@@ -262,7 +262,7 @@ def load_model(whisper_arch,
                compute_type="float16",
                asr_options=None,
                language: Optional[str] = None,
-               vad_model=None,
+               vad_model_fp=None,
                vad_options=None,
                model: Optional[WhisperModel] = None,
                task="transcribe",
@@ -275,6 +275,7 @@ def load_model(whisper_arch,
         compute_type: str - The compute type to use for the model.
         options: dict - A dictionary of options to use for the model.
         language: str - The language of the model. (use English for now)
+        vad_model_fp: str - File path to the VAD model to use.
         model: Optional[WhisperModel] - The WhisperModel instance to use.
         download_root: Optional[str] - The root directory to download the model to.
         threads: int - The number of cpu threads to use per worker, e.g. will be multiplied by num workers.
@@ -341,8 +342,8 @@ def load_model(whisper_arch,
     if vad_options is not None:
         default_vad_options.update(vad_options)
 
-    if vad_model is not None:
-        vad_model = vad_model
+    if vad_model_fp is not None:
+        vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options, model_fp=vad_model_fp)
     else:
         vad_model = load_vad_model(torch.device(device), use_auth_token=None, **default_vad_options)
 
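
For reference, a minimal usage sketch of the new keyword (not part of this diff; the model size, checkpoint path, and audio file below are placeholder values): passing a local checkpoint via vad_model_fp lets load_vad_model read the VAD weights from disk instead of downloading them.

import whisperx

# Assumed example values; vad_model_fp should point to a locally stored
# VAD segmentation checkpoint.
model = whisperx.load_model(
    "large-v2",
    device="cuda",
    compute_type="float16",
    vad_model_fp="/models/vad/pytorch_model.bin",
)

audio = whisperx.load_audio("audio.wav")  # placeholder input file
result = model.transcribe(audio)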