Commit bb75afb

Merge pull request #2176 from coqui-ai/wav2vec2-decoder

Wav2vec2 decoder

2 parents c9e73ee + e36e731

29 files changed: +1020 −514 lines

.github/workflows/build-and-test.yml

Lines changed: 13 additions & 11 deletions
@@ -274,17 +274,17 @@ jobs:
           fetch-depth: 1
       - uses: actions/setup-python@v2
         with:
-          python-version: 3.6
+          python-version: "3.7"
       - uses: actions/download-artifact@v2
         with:
-          name: "coqui_stt_ctcdecoder-Linux-3.6.whl"
+          name: "coqui_stt_ctcdecoder-Linux-3.7.whl"
       - run: |
           python --version
           pip --version
       - run: |
           pip install --upgrade pip setuptools wheel
       - run: |
-          pip install coqui_stt_ctcdecoder-*-cp36-cp36m-*_x86_64.whl
+          pip install coqui_stt_ctcdecoder-*-cp37-cp37m-*_x86_64.whl
           DS_NODECODER=y pip install --upgrade .
       - run: |
           # Easier to rename to that we can exercize the LDC93S1 importer code to
@@ -540,7 +540,7 @@ jobs:
     if: ${{ github.event_name == 'pull_request' }}
     strategy:
       matrix:
-        python-version: ["3.6", "3.7"]
+        python-version: ["3.7"]
         samplerate: ["8000", "16000"]
     env:
       CI_TMP_DIR: ${{ github.workspace }}/tmp/
@@ -700,7 +700,7 @@ jobs:
       - run: |
           python -m pip install --upgrade pip setuptools wheel jupyter
       - run: |
-          python -m pip install coqui_stt_ctcdecoder-*-cp37-cp37m-*_x86_64.whl
+          python -m pip install coqui_stt_ctcdecoder*.whl
           DS_NODECODER=y python -m pip install --upgrade .
       - name: Run python notebooks
         run: |
@@ -713,7 +713,7 @@ jobs:
     strategy:
       matrix:
         samplerate: ["8000", "16000"]
-        pyver: [3.6, 3.7]
+        pyver: ["3.7"]
     steps:
       - uses: actions/checkout@v2
       - uses: actions/setup-python@v2
@@ -779,7 +779,7 @@ jobs:
     strategy:
       matrix:
         samplerate: ["8000", "16000"]
-        pyver: [3.6, 3.7]
+        pyver: ["3.7"]
     steps:
       - uses: actions/checkout@v2
       - uses: actions/setup-python@v2
@@ -830,7 +830,7 @@ jobs:
     strategy:
       matrix:
         samplerate: ["8000", "16000"]
-        pyver: [3.6, 3.7]
+        pyver: ["3.7"]
     steps:
       - uses: actions/checkout@v2
       - uses: actions/setup-python@v2
@@ -874,7 +874,7 @@ jobs:
     strategy:
       matrix:
         samplerate: ["8000", "16000"]
-        pyver: [3.6, 3.7]
+        pyver: ["3.7"]
     steps:
       - uses: actions/checkout@v2
       - uses: actions/setup-python@v2
@@ -914,7 +914,7 @@ jobs:
     strategy:
       matrix:
         samplerate: ["8000", "16000"]
-        pyver: [3.6, 3.7]
+        pyver: ["3.7"]
     steps:
       - uses: actions/checkout@v2
       - uses: actions/setup-python@v2
@@ -950,7 +950,7 @@ jobs:
     strategy:
       matrix:
         samplerate: ["8000", "16000"]
-        pyver: [3.6, 3.7]
+        pyver: ["3.7"]
     steps:
       - uses: actions/checkout@v2
       - uses: actions/setup-python@v2
@@ -1193,6 +1193,8 @@ jobs:
   docker-publish:
     name: "Build and publish Docker training image to GHCR"
     runs-on: ubuntu-20.04
+    needs: [upload-nc-release-assets]
+    if: always()
     steps:
       - uses: actions/checkout@v2
         with:

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ repos:
       - id: end-of-file-fixer
       - id: trailing-whitespace
   - repo: 'https://github.com/psf/black'
-    rev: "22.1.0"
+    rev: "22.3.0"
     hooks:
       - id: black
         language_version: python3

Dockerfile.build

Lines changed: 1 addition & 1 deletion
@@ -146,7 +146,7 @@ RUN bazel build \
     //native_client:libstt.so

 # Copy built libs to /STT/native_client
-RUN cp bazel-bin/native_client/libstt.so /STT/native_client/
+RUN cp bazel-bin/native_client/libstt.so bazel-bin/native_client/libkenlm.so /STT/native_client/

 # Build client.cc and install Python client and decoder bindings
 ENV TFDIR /STT/tensorflow

data/smoke_test/LDC93S1.wav

45.7 KB (binary file not shown)

lm_optimizer.py

Lines changed: 10 additions & 40 deletions
@@ -1,46 +1,16 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-from __future__ import absolute_import, print_function
+from __future__ import absolute_import, division, print_function

-import sys
-
-from coqui_stt_training.train import early_training_checks
-from coqui_stt_training.util.config import (
-    Config,
-    initialize_globals_from_cli,
-    log_error,
-)
-
-from coqui_stt_training.util import lm_optimize as lm_opt
-
-
-def main():
-    initialize_globals_from_cli()
-    early_training_checks()
-
-    if not Config.scorer_path:
-        log_error(
-            "Missing --scorer_path: can't optimize scorer alpha and beta "
-            "parameters without a scorer!"
-        )
-        sys.exit(1)
-
-    if not Config.test_files:
-        log_error(
-            "You need to specify what files to use for evaluation via "
-            "the --test_files flag."
-        )
-        sys.exit(1)
-
-    results = lm_opt.compute_lm_optimization()
+if __name__ == "__main__":
     print(
-        "Best params: lm_alpha={} and lm_beta={} with WER={}".format(
-            results.get("lm_alpha"),
-            results.get("lm_beta"),
-            results.get("wer"),
-        )
+        "Using the top level lm_optimizer.py script is deprecated and will be removed "
+        "in a future release. Instead use: python -m coqui_stt_training.util.lm_optimize"
     )
+    try:
+        from coqui_stt_training.util import lm_optimize
+    except ImportError:
+        print("Training package is not installed. See training documentation.")
+        raise

-
-if __name__ == "__main__":
-    main()
+    lm_optimize.main()

native_client/alphabet.cc

Lines changed: 12 additions & 2 deletions
@@ -90,8 +90,8 @@ Alphabet::SerializeText()
      << "# A line that starts with # is a comment. You can escape it with \\# if you wish\n"
      << "# to use '#' in the Alphabet.\n";

-  for (int idx = 0; idx < entrySize(); ++idx) {
-    out << getEntry(idx) << "\n";
+  for (const std::string& label : GetLabels()) {
+    out << label << "\n";
   }

   out << "# The last (non-comment) line needs to end with a newline.\n";
@@ -174,6 +174,16 @@ Alphabet::GetSize() const
   return entrySize();
 }

+std::vector<std::string>
+Alphabet::GetLabels() const
+{
+  std::vector<std::string> labels;
+  for (int idx = 0; idx < GetSize(); ++idx) {
+    labels.push_back(DecodeSingle(idx));
+  }
+  return labels;
+}
+
 bool
 Alphabet::CanEncodeSingle(const std::string& input) const
 {

native_client/alphabet.h

Lines changed: 20 additions & 19 deletions
@@ -36,41 +36,42 @@ class Alphabet : public fl::lib::text::Dictionary

   size_t GetSize() const;

-  bool IsSpace(unsigned int label) const {
-    return label == space_index_;
+  bool IsSpace(unsigned int index) const {
+    return index == space_index_;
   }

   unsigned int GetSpaceLabel() const {
     return space_index_;
   }

-  // Returns true if the single character/output class has a corresponding label
+  virtual std::vector<std::string> GetLabels() const;
+
+  // Returns true if the single character/output class has a corresponding index
   // in the alphabet.
-  virtual bool CanEncodeSingle(const std::string& string) const;
+  virtual bool CanEncodeSingle(const std::string& label) const;

-  // Returns true if the entire string can be encoded into labels in this
-  // alphabet.
-  virtual bool CanEncode(const std::string& string) const;
+  // Returns true if the entire string can be encoded with this alphabet.
+  virtual bool CanEncode(const std::string& label) const;

-  // Decode a single label into a string.
-  std::string DecodeSingle(unsigned int label) const;
+  // Decode a single index into its label.
+  std::string DecodeSingle(unsigned int index) const;

-  // Encode a single character/output class into a label. Character must be in
+  // Encode a single character/output class into its index. Character must be in
   // the alphabet, this method will assert that. Use `CanEncodeSingle` to test.
-  unsigned int EncodeSingle(const std::string& string) const;
+  unsigned int EncodeSingle(const std::string& label) const;

-  // Decode a sequence of labels into a string.
-  std::string Decode(const std::vector<unsigned int>& input) const;
+  // Decode a sequence of indices into a string.
+  std::string Decode(const std::vector<unsigned int>& indices) const;

   // We provide a C-style overload for accepting NumPy arrays as input, since
   // the NumPy library does not have built-in typemaps for std::vector<T>.
-  std::string Decode(const unsigned int* input, int length) const;
+  std::string Decode(const unsigned int* indices, int length) const;

-  // Encode a sequence of character/output classes into a sequence of labels.
+  // Encode a sequence of character/output classes into a sequence of indices.
   // Characters are assumed to always take a single Unicode codepoint.
   // Characters must be in the alphabet, this method will assert that. Use
   // `CanEncode` and `CanEncodeSingle` to test.
-  virtual std::vector<unsigned int> Encode(const std::string& input) const;
+  virtual std::vector<unsigned int> Encode(const std::string& labels) const;

 protected:
   unsigned int space_index_;
@@ -93,9 +94,9 @@ class UTF8Alphabet : public Alphabet
     return 0;
   }

-  bool CanEncodeSingle(const std::string& string) const override;
-  bool CanEncode(const std::string& string) const override;
-  std::vector<unsigned int> Encode(const std::string& input) const override;
+  bool CanEncodeSingle(const std::string& label) const override;
+  bool CanEncode(const std::string& label) const override;
+  std::vector<unsigned int> Encode(const std::string& label) const override;
 };

 #endif //ALPHABET_H
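
Aside (not part of the diff): the renames above only tighten terminology — Encode maps characters to integer indices, Decode maps indices back to their labels. A minimal sketch of how this surfaces through the Python bindings, assuming the SWIG-exposed Alphabet wrapper in the coqui_stt_ctcdecoder package mirrors these method names and can be constructed from an alphabet file path (both are assumptions, not taken from this diff):

# Hypothetical round trip; the alphabet path is illustrative and the method
# names are assumed to match the C++ declarations above.
from coqui_stt_ctcdecoder import Alphabet

alphabet = Alphabet("data/alphabet.txt")

indices = alphabet.Encode("hello")        # characters -> integer indices
text = alphabet.Decode(indices)           # indices -> "hello"
labels = [alphabet.DecodeSingle(i) for i in range(alphabet.GetSize())]

print(indices, text, labels[:5])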

native_client/ctcdecode/__init__.py

Lines changed: 63 additions & 0 deletions
@@ -178,6 +178,69 @@ def ctc_beam_search_decoder(
     return beam_results


+def ctc_beam_search_decoder_for_wav2vec2am(
+    probs_seq,
+    alphabet,
+    beam_size,
+    cutoff_prob=1.0,
+    cutoff_top_n=40,
+    blank_id=-1,
+    ignored_symbols=frozenset(),
+    scorer=None,
+    hot_words=dict(),
+    num_results=1,
+):
+    """Wrapper for the CTC Beam Search Decoder.
+
+    :param probs_seq: 2-D list of probability distributions over each time
+                      step, with each element being a list of normalized
+                      probabilities over alphabet and blank.
+    :type probs_seq: 2-D list
+    :param alphabet: Alphabet
+    :param beam_size: Width for beam search.
+    :type beam_size: int
+    :param cutoff_prob: Cutoff probability in pruning,
+                        default 1.0, no pruning.
+    :type cutoff_prob: float
+    :param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
+                         characters with highest probs in alphabet will be
+                         used in beam search, default 40.
+    :type cutoff_top_n: int
+    :param scorer: External scorer for partially decoded sentence, e.g. word
+                   count or language model.
+    :type scorer: Scorer
+    :param hot_words: Map of words (keys) to their assigned boosts (values)
+    :type hot_words: dict[string, float]
+    :param num_results: Number of beams to return.
+    :type num_results: int
+    :return: List of tuples of confidence and sentence as decoding
+             results, in descending order of the confidence.
+    :rtype: list
+    """
+    beam_results = swigwrapper.ctc_beam_search_decoder_for_wav2vec2am(
+        probs_seq,
+        alphabet,
+        beam_size,
+        cutoff_prob,
+        cutoff_top_n,
+        blank_id,
+        ignored_symbols,
+        scorer,
+        hot_words,
+        num_results,
+    )
+    beam_results = [
+        DecodeResult(
+            res.confidence,
+            alphabet.Decode(res.tokens),
+            [int(t) for t in res.tokens],
+            [int(t) for t in res.timesteps],
+        )
+        for res in beam_results
+    ]
+    return beam_results
+
+
 def ctc_beam_search_decoder_batch(
     probs_seq,
     seq_lengths,
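
Aside (not part of the diff): a minimal usage sketch of the new wrapper. It assumes the decoder package is importable as coqui_stt_ctcdecoder (the wheel name used in the workflow above), that Alphabet can be constructed from an alphabet file path, and that the wav2vec2 acoustic model reserves index 0 for the CTC blank; paths, shapes, and the random posteriors are illustrative only.

# Hypothetical example: decode fake wav2vec2 acoustic-model posteriors.
import numpy as np

from coqui_stt_ctcdecoder import Alphabet, ctc_beam_search_decoder_for_wav2vec2am

alphabet = Alphabet("data/alphabet.txt")   # illustrative path

# (time, num_labels) posteriors, each time step normalized as the docstring requires.
num_labels = alphabet.GetSize() + 1        # +1 for the CTC blank (assumed at index 0)
probs = np.random.rand(50, num_labels)
probs /= probs.sum(axis=1, keepdims=True)

results = ctc_beam_search_decoder_for_wav2vec2am(
    probs.tolist(),    # 2-D list of per-time-step probabilities
    alphabet,
    beam_size=32,
    blank_id=0,        # assumed blank position for a wav2vec2-style AM; adjust to the model
    num_results=3,
)

# Unpacking assumes each DecodeResult behaves like the
# (confidence, transcript, tokens, timesteps) tuple built above.
for confidence, transcript, tokens, timesteps in results:
    print(confidence, transcript)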
