Skip to content

Commit 78a223f

Browse files
committed
fixed refine() for Faster-Whisper and HF models
-fixed `refine()` failing when word tokens are missing in input `result` (i.e. `transcribe()` outputs from Faster-Whisper and HF models) -fixed `refine()` failing when word probabilities are missing in input `result` (i.e. `transcribe()` outputs from HF models) -fixed incorrect description of alignment and refinement support for Faster-Whisper models in README.md -updated HF model transcription to always return detected language in its result
1 parent 751b041 commit 78a223f

File tree

3 files changed

+43
-25
lines changed

3 files changed

+43
-25
lines changed

README.md

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -376,19 +376,24 @@ Use with [Faster-Whisper](https://github.com/guillaumekln/faster-whisper):
376376
```
377377
pip install -U stable-ts[fw]
378378
```
379-
* [Refinement](#refinement) is not supported on Faster-Whisper models
380-
* [Alignment](#alignment) is slower on Faster-Whisper models than on vanilla models (i.e. ones loaded with `stable_whisper.load_model()`)
379+
* [Refinement](#refinement) is slower on Faster-Whisper models than on vanilla models (i.e. ones loaded with `stable_whisper.load_model()`)
381380
```python
382381
model = stable_whisper.load_faster_whisper('base')
383-
result = model.transcribe_stable('audio.mp3')
384-
385-
# For version 2.18.0+:
386382
result = model.transcribe('audio.mp3')
383+
384+
# For versions < 2.18.0:
385+
result = model.transcribe_stable('audio.mp3')
387386
```
388-
Note: `model.transcribe_stable()` is deprecated in 2.18.0 and will be removed in future versions.
387+
388+
<details>
389+
<summary>CLI</summary>
390+
389391
```commandline
390392
stable-ts audio.mp3 -o audio.srt -fw
391393
```
394+
395+
</details>
396+
392397
Docstring:
393398
<details>
394399
<summary>load_faster_whisper()</summary>

stable_whisper/alignment.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -547,19 +547,18 @@ def refine(
547547
Saved 'audio.srt'
548548
"""
549549
model = as_vanilla(model)
550-
if result:
551-
if not result.has_words:
552-
if not result.language:
553-
raise RuntimeError(f'cannot add word-timestamps to result with missing language')
554-
align_words(model, audio, result)
555-
elif not all(word.tokens for word in result.all_words()):
556-
tokenizer = get_tokenizer(model)
557-
for word in result.all_words():
558-
word.tokens = tokenizer.encode(word.word)
559-
tokenizer = get_tokenizer(model, language=result.language, task='transcribe')
550+
is_faster_model = model.__module__.startswith('faster_whisper.')
551+
if result and (not result.has_words or any(word.probability is None for word in result.all_words())):
552+
if not result.language:
553+
raise RuntimeError(f'cannot align words with result missing language')
554+
align_words(model, audio, result)
555+
tokenizer = get_tokenizer(model, is_faster_model=is_faster_model, language=result.language, task='transcribe')
556+
if result and not all(word.tokens for word in result.all_words()):
557+
for word in result.all_words():
558+
word.tokens = tokenizer.encode(word.word)
560559

561560
options = AllOptions(options, post=False, silence=False, align=False)
562-
model_type = 'fw' if (is_faster_model := model.__module__.startswith('faster_whisper.')) else None
561+
model_type = 'fw' if is_faster_model else None
563562
inference_func = get_whisper_refinement_func(model, tokenizer, model_type, single_batch)
564563
max_inference_tokens = (model.max_length if is_faster_model else model.dims.n_text_ctx) - 6
565564

stable_whisper/whisper_word_level/hf_whisper.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -140,12 +140,15 @@ def _inner_transcribe(
140140
print(f'Transcribing with Hugging Face Whisper ({self.model_name})...')
141141
pipe_kwargs = dict(
142142
generate_kwargs=generate_kwargs,
143-
return_timestamps='word' if word_timestamps else True
143+
return_timestamps='word' if word_timestamps else True,
144+
return_language=True
144145
)
145146
if batch_size is not None:
146147
pipe_kwargs['batch_size'] = batch_size
147148
output = self._pipe(audio, **pipe_kwargs)
148149
result = output['chunks']
150+
if not language and result and 'language' in result[0]:
151+
language = result[0]['language']
149152
if verbose is not None:
150153
print(f'Transcription completed.')
151154

@@ -200,13 +203,24 @@ def _curr_max_end(start: float, next_idx: float) -> float:
200203
for word in result
201204
]
202205
replace_none_ts(words)
203-
return [words]
204-
segs = [
205-
dict(start=seg['timestamp'][0], end=seg['timestamp'][1], text=seg['text'])
206-
for seg in result
207-
]
208-
replace_none_ts(segs)
209-
return segs
206+
if words:
207+
segs = [
208+
dict(
209+
start=words[0]['start'],
210+
end=words[-1]['end'],
211+
text=''.join(w['word'] for w in words),
212+
words=words
213+
)
214+
]
215+
else:
216+
segs = []
217+
else:
218+
segs = [
219+
dict(start=seg['timestamp'][0], end=seg['timestamp'][1], text=seg['text'])
220+
for seg in result
221+
]
222+
replace_none_ts(segs)
223+
return dict(segments=segs, language=language)
210224

211225
def transcribe(
212226
self,

0 commit comments

Comments (0)