@@ -201,7 +201,10 @@ def get_speech_timestamps(audio: torch.Tensor,
                           visualize_probs: bool = False,
                           progress_tracking_callback: Callable[[float], None] = None,
                           neg_threshold: float = None,
-                          window_size_samples: int = 512,):
+                          window_size_samples: int = 512,
+                          hop_size_ratio: float = 1,
+                          min_silence_at_max_speech: float = 98,
+                          use_max_poss_sil_at_max_speech: bool = True):
 
     """
     This method is used for splitting long audios into speech chunks using silero VAD
@@ -251,13 +254,16 @@ def get_speech_timestamps(audio: torch.Tensor,
 
     window_size_samples: int (default - 512 samples)
         !!! DEPRECATED, DOES NOTHING !!!
+
+    hop_size_ratio: float (default - 1), ratio of hop size to window size; 1 means hop_size_samples = window_size_samples, values below 1 make consecutive windows overlap
+    min_silence_at_max_speech: float (default - 98), minimum silence duration in ms, used to avoid abrupt cuts when max_speech_duration_s is reached
+    use_max_poss_sil_at_max_speech: bool (default - True), whether to split at the longest possible silence once max_speech_duration_s is reached; if False, the last detected silence is used
 
     Returns
     ----------
     speeches: list of dicts
         list containing ends and beginnings of speech chunks (samples or seconds based on return_seconds)
     """
-
     if not torch.is_tensor(audio):
         try:
             audio = torch.Tensor(audio)
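A quick sketch of what the new hop_size_ratio controls (illustrative values, not part of the patch): the analysis window length stays fixed and only the stride between successive windows changes, so ratios below 1 make consecutive windows overlap and produce proportionally more speech probabilities. This mirrors the hop computation the patch adds further down.

    # Hypothetical demonstration of the hop computation; window size matches the 16 kHz model.
    window_size_samples = 512
    for hop_size_ratio in (1.0, 0.5, 0.25):
        hop_size_samples = int(window_size_samples * hop_size_ratio)
        print(hop_size_ratio, hop_size_samples)  # 1.0 -> 512, 0.5 -> 256, 0.25 -> 128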
@@ -282,25 +288,29 @@ def get_speech_timestamps(audio: torch.Tensor,
         raise ValueError("Currently silero VAD models support 8000 and 16000 (or multiply of 16000) sample rates")
 
     window_size_samples = 512 if sampling_rate == 16000 else 256
+    hop_size_samples = int(window_size_samples * hop_size_ratio)
 
     model.reset_states()
     min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
     speech_pad_samples = sampling_rate * speech_pad_ms / 1000
     max_speech_samples = sampling_rate * max_speech_duration_s - window_size_samples - 2 * speech_pad_samples
     min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
-    min_silence_samples_at_max_speech = sampling_rate * 98 / 1000
+    min_silence_samples_at_max_speech = sampling_rate * min_silence_at_max_speech / 1000
 
     audio_length_samples = len(audio)
 
     speech_probs = []
-    for current_start_sample in range(0, audio_length_samples, window_size_samples):
+    for current_start_sample in range(0, audio_length_samples, hop_size_samples):
         chunk = audio[current_start_sample: current_start_sample + window_size_samples]
         if len(chunk) < window_size_samples:
             chunk = torch.nn.functional.pad(chunk, (0, int(window_size_samples - len(chunk))))
-        speech_prob = model(chunk, sampling_rate).item()
+        try:
+            speech_prob = model(chunk, sampling_rate).item()
+        except Exception as e:
+            raise RuntimeError(f"VAD model failed on chunk starting at sample {current_start_sample}") from e
         speech_probs.append(speech_prob)
         # calculate progress and send it to the callback function
-        progress = current_start_sample + window_size_samples
+        progress = current_start_sample + hop_size_samples
         if progress > audio_length_samples:
             progress = audio_length_samples
         progress_percent = (progress / audio_length_samples) * 100
@@ -315,42 +325,56 @@ def get_speech_timestamps(audio: torch.Tensor,
         neg_threshold = max(threshold - 0.15, 0.01)
     temp_end = 0  # to save potential segment end (and tolerate some silence)
     prev_end = next_start = 0  # to save potential segment limits in case of maximum segment size reached
-
+    possible_ends = []  # (end_sample, silence_duration) candidates for splitting at max_speech_duration_s
+
     for i, speech_prob in enumerate(speech_probs):
         if (speech_prob >= threshold) and temp_end:
-            temp_end = 0
+            if temp_end != 0:
+                sil_dur = (hop_size_samples * i) - temp_end
+                if sil_dur > min_silence_samples_at_max_speech:
+                    possible_ends.append((temp_end, sil_dur))
+            temp_end = 0
             if next_start < prev_end:
-                next_start = window_size_samples * i
+                next_start = hop_size_samples * i
 
         if (speech_prob >= threshold) and not triggered:
             triggered = True
-            current_speech['start'] = window_size_samples * i
+            current_speech['start'] = hop_size_samples * i
             continue
 
-        if triggered and (window_size_samples * i) - current_speech['start'] > max_speech_samples:
-            if prev_end:
+        if triggered and (hop_size_samples * i) - current_speech['start'] > max_speech_samples:
+            if possible_ends:
+                if use_max_poss_sil_at_max_speech:
+                    prev_end, dur = max(possible_ends, key=lambda x: x[1])  # use the longest possible silence segment in the current speech chunk
+                else:
+                    prev_end, dur = possible_ends[-1]  # use the last possible silence segment
                 current_speech['end'] = prev_end
                 speeches.append(current_speech)
                 current_speech = {}
-                if next_start < prev_end:  # previously reached silence (< neg_thres) and is still not speech (< thres)
-                    triggered = False
-                else:
+                next_start = prev_end + dur
+                if next_start < prev_end + hop_size_samples * i:  # speech resumed after the chosen silence
+                    # keep the current segment open from the resume point
                     current_speech['start'] = next_start
+                else:
+                    triggered = False
+                    # no usable resume point, wait for the next speech onset
                 prev_end = next_start = temp_end = 0
+                possible_ends = []
             else:
-                current_speech['end'] = window_size_samples * i
+                current_speech['end'] = hop_size_samples * i
                 speeches.append(current_speech)
                 current_speech = {}
                 prev_end = next_start = temp_end = 0
                 triggered = False
+                possible_ends = []
             continue
 
         if (speech_prob < neg_threshold) and triggered:
             if not temp_end:
-                temp_end = window_size_samples * i
-            if ((window_size_samples * i) - temp_end) > min_silence_samples_at_max_speech:  # condition to avoid cutting in very short silence
-                prev_end = temp_end
-            if (window_size_samples * i) - temp_end < min_silence_samples:
+                temp_end = hop_size_samples * i
+            # prev_end is no longer tracked here: candidate silences long enough for a
+            # forced split are collected into possible_ends in the threshold branch above
+            if (hop_size_samples * i) - temp_end < min_silence_samples:
                 continue
             else:
                 current_speech['end'] = temp_end
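To make the new forced-split selection concrete, here is a minimal standalone sketch of the two strategies toggled by use_max_poss_sil_at_max_speech. possible_ends holds (end_sample, silence_duration) pairs collected while the segment was open; the values below are made up:

    possible_ends = [(16000, 1600), (48000, 4800), (80000, 2400)]  # hypothetical candidates

    use_max_poss_sil_at_max_speech = True
    if use_max_poss_sil_at_max_speech:
        prev_end, dur = max(possible_ends, key=lambda x: x[1])  # longest silence -> (48000, 4800)
    else:
        prev_end, dur = possible_ends[-1]                       # most recent silence -> (80000, 2400)

    next_start = prev_end + dur  # speech is assumed to resume right after the chosen silence

Splitting at the longest silence tends to land on a more natural pause, while splitting at the last silence keeps segments closer to max_speech_duration_s.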
@@ -359,6 +383,7 @@ def get_speech_timestamps(audio: torch.Tensor,
                 current_speech = {}
                 prev_end = next_start = temp_end = 0
                 triggered = False
+                possible_ends = []
                 continue
 
     if current_speech and (audio_length_samples - current_speech['start']) > min_speech_samples:
@@ -390,7 +415,7 @@ def get_speech_timestamps(audio: torch.Tensor,
             speech_dict['end'] *= step
 
     if visualize_probs:
-        make_visualization(speech_probs, window_size_samples / sampling_rate)
+        make_visualization(speech_probs, hop_size_samples / sampling_rate)
 
     return speeches
 
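For reference, a hypothetical call exercising the new parameters could look like the following; the loading boilerplate follows the usual silero-vad README, and the file name and values are illustrative:

    import torch

    # Load the model and bundled utilities from torch.hub.
    model, utils = torch.hub.load('snakers4/silero-vad', 'silero_vad')
    (get_speech_timestamps, _, read_audio, _, _) = utils

    wav = read_audio('example.wav', sampling_rate=16000)
    speech_timestamps = get_speech_timestamps(
        wav, model,
        sampling_rate=16000,
        hop_size_ratio=0.5,                   # 50% overlap between analysis windows
        min_silence_at_max_speech=98,         # ms of silence usable as a forced split point
        use_max_poss_sil_at_max_speech=True,  # prefer the longest candidate silence
    )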