33
44# The MIT License (MIT)
55
6- # Copyright (c) 2020- CNRS
6+ # Copyright (c) 2020-2025 CNRS
7+ # Copyright (c) 2025- pyannoteAI
78
89# Permission is hereby granted, free of charge, to any person obtaining a copy
910# of this software and associated documentation files (the "Software"), to deal
2425# SOFTWARE.
2526
2627# AUTHORS
27- # Hervé BREDIN - http://herve.niderb.fr
28+ # Hervé BREDIN
2829# Vincent BRIGNATZ
2930
3031"""Data loaders"""
3132
32- from typing import Text
33- from pathlib import Path
3433import string
35- from pyannote .database .util import load_rttm , load_uem , load_lab , load_stm
36- import pandas as pd
37- from pyannote .core import Segment , Timeline , Annotation
38- from pyannote .database .protocol .protocol import ProtocolFile
39- from typing import Union , Any
4034import warnings
35+ from pathlib import Path
36+ from typing import Any , Text
4137
38+ import pandas as pd
39+ from pyannote .core import Annotation , Timeline
40+ from pyannote .database .protocol .protocol import ProtocolFile
41+ from pyannote .database .util import load_lab , load_rttm , load_uem
4242
4343try :
44- from spacy . tokens import Token
45- from spacy . tokens import Doc
44+ import meeteval . io
45+ from meeteval . io . seglst import SegLST
4646
47- Token .set_extension ("time_start" , default = None )
48- Token .set_extension ("time_end" , default = None )
49- Token .set_extension ("confidence" , default = 0.0 )
47+ MEETEVAL_IS_AVAILABLE = True
5048
5149except ImportError :
52- pass
50+ MEETEVAL_IS_AVAILABLE = False
5351
5452
5553def load_lst (file_lst ):
@@ -89,9 +87,7 @@ def load_trial(file_trial):
8987 List of trial
9088 """
9189
92- trials = pd .read_table (
93- file_trial , sep = "\s+" , names = ["reference" , "uri1" , "uri2" ]
94- )
90+ trials = pd .read_table (file_trial , sep = "\s+" , names = ["reference" , "uri1" , "uri2" ])
9591
9692 for _ , reference , uri1 , uri2 in trials .itertuples ():
9793 yield {"reference" : reference , "uri1" : uri1 , "uri2" : uri2 }
@@ -119,7 +115,6 @@ def __init__(self, path: Text = None):
119115 self .loaded_ = dict () if self .placeholders_ else load_rttm (self .path )
120116
121117 def __call__ (self , file : ProtocolFile ) -> Annotation :
122-
123118 uri = file ["uri" ]
124119
125120 if uri in self .loaded_ :
@@ -145,37 +140,66 @@ def __call__(self, file: ProtocolFile) -> Annotation:
145140class STMLoader :
146141 """STM loader
147142
148- Can be used as a preprocessor.
149-
150143 Parameters
151144 ----------
152145 path : str
153- Path to STM file with optional ProtocolFile key placeholders
146+ Path to STM file with optional AudioFile key placeholders
154147 (e.g. "/path/to/{database}/{subset}/{uri}.stm")
155148 """
156149
157- def __init__ (self , path : Text = None ):
150+ def __init__ (self , path : str | Path | None = None ):
158151 super ().__init__ ()
159152
160153 self .path = str (path )
161154
162155 _ , placeholders , _ , _ = zip (* string .Formatter ().parse (self .path ))
163156 self .placeholders_ = set (placeholders ) - set ([None ])
164- self .loaded_ = dict () if self .placeholders_ else load_stm (self .path )
165157
166- def __call__ (self , file : ProtocolFile ) -> Annotation :
158+ if self .placeholders_ :
159+ self .loaded_ : dict [str , "SegLST" ] = dict ()
160+ return
167161
162+ if MEETEVAL_IS_AVAILABLE :
163+ seglst : SegLST = meeteval .io .load (self .path , format = "stm" ).to_seglst ()
164+ session_ids = set (s ["session_id" ] for s in seglst )
165+ self .loaded_ : dict [str , SegLST ] = {
166+ session_id : SegLST ([s for s in seglst if s ["session_id" ] == session_id ])
167+ for session_id in session_ids
168+ }
169+
170+ return
171+
172+ warnings .warn ("MeetEval is not available, STM files cannot be loaded." )
173+ self .loaded_ : dict [str , "SegLST" ] = dict ()
174+
175+ def __call__ (self , file : ProtocolFile ) -> "SegLST" :
168176 uri = file ["uri" ]
169177
170178 if uri in self .loaded_ :
171179 return self .loaded_ [uri ]
172180
173181 sub_file = {key : file [key ] for key in self .placeholders_ }
174- loaded = load_stm (self .path .format (** sub_file ))
182+
183+ if MEETEVAL_IS_AVAILABLE :
184+ seglst : SegLST = meeteval .io .load (
185+ self .path .format (** sub_file ), format = "stm"
186+ ).to_seglst ()
187+ session_ids = set (s ["session_id" ] for s in seglst )
188+ loaded : dict [str , SegLST ] = {
189+ session_id : SegLST ([s for s in seglst if s ["session_id" ] == session_id ])
190+ for session_id in session_ids
191+ }
192+ else :
193+ warnings .warn ("MeetEval is not available, STM files cannot be loaded." )
194+ loaded = dict ()
195+
175196 if uri not in loaded :
176- loaded [uri ] = Annotation (uri = uri )
197+ if MEETEVAL_IS_AVAILABLE :
198+ loaded [uri ] = SegLST ([])
199+ else :
200+ loaded [uri ] = None
177201
178- # do not cache annotations when there is one STM file per "uri"
202+ # do not cache transcription when there is one STM file per "uri"
179203 # since loading it should be quite fast
180204 if "uri" in self .placeholders_ :
181205 return loaded [uri ]
@@ -209,7 +233,6 @@ def __init__(self, path: Text = None):
209233 self .loaded_ = dict () if self .placeholders_ else load_uem (self .path )
210234
211235 def __call__ (self , file : ProtocolFile ) -> Timeline :
212-
213236 uri = file ["uri" ]
214237
215238 if uri in self .loaded_ :
@@ -261,65 +284,12 @@ def __init__(self, path: Text = None):
261284 raise ValueError ("`path` must contain the {uri} placeholder." )
262285
263286 def __call__ (self , file : ProtocolFile ) -> Annotation :
264-
265287 uri = file ["uri" ]
266288
267289 sub_file = {key : file [key ] for key in self .placeholders_ }
268290 return load_lab (self .path .format (** sub_file ), uri = uri )
269291
270292
271- class CTMLoader :
272- """CTM loader
273-
274- Parameter
275- ---------
276- ctm : Path
277- Path to CTM file
278- """
279-
280- def __init__ (self , ctm : Path ):
281- self .ctm = ctm
282-
283- names = ["uri" , "channel" , "start" , "duration" , "word" , "confidence" ]
284- dtype = {
285- "uri" : str ,
286- "start" : float ,
287- "duration" : float ,
288- "word" : str ,
289- "confidence" : float ,
290- }
291- self .data_ = pd .read_csv (
292- ctm , names = names , dtype = dtype , sep = "\s+"
293- ).groupby ("uri" )
294-
295- def __call__ (self , current_file : ProtocolFile ) -> Union ["Doc" , None ]:
296-
297- try :
298- from spacy .vocab import Vocab
299- from spacy .tokens import Doc
300- except ImportError :
301- msg = "Cannot load CTM files because spaCy is not available."
302- warnings .warn (msg )
303- return None
304-
305- uri = current_file ["uri" ]
306-
307- try :
308- lines = list (self .data_ .get_group (uri ).iterrows ())
309- except KeyError :
310- lines = []
311-
312- words = [line .word for _ , line in lines ]
313- doc = Doc (Vocab (), words = words )
314-
315- for token , (_ , line ) in zip (doc , lines ):
316- token ._ .time_start = line .start
317- token ._ .time_end = line .start + line .duration
318- token ._ .confidence = line .confidence
319-
320- return doc
321-
322-
323293class MAPLoader :
324294 """Mapping loader
325295
@@ -353,9 +323,7 @@ def __init__(self, mapping: Path):
353323 dtype = {
354324 "uri" : str ,
355325 }
356- self .data_ = pd .read_csv (
357- mapping , names = names , dtype = dtype , sep = "\s+"
358- )
326+ self .data_ = pd .read_csv (mapping , names = names , dtype = dtype , sep = "\s+" )
359327
360328 # get colum 'value' dtype, allowing us to acces it during subset
361329 self .dtype = self .data_ .dtypes ["value" ]
0 commit comments