2121# SOFTWARE.
2222
2323import os
24- from pathlib import Path
2524
2625from pyannote .audio import Pipeline
26+ from pyannote .audio .core .io import AudioFile
2727from pyannote .core import Annotation , Segment
2828
29+ from ..speaker_diarization import DiarizeOutput
30+
2931
3032class Local (Pipeline ):
3133 """Wrapper around official pyannoteAI on-premise package
@@ -51,29 +53,24 @@ def __init__(self, token: str | None = None, **kwargs):
5153 self .token = token or os .environ .get ("PYANNOTEAI_API_KEY" , None )
5254 self ._pipeline = _LocalPipeline (self .token )
5355
54- def _to_annotation (self , completed_job : dict ) -> Annotation :
55- """Deserialize job output into pyannote.core Annotation"""
56-
57- output = completed_job ["output" ]["diarization" ]
58- job_id = completed_job ["jobId" ]
59-
60- annotation = Annotation (uri = job_id )
61- for t , turn in enumerate (output ):
56+ def _deserialize (self , diarization : list [dict ]) -> Annotation :
57+ # deserialize the output into a good-old Annotation instance
58+ annotation = Annotation ()
59+ for t , turn in enumerate (diarization ):
6260 segment = Segment (start = turn ["start" ], end = turn ["end" ])
6361 speaker = turn ["speaker" ]
6462 annotation [segment , t ] = speaker
65-
6663 return annotation .rename_tracks ("string" )
6764
6865 def apply (
6966 self ,
70- file : Path ,
67+ file : AudioFile ,
7168 num_speakers : int | None = None ,
7269 min_speakers : int | None = None ,
7370 max_speakers : int | None = None ,
74- exclusive : bool = False ,
75- ) -> Annotation :
76- """Speaker diarization using pyannoteAI on-premise package
71+ ** kwargs ,
72+ ) -> DiarizeOutput :
73+ """Speaker diarization using on-premise pyannoteAI package.
7774
7875 Parameters
7976 ----------
@@ -86,41 +83,45 @@ def apply(
8683 Not supported yet. Minimum number of speakers. Has no effect when `num_speakers` is provided.
8784 max_speakers : int, optional
8885 Not supported yet. Maximum number of speakers. Has no effect when `num_speakers` is provided.
89- exclusive : bool, optional
90- Enable exclusive diarization.
9186
9287 Returns
9388 -------
94- speaker_diarization : Annotation
95- Speaker diarization result (when successful)
96-
97- Raises
98- ------
99- PyannoteAIFailedJob
100- If the job failed
101- PyannoteAICanceledJob
102- If the job was canceled
103- HTTPError
104- If something else went wrong
89+ output : DiarizeOutput
90+ DiarizeOutput object containing both regular and exclusive speaker diarization results.
10591 """
10692
107- predictions = self ._pipeline .diarize (
108- file ["audio" ],
109- num_speakers = num_speakers ,
110- min_speakers = min_speakers ,
111- max_speakers = max_speakers ,
112- )
113-
114- # use exclusive diarization whenever requested
115- if exclusive :
116- diarization = predictions ["exclusive_diarization" ]
93+ # if file provides "audio" path
94+ if "audio" in file :
95+ predictions = self ._pipeline .diarize (
96+ file ["audio" ],
97+ num_speakers = num_speakers ,
98+ min_speakers = min_speakers ,
99+ max_speakers = max_speakers ,
100+ ** kwargs ,
101+ )
102+
103+ # if file provides "waveform", make sure it is numpy (and not torch) array
104+ elif "waveform" in file :
105+ waveform = file ["waveform" ]
106+ if hasattr (waveform , "numpy" ):
107+ waveform = waveform .numpy (force = True )
108+
109+ predictions = self ._pipeline .diarize (
110+ {"waveform" : waveform , "sample_rate" : file ["sample_rate" ]},
111+ num_speakers = num_speakers ,
112+ min_speakers = min_speakers ,
113+ max_speakers = max_speakers ,
114+ ** kwargs ,
115+ )
117116 else :
118- diarization = predictions [ "diarization" ]
117+ raise ValueError ( "AudioFile must provide either 'audio' or 'waveform' key" )
119118
120- # deserialize the output into a good-old Annotation instance
121- annotation = Annotation ()
122- for t , turn in enumerate (diarization ):
123- segment = Segment (start = turn ["start" ], end = turn ["end" ])
124- speaker = turn ["speaker" ]
125- annotation [segment , t ] = speaker
126- return annotation .rename_tracks ("string" )
119+ speaker_diarization : Annotation = self ._deserialize (predictions ["diarization" ])
120+ exclusive_speaker_diarization : Annotation = self ._deserialize (
121+ predictions ["exclusive_diarization" ]
122+ )
123+
124+ return DiarizeOutput (
125+ speaker_diarization = speaker_diarization ,
126+ exclusive_speaker_diarization = exclusive_speaker_diarization ,
127+ )
0 commit comments