Skip to content

Commit cc059d6

Browse files
authored
improve(pyannoteAI): update on-premise wrapper to return both regular and exclusive diarization (#1953)
1 parent 121054b commit cc059d6

File tree

3 files changed

+49
-47
lines changed

3 files changed

+49
-47
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
- feat(pipeline): add `preload` option to base `Pipeline.__call__` to force preloading audio in memory ([@antoinelaurent](https://github.com/antoinelaurent/))
99
- feat(pipeline): add `Pipeline.cuda()` convenience method [@tkanarsky](https://github.com/tkanarsky/)
1010
- improve(util): make `permutate` faster thanks to vectorized cost function
11+
- improve(pyannoteAI): update pyannoteAI wrapper to return both regular and exclusive diarization
1112

1213
## Version 4.0.1 (2025-10-10)
1314

src/pyannote/audio/pipelines/pyannoteai/local.py

Lines changed: 46 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@
2121
# SOFTWARE.
2222

2323
import os
24-
from pathlib import Path
2524

2625
from pyannote.audio import Pipeline
26+
from pyannote.audio.core.io import AudioFile
2727
from pyannote.core import Annotation, Segment
2828

29+
from ..speaker_diarization import DiarizeOutput
30+
2931

3032
class Local(Pipeline):
3133
"""Wrapper around official pyannoteAI on-premise package
@@ -51,29 +53,24 @@ def __init__(self, token: str | None = None, **kwargs):
5153
self.token = token or os.environ.get("PYANNOTEAI_API_KEY", None)
5254
self._pipeline = _LocalPipeline(self.token)
5355

54-
def _to_annotation(self, completed_job: dict) -> Annotation:
55-
"""Deserialize job output into pyannote.core Annotation"""
56-
57-
output = completed_job["output"]["diarization"]
58-
job_id = completed_job["jobId"]
59-
60-
annotation = Annotation(uri=job_id)
61-
for t, turn in enumerate(output):
56+
def _deserialize(self, diarization: list[dict]) -> Annotation:
57+
# deserialize the output into a good-old Annotation instance
58+
annotation = Annotation()
59+
for t, turn in enumerate(diarization):
6260
segment = Segment(start=turn["start"], end=turn["end"])
6361
speaker = turn["speaker"]
6462
annotation[segment, t] = speaker
65-
6663
return annotation.rename_tracks("string")
6764

6865
def apply(
6966
self,
70-
file: Path,
67+
file: AudioFile,
7168
num_speakers: int | None = None,
7269
min_speakers: int | None = None,
7370
max_speakers: int | None = None,
74-
exclusive: bool = False,
75-
) -> Annotation:
76-
"""Speaker diarization using pyannoteAI on-premise package
71+
**kwargs,
72+
) -> DiarizeOutput:
73+
"""Speaker diarization using on-premise pyannoteAI package.
7774
7875
Parameters
7976
----------
@@ -86,41 +83,45 @@ def apply(
8683
Not supported yet. Minimum number of speakers. Has no effect when `num_speakers` is provided.
8784
max_speakers : int, optional
8885
Not supported yet. Maximum number of speakers. Has no effect when `num_speakers` is provided.
89-
exclusive : bool, optional
90-
Enable exclusive diarization.
9186
9287
Returns
9388
-------
94-
speaker_diarization : Annotation
95-
Speaker diarization result (when successful)
96-
97-
Raises
98-
------
99-
PyannoteAIFailedJob
100-
If the job failed
101-
PyannoteAICanceledJob
102-
If the job was canceled
103-
HTTPError
104-
If something else went wrong
89+
output : DiarizeOutput
90+
DiarizeOutput object containing both regular and exclusive speaker diarization results.
10591
"""
10692

107-
predictions = self._pipeline.diarize(
108-
file["audio"],
109-
num_speakers=num_speakers,
110-
min_speakers=min_speakers,
111-
max_speakers=max_speakers,
112-
)
113-
114-
# use exclusive diarization whenever requested
115-
if exclusive:
116-
diarization = predictions["exclusive_diarization"]
93+
# if file provides "audio" path
94+
if "audio" in file:
95+
predictions = self._pipeline.diarize(
96+
file["audio"],
97+
num_speakers=num_speakers,
98+
min_speakers=min_speakers,
99+
max_speakers=max_speakers,
100+
**kwargs,
101+
)
102+
103+
# if file provides "waveform", make sure it is numpy (and not torch) array
104+
elif "waveform" in file:
105+
waveform = file["waveform"]
106+
if hasattr(waveform, "numpy"):
107+
waveform = waveform.numpy(force=True)
108+
109+
predictions = self._pipeline.diarize(
110+
{"waveform": waveform, "sample_rate": file["sample_rate"]},
111+
num_speakers=num_speakers,
112+
min_speakers=min_speakers,
113+
max_speakers=max_speakers,
114+
**kwargs,
115+
)
117116
else:
118-
diarization = predictions["diarization"]
117+
raise ValueError("AudioFile must provide either 'audio' or 'waveform' key")
119118

120-
# deserialize the output into a good-old Annotation instance
121-
annotation = Annotation()
122-
for t, turn in enumerate(diarization):
123-
segment = Segment(start=turn["start"], end=turn["end"])
124-
speaker = turn["speaker"]
125-
annotation[segment, t] = speaker
126-
return annotation.rename_tracks("string")
119+
speaker_diarization: Annotation = self._deserialize(predictions["diarization"])
120+
exclusive_speaker_diarization: Annotation = self._deserialize(
121+
predictions["exclusive_diarization"]
122+
)
123+
124+
return DiarizeOutput(
125+
speaker_diarization=speaker_diarization,
126+
exclusive_speaker_diarization=exclusive_speaker_diarization,
127+
)

src/pyannote/audio/pipelines/pyannoteai/sdk.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
# SOFTWARE.
2222

2323
import os
24-
from pathlib import Path
2524

2625
from pyannote.audio import Pipeline
26+
from pyannote.audio.core.io import AudioFile
2727
from pyannote.core import Annotation, Segment
2828

2929
from pyannoteai.sdk import Client
@@ -68,7 +68,7 @@ def _deserialize(self, diarization: list[dict]) -> Annotation:
6868

6969
def apply(
7070
self,
71-
file: Path,
71+
file: AudioFile,
7272
num_speakers: int | None = None,
7373
min_speakers: int | None = None,
7474
max_speakers: int | None = None,

0 commit comments

Comments
 (0)