Skip to content

Commit f27ec3d

Browse files
authored
BREAKING: update STMLoader to return meeteval.io.SegLST instances (#113)
1 parent 3315062 commit f27ec3d

File tree

4 files changed

+213
-137
lines changed

4 files changed

+213
-137
lines changed

doc/source/changelog.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
Changelog
33
#########
44

5+
Version "next" (xxxx-xx-xx)
6+
~~~~~~~~~~~~~~~~~~~~~~~~~~
7+
8+
- BREAKING: update `STMLoader` to return `meeteval.io.SegLST` instances
9+
510
Version 6.0.0 (2025-09-09)
611
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
712

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ pyannote-database = "pyannote.database.cli:main"
2323
[project.entry-points.'pyannote.database.loader']
2424
".rttm" = "pyannote.database.loader:RTTMLoader"
2525
".uem" = "pyannote.database.loader:UEMLoader"
26-
".ctm" = "pyannote.database.loader:CTMLoader"
2726
".map" = "pyannote.database.loader:MAPLoader"
2827
".lab" = "pyannote.database.loader:LABLoader"
2928
".stm" = "pyannote.database.loader:STMLoader"
@@ -39,6 +38,9 @@ doc = [
3938
"sphinx-rtd-theme>=3.0.2",
4039
"sphinx>=8.1.3",
4140
]
41+
transcription = [
42+
"meeteval>=0.4.3",
43+
]
4244

4345
[build-system]
4446
requires = ["hatchling", "hatch-vcs"]

src/pyannote/database/loader.py

Lines changed: 53 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33

44
# The MIT License (MIT)
55

6-
# Copyright (c) 2020- CNRS
6+
# Copyright (c) 2020-2025 CNRS
7+
# Copyright (c) 2025- pyannoteAI
78

89
# Permission is hereby granted, free of charge, to any person obtaining a copy
910
# of this software and associated documentation files (the "Software"), to deal
@@ -24,32 +25,29 @@
2425
# SOFTWARE.
2526

2627
# AUTHORS
27-
# Hervé BREDIN - http://herve.niderb.fr
28+
# Hervé BREDIN
2829
# Vincent BRIGNATZ
2930

3031
"""Data loaders"""
3132

32-
from typing import Text
33-
from pathlib import Path
3433
import string
35-
from pyannote.database.util import load_rttm, load_uem, load_lab, load_stm
36-
import pandas as pd
37-
from pyannote.core import Segment, Timeline, Annotation
38-
from pyannote.database.protocol.protocol import ProtocolFile
39-
from typing import Union, Any
4034
import warnings
35+
from pathlib import Path
36+
from typing import Any, Text
4137

38+
import pandas as pd
39+
from pyannote.core import Annotation, Timeline
40+
from pyannote.database.protocol.protocol import ProtocolFile
41+
from pyannote.database.util import load_lab, load_rttm, load_uem
4242

4343
try:
44-
from spacy.tokens import Token
45-
from spacy.tokens import Doc
44+
import meeteval.io
45+
from meeteval.io.seglst import SegLST
4646

47-
Token.set_extension("time_start", default=None)
48-
Token.set_extension("time_end", default=None)
49-
Token.set_extension("confidence", default=0.0)
47+
MEETEVAL_IS_AVAILABLE = True
5048

5149
except ImportError:
52-
pass
50+
MEETEVAL_IS_AVAILABLE = False
5351

5452

5553
def load_lst(file_lst):
@@ -89,9 +87,7 @@ def load_trial(file_trial):
8987
List of trial
9088
"""
9189

92-
trials = pd.read_table(
93-
file_trial, sep="\s+", names=["reference", "uri1", "uri2"]
94-
)
90+
trials = pd.read_table(file_trial, sep="\s+", names=["reference", "uri1", "uri2"])
9591

9692
for _, reference, uri1, uri2 in trials.itertuples():
9793
yield {"reference": reference, "uri1": uri1, "uri2": uri2}
@@ -119,7 +115,6 @@ def __init__(self, path: Text = None):
119115
self.loaded_ = dict() if self.placeholders_ else load_rttm(self.path)
120116

121117
def __call__(self, file: ProtocolFile) -> Annotation:
122-
123118
uri = file["uri"]
124119

125120
if uri in self.loaded_:
@@ -145,37 +140,66 @@ def __call__(self, file: ProtocolFile) -> Annotation:
145140
class STMLoader:
146141
"""STM loader
147142
148-
Can be used as a preprocessor.
149-
150143
Parameters
151144
----------
152145
path : str
153-
Path to STM file with optional ProtocolFile key placeholders
146+
Path to STM file with optional AudioFile key placeholders
154147
(e.g. "/path/to/{database}/{subset}/{uri}.stm")
155148
"""
156149

157-
def __init__(self, path: Text = None):
150+
def __init__(self, path: str | Path | None = None):
158151
super().__init__()
159152

160153
self.path = str(path)
161154

162155
_, placeholders, _, _ = zip(*string.Formatter().parse(self.path))
163156
self.placeholders_ = set(placeholders) - set([None])
164-
self.loaded_ = dict() if self.placeholders_ else load_stm(self.path)
165157

166-
def __call__(self, file: ProtocolFile) -> Annotation:
158+
if self.placeholders_:
159+
self.loaded_: dict[str, "SegLST"] = dict()
160+
return
167161

162+
if MEETEVAL_IS_AVAILABLE:
163+
seglst: SegLST = meeteval.io.load(self.path, format="stm").to_seglst()
164+
session_ids = set(s["session_id"] for s in seglst)
165+
self.loaded_: dict[str, SegLST] = {
166+
session_id: SegLST([s for s in seglst if s["session_id"] == session_id])
167+
for session_id in session_ids
168+
}
169+
170+
return
171+
172+
warnings.warn("MeetEval is not available, STM files cannot be loaded.")
173+
self.loaded_: dict[str, "SegLST"] = dict()
174+
175+
def __call__(self, file: ProtocolFile) -> "SegLST":
168176
uri = file["uri"]
169177

170178
if uri in self.loaded_:
171179
return self.loaded_[uri]
172180

173181
sub_file = {key: file[key] for key in self.placeholders_}
174-
loaded = load_stm(self.path.format(**sub_file))
182+
183+
if MEETEVAL_IS_AVAILABLE:
184+
seglst: SegLST = meeteval.io.load(
185+
self.path.format(**sub_file), format="stm"
186+
).to_seglst()
187+
session_ids = set(s["session_id"] for s in seglst)
188+
loaded: dict[str, SegLST] = {
189+
session_id: SegLST([s for s in seglst if s["session_id"] == session_id])
190+
for session_id in session_ids
191+
}
192+
else:
193+
warnings.warn("MeetEval is not available, STM files cannot be loaded.")
194+
loaded = dict()
195+
175196
if uri not in loaded:
176-
loaded[uri] = Annotation(uri=uri)
197+
if MEETEVAL_IS_AVAILABLE:
198+
loaded[uri] = SegLST([])
199+
else:
200+
loaded[uri] = None
177201

178-
# do not cache annotations when there is one STM file per "uri"
202+
# do not cache transcription when there is one STM file per "uri"
179203
# since loading it should be quite fast
180204
if "uri" in self.placeholders_:
181205
return loaded[uri]
@@ -209,7 +233,6 @@ def __init__(self, path: Text = None):
209233
self.loaded_ = dict() if self.placeholders_ else load_uem(self.path)
210234

211235
def __call__(self, file: ProtocolFile) -> Timeline:
212-
213236
uri = file["uri"]
214237

215238
if uri in self.loaded_:
@@ -261,65 +284,12 @@ def __init__(self, path: Text = None):
261284
raise ValueError("`path` must contain the {uri} placeholder.")
262285

263286
def __call__(self, file: ProtocolFile) -> Annotation:
264-
265287
uri = file["uri"]
266288

267289
sub_file = {key: file[key] for key in self.placeholders_}
268290
return load_lab(self.path.format(**sub_file), uri=uri)
269291

270292

271-
class CTMLoader:
272-
"""CTM loader
273-
274-
Parameter
275-
---------
276-
ctm : Path
277-
Path to CTM file
278-
"""
279-
280-
def __init__(self, ctm: Path):
281-
self.ctm = ctm
282-
283-
names = ["uri", "channel", "start", "duration", "word", "confidence"]
284-
dtype = {
285-
"uri": str,
286-
"start": float,
287-
"duration": float,
288-
"word": str,
289-
"confidence": float,
290-
}
291-
self.data_ = pd.read_csv(
292-
ctm, names=names, dtype=dtype, sep="\s+"
293-
).groupby("uri")
294-
295-
def __call__(self, current_file: ProtocolFile) -> Union["Doc", None]:
296-
297-
try:
298-
from spacy.vocab import Vocab
299-
from spacy.tokens import Doc
300-
except ImportError:
301-
msg = "Cannot load CTM files because spaCy is not available."
302-
warnings.warn(msg)
303-
return None
304-
305-
uri = current_file["uri"]
306-
307-
try:
308-
lines = list(self.data_.get_group(uri).iterrows())
309-
except KeyError:
310-
lines = []
311-
312-
words = [line.word for _, line in lines]
313-
doc = Doc(Vocab(), words=words)
314-
315-
for token, (_, line) in zip(doc, lines):
316-
token._.time_start = line.start
317-
token._.time_end = line.start + line.duration
318-
token._.confidence = line.confidence
319-
320-
return doc
321-
322-
323293
class MAPLoader:
324294
"""Mapping loader
325295
@@ -353,9 +323,7 @@ def __init__(self, mapping: Path):
353323
dtype = {
354324
"uri": str,
355325
}
356-
self.data_ = pd.read_csv(
357-
mapping, names=names, dtype=dtype, sep="\s+"
358-
)
326+
self.data_ = pd.read_csv(mapping, names=names, dtype=dtype, sep="\s+")
359327

360328
# get colum 'value' dtype, allowing us to acces it during subset
361329
self.dtype = self.data_.dtypes["value"]

0 commit comments

Comments
 (0)