@@ -120,8 +120,7 @@ class VadIterator
120120 void reset_states ()
121121 {
122122 // Call reset before each audio start
123- std::memset (_h.data (), 0 .0f , _h.size () * sizeof (float ));
124- std::memset (_c.data (), 0 .0f , _c.size () * sizeof (float ));
123+ std::memset (_state.data (), 0 .0f , _state.size () * sizeof (float ));
125124 triggered = false ;
126125 temp_end = 0 ;
127126 current_sample = 0 ;
@@ -139,19 +138,16 @@ class VadIterator
139138 input.assign (data.begin (), data.end ());
140139 Ort::Value input_ort = Ort::Value::CreateTensor<float >(
141140 memory_info, input.data (), input.size (), input_node_dims, 2 );
141+ Ort::Value state_ort = Ort::Value::CreateTensor<float >(
142+ memory_info, _state.data (), _state.size (), state_node_dims, 3 );
142143 Ort::Value sr_ort = Ort::Value::CreateTensor<int64_t >(
143144 memory_info, sr.data (), sr.size (), sr_node_dims, 1 );
144- Ort::Value h_ort = Ort::Value::CreateTensor<float >(
145- memory_info, _h.data (), _h.size (), hc_node_dims, 3 );
146- Ort::Value c_ort = Ort::Value::CreateTensor<float >(
147- memory_info, _c.data (), _c.size (), hc_node_dims, 3 );
148145
149146 // Clear and add inputs
150147 ort_inputs.clear ();
151148 ort_inputs.emplace_back (std::move (input_ort));
149+ ort_inputs.emplace_back (std::move (state_ort));
152150 ort_inputs.emplace_back (std::move (sr_ort));
153- ort_inputs.emplace_back (std::move (h_ort));
154- ort_inputs.emplace_back (std::move (c_ort));
155151
156152 // Infer
157153 ort_outputs = session->Run (
@@ -161,10 +157,8 @@ class VadIterator
161157
162158 // Output probability & update the recurrent state for the next call
163159 float speech_prob = ort_outputs[0 ].GetTensorMutableData <float >()[0 ];
164- float *hn = ort_outputs[1 ].GetTensorMutableData <float >();
165- std::memcpy (_h.data (), hn, size_hc * sizeof (float ));
166- float *cn = ort_outputs[2 ].GetTensorMutableData <float >();
167- std::memcpy (_c.data (), cn, size_hc * sizeof (float ));
160+ float *stateN = ort_outputs[1 ].GetTensorMutableData <float >();
161+ std::memcpy (_state.data (), stateN, size_state * sizeof (float ));
168162
169163 // Push forward sample index
170164 current_sample += window_size_samples;
@@ -376,27 +370,26 @@ class VadIterator
376370 // Inputs
377371 std::vector<Ort::Value> ort_inputs;
378372
379- std::vector<const char *> input_node_names = {" input" , " sr " , " h " , " c " };
373+ std::vector<const char *> input_node_names = {" input" , " state " , " sr " };
380374 std::vector<float > input;
375+ unsigned int size_state = 2 * 1 * 128 ; // Fixed size: must equal the element count of state_node_dims (2 x 1 x 128).
376+ std::vector<float > _state;
381377 std::vector<int64_t > sr;
382- unsigned int size_hc = 2 * 1 * 64 ; // It's FIXED.
383- std::vector<float > _h;
384- std::vector<float > _c;
385378
386- int64_t input_node_dims[2 ] = {};
379+ int64_t input_node_dims[2 ] = {};
380+ const int64_t state_node_dims[3 ] = {2 , 1 , 128 };
387381 const int64_t sr_node_dims[1 ] = {1 };
388- const int64_t hc_node_dims[3 ] = {2 , 1 , 64 };
389382
390383 // Outputs
391384 std::vector<Ort::Value> ort_outputs;
392- std::vector<const char *> output_node_names = {" output" , " hn " , " cn " };
385+ std::vector<const char *> output_node_names = {" output" , " stateN " };
393386
394387public:
395388 // Construction
396389 VadIterator (const std::wstring ModelPath,
397- int Sample_rate = 16000 , int windows_frame_size = 64 ,
390+ int Sample_rate = 16000 , int windows_frame_size = 32 ,
398391 float Threshold = 0.5 , int min_silence_duration_ms = 0 ,
399- int speech_pad_ms = 64 , int min_speech_duration_ms = 64 ,
392+ int speech_pad_ms = 32 , int min_speech_duration_ms = 32 ,
400393 float max_speech_duration_s = std::numeric_limits<float >::infinity())
401394 {
402395 init_onnx_model (ModelPath);
@@ -422,8 +415,7 @@ class VadIterator
422415 input_node_dims[0 ] = 1 ;
423416 input_node_dims[1 ] = window_size_samples;
424417
425- _h.resize (size_hc);
426- _c.resize (size_hc);
418+ _state.resize (size_state);
427419 sr.resize (1 );
428420 sr[0 ] = sample_rate;
429421 };
0 commit comments