Skip to content

Commit 2947bf2

Browse files
Fiddle-Config Teamcopybara-github
authored andcommitted
Not a public change.
PiperOrigin-RevId: 555220375
1 parent b5abdf1 commit 2947bf2

File tree

3 files changed

+58
-3
lines changed

3 files changed

+58
-3
lines changed

fiddle/_src/absl_flags/sample_test_binary.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ def base_experiment() -> fdl.Config:
3535
return fake_encoder_decoder.fixture.as_buildable()
3636

3737

38+
def base_experiment_with_bias() -> fdl.Config:
39+
return fake_encoder_decoder.fixture_with_bias.as_buildable()
40+
41+
3842
def set_dtypes(config, dtype: str):
3943
def traverse(value, state):
4044
if state.current_path and state.current_path[-1] == daglish.Attr("dtype"):

fiddle/_src/experimental/auto_config.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,28 @@ def _make_partial(partial_cls, buildable_or_callable, *args, **kwargs):
552552
return partial_cls(buildable_or_callable, *args, **kwargs)
553553

554554

555+
def _override_values_recursively(
556+
base: config.Buildable, overrides: config.Buildable
557+
) -> None:
558+
"""Recursively replaces fields in base with values from overrides."""
559+
for field_name in dir(overrides):
560+
field = getattr(overrides, field_name, config.NO_VALUE)
561+
if field != config.NO_VALUE:
562+
if isinstance(field, config.Buildable):
563+
_override_values_recursively(getattr(base, field_name), field)
564+
else:
565+
setattr(base, field_name, field)
566+
567+
568+
def override_values(
569+
base: config.Buildable, overrides: config.Buildable
570+
) -> config.Buildable:
571+
"""Returns a copy of base with any values present in overrides overridden."""
572+
base = base.__deepcopy__(memo={})
573+
_override_values_recursively(base, overrides)
574+
return base
575+
576+
555577
def exempt(fn_or_cls: Callable[..., Any]) -> Callable[..., Any]:
556578
"""Wrap a callable so that it's exempted from auto_config.
557579
@@ -599,6 +621,7 @@ def auto_config(
599621
experimental_exemption_policy: Optional[auto_config_policy.Policy] = None,
600622
experimental_config_types: ConfigTypes = ConfigTypes(),
601623
experimental_result_must_contain_buildable: bool = True,
624+
base_config: Optional[AutoConfig] = None,
602625
) -> Any: # TODO(b/272377821): More precise return type.
603626
"""Rewrites the given function to make it generate a ``Config``.
604627
@@ -693,6 +716,9 @@ def build_model():
693716
experimental_result_must_contain_buildable: If true, then raise an error if
694717
`fn.as_buildable` returns a result that does not contain any `Buildable`
695718
values -- e.g., if it returns an empty dict.
719+
base_config: Ff given, would be used as a default values. This allows to
720+
have common settings defined once while defing multiple slightly different
721+
configurations.
696722
697723
Returns:
698724
A wrapped version of ``fn``, but with an additional ``as_buildable``
@@ -857,13 +883,25 @@ def make_auto_config(fn):
857883
auto_config_fn.__defaults__ = fn.__defaults__
858884
auto_config_fn.__kwdefaults__ = fn.__kwdefaults__
859885

886+
if base_config is not None:
887+
888+
@functools.wraps(auto_config_fn)
889+
def auto_config_fn_with_base(*args, **kwargs):
890+
return override_values(
891+
base_config.as_buildable(*args, **kwargs),
892+
auto_config_fn(*args, **kwargs), # pylint: disable=not-callable
893+
)
894+
895+
else:
896+
auto_config_fn_with_base = auto_config_fn
897+
860898
# Finally we wrap the rewritten function to perform additional error
861899
# checking and enforce that the output contains a `fdl.Buildable`.
862900
if experimental_result_must_contain_buildable:
863901

864-
@functools.wraps(auto_config_fn)
902+
@functools.wraps(auto_config_fn_with_base)
865903
def as_buildable(*args, **kwargs):
866-
output = auto_config_fn(*args, **kwargs) # pylint: disable=not-callable
904+
output = auto_config_fn_with_base(*args, **kwargs) # pylint: disable=not-callable
867905
if not _contains_buildable(output):
868906
raise TypeError(
869907
f'The `auto_config` rewritten version of `{fn.__qualname__}` '
@@ -875,7 +913,7 @@ def as_buildable(*args, **kwargs):
875913
return output
876914

877915
else:
878-
as_buildable = auto_config_fn
916+
as_buildable = auto_config_fn_with_base
879917

880918
if method_type:
881919
fn = method_type(fn)

fiddle/_src/testing/example/fake_encoder_decoder.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,16 @@ def fixture():
9999
bias_init),
100100
mlp=Mlp(dtype, False, ["num_heads", "head_dim", "embed"]),
101101
))
102+
103+
104+
@auto_config.auto_config(base_config=fixture)
105+
def fixture_with_bias():
106+
# pylint: disable=no-value-for-parameter
107+
# pytype: disable=missing-parameter
108+
return FakeEncoderDecoder(
109+
encoder=FakeEncoder(
110+
mlp=Mlp(use_bias=True),
111+
),
112+
)
113+
# pytype: enable=missing-parameter
114+
# pylint: enable=no-value-for-parameter

0 commit comments

Comments
 (0)