Skip to content

Commit da5794b

Browse files
committed
feat(registry): make registry iterable (yields protocol names)
1 parent bac3225 commit da5794b

File tree

2 files changed

+55
-35
lines changed

2 files changed

+55
-35
lines changed

doc/source/changelog.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
Changelog
33
#########
44

5+
develop
6+
~~~~~~~
7+
8+
- feat(registry): make registry iterable (yields protocol names)
9+
510
Version 5.0.1 (2023-04-21)
611
~~~~~~~~~~~~~~~~~~~~~~~~~~
712

pyannote/database/registry.py

Lines changed: 50 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,13 @@
3939
from .database import Database
4040
import yaml
4141

42+
4243
# controls what to do in case of protocol name conflict
4344
class LoadingMode(Enum):
4445
OVERRIDE = 0 # override existing protocol
45-
KEEP = 1 # keep existing protocol
46-
ERROR = 2 # raise an error
46+
KEEP = 1 # keep existing protocol
47+
ERROR = 2 # raise an error
48+
4749

4850
# To ease the understanding of future me, all comments inside Registry codebase
4951
# assume the existence of the following database.yml files.
@@ -52,7 +54,7 @@ class LoadingMode(Enum):
5254
# Content of /path/to/first/database.yml
5355
# ======================================
5456
# Databases:
55-
# DatabaseA:
57+
# DatabaseA:
5658
# - relative/path/A/trn/{uri}.wav
5759
# - relative/path/A/dev/{uri}.wav
5860
# - relative/path/A/tst/{uri}.wav
@@ -95,7 +97,7 @@ class LoadingMode(Enum):
9597
# DatabaseC:
9698
# SpeakerDiarization:
9799
# Protocol:
98-
# ...
100+
# ...
99101

100102

101103
class Registry:
@@ -109,13 +111,12 @@ class Registry:
109111
"""
110112

111113
def __init__(self) -> None:
112-
113114
# Mapping of database.yml paths to their config in a dictionary
114115
# Example after loading both database.yml:
115116
# {"/path/to/first/database.yml": {
116117
# "Databases":{
117118
# "DatabaseA": ["relative/path/A/trn/{uri}.wav", "relative/path/A/dev/{uri}.wav", relative/path/A/tst/{uri}.wav]
118-
# "DatabaseB": "/absolute/path/B/{uri}.wav"
119+
# "DatabaseB": "/absolute/path/B/{uri}.wav"
119120
# },
120121
# "Protocols":{
121122
# "DatabaseA":{
@@ -134,7 +135,7 @@ def __init__(self) -> None:
134135
# "/path/to/second/database.yml": {
135136
# "Databases":{
136137
# "DatabaseC": /absolute/path/C/{uri}.wav
137-
# "DatabaseB": "/absolute/path/B/{uri}.wav"
138+
# "DatabaseB": "/absolute/path/B/{uri}.wav"
138139
# },
139140
# "Protocols":{
140141
# "DatabaseB":{"SpeakerDiarization": {"Protocol": {...}}},
@@ -144,7 +145,6 @@ def __init__(self) -> None:
144145
# }
145146
self.configs: Dict[Path, Dict] = dict()
146147

147-
148148
# Content of the "Database" root item (= where to find file content)
149149
# Example after loading both database.yml:
150150
# {
@@ -176,7 +176,7 @@ def load_database(
176176
Parameters
177177
----------
178178
path : str or Path
179-
Path to YAML configuration file.
179+
Path to YAML configuration file.
180180
mode : LoadingMode, optional
181181
Controls how to handle conflicts in protocol names.
182182
Defaults to overriding the existing protocol.
@@ -210,7 +210,7 @@ def _load_database_helper(
210210
# make path absolute
211211
database_yml = Path(database_yml).expanduser().resolve()
212212

213-
# stop here if configuration file is already being loaded
213+
# stop here if configuration file is already being loaded
214214
# (possibly because of circular requirements)
215215
if database_yml in loading:
216216
return
@@ -221,9 +221,9 @@ def _load_database_helper(
221221
# load configuration
222222
with open(database_yml, "r") as f:
223223
config = yaml.load(f, Loader=yaml.SafeLoader)
224-
224+
225225
# load every requirement
226-
requirements = config.pop('Requirements', list())
226+
requirements = config.pop("Requirements", list())
227227
if not isinstance(requirements, list):
228228
requirements = [requirements]
229229
for requirement_yaml in requirements:
@@ -244,9 +244,7 @@ def _load_database_helper(
244244

245245
# load protocols of each database
246246
for db_name, db_entries in protocols.items():
247-
self._load_protocols(
248-
db_name, db_entries, database_yml, mode=mode
249-
)
247+
self._load_protocols(db_name, db_entries, database_yml, mode=mode)
250248

251249
# process "Databases" section
252250
databases = config.get("Databases", dict())
@@ -265,7 +263,6 @@ def _load_database_helper(
265263
# save configuration for later reloading of meta-protocols
266264
self.configs[database_yml] = config
267265

268-
269266
def get_database(self, database_name, **kwargs) -> Database:
270267
"""Get database by name
271268
@@ -284,7 +281,6 @@ def get_database(self, database_name, **kwargs) -> Database:
284281
database = self.databases[database_name]
285282

286283
except KeyError:
287-
288284
if database_name == "X":
289285
msg = (
290286
"Could not find any meta-protocol. Please refer to "
@@ -302,7 +298,9 @@ def get_database(self, database_name, **kwargs) -> Database:
302298

303299
return database(**kwargs)
304300

305-
def get_protocol(self, name, preprocessors: Optional[Preprocessors] = None) -> Protocol:
301+
def get_protocol(
302+
self, name, preprocessors: Optional[Preprocessors] = None
303+
) -> Protocol:
306304
"""Get protocol by full name
307305
308306
Parameters
@@ -329,6 +327,14 @@ def get_protocol(self, name, preprocessors: Optional[Preprocessors] = None) -> P
329327
protocol.name = name
330328
return protocol
331329

330+
# iterate over all protocols by name
331+
def __iter__(self):
332+
for database_name in self.databases:
333+
database = self.get_database(database_name)
334+
for task_name in database.get_tasks():
335+
for protocol_name in database.get_protocols(task_name):
336+
yield f"{database_name}.{task_name}.{protocol_name}"
337+
332338
def _load_protocols(
333339
self,
334340
db_name,
@@ -367,7 +373,9 @@ def _load_protocols(
367373
# If needed, merge old protocols dict with the new one (according to current override rules)
368374
if db_name in self.databases:
369375
old_protocols = self.databases[db_name]._protocols
370-
_merge_protocols_inplace(protocols, old_protocols, mode, db_name, database_yml)
376+
_merge_protocols_inplace(
377+
protocols, old_protocols, mode, db_name, database_yml
378+
)
371379

372380
# create database class on-the-fly
373381
protocol_list = [
@@ -389,13 +397,14 @@ def _reload_meta_protocols(self):
389397
for db_yml, config in self.configs.items():
390398
databases = config.get("Protocols", dict())
391399
if "X" in databases:
392-
self._load_protocols("X", databases["X"], db_yml, mode=LoadingMode.OVERRIDE)
393-
400+
self._load_protocols(
401+
"X", databases["X"], db_yml, mode=LoadingMode.OVERRIDE
402+
)
394403

395404

396405
def _env_config_paths() -> List[Path]:
397406
"""Parse PYANNOTE_DATABASE_CONFIG environment variable
398-
407+
399408
PYANNOTE_DATABASE_CONFIG may contain multiple paths separation by ";".
400409
401410
Returns
@@ -413,10 +422,11 @@ def _env_config_paths() -> List[Path]:
413422
paths.append(path)
414423
return paths
415424

425+
416426
def _find_default_ymls() -> List[Path]:
417427
"""Get paths to default YAML configuration files
418428
419-
* $HOME/.pyannote/database.yml
429+
* $HOME/.pyannote/database.yml
420430
* $CWD/database.yml
421431
* PYANNOTE_DATABASE_CONFIG environment variable
422432
@@ -431,21 +441,23 @@ def _find_default_ymls() -> List[Path]:
431441
home_db_yml = Path("~/.pyannote/database.yml").expanduser()
432442
if home_db_yml.is_file():
433443
paths.append(home_db_yml)
434-
444+
435445
cwd_db_yml = Path.cwd() / "database.yml"
436446
if cwd_db_yml.is_file():
437447
paths.append(cwd_db_yml)
438-
448+
439449
paths += _env_config_paths()
440450

441451
return paths
442452

453+
443454
def _merge_protocols_inplace(
444-
new_protocols: Dict[Tuple[Text, Text], Type],
445-
old_protocols: Dict[Tuple[Text, Text], Type],
446-
mode: LoadingMode,
447-
db_name: str,
448-
database_yml: str):
455+
new_protocols: Dict[Tuple[Text, Text], Type],
456+
old_protocols: Dict[Tuple[Text, Text], Type],
457+
mode: LoadingMode,
458+
db_name: str,
459+
database_yml: str,
460+
):
449461
"""Merge new and old protocols inplace into the passed new_protocol.
450462
451463
Warning, merging order might be counterintuitive : "KEEP" strategy keeps element from the OLD protocol
@@ -471,7 +483,6 @@ def _merge_protocols_inplace(
471483

472484
# for all previously defined protocol (in old_protocols)
473485
for p_id, old_p in old_protocols.items():
474-
475486
# if this protocol is redefined
476487
if p_id in new_protocols:
477488
t_name, p_name = p_id
@@ -480,13 +491,16 @@ def _merge_protocols_inplace(
480491
# raise an error
481492
if mode == LoadingMode.ERROR:
482493
raise RuntimeError(
483-
f"Cannot load {realname} protocol from '{database_yml}' as it already exists.")
494+
f"Cannot load {realname} protocol from '{database_yml}' as it already exists."
495+
)
484496

485497
# keep the new protocol
486498
elif mode == LoadingMode.OVERRIDE:
487-
warnings.warn(f"Replacing existing {realname} protocol by the one defined in '{database_yml}'.")
499+
warnings.warn(
500+
f"Replacing existing {realname} protocol by the one defined in '{database_yml}'."
501+
)
488502
pass
489-
503+
490504
# keep the old protocol
491505
elif mode == LoadingMode.KEEP:
492506
warnings.warn(
@@ -498,9 +512,10 @@ def _merge_protocols_inplace(
498512
else:
499513
new_protocols[p_id] = old_p
500514

515+
501516
# initialize the registry singleton
502517
registry = Registry()
503518

504519
# load all database yaml files found at startup
505520
for yml in _find_default_ymls():
506-
registry.load_database(yml)
521+
registry.load_database(yml)

0 commit comments

Comments
 (0)