Skip to content

Commit cce0e4a

Browse files
authored
Merge pull request #244 from Khadija-Bayoud/master
Fix issues with custom Whisper models (transformers backend)
2 parents cd16890 + b25bfa2 commit cce0e4a

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
Cython
22
dtw-python
3-
openai-whisper
3+
openai-whisper

whisper_timestamped/transcribe.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2439,7 +2439,7 @@ def load_model(
24392439
name = f"openai/whisper-{name}"
24402440
# TODO: use download_root
24412441
# TODO: does in_memory makes sense?
2442-
cache_dir=os.path.join(download_root, "huggingface", "hub") if download_root else None,
2442+
cache_dir=os.path.join(download_root, "huggingface", "hub") if download_root else None
24432443
try:
24442444
generation_config = transformers.GenerationConfig.from_pretrained(name, cache_dir=cache_dir)
24452445
except OSError:
@@ -2687,7 +2687,7 @@ def transcribe(self, audio, use_token_timestamps=False, **kwargs):
26872687
return_segments = True,
26882688
return_timestamps = True,
26892689
return_token_timestamps = use_token_timestamps,
2690-
max_length = self.dims.n_text_ctx,
2690+
max_length = self.dims.n_text_ctx if self.dims.n_text_ctx is not None else generation_config.max_length,
26912691
is_multilingual = self.is_multilingual,
26922692
prompt_ids = prompt_ids,
26932693
generation_config = generation_config,
@@ -2735,8 +2735,16 @@ def transcribe(self, audio, use_token_timestamps=False, **kwargs):
27352735
i_sot = -1
27362736
if self.is_multilingual:
27372737
language = self.tokenizer.decode([first_segment_tokens[i_sot+1]], decode_with_timestamps=True)
2738-
assert len(language) in [6,7], f"Unexpected language detected: '{language}' ({first_segment_tokens[i_sot+1]}) in '{self.tokenizer.decode(first_segment_tokens, decode_with_timestamps=True)}'"
2739-
language = language[2:-2]
2738+
2739+
if len(language) in (6,7) and language.startswith("<|") and language.endswith("|>"):
2740+
language = language[2:-2]
2741+
else:
2742+
logging.debug(
2743+
f"Unexpected language detected: '{language}' "
2744+
f"({first_segment_tokens[i_sot+1]}) in "
2745+
f"'{self.tokenizer.decode(first_segment_tokens, decode_with_timestamps=True)}'"
2746+
)
2747+
language = None
27402748
else:
27412749
language = "en"
27422750

0 commit comments

Comments
 (0)