Skip to content

Commit 751b041

Browse files
committed
added alignment and refinement support for HF models
-added full alignment and refinement support for Hugging Face models
1 parent fefaf46 commit 751b041

File tree

5 files changed

+125
-6
lines changed

5 files changed

+125
-6
lines changed

README.md

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -504,17 +504,22 @@ Docstring:
504504

505505

506506
<details>
507-
<summary>Hugging Face Transformers (~9x faster)</summary>
507+
<summary>Hugging Face Transformers</summary>
508508

509-
Run Whisper up to 9x faster with [Hugging Face Transformer](https://huggingface.co/openai/whisper-large-v3):
509+
Transcribe up to 9x faster with [Hugging Face Transformers](https://huggingface.co/openai/whisper-large-v3):
510510
```
511511
pip install -U stable-ts[hf]
512512
```
513-
* [Alignment](#alignment) and [Refinement](#refinement) are not supported on Hugging Face models
513+
514514
```python
515515
model = stable_whisper.load_hf_whisper('base')
516516
result = model.transcribe('audio.mp3')
517517
```
518+
Supports the [various versions on Hugging Face](https://huggingface.co/models?other=whisper&sort=downloads):
519+
```python
520+
model = stable_whisper.load_hf_whisper('openai/whisper-base.en')
521+
```
522+
518523

519524
<details>
520525
<summary>CLI</summary>

stable_whisper/alignment.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from .whisper_compatibility import (
1616
SAMPLE_RATE, N_FRAMES, N_FFT, pad_or_trim, log_mel_spectrogram, FRAMES_PER_SECOND, CHUNK_LENGTH, N_SAMPLES,
17-
median_filter, DecodingTask, DecodingOptions, SuppressTokens, whisper, TOKENS_PER_SECOND
17+
median_filter, DecodingTask, DecodingOptions, SuppressTokens, whisper, TOKENS_PER_SECOND, as_vanilla
1818
)
1919

2020
if TYPE_CHECKING:
@@ -171,6 +171,7 @@ def align(
171171
>>> result.to_srt_vtt('helloworld.srt')
172172
Saved 'helloworld.srt'
173173
"""
174+
model = as_vanilla(model)
174175
is_faster_model = model.__module__.startswith('faster_whisper.')
175176
if not is_faster_model:
176177
warn_compatibility_issues(whisper, ignore_compatibility)
@@ -333,6 +334,7 @@ def align_words(
333334
>>> result = [dict(start=0.0, end=0.5, text='hello world 1'), dict(start=0.5, end=1.0, text='hello world 2')]
334335
>>> result = model.align_words('audio.mp3', result, 'English')
335336
"""
337+
model = as_vanilla(model)
336338
is_faster_model = model.__module__.startswith('faster_whisper.')
337339
if not is_faster_model:
338340
warn_compatibility_issues(whisper, ignore_compatibility)
@@ -544,6 +546,7 @@ def refine(
544546
>>> result.to_srt_vtt('audio.srt')
545547
Saved 'audio.srt'
546548
"""
549+
model = as_vanilla(model)
547550
if result:
548551
if not result.has_words:
549552
if not result.language:

stable_whisper/timing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ def _compute_atten_weights(
8282
if cache['qks'] is None:
8383
_compute_qks(model, tokenizer, text_tokens, mel, tokens, cache)
8484
QKs = cache['qks']
85+
if getattr(model, 'missing_alignment_heads', False) and not dynamic_heads_count:
86+
dynamic_heads_count = 6
8587
if dynamic_heads_count:
8688
max_qk_len = round(num_samples / N_SAMPLES_PER_TOKEN)
8789
if not cache.get('is_processed_qks'):

stable_whisper/whisper_compatibility.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def _dummy_contextmanager():
6767
from whisper.tokenizer import get_tokenizer as get_whisper_tokenizer
6868

6969
from whisper.tokenizer import Tokenizer
70-
from whisper.model import Whisper
70+
from whisper.model import Whisper, ModelDimensions, LayerNorm
7171
from whisper.decoding import DecodingTask, DecodingOptions, DecodingResult, SuppressTokens
7272
try:
7373
from whisper.model import disable_sdpa
@@ -90,7 +90,9 @@ def _dummy_contextmanager():
9090

9191
log_mel_spectrogram = median_filter = dtw = merge_punctuations = get_whisper_tokenizer \
9292
= whisper_not_available
93-
Tokenizer = Whisper = DecodingTask = DecodingOptions = DecodingResult = SuppressTokens = Unavailable
93+
Tokenizer = Whisper = ModelDimensions = LayerNorm = \
94+
DecodingTask = DecodingOptions = DecodingResult = SuppressTokens \
95+
= Unavailable
9496
LANGUAGES = {
9597
"en": "english",
9698
"zh": "chinese",
@@ -330,3 +332,20 @@ def get_tokenizer(model=None, is_faster_model: bool = False, **kwargs):
330332
del kwargs['num_languages']
331333
kwargs['language'] = get_valid_language(kwargs.get('language'), is_faster_model, model)
332334
return tokenizer(**kwargs)
335+
336+
337+
def as_vanilla(model):
    """
    Return the vanilla counterpart of ``model`` when it offers one.

    Objects exposing an ``as_vanilla_model`` attribute are converted by calling it;
    any other object is returned unchanged.
    """
    if not hasattr(model, 'as_vanilla_model'):
        return model
    return model.as_vanilla_model()
339+
340+
341+
def ln_to_fp32(module):
    """
    Convert all parameters in LayerNorm of model to float32.

    The conversion happens in place: the ``data`` of each existing ``weight`` and ``bias``
    is replaced with its float32 version. ``module`` itself and every nested submodule
    are inspected, so passing a bare LayerNorm converts it as well.

    Parameters
    ----------
    module : torch.nn.Module
        Module whose LayerNorm parameters should be converted.
    """
    # ``modules()`` yields ``module`` itself followed by all of its descendants.
    for submodule in module.modules():
        if not isinstance(submodule, LayerNorm):
            continue
        for param_name in ('weight', 'bias'):
            # Either parameter may be absent (``None``), e.g. a LayerNorm without affine parameters.
            param = getattr(submodule, param_name, None)
            if param is not None:
                param.data = param.data.float()

stable_whisper/whisper_word_level/hf_whisper.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from ..non_whisper import transcribe_any
88
from ..utils import isolate_useful_options
99

10+
from ..alignment import align, align_words, refine
11+
1012

1113
HF_MODELS = {
1214
"tiny.en": "openai/whisper-tiny.en",
@@ -25,6 +27,29 @@
2527
"turbo": "openai/whisper-large-v3-turbo"
2628
}
2729

30+
# Substring replacements from vanilla Whisper parameter names (keys) to their
# Hugging Face names (values). Entry order is kept as-is because the mapping is
# iterated in order when renaming state-dict keys.
WHISPER_TO_HF_MAPPING = {
    # transformer blocks and their MLP
    "blocks": "layers",
    "mlp.0": "fc1",
    "mlp.2": "fc2",
    "mlp_ln": "final_layer_norm",
    # self-attention
    ".attn.query": ".self_attn.q_proj",
    ".attn.key": ".self_attn.k_proj",
    ".attn.value": ".self_attn.v_proj",
    ".attn_ln": ".self_attn_layer_norm",
    ".attn.out": ".self_attn.out_proj",
    # cross-attention
    ".cross_attn.query": ".encoder_attn.q_proj",
    ".cross_attn.key": ".encoder_attn.k_proj",
    ".cross_attn.value": ".encoder_attn.v_proj",
    ".cross_attn_ln": ".encoder_attn_layer_norm",
    ".cross_attn.out": ".encoder_attn.out_proj",
    # final layer norms
    "decoder.ln.": "decoder.layer_norm.",
    "encoder.ln.": "encoder.layer_norm.",
    # embeddings
    "token_embedding": "embed_tokens",
    "encoder.positional_embedding": "encoder.embed_positions.weight",
    "decoder.positional_embedding": "decoder.embed_positions.weight",
    "ln_post": "layer_norm",
}
52+
2853

2954
def get_device(device: str = None) -> str:
3055
if device:
@@ -81,6 +106,7 @@ def __init__(self, model_name: str, device: str = None, flash: bool = False, pip
81106
self._pipe = load_hf_pipe(self._model_name, device, flash=flash, **pipeline_kwargs) if pipeline is None \
82107
else pipeline
83108
self._model_name = getattr(self._pipe.model, 'name_or_path', self._model_name)
109+
self._vanilla_model = None
84110

85111
@property
86112
def sampling_rate(self):
@@ -263,6 +289,70 @@ def transcribe(
263289
**transcribe_any_options
264290
)
265291

292+
def as_vanilla_model(self):
    """
    Build (once) and return a vanilla Whisper model carrying this model's current weights.

    The converted model is cached on the instance, so later calls return the same object.
    Most weights share memory with the Hugging Face model instance instead of being copied.
    """
    cached_model = self._vanilla_model
    if cached_model is not None:
        return cached_model

    from ..whisper_compatibility import ModelDimensions, Whisper, ln_to_fp32
    from .original_whisper import modify_model
    try:
        from transformers.models.whisper.convert_openai_to_hf import WHISPER_MAPPING
    except (ImportError, ModuleNotFoundError):
        # Fall back to the mapping bundled with this module.
        whisper2hf_mapping = WHISPER_TO_HF_MAPPING
    else:
        whisper2hf_mapping = WHISPER_MAPPING

    # Invert the Whisper -> HF mapping; the inversion must not collapse any entries.
    hf_mapping = dict((hf_name, whisper_name) for whisper_name, hf_name in whisper2hf_mapping.items())
    assert len(whisper2hf_mapping) == len(hf_mapping)

    hf_model = self._pipe.model
    state_dict = hf_model.model.state_dict()
    config = hf_model.config

    # The encoder's final layer norm is renamed to ``ln_post`` rather than the generic target.
    if 'encoder.layer_norm.' in hf_mapping:
        hf_mapping['encoder.layer_norm.'] = 'encoder.ln_post.'

    # Rename the keys in place. Iterate over a snapshot because the dict is mutated.
    for hf_key in tuple(state_dict):
        whisper_key = hf_key
        for hf_part, whisper_part in hf_mapping.items():
            # Membership is tested against the original key, not the partially renamed one.
            if hf_part in hf_key:
                whisper_key = whisper_key.replace(hf_part, whisper_part)
        if whisper_key == hf_key:
            continue
        state_dict[whisper_key] = state_dict.pop(hf_key)

    dim_kwargs = {
        'n_mels': config.num_mel_bins,
        'n_audio_ctx': config.max_source_positions,
        'n_audio_state': config.d_model,
        'n_audio_head': config.encoder_attention_heads,
        'n_audio_layer': config.encoder_layers,
        'n_vocab': config.vocab_size,
        'n_text_ctx': config.max_target_positions,
        'n_text_state': hf_model.model.decoder.embed_positions.embedding_dim,
        'n_text_head': config.decoder_attention_heads,
        'n_text_layer': config.decoder_layers,
    }
    new_model = Whisper(ModelDimensions(**dim_kwargs))

    alignment_heads = getattr(hf_model.generation_config, 'alignment_heads', None)
    if alignment_heads:
        # Turn the configured head indices into a boolean (layer, head) mask.
        head_indices = torch.as_tensor(alignment_heads).T
        head_mask = torch.zeros(new_model.dims.n_text_layer, new_model.dims.n_text_head, dtype=torch.bool)
        head_mask[head_indices[0], head_indices[1]] = True
        new_model.register_buffer("alignment_heads", head_mask.to_sparse(), persistent=False)
    else:
        # Flag the model so callers can tell that no alignment heads were provided.
        new_model.missing_alignment_heads = True

    new_model.load_state_dict(state_dict, strict=True, assign=True)
    new_model.to(device=hf_model.device)
    ln_to_fp32(new_model)
    modify_model(new_model)

    self._vanilla_model = new_model
    return self._vanilla_model
351+
352+
align = align
353+
align_words = align_words
354+
refine = refine
355+
266356

267357
def load_hf_whisper(model_name: str, device: str = None, flash: bool = False, pipeline=None, **pipeline_kwargs):
268358
return WhisperHF(model_name, device, flash=flash, pipeline=pipeline, **pipeline_kwargs)

0 commit comments

Comments
 (0)