Skip to content

Commit fc0d0da

Browse files
committed
fixed nonspeech_skip
- Fixed `nonspeech_skip` causing alignment to skip sections of speech.
- Fixed the "'last_ts' referenced before assignment" error during alignment (#429).
1 parent e7ff3dd commit fc0d0da

File tree

1 file changed

+22
-14
lines changed

1 file changed

+22
-14
lines changed

stable_whisper/non_whisper/alignment.py

Lines changed: 22 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -283,7 +283,7 @@ def align(
283283
desc='Align'
284284
) as tqdm_pbar:
285285
result: List[BasicWordTiming] = []
286-
286+
last_ts = 0.0
287287
while self._all_word_tokens:
288288

289289
self._time_offset = self._seek_sample / self.sample_rate
@@ -876,42 +876,50 @@ def _skip_nonspeech(
876876

877877
segment_samples = audio_segment.size(-1)
878878

879+
max_time_offset = self._time_offset + self.options.post.min_word_dur
880+
min_time_offset = self._time_offset - self.options.post.min_word_dur
881+
879882
if (
880-
(segment_nonspeech_timings[0][0] <= self._time_offset + self.options.post.min_word_dur) and
881-
(
882-
segment_nonspeech_timings[1][0]
883-
>=
884-
self._time_offset + segment_samples - self.options.post.min_word_dur
885-
)
883+
(segment_nonspeech_timings[0][0] < max_time_offset) and
884+
(segment_nonspeech_timings[1][0] > min_time_offset + segment_samples)
886885
):
886+
# entire audio segment is within first nonspeech section
887887
self._seek_sample += segment_samples
888888
return
889889

890-
timing_indices = (segment_nonspeech_timings[1] - segment_nonspeech_timings[0]) >= self.nonspeech_skip
891-
if not timing_indices.any():
890+
# mask for valid nonspeech sections (i.e. sections with duration >= ``nonspeech_skip``)
891+
valid_sections = (segment_nonspeech_timings[1] - segment_nonspeech_timings[0]) >= self.nonspeech_skip
892+
if not valid_sections.any():
893+
# no valid nonspeech sections
892894
return audio_segment
893895

894-
nonspeech_starts = segment_nonspeech_timings[0][timing_indices]
895-
nonspeech_ends = segment_nonspeech_timings[1][timing_indices]
896-
897-
if nonspeech_ends[0] <= round(self._time_offset, 3) >= nonspeech_starts[0]:
896+
nonspeech_starts = segment_nonspeech_timings[0, valid_sections]
897+
if max_time_offset < nonspeech_starts[0]:
898+
# current time is before the first valid nonspeech section
898899
return audio_segment
899900

901+
nonspeech_ends = segment_nonspeech_timings[1, valid_sections]
900902
curr_total_samples = self.audio_loader.get_total_samples()
901903

904+
# skip to end of the first nonspeech section
902905
self._seek_sample = round(nonspeech_ends[0] * self.sample_rate)
903-
if self._seek_sample + (self.options.post.min_word_dur * self.sample_rate) >= curr_total_samples:
906+
if self._seek_sample + (self.options.post.min_word_dur * self.sample_rate) > curr_total_samples:
907+
# new time is over total duration of the audio
904908
self._seek_sample = curr_total_samples
905909
return
906910

907911
self._time_offset = self._seek_sample / self.sample_rate
908912

913+
# try to load audio segment from the new timestamp
909914
audio_segment = self.audio_loader.next_chunk(self._seek_sample, self.n_samples)
910915
if audio_segment is None:
916+
# reached eof
911917
return
912918

919+
# recompute nonspeech sections for the new audio segment for later use
913920
self._nonspeech_preds = self.nonspeech_predictor.predict(audio=audio_segment, offset=self._time_offset)
914921
if len(nonspeech_starts) > 1:
922+
# remove all audio samples after start of second valid nonspeech section
915923
new_sample_count = round((nonspeech_starts[1] - nonspeech_ends[0]) * self.sample_rate)
916924
audio_segment = audio_segment[:new_sample_count]
917925

0 commit comments

Comments
 (0)