@@ -552,6 +552,28 @@ def _make_partial(partial_cls, buildable_or_callable, *args, **kwargs):
552552 return partial_cls (buildable_or_callable , * args , ** kwargs )
553553
554554
555+ def _override_values_recursively (
556+ base : config .Buildable , overrides : config .Buildable
557+ ) -> None :
558+ """Recursively replaces fields in base with values from overrides."""
559+ for field_name in dir (overrides ):
560+ field = getattr (overrides , field_name , config .NO_VALUE )
561+ if field != config .NO_VALUE :
562+ if isinstance (field , config .Buildable ):
563+ _override_values_recursively (getattr (base , field_name ), field )
564+ else :
565+ setattr (base , field_name , field )
566+
567+
568+ def override_values (
569+ base : config .Buildable , overrides : config .Buildable
570+ ) -> config .Buildable :
571+ """Returns a copy of base with any values present in overrides overridden."""
572+ base = base .__deepcopy__ (memo = {})
573+ _override_values_recursively (base , overrides )
574+ return base
575+
576+
555577def exempt (fn_or_cls : Callable [..., Any ]) -> Callable [..., Any ]:
556578 """Wrap a callable so that it's exempted from auto_config.
557579
@@ -599,6 +621,7 @@ def auto_config(
599621 experimental_exemption_policy : Optional [auto_config_policy .Policy ] = None ,
600622 experimental_config_types : ConfigTypes = ConfigTypes (),
601623 experimental_result_must_contain_buildable : bool = True ,
624+ base_config : Optional [AutoConfig ] = None ,
602625) -> Any : # TODO(b/272377821): More precise return type.
603626 """Rewrites the given function to make it generate a ``Config``.
604627
@@ -693,6 +716,9 @@ def build_model():
693716 experimental_result_must_contain_buildable: If true, then raise an error if
694717 `fn.as_buildable` returns a result that does not contain any `Buildable`
695718 values -- e.g., if it returns an empty dict.
719+ base_config: Ff given, would be used as a default values. This allows to
720+ have common settings defined once while defing multiple slightly different
721+ configurations.
696722
697723 Returns:
698724 A wrapped version of ``fn``, but with an additional ``as_buildable``
@@ -857,13 +883,25 @@ def make_auto_config(fn):
857883 auto_config_fn .__defaults__ = fn .__defaults__
858884 auto_config_fn .__kwdefaults__ = fn .__kwdefaults__
859885
886+ if base_config is not None :
887+
888+ @functools .wraps (auto_config_fn )
889+ def auto_config_fn_with_base (* args , ** kwargs ):
890+ return override_values (
891+ base_config .as_buildable (* args , ** kwargs ),
892+ auto_config_fn (* args , ** kwargs ), # pylint: disable=not-callable
893+ )
894+
895+ else :
896+ auto_config_fn_with_base = auto_config_fn
897+
860898 # Finally we wrap the rewritten function to perform additional error
861899 # checking and enforce that the output contains a `fdl.Buildable`.
862900 if experimental_result_must_contain_buildable :
863901
864- @functools .wraps (auto_config_fn )
902+ @functools .wraps (auto_config_fn_with_base )
865903 def as_buildable (* args , ** kwargs ):
866- output = auto_config_fn (* args , ** kwargs ) # pylint: disable=not-callable
904+ output = auto_config_fn_with_base (* args , ** kwargs ) # pylint: disable=not-callable
867905 if not _contains_buildable (output ):
868906 raise TypeError (
869907 f'The `auto_config` rewritten version of `{ fn .__qualname__ } ` '
@@ -875,7 +913,7 @@ def as_buildable(*args, **kwargs):
875913 return output
876914
877915 else :
878- as_buildable = auto_config_fn
916+ as_buildable = auto_config_fn_with_base
879917
880918 if method_type :
881919 fn = method_type (fn )
0 commit comments