3939from .database import Database
4040import yaml
4141
42+
4243# controls what to do in case of protocol name conflict
class LoadingMode(Enum):
    """Strategy to apply when a loaded protocol name conflicts with an existing one.

    Each member is defined exactly once: `Enum` refuses to build a class whose
    body assigns the same member name twice.
    """

    OVERRIDE = 0  # override existing protocol
    KEEP = 1  # keep existing protocol
    ERROR = 2  # raise an error
48+
4749
4850# To ease the understanding of future me, all comments inside Registry codebase
4951# assume the existence of the following database.yml files.
@@ -52,7 +54,7 @@ class LoadingMode(Enum):
5254# Content of /path/to/first/database.yml
5355# ======================================
5456# Databases:
55- # DatabaseA:
57+ # DatabaseA:
5658# - relative/path/A/trn/{uri}.wav
5759# - relative/path/A/dev/{uri}.wav
5860# - relative/path/A/tst/{uri}.wav
@@ -95,7 +97,7 @@ class LoadingMode(Enum):
9597# DatabaseC:
9698# SpeakerDiarization:
9799# Protocol:
98- # ...
100+ # ...
99101
100102
101103class Registry :
@@ -109,13 +111,12 @@ class Registry:
109111 """
110112
111113 def __init__ (self ) -> None :
112-
113114 # Mapping of database.yml paths to their config in a dictionary
114115 # Example after loading both database.yml:
115116 # {"/path/to/first/database.yml": {
116117 # "Databases":{
117118 # "DatabaseA": ["relative/path/A/trn/{uri}.wav", "relative/path/A/dev/{uri}.wav", "relative/path/A/tst/{uri}.wav"],
118- # "DatabaseB": "/absolute/path/B/{uri}.wav"
119+ # "DatabaseB": "/absolute/path/B/{uri}.wav"
119120 # },
120121 # "Protocols":{
121122 # "DatabaseA":{
@@ -134,7 +135,7 @@ def __init__(self) -> None:
134135 # "/path/to/second/database.yml": {
135136 # "Databases":{
136137 # "DatabaseC": "/absolute/path/C/{uri}.wav",
137- # "DatabaseB": "/absolute/path/B/{uri}.wav"
138+ # "DatabaseB": "/absolute/path/B/{uri}.wav"
138139 # },
139140 # "Protocols":{
140141 # "DatabaseB":{"SpeakerDiarization": {"Protocol": {...}}},
@@ -144,7 +145,6 @@ def __init__(self) -> None:
144145 # }
145146 self .configs : Dict [Path , Dict ] = dict ()
146147
147-
148148 # Content of the "Databases" root item (= where to find file content)
149149 # Example after loading both database.yml:
150150 # {
@@ -176,7 +176,7 @@ def load_database(
176176 Parameters
177177 ----------
178178 path : str or Path
179- Path to YAML configuration file.
179+ Path to YAML configuration file.
180180 mode : LoadingMode, optional
181181 Controls how to handle conflicts in protocol names.
182182 Defaults to overriding the existing protocol.
@@ -210,7 +210,7 @@ def _load_database_helper(
210210 # make path absolute
211211 database_yml = Path (database_yml ).expanduser ().resolve ()
212212
213- # stop here if configuration file is already being loaded
213+ # stop here if configuration file is already being loaded
214214 # (possibly because of circular requirements)
215215 if database_yml in loading :
216216 return
@@ -221,9 +221,9 @@ def _load_database_helper(
221221 # load configuration
222222 with open (database_yml , "r" ) as f :
223223 config = yaml .load (f , Loader = yaml .SafeLoader )
224-
224+
225225 # load every requirement
226- requirements = config .pop (' Requirements' , list ())
226+ requirements = config .pop (" Requirements" , list ())
227227 if not isinstance (requirements , list ):
228228 requirements = [requirements ]
229229 for requirement_yaml in requirements :
@@ -244,9 +244,7 @@ def _load_database_helper(
244244
245245 # load protocols of each database
246246 for db_name , db_entries in protocols .items ():
247- self ._load_protocols (
248- db_name , db_entries , database_yml , mode = mode
249- )
247+ self ._load_protocols (db_name , db_entries , database_yml , mode = mode )
250248
251249 # process "Databases" section
252250 databases = config .get ("Databases" , dict ())
@@ -265,7 +263,6 @@ def _load_database_helper(
265263 # save configuration for later reloading of meta-protocols
266264 self .configs [database_yml ] = config
267265
268-
269266 def get_database (self , database_name , ** kwargs ) -> Database :
270267 """Get database by name
271268
@@ -284,7 +281,6 @@ def get_database(self, database_name, **kwargs) -> Database:
284281 database = self .databases [database_name ]
285282
286283 except KeyError :
287-
288284 if database_name == "X" :
289285 msg = (
290286 "Could not find any meta-protocol. Please refer to "
@@ -302,7 +298,9 @@ def get_database(self, database_name, **kwargs) -> Database:
302298
303299 return database (** kwargs )
304300
305- def get_protocol (self , name , preprocessors : Optional [Preprocessors ] = None ) -> Protocol :
301+ def get_protocol (
302+ self , name , preprocessors : Optional [Preprocessors ] = None
303+ ) -> Protocol :
306304 """Get protocol by full name
307305
308306 Parameters
@@ -329,6 +327,14 @@ def get_protocol(self, name, preprocessors: Optional[Preprocessors] = None) -> P
329327 protocol .name = name
330328 return protocol
331329
def __iter__(self):
    """Iterate over the full names of all registered protocols.

    Yields
    ------
    name : str
        "{database}.{task}.{protocol}" for every protocol of every task of
        every database found in `self.databases`.
    """
    for database_name in self.databases:
        # instantiate the database to query its tasks and protocols
        database = self.get_database(database_name)
        for task_name in database.get_tasks():
            for protocol_name in database.get_protocols(task_name):
                # no whitespace around the "." separators, so that the yielded
                # string is a plain dotted full name
                yield f"{database_name}.{task_name}.{protocol_name}"
337+
332338 def _load_protocols (
333339 self ,
334340 db_name ,
@@ -367,7 +373,9 @@ def _load_protocols(
367373 # If needed, merge old protocols dict with the new one (according to current override rules)
368374 if db_name in self .databases :
369375 old_protocols = self .databases [db_name ]._protocols
370- _merge_protocols_inplace (protocols , old_protocols , mode , db_name , database_yml )
376+ _merge_protocols_inplace (
377+ protocols , old_protocols , mode , db_name , database_yml
378+ )
371379
372380 # create database class on-the-fly
373381 protocol_list = [
@@ -389,13 +397,14 @@ def _reload_meta_protocols(self):
389397 for db_yml , config in self .configs .items ():
390398 databases = config .get ("Protocols" , dict ())
391399 if "X" in databases :
392- self ._load_protocols ("X" , databases ["X" ], db_yml , mode = LoadingMode .OVERRIDE )
393-
400+ self ._load_protocols (
401+ "X" , databases ["X" ], db_yml , mode = LoadingMode .OVERRIDE
402+ )
394403
395404
396405def _env_config_paths () -> List [Path ]:
397406 """Parse PYANNOTE_DATABASE_CONFIG environment variable
398-
407+
399408 PYANNOTE_DATABASE_CONFIG may contain multiple paths separated by ";".
400409
401410 Returns
@@ -413,10 +422,11 @@ def _env_config_paths() -> List[Path]:
413422 paths .append (path )
414423 return paths
415424
425+
416426def _find_default_ymls () -> List [Path ]:
417427 """Get paths to default YAML configuration files
418428
419- * $HOME/.pyannote/database.yml
429+ * $HOME/.pyannote/database.yml
420430 * $CWD/database.yml
421431 * PYANNOTE_DATABASE_CONFIG environment variable
422432
@@ -431,21 +441,23 @@ def _find_default_ymls() -> List[Path]:
431441 home_db_yml = Path ("~/.pyannote/database.yml" ).expanduser ()
432442 if home_db_yml .is_file ():
433443 paths .append (home_db_yml )
434-
444+
435445 cwd_db_yml = Path .cwd () / "database.yml"
436446 if cwd_db_yml .is_file ():
437447 paths .append (cwd_db_yml )
438-
448+
439449 paths += _env_config_paths ()
440450
441451 return paths
442452
453+
443454def _merge_protocols_inplace (
444- new_protocols : Dict [Tuple [Text , Text ], Type ],
445- old_protocols : Dict [Tuple [Text , Text ], Type ],
446- mode : LoadingMode ,
447- db_name : str ,
448- database_yml : str ):
455+ new_protocols : Dict [Tuple [Text , Text ], Type ],
456+ old_protocols : Dict [Tuple [Text , Text ], Type ],
457+ mode : LoadingMode ,
458+ db_name : str ,
459+ database_yml : str ,
460+ ):
449461 """Merge new and old protocols inplace into the passed new_protocol.
450462
451463 Warning, merging order might be counterintuitive : "KEEP" strategy keeps element from the OLD protocol
@@ -471,7 +483,6 @@ def _merge_protocols_inplace(
471483
472484 # for all previously defined protocol (in old_protocols)
473485 for p_id , old_p in old_protocols .items ():
474-
475486 # if this protocol is redefined
476487 if p_id in new_protocols :
477488 t_name , p_name = p_id
@@ -480,13 +491,16 @@ def _merge_protocols_inplace(
480491 # raise an error
481492 if mode == LoadingMode .ERROR :
482493 raise RuntimeError (
483- f"Cannot load { realname } protocol from '{ database_yml } ' as it already exists." )
494+ f"Cannot load { realname } protocol from '{ database_yml } ' as it already exists."
495+ )
484496
485497 # keep the new protocol
486498 elif mode == LoadingMode .OVERRIDE :
487- warnings .warn (f"Replacing existing { realname } protocol by the one defined in '{ database_yml } '." )
499+ warnings .warn (
500+ f"Replacing existing { realname } protocol by the one defined in '{ database_yml } '."
501+ )
488502 pass
489-
503+
490504 # keep the old protocol
491505 elif mode == LoadingMode .KEEP :
492506 warnings .warn (
@@ -498,9 +512,10 @@ def _merge_protocols_inplace(
498512 else :
499513 new_protocols [p_id ] = old_p
500514
515+
# initialize the registry singleton
registry = Registry()

# load all database yaml files found at startup
# (one `load_database` call per file)
for yml in _find_default_ymls():
    registry.load_database(yml)
0 commit comments