Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
262 changes: 259 additions & 3 deletions faster_whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from dataclasses import asdict, dataclass
from inspect import signature
from math import ceil
from typing import BinaryIO, Iterable, List, Optional, Tuple, Union
from typing import BinaryIO, Iterable, List, Optional, Sequence, Tuple, Union
from warnings import warn

import ctranslate2
Expand Down Expand Up @@ -253,7 +253,12 @@ def generate_segment_batched(

def transcribe(
self,
audio: Union[str, BinaryIO, np.ndarray],
audio: Union[
str,
BinaryIO,
np.ndarray,
Sequence[Union[str, BinaryIO, np.ndarray]],
],
language: Optional[str] = None,
task: str = "transcribe",
log_progress: bool = False,
Expand Down Expand Up @@ -296,7 +301,10 @@ def transcribe(
hotwords: Optional[str] = None,
language_detection_threshold: Optional[float] = 0.5,
language_detection_segments: int = 1,
) -> Tuple[Iterable[Segment], TranscriptionInfo]:
) -> Union[
Tuple[Iterable[Segment], TranscriptionInfo],
Tuple[List[List[Segment]], List[TranscriptionInfo]],
]:
"""transcribe audio in chunks in batched fashion and return with language info.

Arguments:
Expand Down Expand Up @@ -374,6 +382,46 @@ def transcribe(
- an instance of TranscriptionInfo
"""

if isinstance(audio, (list, tuple)):
return self._transcribe_multiple(
audio,
language=language,
task=task,
log_progress=log_progress,
beam_size=beam_size,
best_of=best_of,
patience=patience,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
temperature=temperature,
compression_ratio_threshold=compression_ratio_threshold,
log_prob_threshold=log_prob_threshold,
no_speech_threshold=no_speech_threshold,
condition_on_previous_text=condition_on_previous_text,
prompt_reset_on_temperature=prompt_reset_on_temperature,
initial_prompt=initial_prompt,
prefix=prefix,
suppress_blank=suppress_blank,
suppress_tokens=suppress_tokens,
without_timestamps=without_timestamps,
max_initial_timestamp=max_initial_timestamp,
word_timestamps=word_timestamps,
prepend_punctuations=prepend_punctuations,
append_punctuations=append_punctuations,
multilingual=multilingual,
vad_filter=vad_filter,
vad_parameters=vad_parameters,
max_new_tokens=max_new_tokens,
chunk_length=chunk_length,
clip_timestamps=clip_timestamps,
hallucination_silence_threshold=hallucination_silence_threshold,
batch_size=batch_size,
hotwords=hotwords,
language_detection_threshold=language_detection_threshold,
language_detection_segments=language_detection_segments,
)

sampling_rate = self.model.feature_extractor.sampling_rate

if multilingual and not self.model.model.is_multilingual:
Expand Down Expand Up @@ -563,6 +611,214 @@ def transcribe(

return segments, info

def _transcribe_multiple(
self,
audios,
language,
task,
log_progress,
beam_size,
best_of,
patience,
length_penalty,
repetition_penalty,
no_repeat_ngram_size,
temperature,
compression_ratio_threshold,
log_prob_threshold,
no_speech_threshold,
condition_on_previous_text,
prompt_reset_on_temperature,
initial_prompt,
prefix,
suppress_blank,
suppress_tokens,
without_timestamps,
max_initial_timestamp,
word_timestamps,
prepend_punctuations,
append_punctuations,
multilingual,
vad_filter,
vad_parameters,
max_new_tokens,
chunk_length,
clip_timestamps,
hallucination_silence_threshold,
batch_size,
hotwords,
language_detection_threshold,
language_detection_segments,
):
sampling_rate = self.model.feature_extractor.sampling_rate

audio_arrays = []
chunks_metadata = []
durations = []
for a in audios:
if not isinstance(a, np.ndarray):
arr = decode_audio(a, sampling_rate=sampling_rate)
else:
arr = a
duration = arr.shape[0] / sampling_rate
audio_arrays.append(arr)
chunks_metadata.append({"offset": 0.0, "duration": duration})
durations.append(duration)

features = np.stack(
[pad_or_trim(self.model.feature_extractor(a)[..., :-1]) for a in audio_arrays]
)

languages = []
language_probs = []
all_language_probs_list = []

if language is None:
if not self.model.model.is_multilingual:
languages = ["en"] * len(audio_arrays)
language_probs = [1.0] * len(audio_arrays)
all_language_probs_list = [None] * len(audio_arrays)
language = "en"
language_probability = 1.0
all_language_probs = None
else:
for arr in audio_arrays:
lang, lang_prob, all_probs = self.model.detect_language(
arr,
language_detection_segments=language_detection_segments,
language_detection_threshold=language_detection_threshold,
)
self.model.logger.info(
"Detected language '%s' with probability %.2f",
lang,
lang_prob,
)
languages.append(lang)
language_probs.append(lang_prob)
all_language_probs_list.append(all_probs)

if len(set(languages)) != 1:
raise RuntimeError(
"All inputs must share the same detected language; "
"specify `language` to mix languages."
)

language = languages[0]
language_probability = language_probs[0]
all_language_probs = all_language_probs_list[0]
else:
if not self.model.model.is_multilingual and language != "en":
self.model.logger.warning(
"The current model is English-only but the language parameter is set to '%s'; "
"using 'en' instead." % language
)
language = "en"

languages = [language] * len(audio_arrays)
language_probs = [1.0] * len(audio_arrays)
all_language_probs_list = [None] * len(audio_arrays)
language_probability = 1.0
all_language_probs = None

tokenizer = Tokenizer(
self.model.hf_tokenizer,
self.model.model.is_multilingual,
task=task,
language=language,
)

options = TranscriptionOptions(
beam_size=beam_size,
best_of=best_of,
patience=patience,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
log_prob_threshold=log_prob_threshold,
no_speech_threshold=no_speech_threshold,
compression_ratio_threshold=compression_ratio_threshold,
temperatures=(
temperature[:1]
if isinstance(temperature, (list, tuple))
else [temperature]
),
initial_prompt=initial_prompt,
prefix=prefix,
suppress_blank=suppress_blank,
suppress_tokens=(
get_suppressed_tokens(tokenizer, suppress_tokens)
if suppress_tokens
else suppress_tokens
),
prepend_punctuations=prepend_punctuations,
append_punctuations=append_punctuations,
max_new_tokens=max_new_tokens,
hotwords=hotwords,
word_timestamps=word_timestamps,
hallucination_silence_threshold=None,
condition_on_previous_text=False,
clip_timestamps="0",
prompt_reset_on_temperature=0.5,
multilingual=multilingual,
without_timestamps=without_timestamps,
max_initial_timestamp=0.0,
)

segments_list = []
pbar = tqdm(total=len(features), disable=not log_progress, position=0)
for i in range(0, len(features), batch_size):
results = self.forward(
features[i : i + batch_size],
tokenizer,
chunks_metadata[i : i + batch_size],
options,
)

for result in results:
segs = []
seg_idx = 0
for segment in result:
seg_idx += 1
segs.append(
Segment(
seek=segment["seek"],
id=seg_idx,
text=segment["text"],
start=round(segment["start"], 3),
end=round(segment["end"], 3),
words=(
None
if not options.word_timestamps
else [Word(**word) for word in segment["words"]]
),
tokens=segment["tokens"],
avg_logprob=segment["avg_logprob"],
no_speech_prob=segment["no_speech_prob"],
compression_ratio=segment["compression_ratio"],
temperature=options.temperatures[0],
)
)
segments_list.append(segs)
pbar.update(1)

pbar.close()
self.last_speech_timestamp = 0.0

infos = [
TranscriptionInfo(
language=languages[i],
language_probability=language_probs[i],
duration=durations[i],
duration_after_vad=durations[i],
transcription_options=options,
vad_options=vad_parameters,
all_language_probs=all_language_probs_list[i],
)
for i in range(len(audios))
]

return segments_list, infos

def _batched_segments_generator(
self, features, tokenizer, chunks_metadata, batch_size, options, log_progress
):
Expand Down
15 changes: 15 additions & 0 deletions tests/test_transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,21 @@ def test_batched_transcribe(physcisworks_path):
assert len(segments) > 7


def test_batched_transcribe_multiple_files(jfk_path):
    """Passing a list of inputs returns one segment list and one info per input."""
    expected_text = (
        " And so my fellow Americans ask not what your country can do for you, "
        "ask what you can do for your country."
    )
    pipeline = BatchedInferencePipeline(model=WhisperModel("tiny"))

    # The same clip twice: both entries must produce the same transcript.
    transcripts, infos = pipeline.transcribe([jfk_path] * 2, vad_filter=False)

    assert len(transcripts) == 2
    assert len(infos) == 2
    for clip_segments, clip_info in zip(transcripts, infos):
        assert clip_info.language == "en"
        assert len(clip_segments) == 1
        assert clip_segments[0].text == expected_text

def test_empty_audio():
audio = np.asarray([], dtype="float32")
model = WhisperModel("tiny")
Expand Down