
Commit 21a23cb

Add workflow for faster-whisper (ctranslate2)
1 parent: a4d9430

File tree

3 files changed (+363, −0):
- lhotse/bin/modes/workflows.py
- lhotse/workflows/__init__.py
- lhotse/workflows/faster_whisper.py

lhotse/bin/modes/workflows.py

Lines changed: 115 additions & 0 deletions
@@ -114,6 +114,121 @@ def annotate_with_whisper(
            writer.write(cut, flush=True)


@workflows.command()
@click.argument("out_cuts", type=click.Path(allow_dash=True))
@click.option(
    "-m",
    "--recordings-manifest",
    type=click.Path(exists=True, dir_okay=False, allow_dash=True),
    help="Path to an existing recording manifest.",
)
@click.option(
    "-r",
    "--recordings-dir",
    type=click.Path(exists=True, file_okay=False),
    help="Directory with recordings. We will create a RecordingSet for it automatically.",
)
@click.option(
    "-c",
    "--cuts-manifest",
    type=click.Path(exists=True, dir_okay=False, allow_dash=True),
    help="Path to an existing cuts manifest.",
)
@click.option(
    "-e",
    "--extension",
    default="wav",
    help="Audio file extension to search for. Used with RECORDINGS_DIR.",
)
@click.option(
    "-n",
    "--model-name",
    default="base",
    help="One of the available Whisper variants (base, medium, large, etc.).",
)
@click.option(
    "-l",
    "--language",
    help="Language spoken in the audio. Inferred by default.",
)
@click.option(
    "-d", "--device", default="cpu", help="Device on which to run the inference."
)
@click.option(
    "--device-index", default=0, help="Device index on which to run the inference."
)
@click.option(
    "--cpu-threads", default=0, help="Number of threads to use when running on CPU."
)
@click.option(
    "--num-workers",
    default=1,
    help="Number of workers for parallelizing across multiple GPUs.",
)
@click.option("-j", "--jobs", default=1, help="Number of jobs for audio scanning.")
@click.option(
    "--force-nonoverlapping/--keep-overlapping",
    default=False,
    help="If True, the Whisper segment time-stamps will be processed to make sure they are non-overlapping.",
)
def annotate_with_faster_whisper(
    out_cuts: str,
    recordings_manifest: Optional[str],
    recordings_dir: Optional[str],
    cuts_manifest: Optional[str],
    extension: str,
    model_name: str,
    language: Optional[str],
    device: str,
    device_index: int,
    cpu_threads: int,
    num_workers: int,
    jobs: int,
    force_nonoverlapping: bool,
):
    """
    Use the faster-whisper (CTranslate2) implementation of OpenAI Whisper to annotate
    either RECORDINGS_MANIFEST, RECORDINGS_DIR, or CUTS_MANIFEST. It will perform
    automatic segmentation, transcription, and language identification.

    RECORDINGS_MANIFEST, RECORDINGS_DIR, and CUTS_MANIFEST are mutually exclusive. If CUTS_MANIFEST
    is provided, its supervisions will be overwritten with the results of the inference.

    Note: this is an experimental feature of Lhotse, and is not guaranteed to yield
    high-quality data.
    """
    from lhotse import annotate_with_faster_whisper as annotate_with_faster_whisper_

    assert exactly_one_not_null(recordings_manifest, recordings_dir, cuts_manifest), (
        "Options RECORDINGS_MANIFEST, RECORDINGS_DIR, and CUTS_MANIFEST are mutually exclusive "
        "and at least one is required."
    )

    if recordings_manifest is not None:
        manifest = RecordingSet.from_file(recordings_manifest)
    elif recordings_dir is not None:
        manifest = RecordingSet.from_dir(
            recordings_dir, pattern=f"*.{extension}", num_jobs=jobs
        )
    else:
        manifest = CutSet.from_file(cuts_manifest).to_eager()

    with CutSet.open_writer(out_cuts) as writer:
        for cut in tqdm(
            annotate_with_faster_whisper_(
                manifest,
                language=language,
                model_name=model_name,
                device=device,
                device_index=device_index,
                force_nonoverlapping=force_nonoverlapping,
                # Fixed to float16 here; this compute type targets GPU inference.
                compute_type="float16",
                cpu_threads=cpu_threads,
                num_workers=num_workers,
            ),
            total=len(manifest),
            desc="Annotating with faster-whisper",
        ):
            writer.write(cut, flush=True)


@workflows.command()
@click.argument(
    "in_cuts", type=click.Path(exists=True, dir_okay=False, allow_dash=True)

lhotse/workflows/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
from .forced_alignment import align_with_torchaudio
from .meeting_simulation import *
from .whisper import annotate_with_whisper
from .faster_whisper import annotate_with_faster_whisper
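
With this re-export in place, the new workflow can be imported next to the existing one, e.g.:

from lhotse.workflows import annotate_with_faster_whisper, annotate_with_whisper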

lhotse/workflows/faster_whisper.py

Lines changed: 247 additions & 0 deletions
@@ -0,0 +1,247 @@
import logging
from typing import Generator, List, Union

import numpy as np

from lhotse import (
    CutSet,
    MonoCut,
    Recording,
    RecordingSet,
    SupervisionSegment,
    add_durations,
)
from lhotse.qa import trim_supervisions_to_recordings
from lhotse.supervision import AlignmentItem
from lhotse.utils import fastcopy, is_module_available


def annotate_with_faster_whisper(
    manifest: Union[RecordingSet, CutSet],
    model_name: str = "base",
    device: str = "cpu",
    device_index: int = 0,
    force_nonoverlapping: bool = False,
    compute_type: str = "default",
    cpu_threads: int = 0,
    num_workers: int = 1,
    **decode_options,
) -> Generator[MonoCut, None, None]:
    """
    Use the faster-whisper (CTranslate2) implementation of OpenAI Whisper to annotate
    a ``RecordingSet`` or a ``CutSet``. It will perform automatic segmentation,
    transcription, and language identification. If the first argument is a CutSet,
    its supervisions will be overwritten with the results of the inference.

    Note: this is an experimental feature of Lhotse, and is not guaranteed to yield
    high-quality data.

    See the original repo for more details: https://github.com/guillaumekln/faster-whisper

    :param manifest: a ``RecordingSet`` or ``CutSet`` object.
    :param model_name: one of the available Whisper variants (base, medium, large, etc.).
    :param device: where to run the inference (cpu, cuda, etc.).
    :param device_index: index of the device to run the inference on.
    :param force_nonoverlapping: if True, the Whisper segment time-stamps will be processed to make
        sure they are non-overlapping.
    :param compute_type: the CTranslate2 compute type (e.g., "default", "int8", "float16").
    :param cpu_threads: number of threads to use when running on CPU.
    :param num_workers: number of workers for parallelizing across multiple GPUs.
    :param decode_options: additional options to pass to ``WhisperModel.transcribe``;
        e.g., pass ``language`` if it is known upfront, otherwise it will be auto-detected.
    :return: a generator of cuts (use ``CutSet.open_writer()`` to write them).
    """
    assert is_module_available("faster_whisper"), (
        "This function expects faster-whisper to be installed. "
        "You can install it via 'pip install faster-whisper' "
        "(see https://github.com/guillaumekln/faster-whisper/ for details)."
    )

    if isinstance(manifest, RecordingSet):
        yield from _annotate_recordings(
            manifest,
            model_name,
            device,
            device_index,
            force_nonoverlapping,
            compute_type=compute_type,
            cpu_threads=cpu_threads,
            num_workers=num_workers,
            **decode_options,
        )
    elif isinstance(manifest, CutSet):
        yield from _annotate_cuts(
            manifest,
            model_name,
            device,
            device_index,
            force_nonoverlapping,
            compute_type=compute_type,
            cpu_threads=cpu_threads,
            num_workers=num_workers,
            **decode_options,
        )
    else:
        raise ValueError("The ``manifest`` must be either a RecordingSet or a CutSet.")


def _annotate_recordings(
    recordings: RecordingSet,
    model_name: str,
    device: str,
    device_index: int,
    force_nonoverlapping: bool,
    compute_type: str = "default",
    cpu_threads: int = 0,
    num_workers: int = 1,
    **decode_options,
):
    """
    Helper function that annotates a RecordingSet with faster-whisper.
    """
    from faster_whisper import WhisperModel

    model = WhisperModel(
        model_name,
        device=device,
        device_index=device_index,
        compute_type=compute_type,
        cpu_threads=cpu_threads,
        num_workers=num_workers,
    )

    for recording in recordings:
        if recording.num_channels > 1:
            logging.warning(
                f"Skipping recording '{recording.id}'. It has {recording.num_channels} channels, "
                f"but we currently only support mono input."
            )
            continue
        audio = np.squeeze(recording.resample(16000).load_audio())
        segments, info = model.transcribe(
            audio=audio, word_timestamps=True, vad_filter=True, **decode_options
        )
        # Create supervisions from segments while filtering out those with non-positive duration.
        supervisions = [
            SupervisionSegment(
                id=f"{recording.id}-{segment_id:06d}",
                recording_id=recording.id,
                start=round(segment.start, ndigits=8),
                duration=add_durations(
                    segment.end, -segment.start, sampling_rate=16000
                ),
                text=segment.text.strip(),
                language=info.language,
            ).with_alignment(
                "word",
                [
                    AlignmentItem(
                        symbol=ws.word.strip(),
                        start=ws.start,
                        duration=(ws.end - ws.start),
                        score=ws.probability,
                    )
                    for ws in segment.words
                ],
            )
            for segment_id, segment in enumerate(segments)
            if segment.end - segment.start > 0
        ]
        cut = recording.to_cut()
        if supervisions:
            supervisions = (
                _postprocess_timestamps(supervisions)
                if force_nonoverlapping
                else supervisions
            )
            cut.supervisions = list(
                trim_supervisions_to_recordings(
                    recordings=recording, supervisions=supervisions, verbose=False
                )
            )
        yield cut


def _annotate_cuts(
    cuts: CutSet,
    model_name: str,
    device: str,
    device_index: int,
    force_nonoverlapping: bool,
    compute_type: str = "default",
    cpu_threads: int = 0,
    num_workers: int = 1,
    **decode_options,
):
    """
    Helper function that annotates a CutSet with faster-whisper.
    """
    from faster_whisper import WhisperModel

    model = WhisperModel(
        model_name,
        device=device,
        device_index=device_index,
        compute_type=compute_type,
        cpu_threads=cpu_threads,
        num_workers=num_workers,
    )

    for cut in cuts:
        if cut.num_channels > 1:
            logging.warning(
                f"Skipping cut '{cut.id}'. It has {cut.num_channels} channels, "
                f"but we currently only support mono input."
            )
            continue
        audio = np.squeeze(cut.resample(16000).load_audio())
        segments, info = model.transcribe(
            audio=audio, word_timestamps=True, **decode_options
        )
        # Create supervisions from segments while filtering out those with non-positive duration.
        supervisions = [
            SupervisionSegment(
                id=f"{cut.id}-{segment_id:06d}",
                recording_id=cut.recording_id,
                start=round(segment.start, ndigits=8),
                duration=add_durations(
                    min(segment.end, cut.duration),
                    -segment.start,
                    sampling_rate=16000,
                ),
                text=segment.text.strip(),
                language=info.language,
            ).with_alignment(
                "word",
                [
                    AlignmentItem(
                        symbol=ws.word.strip(),
                        start=ws.start,
                        duration=(ws.end - ws.start),
                        score=ws.probability,
                    )
                    for ws in segment.words
                ],
            )
            for segment_id, segment in enumerate(segments)
            if segment.end - segment.start > 0
        ]
        new_cut = fastcopy(
            cut,
            supervisions=_postprocess_timestamps(supervisions)
            if force_nonoverlapping
            else supervisions,
        )
        yield new_cut


def _postprocess_timestamps(supervisions: List[SupervisionSegment]):
    """
    Whisper tends to have a lot of overlapping segments due to inaccurate end timestamps.
    Under a strong assumption that the input speech is non-overlapping, we can fix that
    by always truncating to the start timestamp of the next segment.
    """
    from cytoolz import sliding_window

    supervisions = sorted(supervisions, key=lambda s: s.start)

    if len(supervisions) < 2:
        return supervisions
    out = []
    for cur, nxt in sliding_window(2, supervisions):
        if cur.end > nxt.start:
            cur = cur.trim(end=nxt.start)
        out.append(cur)
    out.append(nxt)
    return out
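
To illustrate the truncation rule that ``_postprocess_timestamps`` applies, here is a small self-contained sketch; the segment ids and times are made up for the example:

from lhotse import SupervisionSegment

# Two made-up, overlapping segments: s0 ends (3.0) after s1 starts (2.5).
s0 = SupervisionSegment(id="s0", recording_id="rec", start=0.0, duration=3.0)
s1 = SupervisionSegment(id="s1", recording_id="rec", start=2.5, duration=2.0)

# The helper truncates the earlier segment to the next segment's start:
s0_fixed = s0.trim(end=s1.start)
assert s0_fixed.end == 2.5  # was 3.0; it no longer overlaps s1

Note that this rests on the stated assumption that the underlying speech is non-overlapping; for genuinely overlapped speakers, keep the default --keep-overlapping behavior.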
