diff --git a/whisperx/alignment.py b/whisperx/alignment.py
index 756d0ba7..dcd972f0 100644
--- a/whisperx/alignment.py
+++ b/whisperx/alignment.py
@@ -5,7 +5,7 @@
 import math
 from dataclasses import dataclass
-from typing import Iterable, Optional, Union, List
+from typing import Optional, Union, List
 
 import numpy as np
 import pandas as pd
@@ -115,7 +115,7 @@ def load_align_model(language_code: str, device: str, model_name: Optional[str]
 
 def align(
-    transcript: Iterable[SingleSegment],
+    transcript: List[SingleSegment],
     model: torch.nn.Module,
     align_model_metadata: dict,
     audio: Union[str, np.ndarray, torch.Tensor],
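Why this signature change matters to callers: `align` consumes the transcript more than once (e.g. it takes `len(transcript)` for progress reporting), so the annotation now reflects that a materialized `List[SingleSegment]` is required rather than an arbitrary `Iterable`. A minimal caller sketch; the audio path, device, and the stand-in transcription dict are placeholders, and the generator stands in for any lazily produced segment stream:

```python
import whisperx

device = "cpu"  # placeholder; use "cuda" where available
audio = whisperx.load_audio("audio.wav")  # placeholder path

model_a, metadata = whisperx.load_align_model(language_code="en", device=device)

# Stand-in for a result normally produced by model.transcribe(audio).
transcription = {"segments": [{"start": 0.0, "end": 1.5, "text": " hello world"}]}

# Upstream code that filters segments lazily must now materialize them
# with list() before calling align().
segments = (s for s in transcription["segments"] if s["text"].strip())
aligned = whisperx.align(list(segments), model_a, metadata, audio, device)
```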
diff --git a/whisperx/asr.py b/whisperx/asr.py
index c35900cf..2053cfaf 100644
--- a/whisperx/asr.py
+++ b/whisperx/asr.py
@@ -1,6 +1,6 @@
 import os
-from typing import List, Optional, Union
-from dataclasses import replace
+from typing import Iterable, List, Optional, Union
+from dataclasses import dataclass, replace, field
 
 import ctranslate2
 import faster_whisper
@@ -28,6 +28,78 @@ def find_numeral_symbol_tokens(tokenizer):
             numeral_symbol_tokens.append(i)
     return numeral_symbol_tokens
 
+
+@dataclass
+class DefaultASRTranscriptionOptions(TranscriptionOptions):
+    """Default options for Automatic Speech Recognition (ASR) transcription with Whisper models.
+
+    Extends `faster_whisper.transcribe.TranscriptionOptions` with documented defaults that balance
+    accuracy, speed, and output quality. Override these via the `asr_options` parameter of `load_model`.
+
+    Fields:
+        beam_size: Number of beams in beam search. Higher values increase accuracy but slow down inference.
+        best_of: Number of candidates when sampling with non-zero temperature. Higher values improve
+            quality but increase computation.
+        patience: Beam search patience factor. Values >1 keep the search running longer before stopping.
+        length_penalty: Exponential length penalty (alpha in the Google NMT paper). Adjusts for
+            sequence-length bias.
+        repetition_penalty: Penalty applied to previously generated tokens. Values >1 discourage repetition.
+        no_repeat_ngram_size: Prevent repetition of n-grams of this size (0 disables the check).
+        temperatures: Fallback temperatures, tried in order when a decode fails the quality thresholds.
+            Lower values make output more deterministic.
+        compression_ratio_threshold: If the gzip compression ratio of the decoded text exceeds this,
+            the result is treated as failed (likely repetition or hallucination) and decoding retries
+            at the next temperature.
+        log_prob_threshold: If the average log probability of the sampled tokens falls below this,
+            the result is treated as failed.
+        no_speech_threshold: If the no-speech probability is above this and decoding failed, the
+            segment is considered silent.
+        condition_on_previous_text: Whether to condition each window on the previous output for continuity.
+        prompt_reset_on_temperature: Temperature at or above which the prompt is reset. Helps stop
+            hallucinations from propagating between windows.
+        initial_prompt: Initial text prompt (or pre-encoded token ids) providing context for the
+            first window.
+        prefix: Optional text to use as a prefix for the first window's decoding.
+        suppress_blank: Suppress blank outputs at the beginning of sampling.
+        suppress_tokens: Token IDs to suppress (-1 selects the model's default set).
+        without_timestamps: Disable timestamp prediction during transcription.
+        max_initial_timestamp: Maximum initial timestamp allowed.
+        word_timestamps: Enable word-level timestamps.
+        prepend_punctuations: Punctuation marks merged with the following word during word-level alignment.
+        append_punctuations: Punctuation marks merged with the preceding word during word-level alignment.
+        multilingual: Whether the model supports multilingual transcription. Set from
+            `model.is_multilingual` in `load_model`.
+        suppress_numerals: Suppress numeral and symbol tokens (e.g., digits, "%") to produce cleaner
+            text. Handled by the WhisperX pipeline itself rather than passed to faster-whisper.
+        max_new_tokens: Maximum number of new tokens to generate per window.
+        clip_timestamps: Comma-separated string or list of timestamps (in seconds) delimiting the
+            regions of audio to process.
+        hallucination_silence_threshold: Skip silent periods longer than this many seconds when a
+            possible hallucination is detected (requires word timestamps).
+        hotwords: Hotwords/hint phrases to provide to the model.
+    """
+    suppress_numerals: bool = False
+    beam_size: int = 5
+    best_of: int = 5
+    patience: float = 1.0
+    length_penalty: float = 1.0
+    repetition_penalty: float = 1.0
+    no_repeat_ngram_size: int = 0
+    temperatures: List[float] = field(default_factory=lambda: [0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
+    compression_ratio_threshold: float = 2.4
+    log_prob_threshold: float = -1.0
+    no_speech_threshold: float = 0.6
+    condition_on_previous_text: bool = False
+    prompt_reset_on_temperature: float = 0.5
+    initial_prompt: Optional[Union[str, Iterable[int]]] = None
+    prefix: Optional[str] = None
+    suppress_blank: bool = True
+    suppress_tokens: List[int] = field(default_factory=lambda: [-1])
+    without_timestamps: bool = True
+    max_initial_timestamp: float = 0.0
+    word_timestamps: bool = False
+    prepend_punctuations: str = "\"'“¿([{-"
+    append_punctuations: str = "\"'.。,,!!??::”)]}、"
+    multilingual: bool = False  # Set from model.is_multilingual in load_model
+    max_new_tokens: Optional[int] = None
+    clip_timestamps: Union[str, List[float], None] = None
+    hallucination_silence_threshold: Optional[float] = None
+    hotwords: Optional[str] = None
+
+
 class WhisperModel(faster_whisper.WhisperModel):
     '''
     FasterWhisperModel provides batched inference for faster-whisper.
@@ -45,8 +117,11 @@ def generate_segment_batched(
         all_tokens = []
         prompt_reset_since = 0
         if options.initial_prompt is not None:
-            initial_prompt = " " + options.initial_prompt.strip()
-            initial_prompt_tokens = tokenizer.encode(initial_prompt)
+            if isinstance(options.initial_prompt, str):
+                initial_prompt = " " + options.initial_prompt.strip()
+                initial_prompt_tokens = tokenizer.encode(initial_prompt)
+            else:
+                initial_prompt_tokens = list(options.initial_prompt)
             all_tokens.extend(initial_prompt_tokens)
         previous_tokens = all_tokens[prompt_reset_since:]
         prompt = self.get_prompt(
@@ -98,6 +173,7 @@ def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
 
         return self.model.encode(features, to_cpu=to_cpu)
 
+
 class FasterWhisperPipeline(Pipeline):
     """
     Huggingface Pipeline wrapper for FasterWhisperModel.
@@ -324,7 +400,7 @@ def load_model(
         compute_type - The compute type to use for the model.
         vad_model - The vad model to manually assign.
         vad_method - The vad method to use. vad_model has a higher priority if it is not None.
-        options - A dictionary of options to use for the model.
+        asr_options - A dictionary of options to override defaults in DefaultASRTranscriptionOptions.
         language - The language of the model. (use English for now)
         model - The WhisperModel instance to use.
         download_root - The root directory to download the model to.
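Seen from a caller, the `asr_options` dict interface is unchanged; unknown option keys raise a `TypeError` inside `dataclasses.replace`. A sketch of both usage paths; the model name, device, and option values are illustrative, not recommendations:

```python
from dataclasses import replace

import whisperx
from whisperx.asr import DefaultASRTranscriptionOptions

# Path 1: the existing public interface, unchanged for callers.
model = whisperx.load_model(
    "large-v2",
    device="cpu",
    compute_type="int8",
    asr_options={"beam_size": 10, "suppress_numerals": True},
)

# Path 2: inspect or build the defaults directly, e.g. in tests or for logging.
opts = DefaultASRTranscriptionOptions(multilingual=True)
opts = replace(opts, temperatures=[0.0])  # greedy-only, no fallback ladder
# replace(opts, beam_sz=10) would raise TypeError: "beam_sz" is not a field.
```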
@@ -350,43 +426,19 @@ def load_model(
         logger.info("No language specified, language will be detected for each audio file (increases inference time)")
         tokenizer = None
 
-    default_asr_options = {
-        "beam_size": 5,
-        "best_of": 5,
-        "patience": 1,
-        "length_penalty": 1,
-        "repetition_penalty": 1,
-        "no_repeat_ngram_size": 0,
-        "temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
-        "compression_ratio_threshold": 2.4,
-        "log_prob_threshold": -1.0,
-        "no_speech_threshold": 0.6,
-        "condition_on_previous_text": False,
-        "prompt_reset_on_temperature": 0.5,
-        "initial_prompt": None,
-        "prefix": None,
-        "suppress_blank": True,
-        "suppress_tokens": [-1],
-        "without_timestamps": True,
-        "max_initial_timestamp": 0.0,
-        "word_timestamps": False,
-        "prepend_punctuations": "\"'“¿([{-",
-        "append_punctuations": "\"'.。,,!!??::”)]}、",
-        "multilingual": model.model.is_multilingual,
-        "suppress_numerals": False,
-        "max_new_tokens": None,
-        "clip_timestamps": None,
-        "hallucination_silence_threshold": None,
-        "hotwords": None,
-    }
+    transcription_options = DefaultASRTranscriptionOptions(
+        multilingual=model.model.is_multilingual,
+    )
 
     if asr_options is not None:
-        default_asr_options.update(asr_options)
-
-    suppress_numerals = default_asr_options["suppress_numerals"]
-    del default_asr_options["suppress_numerals"]
+        transcription_options = replace(transcription_options, **asr_options)
 
-    default_asr_options = TranscriptionOptions(**default_asr_options)
+    # Extract suppress_numerals (pipeline-specific, not part of the base TranscriptionOptions)
+    suppress_numerals = transcription_options.suppress_numerals
+    # Rebuild as the base TranscriptionOptions for compatibility (drops suppress_numerals)
+    transcription_options = TranscriptionOptions(
+        **{k: v for k, v in transcription_options.__dict__.items() if k != 'suppress_numerals'}
+    )
 
     default_vad_options = {
         "chunk_size": 30, # needed by silero since binarization happens before merge_chunks
@@ -416,9 +468,9 @@ def load_model(
     return FasterWhisperPipeline(
         model=model,
         vad=vad_model,
-        options=default_asr_options,
+        options=transcription_options,
         tokenizer=tokenizer,
         language=language,
         suppress_numerals=suppress_numerals,
         vad_params=default_vad_options,
     )
diff --git a/whisperx/diarize.py b/whisperx/diarize.py
index 9f46b028..328d59d1 100644
--- a/whisperx/diarize.py
+++ b/whisperx/diarize.py
@@ -16,7 +16,7 @@ def __init__(
         self,
         model_name=None,
         use_auth_token=None,
-        device: Optional[Union[str, torch.device]] = "cpu",
+        device: Union[str, torch.device] = "cpu",
     ):
         if isinstance(device, str):
             device = torch.device(device)
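Since `device` is no longer `Optional`, callers should pass an explicit string or `torch.device`. Combined with the `return_embeddings` flag used in `transcribe.py` below, usage looks roughly like this; the auth token and audio path are placeholders:

```python
import torch

from whisperx.diarize import DiarizationPipeline

# device must now be a string or torch.device, never None.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
diarize_model = DiarizationPipeline(use_auth_token="hf_xxx", device=device)  # placeholder token

# Plain call: a pandas DataFrame of speaker segments.
segments = diarize_model("audio.wav", min_speakers=1, max_speakers=4)

# With embeddings: a (segments, embeddings) tuple, matching the isinstance
# checks added in transcribe.py.
segments, embeddings = diarize_model("audio.wav", return_embeddings=True)
```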
diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index 04c2ab36..61553b5c 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -4,6 +4,7 @@
 import warnings
 
 import numpy as np
+import pandas as pd
 import torch
 
 from whisperx.alignment import align, load_align_model
@@ -186,7 +187,7 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
                     result["language"], device
                 )
             logger.info("Performing alignment...")
-            result: AlignedTranscriptionResult = align(
+            aligned_result: AlignedTranscriptionResult = align(
                 result["segments"],
                 align_model,
                 align_metadata,
@@ -196,6 +197,7 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
                 return_char_alignments=return_char_alignments,
                 print_progress=print_progress,
             )
+            result = aligned_result
 
             results.append((result, audio_path))
 
@@ -217,20 +219,28 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
         diarize_model = DiarizationPipeline(model_name=diarize_model_name, use_auth_token=hf_token, device=device)
         for result, input_audio_path in tmp_results:
             diarize_result = diarize_model(
                 input_audio_path,
                 min_speakers=min_speakers,
                 max_speakers=max_speakers,
+                return_embeddings=return_speaker_embeddings
             )
 
             if return_speaker_embeddings:
+                if not isinstance(diarize_result, tuple):
+                    raise TypeError(
+                        f"Expected tuple when return_embeddings=True, got {type(diarize_result).__name__}"
+                    )
                 diarize_segments, speaker_embeddings = diarize_result
             else:
+                if not isinstance(diarize_result, pd.DataFrame):
+                    raise TypeError(
+                        f"Expected DataFrame when return_embeddings=False, got {type(diarize_result).__name__}"
+                    )
                 diarize_segments = diarize_result
                 speaker_embeddings = None
 
-            result = assign_word_speakers(diarize_segments, result, speaker_embeddings)
-            results.append((result, input_audio_path))
+            diarized_result = assign_word_speakers(diarize_segments, result, speaker_embeddings)
+            results.append((diarized_result, input_audio_path))
 
     # >> Write
     for result, audio_path in results:
         result["language"] = align_language
diff --git a/whisperx/utils.py b/whisperx/utils.py
index ada0deb9..711bbbb4 100644
--- a/whisperx/utils.py
+++ b/whisperx/utils.py
@@ -425,9 +425,9 @@ def get_writer(
     if output_format == "all":
         all_writers = [writer(output_dir) for writer in writers.values()]
 
-        def write_all(result: dict, file: str, options: dict):
+        def write_all(result: dict, audio_path: str, options: dict):
             for writer in all_writers:
-                writer(result, file, options)
+                writer(result, audio_path, options)
 
         return write_all
 
@@ -435,6 +435,7 @@ def write_all(result: dict, file: str, options: dict):
     if output_format in optional_writers:
         return optional_writers[output_format](output_dir)
     return writers[output_format](output_dir)
 
+
 def interpolate_nans(x, method='nearest'):
     if x.notnull().sum() > 1:
         return x.interpolate(method=method).ffill().bfill()
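For completeness, the renamed `audio_path` parameter as seen by a caller of `get_writer`. The `result` dict here is a minimal stand-in for a real transcription result, and the option values mirror the CLI defaults:

```python
import os

from whisperx.utils import get_writer

# Stand-in result; normally produced by transcribe() and align().
result = {
    "segments": [{"start": 0.0, "end": 1.5, "text": "hello world"}],
    "language": "en",
}

os.makedirs("out", exist_ok=True)
writer = get_writer("all", output_dir="out")  # fans out to every standard writer
writer(result, "audio.wav", {"max_line_width": None, "max_line_count": None, "highlight_words": False})
```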