Skip to content

Commit c125472

Browse files
added support for MLX-Whisper (#442)
- added transcription, alignment, and refinement support for MLX-Whisper models - added tests for MLX-Whisper models - updated README.md to list support for MLX-Whisper models Co-authored-by: jian <[email protected]>
1 parent 6172a0c commit c125472

File tree

7 files changed

+505
-9
lines changed

7 files changed

+505
-9
lines changed

.github/workflows/test.yml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,22 @@ jobs:
6767
- run: python test/test_transcribe.py load_hf_whisper
6868
- run: python test/test_align.py load_hf_whisper
6969
- run: python test/test_refine.py load_hf_whisper
70+
71+
mlx-test:
72+
runs-on: macos-latest
73+
steps:
74+
- uses: actions/checkout@v4
75+
- name: Set up Python
76+
uses: actions/setup-python@v5
77+
with:
78+
python-version: '3.12'
79+
- name: Install FFmpeg
80+
run: brew install ffmpeg
81+
- name: Install Package with MLX dependencies
82+
run: pip3 install .["mlx"]
83+
- name: Run MLX transcribe tests
84+
run: python test/test_transcribe.py load_mlx_whisper
85+
- name: Run MLX align tests
86+
run: python test/test_align.py load_mlx_whisper
87+
- name: Run MLX refine tests
88+
run: python test/test_refine.py load_mlx_whisper

README.md

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,32 @@ stable-ts audio.mp3 -o audio.srt -hw
536536

537537
</details>
538538

539+
<details>
540+
<summary>MLX Whisper (on Apple Silicon)</summary>
541+
542+
Transcribe faster on Apple devices with [MLX Whisper](https://github.com/ml-explore/mlx-examples/tree/main/whisper):
543+
```
544+
pip install -U "stable-ts[mlx]"
545+
```
546+
547+
```python
548+
import stable_whisper
549+
550+
model = stable_whisper.load_mlx_whisper('base')
551+
result = model.transcribe('audio.mp3')
552+
```
553+
554+
555+
<details>
556+
<summary>CLI</summary>
557+
558+
```commandline
559+
stable-ts audio.mp3 -o audio.srt -mlx
560+
```
561+
</details>
562+
563+
</details>
564+
539565
### Output
540566

541567
https://github.com/jianfch/stable-ts/assets/28970749/c22dcdf9-79cb-485a-ae38-184d006e513e

setup.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ def read_me() -> str:
3838
"transformers>=4.23.0",
3939
"optimum",
4040
"accelerate"
41+
],
42+
"mlx": [
43+
"mlx-whisper"
4144
]
4245
},
4346
entry_points={

stable_whisper/alignment.py

Lines changed: 105 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ def align(
173173
"""
174174
model = as_vanilla(model)
175175
is_faster_model = model.__module__.startswith('faster_whisper.')
176+
is_mlx_model = model.__module__.startswith('mlx_whisper.')
176177
if not is_faster_model:
177178
warn_compatibility_issues(whisper, ignore_compatibility)
178179
max_token_step = (model.max_length if is_faster_model else model.dims.n_text_ctx) - 6
@@ -185,7 +186,7 @@ def align(
185186

186187
options = AllOptions(options, vanilla_align=not is_faster_model)
187188
split_words_by_space = getattr(tokenizer, 'language_code', tokenizer.language) not in {"zh", "ja", "th", "lo", "my"}
188-
model_type = 'fw' if is_faster_model else None
189+
model_type = 'fw' if is_faster_model else 'mlx' if is_mlx_model else None
189190
inference_func = get_whisper_alignment_func(model, tokenizer, model_type, options)
190191

191192
aligner = Aligner(
@@ -336,14 +337,15 @@ def align_words(
336337
"""
337338
model = as_vanilla(model)
338339
is_faster_model = model.__module__.startswith('faster_whisper.')
340+
is_mlx_model = model.__module__.startswith('mlx_whisper.')
339341
if not is_faster_model:
340342
warn_compatibility_issues(whisper, ignore_compatibility)
341343
tokenizer, supported_languages = get_alignment_tokenizer(model, is_faster_model, result, language, tokenizer)
342344

343345
options = AllOptions(options)
344346
split_words_by_space = getattr(tokenizer, 'language_code', tokenizer.language) not in {"zh", "ja", "th", "lo", "my"}
345347
max_segment_tokens = model.max_length if is_faster_model else model.dims.n_text_ctx
346-
inference_func = get_whisper_alignment_func(model, tokenizer, 'fw' if is_faster_model else None, options)
348+
inference_func = get_whisper_alignment_func(model, tokenizer, 'fw' if is_faster_model else 'mlx' if is_mlx_model else None, options)
347349

348350
aligner = Aligner(
349351
inference_func=inference_func,
@@ -393,7 +395,7 @@ def get_whisper_alignment_func(
393395
model_type: Optional[str] = None,
394396
options: Optional[AllOptions] = None
395397
):
396-
assert model_type in (None, 'fw')
398+
assert model_type in (None, 'fw', 'mlx')
397399

398400
if model_type is None:
399401
def compute_timestamps(audio_segment: torch.Tensor, word_tokens: List[WordToken]) -> List[dict]:
@@ -421,6 +423,53 @@ def compute_timestamps(audio_segment: torch.Tensor, word_tokens: List[WordToken]
421423
)
422424
return [w for seg in temp_segments for w in seg['words']]
423425

426+
elif model_type == 'mlx':
427+
from mlx_whisper.audio import (
428+
N_FRAMES as MLX_N_FRAMES,
429+
SAMPLE_RATE as MLX_SAMPLE_RATE,
430+
log_mel_spectrogram as log_mel_spectrogram_mx,
431+
pad_or_trim as pad_or_trim_mx
432+
)
433+
import mlx.core as mx
434+
import mlx_whisper.timing as timing
435+
436+
def compute_timestamps(audio_segment_torch: torch.Tensor, word_tokens: List[WordToken]) -> List[dict]:
    """
    Compute word-level timestamps for one audio segment with an MLX-Whisper model.

    Parameters
    ----------
    audio_segment_torch : torch.Tensor
        Audio samples of the segment; singleton dimensions are squeezed out before use.
    word_tokens : list of WordToken
        Words to align; their ``tokens`` are flattened, in order, into one token sequence.

    Returns
    -------
    list of dict
        The ``words`` entry of the temporary segment after
        ``mlx_whisper.timing.add_word_timestamps`` has processed it.
    """
    # ``Tensor.numpy()`` only works on CPU tensors that do not require grad,
    # so detach and move to CPU first (both are no-ops for plain CPU tensors).
    audio_segment_np = audio_segment_torch.detach().cpu().squeeze().numpy().astype('float32')

    segment_samples = audio_segment_np.shape[-1]

    # Single throwaway segment spanning the whole audio; ``words`` is handed to
    # ``add_word_timestamps`` below and returned afterwards.
    temp_segment = dict(
        seek=0,
        start=0.0,
        end=round(segment_samples / MLX_SAMPLE_RATE, 3),
        tokens=[t for wt in word_tokens for t in wt.tokens],
        words=[]
    )

    # Pass the NumPy array directly; there is no need to round-trip it through ``mx.array``.
    mel_segments_raw = log_mel_spectrogram_mx(
        audio=audio_segment_np,
        n_mels=model.dims.n_mels,
        padding=0
    )

    # Frame count before padding (assumes axis 0 of the spectrogram is time -- TODO confirm).
    num_frames_unpadded = mel_segments_raw.shape[0]

    # ``pad_or_trim`` is applied to the transposed spectrogram and the result is
    # transposed back, so the padded axis is axis 0 of the original layout.
    mel_segments_nlc = pad_or_trim_mx(mel_segments_raw.T, MLX_N_FRAMES).T

    timing.add_word_timestamps(
        segments=[temp_segment],
        model=model,
        tokenizer=tokenizer,
        mel=mel_segments_nlc,
        num_frames=num_frames_unpadded,
        last_speech_timestamp=0.0
    )

    return temp_segment.get('words', [])
472+
424473
else:
425474
from .whisper_compatibility import is_faster_whisper_on_pt
426475
from faster_whisper.version import __version__ as fw_ver
@@ -548,6 +597,7 @@ def refine(
548597
"""
549598
model = as_vanilla(model)
550599
is_faster_model = model.__module__.startswith('faster_whisper.')
600+
is_mlx_model = model.__module__.startswith('mlx_whisper')
551601
if result and (not result.has_words or any(word.probability is None for word in result.all_words())):
552602
if not result.language:
553603
raise RuntimeError(f'cannot align words with result missing language')
@@ -558,7 +608,7 @@ def refine(
558608
word.tokens = tokenizer.encode(word.word)
559609

560610
options = AllOptions(options, post=False, silence=False, align=False)
561-
model_type = 'fw' if is_faster_model else None
611+
model_type = 'fw' if is_faster_model else 'mlx' if is_mlx_model else None
562612
inference_func = get_whisper_refinement_func(model, tokenizer, model_type, single_batch)
563613
max_inference_tokens = (model.max_length if is_faster_model else model.dims.n_text_ctx) - 6
564614

@@ -588,7 +638,7 @@ def get_whisper_refinement_func(
588638
model_type: Optional[str] = None,
589639
single_batch: bool = False
590640
):
591-
assert model_type in (None, 'fw')
641+
assert model_type in (None, 'fw', 'mlx')
592642

593643
if model_type is None:
594644
def inference_func(audio_segment: torch, tokens: List[int]) -> torch.Tensor:
@@ -616,6 +666,56 @@ def inference_func(audio_segment: torch, tokens: List[int]) -> torch.Tensor:
616666
token_probs = sampled_logits.softmax(dim=-1)
617667
return token_probs
618668

669+
elif model_type == 'mlx':
670+
# ``mlx_whisper.audio`` exposes the frame-count constant as ``N_FRAMES``;
# alias it on import, as the MLX alignment branch does.
from mlx_whisper.audio import (
    N_FRAMES as N_FRAMES_MLX,
    log_mel_spectrogram as log_mel_spectrogram_mx,
    pad_or_trim as pad_or_trim_mx
)
import mlx.core as mx

def inference_func(audio_batch_torch: torch.Tensor, tokens: List[int]) -> torch.Tensor:
    """
    Return token probabilities for ``tokens`` on each audio segment of the batch.

    Parameters
    ----------
    audio_batch_torch : torch.Tensor
        Batch of audio segments; it is iterated along its first dimension.
    tokens : list of int
        Text tokens to score.

    Returns
    -------
    torch.Tensor
        Softmax probabilities over token ids below ``tokenizer.eot``, one row per
        position in the sliced range, for every segment of the batch.

    Notes
    -----
    The model is called once per segment and the logits are concatenated afterwards.
    """
    # Decoder input: SOT sequence, no-timestamps token, the text tokens, then EOT.
    # The leading batch axis is added once here instead of on every model call.
    input_tokens_mx = mx.array(
        [
            *tokenizer.sot_sequence,
            tokenizer.no_timestamps,
            *tokens,
            tokenizer.eot,
        ]
    )[None]

    # ``Tensor.numpy()`` only works on CPU tensors that do not require grad.
    audio_batch_np = audio_batch_torch.detach().cpu().numpy().astype('float32')

    logits_list = []
    for audio_segment_np in audio_batch_np:
        mel_raw = log_mel_spectrogram_mx(audio=audio_segment_np, n_mels=model.dims.n_mels, padding=0)
        # Pad/trim on the transposed spectrogram, then restore the original layout.
        mel_nlc = pad_or_trim_mx(mel_raw.T, N_FRAMES_MLX).T
        logits_list.append(model(mel_nlc[None], input_tokens_mx))

    logits = mx.concatenate(logits_list, axis=0)

    # NOTE(review): the slice starts one position after the ``no_timestamps`` token;
    # confirm this offset matches what the other backends' inference functions return.
    start_idx = len(tokenizer.sot_sequence) + 1
    end_idx = start_idx + len(tokens)

    sampled_logits = logits[:, start_idx:end_idx, :tokenizer.eot]
    token_probs = mx.softmax(sampled_logits, axis=-1)

    # Copy into NumPy-owned memory before handing the data to torch.
    token_probs_np = np.array(token_probs, copy=True)
    return torch.from_numpy(token_probs_np)
619719
else:
620720
from .whisper_compatibility import is_faster_whisper_on_pt
621721
from faster_whisper.version import __version__ as fw_ver

stable_whisper/whisper_word_level/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
from .original_whisper import transcribe_stable, transcribe_minimal, load_model, modify_model
44
from .faster_whisper import load_faster_whisper
55
from .hf_whisper import load_hf_whisper
6+
from .mlx_whisper import load_mlx_whisper
67

78

8-
__all__ = ['load_model', 'modify_model', 'load_faster_whisper', 'load_hf_whisper']
9+
__all__ = ['load_model', 'modify_model', 'load_faster_whisper', 'load_hf_whisper', 'load_mlx_whisper',]
910

1011
warnings.filterwarnings('ignore', module='whisper', message='.*Triton.*', category=UserWarning)
1112

stable_whisper/whisper_word_level/cli.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def valid_model_name(name: str) -> str:
5252
from whisper import available_models
5353
elif is_faster_whisper:
5454
from faster_whisper.utils import available_models
55-
_models = None if is_hf_whisper or available_models is None else available_models()
55+
_models = None if is_hf_whisper or is_mlx_whisper or available_models is None else available_models()
5656

5757
if not _models or name in _models or os.path.exists(name):
5858
return name
@@ -373,6 +373,10 @@ def url_to_path(url: str):
373373
' and even more speed with Flash Attention enabled on supported GPUs'
374374
'(https://huggingface.co/openai/whisper-large-v3); '
375375
'note: some features may not be available')
376+
parser.add_argument('--mlx_whisper', '-mlx', action='store_true',
377+
help='whether to use mlx-whisper '
378+
'(https://github.com/ml-explore/mlx-examples/tree/main/whisper); '
379+
'note: some features may not be available')
376380

377381
parser.add_argument('--persist', '-p', action='store_true',
378382
help='Keep previous model loaded for the future sets of commands in the same CLI instance')
@@ -403,12 +407,14 @@ def url_to_path(url: str):
403407
if '--model' not in args and '-m' not in args:
404408
args.extend(['-m', _cache['model']['name']])
405409
model_type = _cache['model']['type']
406-
type_arg = '--faster_whisper' in args or '-fw' in args or '--huggingface_whisper' in args or '-hw' in args
410+
type_arg = '--faster_whisper' in args or '-fw' in args or '--huggingface_whisper' in args or '-hw' in args or '--mlx_whisper' in args or '-mlx' in args
407411
if not type_arg:
408412
if model_type == 'Faster-Whisper':
409413
args.append('-fw')
410414
elif model_type == 'Hugging Face Whisper':
411415
args.append('-hw')
416+
elif model_type == 'MLX Whisper':
417+
args.append('-mlx')
412418

413419
_, invalid_args = parser.parse_known_args(args)
414420
if invalid_args:
@@ -423,9 +429,12 @@ def url_to_path(url: str):
423429
raise ValueError('langauge is required for --align / --locate')
424430

425431
is_faster_whisper = args.pop('faster_whisper')
432+
is_mlx_whisper = args.pop('mlx_whisper')
426433
is_hf_whisper = args.pop('huggingface_whisper')
427434
assert not (is_faster_whisper and is_hf_whisper), f'--huggingface_whisper cannot be used with --faster_whisper'
428-
is_original_whisper = not (is_faster_whisper or is_hf_whisper)
435+
assert not (is_faster_whisper and is_mlx_whisper), f'--mlx_whisper cannot be used with --faster_whisper'
436+
assert not (is_hf_whisper and is_mlx_whisper), f'--mlx_whisper cannot be used with --huggingface_whisper'
437+
is_original_whisper = not (is_faster_whisper or is_hf_whisper or is_mlx_whisper)
429438
args['language'] = valid_language(args['language'])
430439
model_name: str = valid_model_name(args.pop("model"))
431440
model_dir: str = args.pop("model_dir")
@@ -463,6 +472,10 @@ def url_to_path(url: str):
463472
model_type_name = 'Faster-Whisper'
464473
from .faster_whisper import load_faster_whisper as load_model_func
465474
model_name_kwarg = dict(model_size_or_path=model_name)
475+
elif is_mlx_whisper:
476+
model_type_name = 'MLX Whisper'
477+
from .mlx_whisper import load_mlx_whisper as load_model_func
478+
model_name_kwarg = dict(model_name=model_name)
466479
else:
467480
model_type_name = 'Hugging Face Whisper'
468481
from .hf_whisper import load_hf_whisper as load_model_func

0 commit comments

Comments
 (0)