11 changes: 10 additions & 1 deletion .github/workflows/test.yml
@@ -17,38 +17,47 @@ jobs:
             pytorch-version: 1.10.1
             numpy-requirement: "'numpy<2'"
             tokenizers-requirement: "'tokenizers<=0.20.3'"
+            transformers-requirement: "'transformers==4.46.3'"
           - python-version: '3.8'
             pytorch-version: 1.13.1
             numpy-requirement: "'numpy<2'"
             tokenizers-requirement: "'tokenizers<=0.20.3'"
+            transformers-requirement: "'transformers==4.46.3'"
           - python-version: '3.8'
             pytorch-version: 2.0.1
             numpy-requirement: "'numpy<2'"
             tokenizers-requirement: "'tokenizers<=0.20.3'"
+            transformers-requirement: "'transformers==4.46.3'"
           - python-version: '3.9'
             pytorch-version: 2.1.2
             numpy-requirement: "'numpy<2'"
             tokenizers-requirement: "'tokenizers'"
+            transformers-requirement: "'transformers'"
           - python-version: '3.10'
             pytorch-version: 2.2.2
             numpy-requirement: "'numpy<2'"
             tokenizers-requirement: "'tokenizers'"
+            transformers-requirement: "'transformers'"
           - python-version: '3.11'
             pytorch-version: 2.3.1
             numpy-requirement: "'numpy'"
             tokenizers-requirement: "'tokenizers'"
+            transformers-requirement: "'transformers'"
           - python-version: '3.12'
             pytorch-version: 2.4.1
             numpy-requirement: "'numpy'"
             tokenizers-requirement: "'tokenizers'"
+            transformers-requirement: "'transformers'"
           - python-version: '3.12'
             pytorch-version: 2.5.0
             numpy-requirement: "'numpy'"
             tokenizers-requirement: "'tokenizers'"
+            transformers-requirement: "'transformers'"
           - python-version: '3.12'
             pytorch-version: 2.6.0
             numpy-requirement: "'numpy'"
             tokenizers-requirement: "'tokenizers'"
+            transformers-requirement: "'transformers'"
     steps:
       - uses: conda-incubator/setup-miniconda@v3
       - run: conda install -n test ffmpeg python=${{ matrix.python-version }}
@@ -63,7 +72,7 @@ jobs:
       - run: python test/test_transcribe.py load_faster_whisper
       - run: python test/test_align.py load_faster_whisper
       - run: python test/test_refine.py load_faster_whisper
-      - run: pip3 install .["hf"] 'transformers<=4.46.3'
+      - run: pip3 install .["hf"] ${{ matrix.transformers-requirement }}
       - run: python test/test_transcribe.py load_hf_whisper
       - run: python test/test_align.py load_hf_whisper
       - run: python test/test_refine.py load_hf_whisper
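The matrix above pins `transformers==4.46.3` for the Python 3.8 jobs (presumably because later transformers releases drop Python 3.8 support) and leaves the requirement unpinned for newer interpreters, so the updated install step picks up whichever value the job defines. A minimal sketch, illustrative only and not part of the workflow, of how the quoted matrix values expand in the shell:

```python
# Illustrative only: how each matrix value expands in the install step
# `pip3 install .["hf"] ${{ matrix.transformers-requirement }}`.
matrix = [
    {"python": "3.8", "transformers_requirement": "'transformers==4.46.3'"},  # legacy jobs: pinned
    {"python": "3.12", "transformers_requirement": "'transformers'"},         # newer jobs: unpinned
]
for entry in matrix:
    # The outer single quotes are consumed by the shell, so pip itself sees a
    # bare requirement specifier such as transformers==4.46.3.
    print(f'pip3 install .["hf"] {entry["transformers_requirement"]}')
```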
2 changes: 1 addition & 1 deletion setup.py
@@ -28,7 +28,7 @@ def read_me() -> str:
"torch",
"torchaudio",
"tqdm",
"openai-whisper>=20230314,<=20240930"
"openai-whisper>=20250625"
],
extras_require={
"fw": [
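The `openai-whisper` requirement moves from a bounded range to a simple floor: releases at or above 20250625 are accepted, and everything older is now rejected at install time. A minimal sketch comparing the two specifiers, using the `packaging` library (an assumption; it is not part of this diff):

```python
# Sketch only: evaluates the old and new openai-whisper specifiers from setup.py.
from packaging.specifiers import SpecifierSet
from packaging.version import Version

old_spec = SpecifierSet(">=20230314,<=20240930")
new_spec = SpecifierSet(">=20250625")

for release in ("20230314", "20240930", "20250625"):
    v = Version(release)
    print(f"{release}: old={v in old_spec} new={v in new_spec}")
# 20230314: old=True new=False
# 20240930: old=True new=False
# 20250625: old=False new=True
```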
1 change: 1 addition & 0 deletions stable_whisper/whisper_compatibility.py
@@ -16,6 +16,7 @@
     '20231117',
     '20240927',
     '20240930',
+    '20250625',
 )
 _required_whisper_ver = _COMPATIBLE_WHISPER_VERSIONS[-1]
 
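Because `_required_whisper_ver` is taken from the last element of the tuple, appending `'20250625'` both whitelists the new `openai-whisper` release and makes it the version stable-ts treats as required. A minimal sketch of how such a tuple can back a compatibility check (the helper below is hypothetical, not from `whisper_compatibility.py`; earlier tuple entries are elided):

```python
# Hypothetical helper illustrating how the version tuple gates compatibility.
_COMPATIBLE_WHISPER_VERSIONS = (
    # ...earlier entries elided...
    '20231117',
    '20240927',
    '20240930',
    '20250625',
)
_required_whisper_ver = _COMPATIBLE_WHISPER_VERSIONS[-1]

def is_compatible_whisper(installed_version: str) -> bool:
    """Return True if the installed openai-whisper version is known to work."""
    return installed_version in _COMPATIBLE_WHISPER_VERSIONS

assert is_compatible_whisper('20250625')
assert _required_whisper_ver == '20250625'
```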
63 changes: 52 additions & 11 deletions stable_whisper/whisper_word_level/hf_whisper.py
@@ -63,38 +63,39 @@ def get_device(device: str = None) -> str:

 def load_hf_pipe(model_name: str, device: str = None, flash: bool = False, **pipeline_kwargs):
     from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
+    from transformers.configuration_utils import PretrainedConfig
     device = get_device(device)
     is_cpu = (device if isinstance(device, str) else getattr(device, 'type', None)) == 'cpu'
     dtype = torch.float32 if is_cpu or not torch.cuda.is_available() else torch.float16
     model_id = HF_MODELS.get(model_name, model_name)
+
+    if flash:
+        config = PretrainedConfig(
+            attn_implementation="flash_attention_2",
+        )
+    else:
+        config = None
+
     model = AutoModelForSpeechSeq2Seq.from_pretrained(
         model_id,
         torch_dtype=dtype,
         low_cpu_mem_usage=True,
         use_safetensors=True,
-        use_flash_attention_2=flash
+        config=config
     ).to(device)
 
     processor = AutoProcessor.from_pretrained(model_id)
 
-    if not flash:
-        try:
-            model = model.to_bettertransformer()
-        except (ValueError, ImportError) as e:
-            import warnings
-            warnings.warn(
-                f'Failed convert model to BetterTransformer due to: {e}'
-            )
-
     final_pipe_kwargs = dict(
         task="automatic-speech-recognition",
         model=model,
         tokenizer=processor.tokenizer,
         feature_extractor=processor.feature_extractor,
         max_new_tokens=128,
-        chunk_length_s=30,
+        # chunk_length_s=30,
         torch_dtype=dtype,
         device=device,
+        return_language=True
     )
     final_pipe_kwargs.update(**pipeline_kwargs)
     pipe = pipeline(**final_pipe_kwargs)
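This hunk drops the `to_bettertransformer()` fallback (BetterTransformer is deprecated in recent transformers releases) and replaces the removed `use_flash_attention_2` argument with a config object. For reference, a minimal sketch of the other route recent transformers versions accept, passing `attn_implementation` directly to `from_pretrained` (the checkpoint name is only an example; FlashAttention 2 additionally requires the `flash-attn` package and a supported GPU):

```python
# Sketch of an alternative to the PretrainedConfig approach above: recent
# transformers releases accept attn_implementation directly in from_pretrained.
import torch
from transformers import AutoModelForSpeechSeq2Seq

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    "openai/whisper-large-v3",                 # example checkpoint
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    attn_implementation="flash_attention_2",   # requires flash-attn + CUDA GPU
).to("cuda")
```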
@@ -106,6 +107,7 @@ class WhisperHF:

     def __init__(self, model_name: str, device: str = None, flash: bool = False, pipeline=None, **pipeline_kwargs):
         self._model_name = model_name
+        pipeline_kwargs['return_language'] = True
         self._pipe = load_hf_pipe(self._model_name, device, flash=flash, **pipeline_kwargs) if pipeline is None \
             else pipeline
         self._model_name = getattr(self._pipe.model, 'name_or_path', self._model_name)
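Forcing `return_language=True` before the pipeline is built guarantees that every result carries per-chunk language information, which the detection logic below reads from `result[0]['language']`. A minimal usage sketch (checkpoint and audio path are placeholders, and the exact output shape is an assumption based on how this PR consumes it):

```python
# Sketch only: with return_language=True, each chunk in the pipeline output
# carries a 'language' entry.
from transformers import pipeline

asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny",   # placeholder checkpoint
    return_language=True,
)
output = asr("audio.wav", return_timestamps=True)  # placeholder audio path
for chunk in output["chunks"]:
    print(chunk["language"], chunk["text"])
```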
@@ -154,6 +156,45 @@ def _inner_transcribe(
             language = 'en'
         if not language and result and 'language' in result[0]:
             language = result[0]['language']
+        if not language and hasattr(output, 'get') and 'detected_language' in output:
+            language = output['detected_language']
+        if not language:
+            # HF Pipelines have broken language detection.
+            # Manually detect language by generating tokens from the first 10 seconds of the audio.
+            try:
+                import torch
+                sample_audio = audio[:int(self.sampling_rate * 10)]  # use first 10 seconds
+                inputs = self._pipe.feature_extractor(sample_audio, sampling_rate=self.sampling_rate, return_tensors="pt")
+
+                # Ensure input features match model dtype and device.
+                model_dtype = next(self._pipe.model.parameters()).dtype
+                model_device = next(self._pipe.model.parameters()).device
+                inputs.input_features = inputs.input_features.to(dtype=model_dtype, device=model_device)
+
+                # Generate with minimal tokens to detect language.
+                with torch.no_grad():
+                    generated_ids = self._pipe.model.generate(
+                        inputs.input_features,
+                        max_new_tokens=10,
+                        do_sample=False,
+                        output_scores=True,
+                        return_dict_in_generate=True
+                    )
+
+                # Decode the tokens to extract language information.
+                tokens = self._pipe.tokenizer.batch_decode(generated_ids.sequences, skip_special_tokens=False)[0]
+
+                # Extract language token (format: <|en|>, <|fr|>, etc.; codes such as <|yue|> have 3 letters).
+                import re
+                lang_match = re.search(r'<\|(\w{2,3})\|>', tokens)
+                if lang_match:
+                    language = lang_match.group(1)
+                else:
+                    language = None
+
+            except Exception as e:
+                print(f'Error detecting language: {e}')
+                language = None
         if verbose is not None:
             print(f'Transcription completed.')
 
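The fallback above mirrors how Whisper itself reports language: the model emits a language token (e.g. `<|en|>`) right after `<|startoftranscript|>`, so generating a handful of tokens from a short audio slice and scanning the decoded special tokens recovers the detected language. A standalone sketch of the same trick outside the pipeline (the checkpoint is an example; the audio is placeholder silence):

```python
# Standalone sketch of the detection trick used above: generate a few tokens
# and pull the language token out of the decoded special tokens.
import re
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

model_id = "openai/whisper-tiny"                    # example checkpoint
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id)

audio = torch.zeros(16000 * 10).numpy()             # placeholder: 10 s of silence at 16 kHz
inputs = processor(audio, sampling_rate=16000, return_tensors="pt")

with torch.no_grad():
    generated = model.generate(inputs.input_features, max_new_tokens=10)

decoded = processor.tokenizer.batch_decode(generated, skip_special_tokens=False)[0]
match = re.search(r'<\|(\w{2,3})\|>', decoded)      # 2- and 3-letter codes (en, yue, ...)
print(match.group(1) if match else None)
```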