Skip to content

Commit cd9232e

Browse files
committed
updated language compatibility
-updated `language` parameter to ignore case and accept language labels and codes
1 parent fe15241 commit cd9232e

File tree

1 file changed

+35
-5
lines changed

1 file changed

+35
-5
lines changed

stable_whisper/whisper_compatibility.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,40 @@ def warn_compatibility_issues(
270270
warnings.warn(compatibility_warning)
271271

272272

273+
def get_valid_language(language: str, is_faster_model: bool, model=None):
274+
if language is None:
275+
if model is None:
276+
return language
277+
if is_faster_model:
278+
return model.supported_languages[0] if len(model.supported_languages) == 1 else language
279+
return language if model.is_multilingual else 'en'
280+
281+
if is_faster_model:
282+
from faster_whisper.tokenizer import _LANGUAGE_CODES
283+
if language in _LANGUAGE_CODES:
284+
return language
285+
faster_language_code_lower = {code.lower(): code for code in _LANGUAGE_CODES}
286+
if language.lower() in faster_language_code_lower:
287+
return faster_language_code_lower[language.lower()]
288+
for k, v in LANGUAGES.items():
289+
if v.lower() == language.lower() and k.lower() in faster_language_code_lower:
290+
return faster_language_code_lower[k.lower()]
291+
292+
raise ValueError(f'{language} is not a valid language or language code. '
293+
f'Available languages: {tuple(_LANGUAGE_CODES.keys())}')
294+
else:
295+
if language in LANGUAGES:
296+
return language
297+
language_codes_lower = {code.lower(): code for code in LANGUAGES}
298+
if language.lower() in language_codes_lower:
299+
return language_codes_lower[language.lower()]
300+
for k, v in LANGUAGES.items():
301+
if v.lower() == language.lower():
302+
return k
303+
raise ValueError(f'{language} is not a valid language or language code. '
304+
f'Available languages: {tuple(LANGUAGES.keys())}')
305+
306+
273307
def get_tokenizer(model=None, is_faster_model: bool = False, **kwargs):
274308
"""
275309
Backward compatible wrapper of :func:`whisper.tokenizer.get_tokenizer` and
@@ -282,11 +316,6 @@ def get_tokenizer(model=None, is_faster_model: bool = False, **kwargs):
282316
params = get_func_parameters(tokenizer)
283317
if model is not None and 'tokenizer' not in kwargs:
284318
kwargs['tokenizer'] = model.hf_tokenizer
285-
if 'language' in kwargs and kwargs['language'] not in _LANGUAGE_CODES:
286-
for k, v in LANGUAGES.items():
287-
if v == kwargs['language'] and k in _LANGUAGE_CODES:
288-
kwargs['language'] = k
289-
break
290319
else:
291320
tokenizer = whisper.tokenizer.get_tokenizer
292321
params = _TOKENIZER_PARAMS
@@ -299,4 +328,5 @@ def get_tokenizer(model=None, is_faster_model: bool = False, **kwargs):
299328
(model.num_languages if hasattr(model, 'num_languages') else model.model.num_languages)
300329
elif 'num_languages' in kwargs:
301330
del kwargs['num_languages']
331+
kwargs['language'] = get_valid_language(kwargs.get('language'))
302332
return tokenizer(**kwargs)

0 commit comments

Comments
 (0)