
Commit 34dea51

Merge pull request #664 from shashank14k/master
Adding additional params to get_speech_timestamps
2 parents: 51fd431 + bbf22a0


src/silero_vad/utils_vad.py

Lines changed: 46 additions & 21 deletions
@@ -201,7 +201,10 @@ def get_speech_timestamps(audio: torch.Tensor,
                           visualize_probs: bool = False,
                           progress_tracking_callback: Callable[[float], None] = None,
                           neg_threshold: float = None,
-                          window_size_samples: int = 512,):
+                          window_size_samples: int = 512,
+                          hop_size_ratio: float = 1,
+                          min_silence_at_max_speech: float = 98,
+                          use_max_poss_sil_at_max_speech: bool = True):

     """
     This method is used for splitting long audios into speech chunks using silero VAD
@@ -251,13 +254,16 @@ def get_speech_timestamps(audio: torch.Tensor,

     window_size_samples: int (default - 512 samples)
         !!! DEPRECATED, DOES NOTHING !!!
+
+    hop_size_ratio: float (default - 1), ratio of hop size to window size; the window is shifted by window_size_samples * hop_size_ratio samples, so 1 means hop_size_samples = window_size_samples
+    min_silence_at_max_speech: float (default - 98 ms), minimum silence duration in ms used to avoid abrupt cuts when max_speech_duration_s is reached
+    use_max_poss_sil_at_max_speech: bool (default - True), whether to split at the longest qualifying silence when max_speech_duration_s is reached; if False, the last detected silence is used

     Returns
     ----------
     speeches: list of dicts
        list containing ends and beginnings of speech chunks (samples or seconds based on return_seconds)
     """
-
     if not torch.is_tensor(audio):
         try:
             audio = torch.Tensor(audio)
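
For context, a minimal usage sketch of the parameters this PR adds (not part of the commit); it assumes the pip-installed silero-vad package, and 'example.wav' is a placeholder file name:

# Minimal sketch, not part of this commit: calling get_speech_timestamps with the new
# keyword arguments. Assumes the pip package `silero-vad`; 'example.wav' is a placeholder.
from silero_vad import load_silero_vad, read_audio, get_speech_timestamps

model = load_silero_vad()
audio = read_audio('example.wav', sampling_rate=16000)

speech_timestamps = get_speech_timestamps(
    audio,
    model,
    sampling_rate=16000,
    hop_size_ratio=0.5,                   # shift the 512-sample window by 256 samples per step
    min_silence_at_max_speech=98,         # ms of silence that counts as a possible split point
    use_max_poss_sil_at_max_speech=True,  # split over-long chunks at the longest such silence
    return_seconds=True,
)
print(speech_timestamps)
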
@@ -282,25 +288,29 @@ def get_speech_timestamps(audio: torch.Tensor,
         raise ValueError("Currently silero VAD models support 8000 and 16000 (or multiply of 16000) sample rates")

     window_size_samples = 512 if sampling_rate == 16000 else 256
+    hop_size_samples = int(window_size_samples * hop_size_ratio)

     model.reset_states()
     min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
     speech_pad_samples = sampling_rate * speech_pad_ms / 1000
     max_speech_samples = sampling_rate * max_speech_duration_s - window_size_samples - 2 * speech_pad_samples
     min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
-    min_silence_samples_at_max_speech = sampling_rate * 98 / 1000
+    min_silence_samples_at_max_speech = sampling_rate * min_silence_at_max_speech / 1000

     audio_length_samples = len(audio)

     speech_probs = []
-    for current_start_sample in range(0, audio_length_samples, window_size_samples):
+    for current_start_sample in range(0, audio_length_samples, hop_size_samples):
         chunk = audio[current_start_sample: current_start_sample + window_size_samples]
         if len(chunk) < window_size_samples:
             chunk = torch.nn.functional.pad(chunk, (0, int(window_size_samples - len(chunk))))
-        speech_prob = model(chunk, sampling_rate).item()
+        try:
+            speech_prob = model(chunk, sampling_rate).item()
+        except Exception as e:
+            raise RuntimeError(f"VAD model failed on chunk starting at sample {current_start_sample}") from e
         speech_probs.append(speech_prob)
         # calculate progress and send it to the callback function
-        progress = current_start_sample + window_size_samples
+        progress = current_start_sample + hop_size_samples
         if progress > audio_length_samples:
             progress = audio_length_samples
         progress_percent = (progress / audio_length_samples) * 100
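
As a rough illustration of how hop_size_ratio changes the probability loop above (illustrative numbers, not from the commit), the hop controls how far the fixed analysis window advances per model call:

# Illustrative arithmetic, not from the commit: how hop_size_ratio affects the loop above.
sampling_rate = 16000
window_size_samples = 512                  # fixed per sampling rate, as in the code above
audio_length_samples = 10 * sampling_rate  # assume 10 s of audio

for hop_size_ratio in (1.0, 0.5):
    hop_size_samples = int(window_size_samples * hop_size_ratio)
    n_model_calls = -(-audio_length_samples // hop_size_samples)  # ceil division
    ms_per_prob = 1000 * hop_size_samples / sampling_rate
    print(f"hop_size_ratio={hop_size_ratio}: hop={hop_size_samples} samples, "
          f"{n_model_calls} model calls, one probability every {ms_per_prob:.0f} ms")

Smaller ratios give finer-grained timestamps at the cost of proportionally more model calls.
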
@@ -315,42 +325,56 @@ def get_speech_timestamps(audio: torch.Tensor,
         neg_threshold = max(threshold - 0.15, 0.01)
     temp_end = 0  # to save potential segment end (and tolerate some silence)
     prev_end = next_start = 0  # to save potential segment limits in case of maximum segment size reached
-
+    possible_ends = []
+
     for i, speech_prob in enumerate(speech_probs):
         if (speech_prob >= threshold) and temp_end:
-            temp_end = 0
+            if temp_end != 0:
+                sil_dur = (hop_size_samples * i) - temp_end
+                if sil_dur > min_silence_samples_at_max_speech:
+                    possible_ends.append((temp_end, sil_dur))
+            temp_end = 0
             if next_start < prev_end:
-                next_start = window_size_samples * i
+                next_start = hop_size_samples * i

         if (speech_prob >= threshold) and not triggered:
             triggered = True
-            current_speech['start'] = window_size_samples * i
+            current_speech['start'] = hop_size_samples * i
             continue

-        if triggered and (window_size_samples * i) - current_speech['start'] > max_speech_samples:
-            if prev_end:
+        if triggered and (hop_size_samples * i) - current_speech['start'] > max_speech_samples:
+            if possible_ends:
+                if use_max_poss_sil_at_max_speech:
+                    prev_end, dur = max(possible_ends, key=lambda x: x[1])  # use the longest possible silence segment in the current speech chunk
+                else:
+                    prev_end, dur = possible_ends[-1]  # use the last possible silence segment
                 current_speech['end'] = prev_end
                 speeches.append(current_speech)
                 current_speech = {}
-                if next_start < prev_end:  # previously reached silence (< neg_thres) and is still not speech (< thres)
-                    triggered = False
-                else:
+                next_start = prev_end + dur
+                if next_start < prev_end + hop_size_samples * i:  # previously reached silence (< neg_thres) and is still not speech (< thres)
+                    #triggered = False
                     current_speech['start'] = next_start
+                else:
+                    triggered = False
+                    #current_speech['start'] = next_start
                 prev_end = next_start = temp_end = 0
+                possible_ends = []
             else:
-                current_speech['end'] = window_size_samples * i
+                current_speech['end'] = hop_size_samples * i
                 speeches.append(current_speech)
                 current_speech = {}
                 prev_end = next_start = temp_end = 0
                 triggered = False
+                possible_ends = []
                 continue

         if (speech_prob < neg_threshold) and triggered:
             if not temp_end:
-                temp_end = window_size_samples * i
-            if ((window_size_samples * i) - temp_end) > min_silence_samples_at_max_speech:  # condition to avoid cutting in very short silence
-                prev_end = temp_end
-            if (window_size_samples * i) - temp_end < min_silence_samples:
+                temp_end = hop_size_samples * i
+            # if ((hop_size_samples * i) - temp_end) > min_silence_samples_at_max_speech:  # condition to avoid cutting in very short silence
+            #     prev_end = temp_end
+            if (hop_size_samples * i) - temp_end < min_silence_samples:
                 continue
             else:
                 current_speech['end'] = temp_end
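
A standalone sketch (not from the commit) of the split-point rule above: when a chunk exceeds max_speech_duration_s, either the longest recorded silence or the most recent one is taken from possible_ends, depending on use_max_poss_sil_at_max_speech; the sample values below are made up.

# Standalone sketch, not from the commit: choosing a split point from possible_ends.
# Each entry is (silence_start_sample, silence_duration_samples); values below are made up.
possible_ends = [(48000, 1600), (96000, 4800), (144000, 2400)]

def pick_split(possible_ends, use_max_poss_sil_at_max_speech=True):
    if use_max_poss_sil_at_max_speech:
        return max(possible_ends, key=lambda x: x[1])  # longest silence inside the chunk
    return possible_ends[-1]                           # otherwise the most recent silence

print(pick_split(possible_ends, True))   # -> (96000, 4800), split at the longest silence
print(pick_split(possible_ends, False))  # -> (144000, 2400), split at the last silence
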
@@ -359,6 +383,7 @@ def get_speech_timestamps(audio: torch.Tensor,
                 current_speech = {}
                 prev_end = next_start = temp_end = 0
                 triggered = False
+                possible_ends = []
                 continue

     if current_speech and (audio_length_samples - current_speech['start']) > min_speech_samples:
@@ -390,7 +415,7 @@ def get_speech_timestamps(audio: torch.Tensor,
             speech_dict['end'] *= step

     if visualize_probs:
-        make_visualization(speech_probs, window_size_samples / sampling_rate)
+        make_visualization(speech_probs, hop_size_samples / sampling_rate)

     return speeches
