Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions whisperx/asr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from typing import List, Optional, Union
from dataclasses import replace
from typing import List, Optional, Union

import ctranslate2
import faster_whisper
Expand All @@ -13,7 +13,7 @@

from whisperx.audio import N_SAMPLES, SAMPLE_RATE, load_audio, log_mel_spectrogram
from whisperx.types import SingleSegment, TranscriptionResult
from whisperx.vads import Vad, Silero, Pyannote
from whisperx.vads import Pyannote, Silero, Vad


def find_numeral_symbol_tokens(tokenizer):
Expand Down Expand Up @@ -313,6 +313,7 @@ def load_model(
download_root: Optional[str] = None,
local_files_only=False,
threads=4,
use_auth_token: Optional[Union[str, bool]] = None,
) -> FasterWhisperPipeline:
"""Load a Whisper model for inference.
Args:
Expand All @@ -326,6 +327,7 @@ def load_model(
download_root - The root directory to download the model to.
local_files_only - If `True`, avoid downloading the file and return the path to the local cached file if it exists.
threads - The number of cpu threads to use per worker, e.g. will be multiplied by num workers.
use_auth_token - HuggingFace authentication token, or True to use the token stored in the HuggingFace config folder.
Returns:
A Whisper pipeline.
"""
Expand All @@ -339,7 +341,9 @@ def load_model(
compute_type=compute_type,
download_root=download_root,
local_files_only=local_files_only,
cpu_threads=threads)
cpu_threads=threads,
use_auth_token=use_auth_token
)
if language is not None:
tokenizer = Tokenizer(model.hf_tokenizer, model.model.is_multilingual, task=task, language=language)
else:
Expand Down