Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions whisperx/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import math

from dataclasses import dataclass
from typing import Iterable, Optional, Union, List
from typing import Optional, Union, List

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -115,7 +115,7 @@ def load_align_model(language_code: str, device: str, model_name: Optional[str]


def align(
transcript: Iterable[SingleSegment],
transcript: List[SingleSegment],
model: torch.nn.Module,
align_model_metadata: dict,
audio: Union[str, np.ndarray, torch.Tensor],
Expand Down
134 changes: 93 additions & 41 deletions whisperx/asr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from typing import List, Optional, Union
from dataclasses import replace
from typing import List, Optional, Union, Tuple
from dataclasses import dataclass, replace, field

import ctranslate2
import faster_whisper
Expand Down Expand Up @@ -28,6 +28,78 @@ def find_numeral_symbol_tokens(tokenizer):
numeral_symbol_tokens.append(i)
return numeral_symbol_tokens


@dataclass
class DefaultASRTranscriptionOptions(TranscriptionOptions):
    """Default option values for Whisper ASR transcription in this pipeline.

    Subclasses `TranscriptionOptions` and assigns a default to each field
    declared below, plus one extra pipeline-only field, `suppress_numerals`.
    `load_model` builds an instance with `multilingual` taken from the loaded
    model, applies any user-supplied `asr_options` on top via
    `dataclasses.replace`, and then rebuilds a plain `TranscriptionOptions`
    from every field except `suppress_numerals`.

    NOTE(review): the base class is presumably
    `faster_whisper.transcribe.TranscriptionOptions`; the field descriptions
    below paraphrase the upstream meaning and should be confirmed against the
    installed faster-whisper version.

    Fields:
        beam_size: Number of beams in beam search. Higher values increase accuracy but slow down inference.
        best_of: Number of candidates when sampling multiple temperatures. Higher values improve quality
            but increase computation.
        patience: Beam search patience factor (early stopping). Values >1 allow longer sequences.
        length_penalty: Exponential penalty to length (alpha in Google NMT). Adjusts for sequence length bias.
        repetition_penalty: Penalty for repeated tokens. Values >1 discourage repetition.
        no_repeat_ngram_size: Prevent repetition of n-grams of this size.
        temperatures: Sampling temperatures. Lower values make output more deterministic.
            NOTE(review): upstream appears to use multiple values as successive fallbacks
            when decoding fails a threshold check, rather than as best-of sampling - confirm.
        compression_ratio_threshold: Threshold for detecting compression ratio issues (hallucinations).
            If exceeded, sample is discarded.
        log_prob_threshold: Threshold for log probability (confidence). Samples below this are discarded.
        no_speech_threshold: Threshold for no-speech detection. Samples above this are considered non-speech.
        condition_on_previous_text: Whether to condition on previous text for continuity.
        prompt_reset_on_temperature: Temperature threshold above which to reset prompt.
        initial_prompt: Initial prompt to guide transcription. `generate_segment_batched`
            encodes a `str` prompt with the tokenizer; any other non-None value is converted
            with `list(...)` and used directly as prompt token ids.
        prefix: Text prefix for the decoded output.
        suppress_blank: Suppress blank tokens in output.
        suppress_tokens: Token IDs to suppress.
        without_timestamps: Disable timestamp prediction during transcription.
        max_initial_timestamp: Maximum initial timestamp allowed.
        word_timestamps: Enable word-level timestamps.
        prepend_punctuations: Punctuation marks that should be merged with the following word.
        append_punctuations: Punctuation marks that should be merged with the preceding word.
        multilingual: Whether the model is multilingual. `load_model` overrides the default
            with `model.model.is_multilingual`.
        suppress_numerals: Suppress numeral and symbol tokens (e.g., digits, %). Not part of the
            base options; `load_model` extracts it and passes it to the pipeline separately.
        max_new_tokens: Maximum new tokens to generate.
        clip_timestamps: Timestamps selecting the audio clips to process.
            NOTE(review): annotated here as (start, end) pairs; confirm this matches the
            format the installed faster-whisper expects.
        hallucination_silence_threshold: Silence duration (in seconds) used when handling
            potential hallucinations.
        hotwords: Hotwords to boost in scoring.
            NOTE(review): annotated here as a list of strings; confirm against upstream.
    """
    # Pipeline-only flag: load_model reads it, then drops it before constructing
    # the base TranscriptionOptions.
    suppress_numerals: bool = False
    beam_size: int = 5
    best_of: int = 5
    patience: float = 1.0
    length_penalty: float = 1.0
    repetition_penalty: float = 1.0
    no_repeat_ngram_size: int = 0
    # default_factory gives every instance its own list instead of a shared one.
    temperatures: List[float] = field(default_factory=lambda: [0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
    compression_ratio_threshold: float = 2.4
    log_prob_threshold: float = -1.0
    no_speech_threshold: float = 0.6
    condition_on_previous_text: bool = False
    prompt_reset_on_temperature: float = 0.5
    # Annotated as str, but generate_segment_batched also accepts a non-str
    # iterable here and treats it as already-encoded prompt token ids.
    initial_prompt: Optional[str] = None
    prefix: Optional[str] = None
    suppress_blank: bool = True
    # default_factory gives every instance its own list instead of a shared one.
    suppress_tokens: List[int] = field(default_factory=lambda: [-1])
    without_timestamps: bool = True
    max_initial_timestamp: float = 0.0
    word_timestamps: bool = False
    prepend_punctuations: str = "\"'“¿([{-"
    append_punctuations: str = "\"'.。,,!!??::”)]}、"
    multilingual: bool = False  # Set from model.is_multilingual
    max_new_tokens: Optional[int] = None
    clip_timestamps: Optional[List[Tuple[float, float]]] = None
    hallucination_silence_threshold: Optional[float] = None
    hotwords: Optional[List[str]] = None


class WhisperModel(faster_whisper.WhisperModel):
'''
FasterWhisperModel provides batched inference for faster-whisper.
Expand All @@ -45,8 +117,11 @@ def generate_segment_batched(
all_tokens = []
prompt_reset_since = 0
if options.initial_prompt is not None:
initial_prompt = " " + options.initial_prompt.strip()
initial_prompt_tokens = tokenizer.encode(initial_prompt)
if isinstance(options.initial_prompt, str):
initial_prompt = " " + options.initial_prompt.strip()
initial_prompt_tokens = tokenizer.encode(initial_prompt)
else:
initial_prompt_tokens = list(options.initial_prompt)
all_tokens.extend(initial_prompt_tokens)
previous_tokens = all_tokens[prompt_reset_since:]
prompt = self.get_prompt(
Expand Down Expand Up @@ -98,6 +173,7 @@ def encode(self, features: np.ndarray) -> ctranslate2.StorageView:

return self.model.encode(features, to_cpu=to_cpu)


class FasterWhisperPipeline(Pipeline):
"""
Huggingface Pipeline wrapper for FasterWhisperModel.
Expand Down Expand Up @@ -324,7 +400,7 @@ def load_model(
compute_type - The compute type to use for the model.
vad_model - The vad model to manually assign.
vad_method - The vad method to use. vad_model has a higher priority if it is not None.
options - A dictionary of options to use for the model.
asr_options - A dictionary of options to override defaults in DefaultASRTranscriptionOptions.
language - The language of the model. (use English for now)
model - The WhisperModel instance to use.
download_root - The root directory to download the model to.
Expand All @@ -350,43 +426,19 @@ def load_model(
logger.info("No language specified, language will be detected for each audio file (increases inference time)")
tokenizer = None

default_asr_options = {
"beam_size": 5,
"best_of": 5,
"patience": 1,
"length_penalty": 1,
"repetition_penalty": 1,
"no_repeat_ngram_size": 0,
"temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0],
"compression_ratio_threshold": 2.4,
"log_prob_threshold": -1.0,
"no_speech_threshold": 0.6,
"condition_on_previous_text": False,
"prompt_reset_on_temperature": 0.5,
"initial_prompt": None,
"prefix": None,
"suppress_blank": True,
"suppress_tokens": [-1],
"without_timestamps": True,
"max_initial_timestamp": 0.0,
"word_timestamps": False,
"prepend_punctuations": "\"'“¿([{-",
"append_punctuations": "\"'.。,,!!??::”)]}、",
"multilingual": model.model.is_multilingual,
"suppress_numerals": False,
"max_new_tokens": None,
"clip_timestamps": None,
"hallucination_silence_threshold": None,
"hotwords": None,
}
transcription_options = DefaultASRTranscriptionOptions(
multilingual=model.model.is_multilingual,
)

if asr_options is not None:
default_asr_options.update(asr_options)

suppress_numerals = default_asr_options["suppress_numerals"]
del default_asr_options["suppress_numerals"]
transcription_options = replace(transcription_options, **asr_options)

default_asr_options = TranscriptionOptions(**default_asr_options)
# Extract suppress_numerals (pipeline-specific, not part of base TranscriptionOptions)
suppress_numerals = transcription_options.suppress_numerals
# Cast to base TranscriptionOptions for compatibility (excludes suppress_numerals)
transcription_options = TranscriptionOptions(
**{k: v for k, v in transcription_options.__dict__.items() if k != 'suppress_numerals'}
)

default_vad_options = {
"chunk_size": 30, # needed by silero since binarization happens before merge_chunks
Expand Down Expand Up @@ -416,9 +468,9 @@ def load_model(
return FasterWhisperPipeline(
model=model,
vad=vad_model,
options=default_asr_options,
options=transcription_options,
tokenizer=tokenizer,
language=language,
suppress_numerals=suppress_numerals,
vad_params=default_vad_options,
)
)
2 changes: 1 addition & 1 deletion whisperx/diarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(
self,
model_name=None,
use_auth_token=None,
device: Optional[Union[str, torch.device]] = "cpu",
device: Union[str, torch.device] = "cpu",
):
if isinstance(device, str):
device = torch.device(device)
Expand Down
22 changes: 16 additions & 6 deletions whisperx/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import warnings

import numpy as np
import pandas as pd
import torch

from whisperx.alignment import align, load_align_model
Expand Down Expand Up @@ -186,7 +187,7 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
result["language"], device
)
logger.info("Performing alignment...")
result: AlignedTranscriptionResult = align(
aligned_result: AlignedTranscriptionResult = align(
result["segments"],
align_model,
align_metadata,
Expand All @@ -196,6 +197,7 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
return_char_alignments=return_char_alignments,
print_progress=print_progress,
)
result = aligned_result

results.append((result, audio_path))

Expand All @@ -217,20 +219,28 @@ def transcribe_task(args: dict, parser: argparse.ArgumentParser):
diarize_model = DiarizationPipeline(model_name=diarize_model_name, use_auth_token=hf_token, device=device)
for result, input_audio_path in tmp_results:
diarize_result = diarize_model(
input_audio_path,
min_speakers=min_speakers,
max_speakers=max_speakers,
input_audio_path,
min_speakers=min_speakers,
max_speakers=max_speakers,
return_embeddings=return_speaker_embeddings
)

if return_speaker_embeddings:
if not isinstance(diarize_result, tuple):
raise TypeError(
f"Expected tuple when return_embeddings=True, got {type(diarize_result).__name__}"
)
diarize_segments, speaker_embeddings = diarize_result
else:
if not isinstance(diarize_result, pd.DataFrame):
raise TypeError(
f"Expected DataFrame when return_embeddings=False, got {type(diarize_result).__name__}"
)
diarize_segments = diarize_result
speaker_embeddings = None

result = assign_word_speakers(diarize_segments, result, speaker_embeddings)
results.append((result, input_audio_path))
diarized_result = assign_word_speakers(diarize_segments, result, speaker_embeddings)
results.append((diarized_result, input_audio_path))
# >> Write
for result, audio_path in results:
result["language"] = align_language
Expand Down
5 changes: 3 additions & 2 deletions whisperx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,16 +425,17 @@ def get_writer(
if output_format == "all":
all_writers = [writer(output_dir) for writer in writers.values()]

def write_all(result: dict, file: str, options: dict):
def write_all(result: dict, audio_path: str, options: dict):
for writer in all_writers:
writer(result, file, options)
writer(result, audio_path, options)

return write_all

if output_format in optional_writers:
return optional_writers[output_format](output_dir)
return writers[output_format](output_dir)


def interpolate_nans(x, method='nearest'):
if x.notnull().sum() > 1:
return x.interpolate(method=method).ffill().bfill()
Expand Down