Skip to content
This repository was archived by the owner on Jun 19, 2025. It is now read-only.

Commit 69538f2

Browse files
authored
Merge pull request #2121 from dabinat/streaming-decoder
CTC streaming decoder
2 parents df5bb31 + d9a2694 commit 69538f2

File tree

4 files changed

+220
-91
lines changed

4 files changed

+220
-91
lines changed

native_client/ctcdecode/ctc_beam_search_decoder.cpp

Lines changed: 86 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -14,49 +14,61 @@
1414

1515
using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
1616

17-
std::vector<Output> ctc_beam_search_decoder(
18-
const double *probs,
19-
int time_dim,
20-
int class_dim,
21-
const Alphabet &alphabet,
22-
size_t beam_size,
23-
double cutoff_prob,
24-
size_t cutoff_top_n,
25-
Scorer *ext_scorer) {
17+
DecoderState* decoder_init(const Alphabet &alphabet,
18+
int class_dim,
19+
Scorer* ext_scorer) {
20+
2621
// dimension check
2722
VALID_CHECK_EQ(class_dim, alphabet.GetSize()+1,
2823
"The shape of probs does not match with "
2924
"the shape of the vocabulary");
3025

3126
// assign special ids
32-
int space_id = alphabet.GetSpaceLabel();
33-
int blank_id = alphabet.GetSize();
27+
DecoderState *state = new DecoderState;
28+
state->space_id = alphabet.GetSpaceLabel();
29+
state->blank_id = alphabet.GetSize();
3430

3531
// init prefixes' root
36-
PathTrie root;
37-
root.score = root.log_prob_b_prev = 0.0;
38-
std::vector<PathTrie *> prefixes;
39-
prefixes.push_back(&root);
32+
PathTrie *root = new PathTrie;
33+
root->score = root->log_prob_b_prev = 0.0;
34+
35+
state->prefix_root = root;
36+
37+
state->prefixes.push_back(root);
4038

4139
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
4240
auto dict_ptr = ext_scorer->dictionary->Copy(true);
43-
root.set_dictionary(dict_ptr);
41+
root->set_dictionary(dict_ptr);
4442
auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
45-
root.set_matcher(matcher);
43+
root->set_matcher(matcher);
4644
}
45+
46+
return state;
47+
}
48+
49+
void decoder_next(const double *probs,
50+
const Alphabet &alphabet,
51+
DecoderState *state,
52+
int time_dim,
53+
int class_dim,
54+
double cutoff_prob,
55+
size_t cutoff_top_n,
56+
size_t beam_size,
57+
Scorer *ext_scorer) {
4758

48-
// prefix search over time
59+
// prefix search over time
4960
for (size_t time_step = 0; time_step < time_dim; ++time_step) {
5061
auto *prob = &probs[time_step*class_dim];
5162

5263
float min_cutoff = -NUM_FLT_INF;
5364
bool full_beam = false;
5465
if (ext_scorer != nullptr) {
55-
size_t num_prefixes = std::min(prefixes.size(), beam_size);
66+
size_t num_prefixes = std::min(state->prefixes.size(), beam_size);
5667
std::sort(
57-
prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
58-
min_cutoff = prefixes[num_prefixes - 1]->score +
59-
std::log(prob[blank_id]) - std::max(0.0, ext_scorer->beta);
68+
state->prefixes.begin(), state->prefixes.begin() + num_prefixes, prefix_compare);
69+
70+
min_cutoff = state->prefixes[num_prefixes - 1]->score +
71+
std::log(prob[state->blank_id]) - std::max(0.0, ext_scorer->beta);
6072
full_beam = (num_prefixes == beam_size);
6173
}
6274

@@ -67,22 +79,25 @@ std::vector<Output> ctc_beam_search_decoder(
6779
auto c = log_prob_idx[index].first;
6880
auto log_prob_c = log_prob_idx[index].second;
6981

70-
for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) {
71-
auto prefix = prefixes[i];
82+
for (size_t i = 0; i < state->prefixes.size() && i < beam_size; ++i) {
83+
auto prefix = state->prefixes[i];
7284
if (full_beam && log_prob_c + prefix->score < min_cutoff) {
7385
break;
7486
}
87+
7588
// blank
76-
if (c == blank_id) {
89+
if (c == state->blank_id) {
7790
prefix->log_prob_b_cur =
7891
log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score);
7992
continue;
8093
}
94+
8195
// repeated character
8296
if (c == prefix->character) {
8397
prefix->log_prob_nb_cur = log_sum_exp(
8498
prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev);
8599
}
100+
86101
// get new prefix
87102
auto prefix_new = prefix->get_path_trie(c, time_step, log_prob_c);
88103

@@ -98,7 +113,7 @@ std::vector<Output> ctc_beam_search_decoder(
98113

99114
// language model scoring
100115
if (ext_scorer != nullptr &&
101-
(c == space_id || ext_scorer->is_character_based())) {
116+
(c == state->space_id || ext_scorer->is_character_based())) {
102117
PathTrie *prefix_to_score = nullptr;
103118
// skip scoring the space
104119
if (ext_scorer->is_character_based()) {
@@ -114,34 +129,41 @@ std::vector<Output> ctc_beam_search_decoder(
114129
log_p += score;
115130
log_p += ext_scorer->beta;
116131
}
132+
117133
prefix_new->log_prob_nb_cur =
118134
log_sum_exp(prefix_new->log_prob_nb_cur, log_p);
119135
}
120136
} // end of loop over prefix
121137
} // end of loop over vocabulary
122-
123-
124-
prefixes.clear();
138+
125139
// update log probs
126-
root.iterate_to_vec(prefixes);
140+
state->prefixes.clear();
141+
state->prefix_root->iterate_to_vec(state->prefixes);
127142

128143
// only preserve top beam_size prefixes
129-
if (prefixes.size() >= beam_size) {
130-
std::nth_element(prefixes.begin(),
131-
prefixes.begin() + beam_size,
132-
prefixes.end(),
144+
if (state->prefixes.size() >= beam_size) {
145+
std::nth_element(state->prefixes.begin(),
146+
state->prefixes.begin() + beam_size,
147+
state->prefixes.end(),
133148
prefix_compare);
134-
for (size_t i = beam_size; i < prefixes.size(); ++i) {
135-
prefixes[i]->remove();
149+
for (size_t i = beam_size; i < state->prefixes.size(); ++i) {
150+
state->prefixes[i]->remove();
136151
}
137152
}
153+
138154
} // end of loop over time
155+
}
156+
157+
std::vector<Output> decoder_decode(DecoderState *state,
158+
const Alphabet &alphabet,
159+
size_t beam_size,
160+
Scorer* ext_scorer) {
139161

140162
// score the last word of each prefix that doesn't end with space
141163
if (ext_scorer != nullptr && !ext_scorer->is_character_based()) {
142-
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
143-
auto prefix = prefixes[i];
144-
if (!prefix->is_empty() && prefix->character != space_id) {
164+
for (size_t i = 0; i < beam_size && i < state->prefixes.size(); ++i) {
165+
auto prefix = state->prefixes[i];
166+
if (!prefix->is_empty() && prefix->character != state->space_id) {
145167
float score = 0.0;
146168
std::vector<std::string> ngram = ext_scorer->make_ngram(prefix);
147169
score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha;
@@ -151,30 +173,48 @@ std::vector<Output> ctc_beam_search_decoder(
151173
}
152174
}
153175

154-
size_t num_prefixes = std::min(prefixes.size(), beam_size);
155-
std::sort(prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare);
176+
size_t num_prefixes = std::min(state->prefixes.size(), beam_size);
177+
std::sort(state->prefixes.begin(), state->prefixes.begin() + num_prefixes, prefix_compare);
156178

157179
// compute approximate ctc score as the return score, without affecting the
158180
// return order of decoding result. To delete when decoder gets stable.
159-
for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
160-
double approx_ctc = prefixes[i]->score;
181+
for (size_t i = 0; i < beam_size && i < state->prefixes.size(); ++i) {
182+
double approx_ctc = state->prefixes[i]->score;
161183
if (ext_scorer != nullptr) {
162184
std::vector<int> output;
163185
std::vector<int> timesteps;
164-
prefixes[i]->get_path_vec(output, timesteps);
186+
state->prefixes[i]->get_path_vec(output, timesteps);
165187
auto prefix_length = output.size();
166188
auto words = ext_scorer->split_labels(output);
167189
// remove word insert
168190
approx_ctc = approx_ctc - prefix_length * ext_scorer->beta;
169191
// remove language model weight:
170192
approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha;
171193
}
172-
prefixes[i]->approx_ctc = approx_ctc;
194+
state->prefixes[i]->approx_ctc = approx_ctc;
173195
}
174196

175-
return get_beam_search_result(prefixes, beam_size);
197+
return get_beam_search_result(state->prefixes, beam_size);
176198
}
177199

200+
std::vector<Output> ctc_beam_search_decoder(
201+
const double *probs,
202+
int time_dim,
203+
int class_dim,
204+
const Alphabet &alphabet,
205+
size_t beam_size,
206+
double cutoff_prob,
207+
size_t cutoff_top_n,
208+
Scorer *ext_scorer) {
209+
210+
DecoderState *state = decoder_init(alphabet, class_dim, ext_scorer);
211+
decoder_next(probs, alphabet, state, time_dim, class_dim, cutoff_prob, cutoff_top_n, beam_size, ext_scorer);
212+
std::vector<Output> out = decoder_decode(state, alphabet, beam_size, ext_scorer);
213+
214+
delete state;
215+
216+
return out;
217+
}
178218

179219
std::vector<std::vector<Output>>
180220
ctc_beam_search_decoder_batch(

native_client/ctcdecode/ctc_beam_search_decoder.h

Lines changed: 68 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,73 @@
77
#include "scorer.h"
88
#include "output.h"
99
#include "alphabet.h"
10+
#include "decoderstate.h"
1011

11-
/* CTC Beam Search Decoder
12+
/* Initialize CTC beam search decoder
13+
14+
* Parameters:
15+
* alphabet: The alphabet.
16+
* class_dim: Alphabet length plus 1 (the extra class is the blank label).
17+
* ext_scorer: External scorer to evaluate a prefix, which consists of
18+
* n-gram language model scoring and word insertion term.
19+
* Default null, decoding the input sample without scorer.
20+
* Return:
21+
* A struct containing prefixes and state variables.
22+
*/
23+
DecoderState* decoder_init(const Alphabet &alphabet,
24+
int class_dim,
25+
Scorer *ext_scorer);
26+
27+
/* Send data to the decoder
28+
29+
* Parameters:
30+
* probs: Flat array of time_dim * class_dim probabilities; each consecutive
31+
* run of class_dim values holds the probabilities of one time step.
32+
* alphabet: The alphabet.
33+
* state: The state structure previously obtained from decoder_init().
34+
* time_dim: Number of timesteps.
35+
* class_dim: Alphabet length plus 1 (the extra class is the blank label).
36+
* cutoff_prob: Cutoff probability for pruning.
37+
* cutoff_top_n: Cutoff number for pruning.
38+
* beam_size: The width of beam search.
39+
* ext_scorer: External scorer to evaluate a prefix, which consists of
40+
* n-gram language model scoring and word insertion term.
41+
* Default null, decoding the input sample without scorer.
42+
*/
43+
void decoder_next(const double *probs,
44+
const Alphabet &alphabet,
45+
DecoderState *state,
46+
int time_dim,
47+
int class_dim,
48+
double cutoff_prob,
49+
size_t cutoff_top_n,
50+
size_t beam_size,
51+
Scorer *ext_scorer);
52+
53+
/* Get transcription for the data you sent via decoder_next()
54+
55+
* Parameters:
56+
* state: The state structure previously obtained from decoder_init().
57+
* alphabet: The alphabet.
58+
* beam_size: The width of beam search.
59+
* ext_scorer: External scorer to evaluate a prefix, which consists of
60+
* n-gram language model scoring and word insertion term.
61+
* Default null, decoding the input sample without scorer.
62+
* Return:
63+
* A vector where each element is a pair of score and decoding result,
64+
* in descending order.
65+
*/
66+
std::vector<Output> decoder_decode(DecoderState *state,
67+
const Alphabet &alphabet,
68+
size_t beam_size,
69+
Scorer* ext_scorer);
1270

71+
/* CTC Beam Search Decoder
1372
* Parameters:
14-
* probs_seq: 2-D vector that each element is a vector of probabilities
15-
* over alphabet of one time step.
73+
* probs: Flat array of time_dim * class_dim probabilities; each consecutive
74+
* run of class_dim values holds the probabilities of one time step.
75+
* time_dim: Number of timesteps.
76+
* class_dim: Alphabet length plus 1 (the extra class is the blank label).
1677
* alphabet: The alphabet.
1778
* beam_size: The width of beam search.
1879
* cutoff_prob: Cutoff probability for pruning.
@@ -21,8 +82,8 @@
2182
* n-gram language model scoring and word insertion term.
2283
* Default null, decoding the input sample without scorer.
2384
* Return:
24-
* A vector that each element is a pair of score and decoding result,
25-
* in desending order.
85+
* A vector where each element is a pair of score and decoding result,
86+
* in descending order.
2687
*/
2788

2889
std::vector<Output> ctc_beam_search_decoder(
@@ -36,9 +97,8 @@ std::vector<Output> ctc_beam_search_decoder(
3697
Scorer *ext_scorer);
3798

3899
/* CTC Beam Search Decoder for batch data
39-
40100
* Parameters:
41-
* probs_seq: 3-D vector that each element is a 2-D vector that can be used
101+
* probs: 3-D vector where each element is a 2-D vector that can be used
42102
* by ctc_beam_search_decoder().
43103
* alphabet: The alphabet.
44104
* beam_size: The width of beam search.
@@ -49,7 +109,7 @@ std::vector<Output> ctc_beam_search_decoder(
49109
* n-gram language model scoring and word insertion term.
50110
* Default null, decoding the input sample without scorer.
51111
* Return:
52-
* A 2-D vector that each element is a vector of beam search decoding
112+
* A 2-D vector where each element is a vector of beam search decoding
53113
* result for one audio sample.
54114
*/
55115
std::vector<std::vector<Output>>
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#ifndef DECODERSTATE_H_
2+
#define DECODERSTATE_H_
3+
4+
#include <vector>
5+
6+
/* Struct for the state of the decoder, containing the prefixes and initial root prefix plus state variables. */
7+
8+
struct DecoderState {
9+
int space_id;
10+
int blank_id;
11+
std::vector<PathTrie*> prefixes;
12+
PathTrie *prefix_root;
13+
14+
~DecoderState() {
15+
if (prefix_root != nullptr) {
16+
delete prefix_root;
17+
}
18+
prefix_root = nullptr;
19+
}
20+
};
21+
22+
#endif // DECODERSTATE_H_

0 commit comments

Comments
 (0)