Skip to content

Commit 06c50eb

Browse files
Yi Wangchanglan
authored andcommitted
Move get_named_trainer_config from experiments to common
GitOrigin-RevId: 3b4449a5a642b42aa32ffea2d4daa85506091b0c
1 parent db8a364 commit 06c50eb

File tree

5 files changed

+81
-87
lines changed

5 files changed

+81
-87
lines changed

axlearn/common/config.py

Lines changed: 78 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ class Config(ConfigBase):
6060
config fields with mutable values, including config values.
6161
"""
6262

63-
# Note: config.py should not depend on jax, torch, or tf.
6463
import copy
6564
import dataclasses
6665
import enum
@@ -83,6 +82,9 @@ class Config(ConfigBase):
8382
# to apply validation on field names and values.
8483
import attr
8584

85+
# Note: config.py should not depend on jax, torch, or tf.
86+
from absl import logging
87+
8688

8789
def is_named_tuple(x: Any):
8890
"""Returns whether an object is an instance of a collections.namedtuple.
@@ -210,6 +212,32 @@ def validate_config_field_name(name: str) -> None:
210212
raise InvalidConfigNameError(f'Invalid config field name "{name}"')
211213

212214

215+
def validate_config_field_value(value: Any) -> None:
216+
"""Validates a config field value.
217+
218+
Validation is handled by validators registered via `register_validator`. `match_fn`s will be
219+
invoked in order of registration, and all matched `validate_fn`s will be invoked.
220+
221+
Args:
222+
value: The value to be validated.
223+
224+
Raises:
225+
InvalidConfigValueError: If no validator matched the given value.
226+
"""
227+
matched = False
228+
for match_fn, validate_fn in _config_field_validators.items():
229+
if match_fn(value):
230+
matched = True
231+
validate_fn(value)
232+
233+
# No validators matched.
234+
if not matched:
235+
raise InvalidConfigValueError(
236+
f'Invalid config value type {type(value)} for value "{value}". '
237+
f"Consider registering a custom validator with `{register_validator.__name__}`."
238+
)
239+
240+
213241
# Validate basic types.
214242
register_validator(
215243
match_fn=lambda v: (
@@ -280,32 +308,6 @@ def _maybe_register_optional_type(module: str, attribute: str):
280308
_maybe_register_optional_type("jax.sharding", "PartitionSpec")
281309

282310

283-
def validate_config_field_value(value: Any) -> None:
284-
"""Validates a config field value.
285-
286-
Validation is handled by validators registered via `register_validator`. `match_fn`s will be
287-
invoked in order of registration, and all matched `validate_fn`s will be invoked.
288-
289-
Args:
290-
value: The value to be validated.
291-
292-
Raises:
293-
InvalidConfigValueError: If no validator matched the given value.
294-
"""
295-
matched = False
296-
for match_fn, validate_fn in _config_field_validators.items():
297-
if match_fn(value):
298-
matched = True
299-
validate_fn(value)
300-
301-
# No validators matched.
302-
if not matched:
303-
raise InvalidConfigValueError(
304-
f'Invalid config value type {type(value)} for value "{value}". '
305-
f"Consider registering a custom validator with `{register_validator.__name__}`."
306-
)
307-
308-
309311
def _validate_and_transform_field(instance, attribute, value):
310312
"""Validates an attribute as a config field.
311313
@@ -1055,3 +1057,52 @@ class ConfigModifier(Configurable):
10551057
def __call__(self, cfg: InstantiableConfig[T]) -> InstantiableConfig[T]:
10561058
"""A function that modifies the input config, should be defined by subclasses."""
10571059
return cfg
1060+
1061+
1062+
def _load_trainer_configs(
1063+
config_module: str, *, optional: bool = False
1064+
) -> dict[str, TrainerConfigFn]:
1065+
try:
1066+
module = importlib.import_module(config_module)
1067+
return module.named_trainer_configs()
1068+
except (ImportError, AttributeError):
1069+
if not optional:
1070+
raise
1071+
logging.warning(
1072+
"Missing dependencies for %s but it's marked optional -- skipping.", config_module
1073+
)
1074+
return {}
1075+
1076+
1077+
def get_named_trainer_config(config_name: str, *, config_module: str) -> TrainerConfigFn:
1078+
"""Looks up TrainerConfigFn by config name.
1079+
1080+
Args:
1081+
config_name: Candidate config name.
1082+
config_module: Config module name.
1083+
1084+
Returns:
1085+
A TrainerConfigFn corresponding to the config name.
1086+
1087+
Raises:
1088+
KeyError: Error containing the message to show to the user.
1089+
"""
1090+
config_map = _load_trainer_configs(config_module)
1091+
if callable(config_map):
1092+
return config_map(config_name)
1093+
1094+
try:
1095+
return config_map[config_name]
1096+
except KeyError as e:
1097+
similar = similar_names(config_name, set(config_map.keys()))
1098+
if similar:
1099+
message = f"Unrecognized config {config_name}; did you mean [{', '.join(similar)}]"
1100+
else:
1101+
message = (
1102+
f"Unrecognized config {config_name} under {config_module}; "
1103+
f"Please make sure that the following conditions are met:\n"
1104+
f" 1. {config_module} can be imported; "
1105+
f" 2. {config_module} defines `named_trainer_configs()`; "
1106+
f" 3. `named_trainer_configs()` returns a dict with '{config_name}' as a key."
1107+
)
1108+
raise KeyError(message) from e

axlearn/common/launch_trainer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,9 @@
1111

1212
from axlearn.common import file_system as fs
1313
from axlearn.common import measurement
14-
from axlearn.common.config import TrainerConfigFn
14+
from axlearn.common.config import TrainerConfigFn, get_named_trainer_config
1515
from axlearn.common.trainer import SpmdTrainer, select_mesh_config
1616
from axlearn.common.utils import MeshShape, get_data_dir, infer_mesh_shape
17-
from axlearn.experiments import get_named_trainer_config
1817

1918
# Trainer-specific flags.
2019
flags.DEFINE_string(

axlearn/experiments/__init__.py

Lines changed: 0 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,3 @@
11
# Copyright © 2023 Apple Inc.
22

33
"""AXLearn experiments."""
4-
5-
from importlib import import_module
6-
7-
from absl import logging
8-
9-
from axlearn.common.config import TrainerConfigFn, similar_names
10-
11-
12-
def _load_trainer_configs(
13-
config_module: str, *, optional: bool = False
14-
) -> dict[str, TrainerConfigFn]:
15-
try:
16-
module = import_module(config_module)
17-
return module.named_trainer_configs()
18-
except (ImportError, AttributeError):
19-
if not optional:
20-
raise
21-
logging.warning(
22-
"Missing dependencies for %s but it's marked optional -- skipping.", config_module
23-
)
24-
return {}
25-
26-
27-
def get_named_trainer_config(config_name: str, *, config_module: str) -> TrainerConfigFn:
28-
"""Looks up TrainerConfigFn by config name.
29-
30-
Args:
31-
config_name: Candidate config name.
32-
config_module: Config module name.
33-
34-
Returns:
35-
A TrainerConfigFn corresponding to the config name.
36-
37-
Raises:
38-
KeyError: Error containing the message to show to the user.
39-
"""
40-
config_map = _load_trainer_configs(config_module)
41-
if callable(config_map):
42-
return config_map(config_name)
43-
44-
try:
45-
return config_map[config_name]
46-
except KeyError as e:
47-
similar = similar_names(config_name, set(config_map.keys()))
48-
if similar:
49-
message = f"Unrecognized config {config_name}; did you mean [{', '.join(similar)}]"
50-
else:
51-
message = (
52-
f"Unrecognized config {config_name} under {config_module}; "
53-
f"Please make sure that the following conditions are met:\n"
54-
f" 1. {config_module} can be imported; "
55-
f" 2. {config_module} defines `named_trainer_configs()`; "
56-
f" 3. `named_trainer_configs()` returns a dict with '{config_name}' as a key."
57-
)
58-
raise KeyError(message) from e

axlearn/experiments/golden_ckpt_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@
1313
from absl import flags
1414
from absl.testing import absltest
1515

16+
from axlearn.common.config import get_named_trainer_config
1617
from axlearn.common.inference import InferenceRunner
1718
from axlearn.common.summary_writer import NoOpWriter
1819
from axlearn.common.test_utils import TestCase
1920
from axlearn.common.trainer import SpmdTrainer
20-
from axlearn.experiments import get_named_trainer_config
2121

2222
flags.DEFINE_boolean("update_golden_checkpoints", False, "If true, update golden config files.")
2323

axlearn/experiments/run_aot_compilation.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,10 @@
4040
from jax.experimental.serialize_executable import serialize
4141

4242
from axlearn.common import aot_compilation, compiler_options
43-
from axlearn.common.config import TrainerConfigFn
43+
from axlearn.common.config import TrainerConfigFn, get_named_trainer_config
4444
from axlearn.common.trainer import SpmdTrainer, aot_model_analysis, select_mesh_config
4545
from axlearn.common.utils import set_data_dir
4646
from axlearn.common.utils_spmd import setup
47-
from axlearn.experiments import get_named_trainer_config
4847

4948
flags.DEFINE_string("module", None, "The trainer config module.", required=True)
5049
flags.DEFINE_string("config", None, "The trainer config name.", required=True)

0 commit comments

Comments
 (0)