1414
1515using FSTMATCH = fst::SortedMatcher<fst::StdVectorFst>;
1616
17- std::vector<Output> ctc_beam_search_decoder (
18- const double *probs,
19- int time_dim,
20- int class_dim,
21- const Alphabet &alphabet,
22- size_t beam_size,
23- double cutoff_prob,
24- size_t cutoff_top_n,
25- Scorer *ext_scorer) {
17+ DecoderState* decoder_init (const Alphabet &alphabet,
18+ int class_dim,
19+ Scorer* ext_scorer) {
20+
2621 // dimension check
2722 VALID_CHECK_EQ (class_dim, alphabet.GetSize ()+1 ,
2823 " The shape of probs does not match with "
2924 " the shape of the vocabulary" );
3025
3126 // assign special ids
32- int space_id = alphabet.GetSpaceLabel ();
33- int blank_id = alphabet.GetSize ();
27+ DecoderState *state = new DecoderState;
28+ state->space_id = alphabet.GetSpaceLabel ();
29+ state->blank_id = alphabet.GetSize ();
3430
3531 // init prefixes' root
36- PathTrie root;
37- root.score = root.log_prob_b_prev = 0.0 ;
38- std::vector<PathTrie *> prefixes;
39- prefixes.push_back (&root);
32+ PathTrie *root = new PathTrie;
33+ root->score = root->log_prob_b_prev = 0.0 ;
34+
35+ state->prefix_root = root;
36+
37+ state->prefixes .push_back (root);
4038
4139 if (ext_scorer != nullptr && !ext_scorer->is_character_based ()) {
4240 auto dict_ptr = ext_scorer->dictionary ->Copy (true );
43- root. set_dictionary (dict_ptr);
41+ root-> set_dictionary (dict_ptr);
4442 auto matcher = std::make_shared<FSTMATCH>(*dict_ptr, fst::MATCH_INPUT);
45- root. set_matcher (matcher);
43+ root-> set_matcher (matcher);
4644 }
45+
46+ return state;
47+ }
48+
49+ void decoder_next (const double *probs,
50+ const Alphabet &alphabet,
51+ DecoderState *state,
52+ int time_dim,
53+ int class_dim,
54+ double cutoff_prob,
55+ size_t cutoff_top_n,
56+ size_t beam_size,
57+ Scorer *ext_scorer) {
4758
48- // prefix search over time
59+ // prefix search over time
4960 for (size_t time_step = 0 ; time_step < time_dim; ++time_step) {
5061 auto *prob = &probs[time_step*class_dim];
5162
5263 float min_cutoff = -NUM_FLT_INF;
5364 bool full_beam = false ;
5465 if (ext_scorer != nullptr ) {
55- size_t num_prefixes = std::min (prefixes.size (), beam_size);
66+ size_t num_prefixes = std::min (state-> prefixes .size (), beam_size);
5667 std::sort (
57- prefixes.begin (), prefixes.begin () + num_prefixes, prefix_compare);
58- min_cutoff = prefixes[num_prefixes - 1 ]->score +
59- std::log (prob[blank_id]) - std::max (0.0 , ext_scorer->beta );
68+ state->prefixes .begin (), state->prefixes .begin () + num_prefixes, prefix_compare);
69+
70+ min_cutoff = state->prefixes [num_prefixes - 1 ]->score +
71+ std::log (prob[state->blank_id ]) - std::max (0.0 , ext_scorer->beta );
6072 full_beam = (num_prefixes == beam_size);
6173 }
6274
@@ -67,22 +79,25 @@ std::vector<Output> ctc_beam_search_decoder(
6779 auto c = log_prob_idx[index].first ;
6880 auto log_prob_c = log_prob_idx[index].second ;
6981
70- for (size_t i = 0 ; i < prefixes.size () && i < beam_size; ++i) {
71- auto prefix = prefixes[i];
82+ for (size_t i = 0 ; i < state-> prefixes .size () && i < beam_size; ++i) {
83+ auto prefix = state-> prefixes [i];
7284 if (full_beam && log_prob_c + prefix->score < min_cutoff) {
7385 break ;
7486 }
87+
7588 // blank
76- if (c == blank_id) {
89+ if (c == state-> blank_id ) {
7790 prefix->log_prob_b_cur =
7891 log_sum_exp (prefix->log_prob_b_cur , log_prob_c + prefix->score );
7992 continue ;
8093 }
94+
8195 // repeated character
8296 if (c == prefix->character ) {
8397 prefix->log_prob_nb_cur = log_sum_exp (
8498 prefix->log_prob_nb_cur , log_prob_c + prefix->log_prob_nb_prev );
8599 }
100+
86101 // get new prefix
87102 auto prefix_new = prefix->get_path_trie (c, time_step, log_prob_c);
88103
@@ -98,7 +113,7 @@ std::vector<Output> ctc_beam_search_decoder(
98113
99114 // language model scoring
100115 if (ext_scorer != nullptr &&
101- (c == space_id || ext_scorer->is_character_based ())) {
116+ (c == state-> space_id || ext_scorer->is_character_based ())) {
102117 PathTrie *prefix_to_score = nullptr ;
103118 // skip scoring the space
104119 if (ext_scorer->is_character_based ()) {
@@ -114,34 +129,41 @@ std::vector<Output> ctc_beam_search_decoder(
114129 log_p += score;
115130 log_p += ext_scorer->beta ;
116131 }
132+
117133 prefix_new->log_prob_nb_cur =
118134 log_sum_exp (prefix_new->log_prob_nb_cur , log_p);
119135 }
120136 } // end of loop over prefix
121137 } // end of loop over vocabulary
122-
123-
124- prefixes.clear ();
138+
125139 // update log probs
126- root.iterate_to_vec (prefixes);
140+ state->prefixes .clear ();
141+ state->prefix_root ->iterate_to_vec (state->prefixes );
127142
128143 // only preserve top beam_size prefixes
129- if (prefixes.size () >= beam_size) {
130- std::nth_element (prefixes.begin (),
131- prefixes.begin () + beam_size,
132- prefixes.end (),
144+ if (state-> prefixes .size () >= beam_size) {
145+ std::nth_element (state-> prefixes .begin (),
146+ state-> prefixes .begin () + beam_size,
147+ state-> prefixes .end (),
133148 prefix_compare);
134- for (size_t i = beam_size; i < prefixes.size (); ++i) {
135- prefixes[i]->remove ();
149+ for (size_t i = beam_size; i < state-> prefixes .size (); ++i) {
150+ state-> prefixes [i]->remove ();
136151 }
137152 }
153+
138154 } // end of loop over time
155+ }
156+
157+ std::vector<Output> decoder_decode (DecoderState *state,
158+ const Alphabet &alphabet,
159+ size_t beam_size,
160+ Scorer* ext_scorer) {
139161
140162 // score the last word of each prefix that doesn't end with space
141163 if (ext_scorer != nullptr && !ext_scorer->is_character_based ()) {
142- for (size_t i = 0 ; i < beam_size && i < prefixes.size (); ++i) {
143- auto prefix = prefixes[i];
144- if (!prefix->is_empty () && prefix->character != space_id) {
164+ for (size_t i = 0 ; i < beam_size && i < state-> prefixes .size (); ++i) {
165+ auto prefix = state-> prefixes [i];
166+ if (!prefix->is_empty () && prefix->character != state-> space_id ) {
145167 float score = 0.0 ;
146168 std::vector<std::string> ngram = ext_scorer->make_ngram (prefix);
147169 score = ext_scorer->get_log_cond_prob (ngram) * ext_scorer->alpha ;
@@ -151,30 +173,48 @@ std::vector<Output> ctc_beam_search_decoder(
151173 }
152174 }
153175
154- size_t num_prefixes = std::min (prefixes.size (), beam_size);
155- std::sort (prefixes.begin (), prefixes.begin () + num_prefixes, prefix_compare);
176+ size_t num_prefixes = std::min (state-> prefixes .size (), beam_size);
177+ std::sort (state-> prefixes .begin (), state-> prefixes .begin () + num_prefixes, prefix_compare);
156178
157179 // compute aproximate ctc score as the return score, without affecting the
158180 // return order of decoding result. To delete when decoder gets stable.
159- for (size_t i = 0 ; i < beam_size && i < prefixes.size (); ++i) {
160- double approx_ctc = prefixes[i]->score ;
181+ for (size_t i = 0 ; i < beam_size && i < state-> prefixes .size (); ++i) {
182+ double approx_ctc = state-> prefixes [i]->score ;
161183 if (ext_scorer != nullptr ) {
162184 std::vector<int > output;
163185 std::vector<int > timesteps;
164- prefixes[i]->get_path_vec (output, timesteps);
186+ state-> prefixes [i]->get_path_vec (output, timesteps);
165187 auto prefix_length = output.size ();
166188 auto words = ext_scorer->split_labels (output);
167189 // remove word insert
168190 approx_ctc = approx_ctc - prefix_length * ext_scorer->beta ;
169191 // remove language model weight:
170192 approx_ctc -= (ext_scorer->get_sent_log_prob (words)) * ext_scorer->alpha ;
171193 }
172- prefixes[i]->approx_ctc = approx_ctc;
194+ state-> prefixes [i]->approx_ctc = approx_ctc;
173195 }
174196
175- return get_beam_search_result (prefixes, beam_size);
197+ return get_beam_search_result (state-> prefixes , beam_size);
176198}
177199
200+ std::vector<Output> ctc_beam_search_decoder (
201+ const double *probs,
202+ int time_dim,
203+ int class_dim,
204+ const Alphabet &alphabet,
205+ size_t beam_size,
206+ double cutoff_prob,
207+ size_t cutoff_top_n,
208+ Scorer *ext_scorer) {
209+
210+ DecoderState *state = decoder_init (alphabet, class_dim, ext_scorer);
211+ decoder_next (probs, alphabet, state, time_dim, class_dim, cutoff_prob, cutoff_top_n, beam_size, ext_scorer);
212+ std::vector<Output> out = decoder_decode (state, alphabet, beam_size, ext_scorer);
213+
214+ delete state;
215+
216+ return out;
217+ }
178218
179219std::vector<std::vector<Output>>
180220ctc_beam_search_decoder_batch (
0 commit comments