@@ -60,7 +60,6 @@ class Config(ConfigBase):
60
60
config fields with mutable values, including config values.
61
61
"""
62
62
63
- # Note: config.py should not depend on jax, torch, or tf.
64
63
import copy
65
64
import dataclasses
66
65
import enum
@@ -83,6 +82,9 @@ class Config(ConfigBase):
83
82
# to apply validation on field names and values.
84
83
import attr
85
84
85
+ # Note: config.py should not depend on jax, torch, or tf.
86
+ from absl import logging
87
+
86
88
87
89
def is_named_tuple (x : Any ):
88
90
"""Returns whether an object is an instance of a collections.namedtuple.
@@ -210,6 +212,32 @@ def validate_config_field_name(name: str) -> None:
210
212
raise InvalidConfigNameError (f'Invalid config field name "{ name } "' )
211
213
212
214
215
+ def validate_config_field_value (value : Any ) -> None :
216
+ """Validates a config field value.
217
+
218
+ Validation is handled by validators registered via `register_validator`. `match_fn`s will be
219
+ invoked in order of registration, and all matched `validate_fn`s will be invoked.
220
+
221
+ Args:
222
+ value: The value to be validated.
223
+
224
+ Raises:
225
+ InvalidConfigValueError: If no validator matched the given value.
226
+ """
227
+ matched = False
228
+ for match_fn , validate_fn in _config_field_validators .items ():
229
+ if match_fn (value ):
230
+ matched = True
231
+ validate_fn (value )
232
+
233
+ # No validators matched.
234
+ if not matched :
235
+ raise InvalidConfigValueError (
236
+ f'Invalid config value type { type (value )} for value "{ value } ". '
237
+ f"Consider registering a custom validator with `{ register_validator .__name__ } `."
238
+ )
239
+
240
+
213
241
# Validate basic types.
214
242
register_validator (
215
243
match_fn = lambda v : (
@@ -280,32 +308,6 @@ def _maybe_register_optional_type(module: str, attribute: str):
280
308
_maybe_register_optional_type ("jax.sharding" , "PartitionSpec" )
281
309
282
310
283
- def validate_config_field_value (value : Any ) -> None :
284
- """Validates a config field value.
285
-
286
- Validation is handled by validators registered via `register_validator`. `match_fn`s will be
287
- invoked in order of registration, and all matched `validate_fn`s will be invoked.
288
-
289
- Args:
290
- value: The value to be validated.
291
-
292
- Raises:
293
- InvalidConfigValueError: If no validator matched the given value.
294
- """
295
- matched = False
296
- for match_fn , validate_fn in _config_field_validators .items ():
297
- if match_fn (value ):
298
- matched = True
299
- validate_fn (value )
300
-
301
- # No validators matched.
302
- if not matched :
303
- raise InvalidConfigValueError (
304
- f'Invalid config value type { type (value )} for value "{ value } ". '
305
- f"Consider registering a custom validator with `{ register_validator .__name__ } `."
306
- )
307
-
308
-
309
311
def _validate_and_transform_field (instance , attribute , value ):
310
312
"""Validates an attribute as a config field.
311
313
@@ -1055,3 +1057,52 @@ class ConfigModifier(Configurable):
1055
1057
def __call__ (self , cfg : InstantiableConfig [T ]) -> InstantiableConfig [T ]:
1056
1058
"""A function that modifies the input config, should be defined by subclasses."""
1057
1059
return cfg
1060
+
1061
+
1062
+ def _load_trainer_configs (
1063
+ config_module : str , * , optional : bool = False
1064
+ ) -> dict [str , TrainerConfigFn ]:
1065
+ try :
1066
+ module = importlib .import_module (config_module )
1067
+ return module .named_trainer_configs ()
1068
+ except (ImportError , AttributeError ):
1069
+ if not optional :
1070
+ raise
1071
+ logging .warning (
1072
+ "Missing dependencies for %s but it's marked optional -- skipping." , config_module
1073
+ )
1074
+ return {}
1075
+
1076
+
1077
+ def get_named_trainer_config (config_name : str , * , config_module : str ) -> TrainerConfigFn :
1078
+ """Looks up TrainerConfigFn by config name.
1079
+
1080
+ Args:
1081
+ config_name: Candidate config name.
1082
+ config_module: Config module name.
1083
+
1084
+ Returns:
1085
+ A TrainerConfigFn corresponding to the config name.
1086
+
1087
+ Raises:
1088
+ KeyError: Error containing the message to show to the user.
1089
+ """
1090
+ config_map = _load_trainer_configs (config_module )
1091
+ if callable (config_map ):
1092
+ return config_map (config_name )
1093
+
1094
+ try :
1095
+ return config_map [config_name ]
1096
+ except KeyError as e :
1097
+ similar = similar_names (config_name , set (config_map .keys ()))
1098
+ if similar :
1099
+ message = f"Unrecognized config { config_name } ; did you mean [{ ', ' .join (similar )} ]"
1100
+ else :
1101
+ message = (
1102
+ f"Unrecognized config { config_name } under { config_module } ; "
1103
+ f"Please make sure that the following conditions are met:\n "
1104
+ f" 1. { config_module } can be imported; "
1105
+ f" 2. { config_module } defines `named_trainer_configs()`; "
1106
+ f" 3. `named_trainer_configs()` returns a dict with '{ config_name } ' as a key."
1107
+ )
1108
+ raise KeyError (message ) from e
0 commit comments