Skip to content

Commit 7c0428c

Browse files
committed
fix bug
1 parent 4849a38 commit 7c0428c

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

stable_whisper/whisper_word_level/mlx_whisper.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from functools import lru_cache
55

66
import numpy as np
7-
import mlx.core as mx
87

98
from ..audio import convert_demucs_kwargs, prep_audio
109
from ..non_whisper import transcribe_any
@@ -62,8 +61,6 @@ def load_mlx_model(model_name: str, dtype=None):
6261
from mlx_whisper import load_models
6362

6463
model_id = MLX_MODELS.get(model_name, model_name)
65-
if dtype is None:
66-
dtype = mx.float32
6764

6865
return load_models.load_model(model_id, dtype=dtype)
6966

@@ -186,6 +183,7 @@ def _inner_transcribe(
186183
print(f'Transcribing with MLX Whisper ({model_path})...')
187184

188185
if isinstance(audio, np.ndarray):
186+
import mlx.core as mx
189187
audio_mx = mx.array(audio)
190188
else:
191189
audio_mx = audio
@@ -316,6 +314,7 @@ def transcribe(
316314

317315

318316
def load_mlx_whisper(model_name: str, dtype=None, **model_kwargs):
317+
import mlx.core as mx
319318
if dtype is None:
320319
dtype = mx.float32
321320

0 commit comments

Comments
 (0)