Skip to content

Commit 43bb258

Browse files
authored
feat: add CLI to apply (and benchmark) pretrained pipelines
1 parent fd91866 commit 43bb258

File tree

2 files changed

+222
-0
lines changed

2 files changed

+222
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ Clipping and speaker/source alignment issues in speech separation pipeline have
3636
- feat(utils): add `FilterByNumberOfSpeakers` protocol files filter
3737
- feat(core): add `Calibration` class to calibrate logits/distances into probabilities
3838
- feat(metric): add `DetectionErrorRate`, `SegmentationErrorRate`, `DiarizationPrecision`, and `DiarizationRecall` metrics
39+
- feat(cli): add CLI to apply (and benchmark) pretrained pipelines
3940

4041
### Improvements
4142

pyannote/audio/__main__.py

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
#!/usr/bin/env python
2+
# encoding: utf-8
3+
4+
# MIT License
5+
#
6+
# Copyright (c) 2024- CNRS
7+
#
8+
# Permission is hereby granted, free of charge, to any person obtaining a copy
9+
# of this software and associated documentation files (the "Software"), to deal
10+
# in the Software without restriction, including without limitation the rights
11+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12+
# copies of the Software, and to permit persons to whom the Software is
13+
# furnished to do so, subject to the following conditions:
14+
#
15+
# The above copyright notice and this permission notice shall be included in all
16+
# copies or substantial portions of the Software.
17+
#
18+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24+
# SOFTWARE.
25+
26+
27+
import sys
28+
from contextlib import nullcontext
29+
from enum import Enum
30+
from pathlib import Path
31+
from typing import Optional
32+
33+
import pyannote.database
34+
import torch
35+
import typer
36+
from pyannote.core import Annotation
37+
from typing_extensions import Annotated
38+
39+
from pyannote.audio import Pipeline
40+
41+
42+
class Subset(str, Enum):
43+
train = "train"
44+
development = "development"
45+
test = "test"
46+
47+
48+
class Device(str, Enum):
49+
CPU = "cpu"
50+
CUDA = "cuda"
51+
MPS = "mps"
52+
AUTO = "auto"
53+
54+
55+
def parse_device(device: Device) -> torch.device:
56+
if device == Device.AUTO:
57+
if torch.cuda.is_available():
58+
device = Device.CUDA
59+
60+
elif torch.backends.mps.is_available():
61+
device = Device.MPS
62+
63+
else:
64+
device = Device.CPU
65+
66+
return torch.device(device.value)
67+
68+
69+
app = typer.Typer()
70+
71+
72+
# TODO: add option to download pretrained pipeline for later use without internet
73+
74+
75+
@app.command("apply")
76+
def apply(
77+
pipeline: Annotated[
78+
str,
79+
typer.Argument(
80+
help="Pretrained pipeline (e.g. pyannote/speaker-diarization-3.1)"
81+
),
82+
],
83+
audio: Annotated[
84+
Path,
85+
typer.Argument(
86+
help="Path to audio file",
87+
exists=True,
88+
file_okay=True,
89+
readable=True,
90+
),
91+
],
92+
into: Annotated[
93+
Path,
94+
typer.Option(
95+
help="Path to file where results are saved.",
96+
exists=False,
97+
dir_okay=False,
98+
file_okay=True,
99+
writable=True,
100+
resolve_path=True,
101+
),
102+
] = None,
103+
device: Annotated[
104+
Device, typer.Option(help="Accelerator to use (CPU, CUDA, MPS)")
105+
] = Device.AUTO,
106+
):
107+
"""
108+
Apply a pretrained PIPELINE to an AUDIO file
109+
"""
110+
111+
# load pretrained pipeline
112+
pretrained_pipeline = Pipeline.from_pretrained(pipeline)
113+
114+
# send pipeline to device
115+
torch_device = parse_device(device)
116+
pretrained_pipeline.to(torch_device)
117+
118+
# apply pipeline to audio file
119+
prediction: Annotation = pretrained_pipeline(audio)
120+
121+
# save (or print) results
122+
with open(into, "w") if into else nullcontext(sys.stdout) as rttm:
123+
prediction.write_rttm(rttm)
124+
125+
126+
@app.command("benchmark")
127+
def benchmark(
128+
pipeline: Annotated[
129+
str,
130+
typer.Argument(
131+
help="Pretrained pipeline (e.g. pyannote/speaker-diarization-3.1)"
132+
),
133+
],
134+
protocol: Annotated[
135+
str,
136+
typer.Argument(help="Benchmarked protocol"),
137+
],
138+
into: Annotated[
139+
Path,
140+
typer.Argument(
141+
help="Directory into which benchmark results are saved",
142+
exists=True,
143+
dir_okay=True,
144+
file_okay=False,
145+
writable=True,
146+
resolve_path=True,
147+
),
148+
],
149+
subset: Annotated[
150+
Subset,
151+
typer.Option(
152+
help="Benchmarked subset",
153+
case_sensitive=False,
154+
),
155+
] = Subset.test,
156+
device: Annotated[
157+
Device, typer.Option(help="Accelerator to use (CPU, CUDA, MPS)")
158+
] = Device.AUTO,
159+
registry: Annotated[
160+
Optional[Path],
161+
typer.Option(
162+
help="Loaded registry",
163+
exists=True,
164+
dir_okay=False,
165+
file_okay=True,
166+
readable=True,
167+
),
168+
] = None,
169+
):
170+
"""
171+
Benchmark a pretrained PIPELINE
172+
"""
173+
174+
# load pretrained pipeline
175+
pretrained_pipeline = Pipeline.from_pretrained(pipeline)
176+
177+
# send pipeline to device
178+
torch_device = parse_device(device)
179+
pretrained_pipeline.to(torch_device)
180+
181+
# load pipeline metric (when available)
182+
try:
183+
metric = pretrained_pipeline.get_metric()
184+
except NotImplementedError:
185+
metric = None
186+
187+
# load protocol from (optional) registry
188+
if registry:
189+
pyannote.database.registry.load_database(registry)
190+
191+
loaded_protocol = pyannote.database.registry.get_protocol(
192+
protocol, {"audio": pyannote.database.FileFinder()}
193+
)
194+
195+
with open(into / f"{protocol}.{subset.value}.rttm", "w") as rttm:
196+
for file in getattr(loaded_protocol, subset.value)():
197+
prediction: Annotation = pretrained_pipeline(file)
198+
prediction.write_rttm(rttm)
199+
rttm.flush()
200+
201+
if metric is None:
202+
continue
203+
204+
groundtruth = file.get("annotation", None)
205+
if groundtruth is None:
206+
continue
207+
208+
annotated = file.get("annotated", None)
209+
_ = metric(groundtruth, prediction, uem=annotated)
210+
211+
if metric is None:
212+
return
213+
214+
with open(into / f"{protocol}.{subset.value}.txt", "w") as txt:
215+
txt.write(str(metric))
216+
217+
print(str(metric))
218+
219+
220+
if __name__ == "__main__":
221+
app()

0 commit comments

Comments
 (0)