@@ -173,6 +173,7 @@ def align(
173173 """
174174 model = as_vanilla (model )
175175 is_faster_model = model .__module__ .startswith ('faster_whisper.' )
176+ is_mlx_model = model .__module__ .startswith ('mlx_whisper.' )
176177 if not is_faster_model :
177178 warn_compatibility_issues (whisper , ignore_compatibility )
178179 max_token_step = (model .max_length if is_faster_model else model .dims .n_text_ctx ) - 6
@@ -185,7 +186,7 @@ def align(
185186
186187 options = AllOptions (options , vanilla_align = not is_faster_model )
187188 split_words_by_space = getattr (tokenizer , 'language_code' , tokenizer .language ) not in {"zh" , "ja" , "th" , "lo" , "my" }
188- model_type = 'fw' if is_faster_model else None
189+ model_type = 'fw' if is_faster_model else 'mlx' if is_mlx_model else None
189190 inference_func = get_whisper_alignment_func (model , tokenizer , model_type , options )
190191
191192 aligner = Aligner (
@@ -336,14 +337,15 @@ def align_words(
336337 """
337338 model = as_vanilla (model )
338339 is_faster_model = model .__module__ .startswith ('faster_whisper.' )
340+ is_mlx_model = model .__module__ .startswith ('mlx_whisper.' )
339341 if not is_faster_model :
340342 warn_compatibility_issues (whisper , ignore_compatibility )
341343 tokenizer , supported_languages = get_alignment_tokenizer (model , is_faster_model , result , language , tokenizer )
342344
343345 options = AllOptions (options )
344346 split_words_by_space = getattr (tokenizer , 'language_code' , tokenizer .language ) not in {"zh" , "ja" , "th" , "lo" , "my" }
345347 max_segment_tokens = model .max_length if is_faster_model else model .dims .n_text_ctx
346- inference_func = get_whisper_alignment_func (model , tokenizer , 'fw' if is_faster_model else None , options )
348+ inference_func = get_whisper_alignment_func (model , tokenizer , 'fw' if is_faster_model else 'mlx' if is_mlx_model else None , options )
347349
348350 aligner = Aligner (
349351 inference_func = inference_func ,
@@ -393,7 +395,7 @@ def get_whisper_alignment_func(
393395 model_type : Optional [str ] = None ,
394396 options : Optional [AllOptions ] = None
395397):
396- assert model_type in (None , 'fw' )
398+ assert model_type in (None , 'fw' , 'mlx' )
397399
398400 if model_type is None :
399401 def compute_timestamps (audio_segment : torch .Tensor , word_tokens : List [WordToken ]) -> List [dict ]:
@@ -421,6 +423,53 @@ def compute_timestamps(audio_segment: torch.Tensor, word_tokens: List[WordToken]
421423 )
422424 return [w for seg in temp_segments for w in seg ['words' ]]
423425
426+ elif model_type == 'mlx' :
427+ from mlx_whisper .audio import (
428+ N_FRAMES as MLX_N_FRAMES ,
429+ SAMPLE_RATE as MLX_SAMPLE_RATE ,
430+ log_mel_spectrogram as log_mel_spectrogram_mx ,
431+ pad_or_trim as pad_or_trim_mx
432+ )
433+ import mlx .core as mx
434+ import mlx_whisper .timing as timing
435+
436+ def compute_timestamps (audio_segment_torch : torch .Tensor , word_tokens : List [WordToken ]) -> List [dict ]:
437+ audio_segment_np = audio_segment_torch .squeeze ().numpy ().astype ('float32' )
438+ audio_segment_mx = mx .array (audio_segment_np )
439+
440+ segment_samples = audio_segment_mx .shape [- 1 ]
441+
442+ temp_segment = dict (
443+ seek = 0 ,
444+ start = 0.0 ,
445+ end = round (segment_samples / MLX_SAMPLE_RATE , 3 ),
446+ tokens = [t for wt in word_tokens for t in wt .tokens ],
447+ words = []
448+ )
449+
450+ mel_segments_raw = log_mel_spectrogram_mx (
451+ audio = np .array (audio_segment_mx ),
452+ n_mels = model .dims .n_mels ,
453+ padding = 0
454+ )
455+
456+ num_frames_unpadded = mel_segments_raw .shape [0 ]
457+
458+ mel_segments_padded_time = pad_or_trim_mx (mel_segments_raw .T , MLX_N_FRAMES )
459+
460+ mel_segments_nlc = mel_segments_padded_time .T
461+
462+ timing .add_word_timestamps (
463+ segments = [temp_segment ],
464+ model = model ,
465+ tokenizer = tokenizer ,
466+ mel = mel_segments_nlc ,
467+ num_frames = num_frames_unpadded ,
468+ last_speech_timestamp = 0.0
469+ )
470+
471+ return temp_segment .get ('words' , [])
472+
424473 else :
425474 from .whisper_compatibility import is_faster_whisper_on_pt
426475 from faster_whisper .version import __version__ as fw_ver
@@ -548,6 +597,7 @@ def refine(
548597 """
549598 model = as_vanilla (model )
550599 is_faster_model = model .__module__ .startswith ('faster_whisper.' )
600+ is_mlx_model = model .__module__ .startswith ('mlx_whisper.' )
551601 if result and (not result .has_words or any (word .probability is None for word in result .all_words ())):
552602 if not result .language :
553603 raise RuntimeError (f'cannot align words with result missing language' )
@@ -558,7 +608,7 @@ def refine(
558608 word .tokens = tokenizer .encode (word .word )
559609
560610 options = AllOptions (options , post = False , silence = False , align = False )
561- model_type = 'fw' if is_faster_model else None
611+ model_type = 'fw' if is_faster_model else 'mlx' if is_mlx_model else None
562612 inference_func = get_whisper_refinement_func (model , tokenizer , model_type , single_batch )
563613 max_inference_tokens = (model .max_length if is_faster_model else model .dims .n_text_ctx ) - 6
564614
@@ -588,7 +638,7 @@ def get_whisper_refinement_func(
588638 model_type : Optional [str ] = None ,
589639 single_batch : bool = False
590640):
591- assert model_type in (None , 'fw' )
641+ assert model_type in (None , 'fw' , 'mlx' )
592642
593643 if model_type is None :
594644 def inference_func (audio_segment : torch , tokens : List [int ]) -> torch .Tensor :
@@ -616,6 +666,56 @@ def inference_func(audio_segment: torch, tokens: List[int]) -> torch.Tensor:
616666 token_probs = sampled_logits .softmax (dim = - 1 )
617667 return token_probs
618668
669+ elif model_type == 'mlx' :
670+ from mlx_whisper .audio import (
671+ N_FRAMES as N_FRAMES_MLX ,
672+ log_mel_spectrogram as log_mel_spectrogram_mx ,
673+ pad_or_trim as pad_or_trim_mx
674+ )
675+ import mlx .core as mx
676+
677+ def inference_func (audio_batch_torch : torch .Tensor , tokens : List [int ]) -> torch .Tensor :
678+ input_tokens_mx = mx .array (
679+ [
680+ * tokenizer .sot_sequence ,
681+ tokenizer .no_timestamps ,
682+ * tokens ,
683+ tokenizer .eot ,
684+ ]
685+ )
686+
687+ audio_batch_np = audio_batch_torch .numpy ().astype ('float32' )
688+ audio_batch_mx = mx .array (audio_batch_np )
689+
690+ mel_list = []
691+ for audio_segment_mx in audio_batch_mx :
692+ mel_raw = log_mel_spectrogram_mx (audio = np .array (audio_segment_mx ), n_mels = model .dims .n_mels , padding = 0 )
693+
694+ mel_transposed = mel_raw .T
695+
696+ mel_padded_cl = pad_or_trim_mx (mel_transposed , N_FRAMES_MLX )
697+
698+ mel_nlc = mel_padded_cl .T
699+ mel_list .append (mel_nlc )
700+
701+ mel_batch_nlc = mx .stack (mel_list , axis = 0 )
702+
703+ logits_list = []
704+ for single_mel_nlc in mel_batch_nlc :
705+ logits_single = model (single_mel_nlc [None ], input_tokens_mx [None ])
706+ logits_list .append (logits_single )
707+
708+ logits = mx .concatenate (logits_list , axis = 0 )
709+
710+ sot_len = len (tokenizer .sot_sequence )
711+ start_idx = sot_len
712+ end_idx = start_idx + len (tokens )
713+
714+ sampled_logits = logits [:, start_idx :end_idx , :tokenizer .eot ]
715+ token_probs = mx .softmax (sampled_logits , axis = - 1 )
716+
717+ token_probs_np = np .array (token_probs , copy = True )
718+ return torch .from_numpy (token_probs_np )
619719 else :
620720 from .whisper_compatibility import is_faster_whisper_on_pt
621721 from faster_whisper .version import __version__ as fw_ver
0 commit comments