|
7 | 7 | from ..non_whisper import transcribe_any |
8 | 8 | from ..utils import isolate_useful_options |
9 | 9 |
|
| 10 | +from ..alignment import align, align_words, refine |
| 11 | + |
10 | 12 |
|
11 | 13 | HF_MODELS = { |
12 | 14 | "tiny.en": "openai/whisper-tiny.en", |
|
25 | 27 | "turbo": "openai/whisper-large-v3-turbo" |
26 | 28 | } |
27 | 29 |
|
# Substring-rename table from vanilla openai-whisper parameter names (keys) to
# Hugging Face ``WhisperModel`` parameter names (values).  This mirrors
# ``transformers.models.whisper.convert_openai_to_hf.WHISPER_MAPPING`` and is
# only used as a fallback when that import is unavailable (see
# ``WhisperHF.as_vanilla_model``, which inverts it to convert HF weights back
# to the vanilla layout).
# NOTE(review): entries are substrings, not full parameter names — renaming is
# done via ``str.replace`` on each state-dict key.
WHISPER_TO_HF_MAPPING = {
    "blocks": "layers",
    "mlp.0": "fc1",
    "mlp.2": "fc2",
    "mlp_ln": "final_layer_norm",
    ".attn.query": ".self_attn.q_proj",
    ".attn.key": ".self_attn.k_proj",
    ".attn.value": ".self_attn.v_proj",
    ".attn_ln": ".self_attn_layer_norm",
    ".attn.out": ".self_attn.out_proj",
    ".cross_attn.query": ".encoder_attn.q_proj",
    ".cross_attn.key": ".encoder_attn.k_proj",
    ".cross_attn.value": ".encoder_attn.v_proj",
    ".cross_attn_ln": ".encoder_attn_layer_norm",
    ".cross_attn.out": ".encoder_attn.out_proj",
    "decoder.ln.": "decoder.layer_norm.",
    "encoder.ln.": "encoder.layer_norm.",
    "token_embedding": "embed_tokens",
    "encoder.positional_embedding": "encoder.embed_positions.weight",
    "decoder.positional_embedding": "decoder.embed_positions.weight",
    "ln_post": "layer_norm",
}
| 52 | + |
28 | 53 |
|
29 | 54 | def get_device(device: str = None) -> str: |
30 | 55 | if device: |
@@ -81,6 +106,7 @@ def __init__(self, model_name: str, device: str = None, flash: bool = False, pip |
81 | 106 | self._pipe = load_hf_pipe(self._model_name, device, flash=flash, **pipeline_kwargs) if pipeline is None \ |
82 | 107 | else pipeline |
83 | 108 | self._model_name = getattr(self._pipe.model, 'name_or_path', self._model_name) |
| 109 | + self._vanilla_model = None |
84 | 110 |
|
85 | 111 | @property |
86 | 112 | def sampling_rate(self): |
@@ -263,6 +289,70 @@ def transcribe( |
263 | 289 | **transcribe_any_options |
264 | 290 | ) |
265 | 291 |
|
    def as_vanilla_model(self):
        """
        Return a vanilla Whisper model instance with current weights.

        The new instance is only loaded once. Most weights share the same memory as this Hugging Face model instance.
        """
        # Cached: the conversion runs once, later calls return the same instance.
        if self._vanilla_model is not None:
            return self._vanilla_model

        from ..whisper_compatibility import ModelDimensions, Whisper, ln_to_fp32
        from .original_whisper import modify_model
        # Prefer the rename table shipped with ``transformers``; fall back to
        # the local copy (WHISPER_TO_HF_MAPPING) if the conversion helper
        # module is missing in the installed version.
        try:
            from transformers.models.whisper.convert_openai_to_hf import WHISPER_MAPPING
            whisper2hf_mapping = WHISPER_MAPPING
        except (ImportError, ModuleNotFoundError):
            whisper2hf_mapping = WHISPER_TO_HF_MAPPING

        # Invert to HF-fragment -> vanilla-fragment; the assert guards against
        # two vanilla fragments collapsing onto one HF fragment (which would
        # silently drop a rename rule).
        hf_mapping = {v: k for k, v in whisper2hf_mapping.items()}
        assert len(whisper2hf_mapping) == len(hf_mapping)

        state_dict = self._pipe.model.model.state_dict()
        config = self._pipe.model.config

        # Inverting "encoder.ln." -> "encoder.layer_norm." would map the
        # encoder's final layer norm back to ``encoder.ln``, but in vanilla
        # Whisper that module is ``encoder.ln_post`` — override the entry.
        if 'encoder.layer_norm.' in hf_mapping:
            hf_mapping['encoder.layer_norm.'] = 'encoder.ln_post.'
        # Rename every HF state-dict key to its vanilla equivalent by substring
        # replacement.  Note the membership test uses the ORIGINAL ``key`` while
        # replacement is applied to ``new_key`` — this is deliberate: a later,
        # shorter fragment (e.g. "layer_norm") matching the original key is a
        # no-op on the already-renamed ``new_key``.
        for key in list(state_dict.keys()):
            new_key = key
            for k, v in hf_mapping.items():
                if k in key:
                    new_key = new_key.replace(k, v)
            if new_key != key:
                state_dict[new_key] = state_dict.pop(key)

        # Rebuild vanilla ModelDimensions from the HF config.
        dims = ModelDimensions(
            n_mels=config.num_mel_bins,
            n_audio_ctx=config.max_source_positions,
            n_audio_state=config.d_model,
            n_audio_head=config.encoder_attention_heads,
            n_audio_layer=config.encoder_layers,
            n_vocab=config.vocab_size,
            n_text_ctx=config.max_target_positions,
            # Read the decoder width off the actual embedding rather than
            # config.d_model — presumably to handle models where they differ;
            # TODO(review): confirm.
            n_text_state=self._pipe.model.model.decoder.embed_positions.embedding_dim,
            n_text_head=config.decoder_attention_heads,
            n_text_layer=config.decoder_layers
        )
        new_model = Whisper(dims)
        # HF stores alignment heads as (layer, head) pairs; vanilla Whisper
        # expects a sparse boolean (n_text_layer, n_text_head) mask buffer.
        # NOTE(review): relies on a module-level ``torch`` import not visible
        # in this chunk — confirm it exists at the top of the file.
        if alignment_heads := getattr(self._pipe.model.generation_config, 'alignment_heads', None):
            alignment_heads = torch.as_tensor(alignment_heads).T
            final_heads = torch.zeros(new_model.dims.n_text_layer, new_model.dims.n_text_head, dtype=torch.bool)
            final_heads[alignment_heads[0], alignment_heads[1]] = True
            new_model.register_buffer("alignment_heads", final_heads.to_sparse(), persistent=False)
        else:
            # Flag the absence so downstream alignment code can detect it.
            setattr(new_model, 'missing_alignment_heads', True)
        # assign=True keeps the loaded tensors' storage shared with the HF
        # model instead of copying into freshly allocated parameters.
        new_model.load_state_dict(state_dict, strict=True, assign=True)
        new_model.to(device=self._pipe.model.device)
        ln_to_fp32(new_model)
        modify_model(new_model)
        self._vanilla_model = new_model
        return self._vanilla_model

    # Expose the alignment helpers as bound methods of this class so a
    # WhisperHF instance supports .align()/.align_words()/.refine() directly.
    align = align
    align_words = align_words
    refine = refine
| 355 | + |
266 | 356 |
|
def load_hf_whisper(model_name: str, device: str = None, flash: bool = False, pipeline=None, **pipeline_kwargs):
    """
    Load a Hugging Face Whisper model wrapped in a :class:`WhisperHF` instance.

    All arguments are forwarded unchanged to the :class:`WhisperHF` constructor.
    """
    whisper_hf = WhisperHF(
        model_name,
        device,
        flash=flash,
        pipeline=pipeline,
        **pipeline_kwargs,
    )
    return whisper_hf
0 commit comments