diff --git a/fiddle/_src/config.py b/fiddle/_src/config.py index c82a61d7..9e9c7732 100644 --- a/fiddle/_src/config.py +++ b/fiddle/_src/config.py @@ -23,7 +23,7 @@ import dataclasses import functools import types -from typing import Any, Callable, Collection, Dict, FrozenSet, Generic, Iterable, Mapping, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union +from typing import Any, Callable, Collection, Dict, FrozenSet, Generic, Iterable, Mapping, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union, cast from fiddle._src import daglish from fiddle._src import history @@ -341,14 +341,42 @@ def __getattr__(self, name: str): if value is not _UNSET_SENTINEL: return value - if dataclasses.is_dataclass( - self.__fn_or_cls__ - ) and _field_uses_default_factory(self.__fn_or_cls__, name): - raise ValueError( - "Can't get default value for dataclass field " - + f'{self.__fn_or_cls__.__qualname__}.{name} ' - + 'since it uses a default_factory.' - ) + if dataclasses.is_dataclass(self.__fn_or_cls__) and ( + default_factory := _field_default_factory(self.__fn_or_cls__, name) + ): + if hasattr(default_factory, 'as_buildable'): + # The default_factory is a function wrapped by auto_config. In this case + # it would appear reasonable to populate a child config by setting: + # self.__arguments__[name] = default_factory.as_buildable() + # However paxml.tools.fiddle / praxis.pax_fiddle unfortunately implement + # some tricky magic of their own relating to auto_config and dataclass + # default_factories, which would appear to be broken (or at least + # interfered with) by any support for this here. It is likely a fairly + # rare use case anyway outside of paxml, so we are choosing not to + # support it here. + raise ValueError( + "We don't currently support exposing a sub-config to build the " + f"default value for an auto_config'd default_factory (field {name} " + f'of dataclass {self.__fn_or_cls__}.' + ) + elif _is_resolvable(default_factory): + self.__arguments__[name] = Config(default_factory) + elif isinstance(default_factory, functools.partial) and _is_resolvable( + default_factory.func + ): + self.__arguments__[name] = Config( + default_factory.func, + *default_factory.args, + **default_factory.keywords, + ) + else: + raise ValueError( + "Can't expose a sub-config to build default value for field " + f'{name} of dataclass {self.__fn_or_cls__} since it uses an ' + 'anonymous default_factory.' + ) + return self.__arguments__[name] + if param is not None and param.default is not param.empty: return param.default msg = f"No parameter '{name}' has been set on {self!r}." @@ -848,12 +876,27 @@ def __build__(self, /, *args: Any, **kwargs: Any) -> T: return self.__fn_or_cls__(tags=self.tags, *args, **kwargs) -def _field_uses_default_factory(dataclass_type: Type[Any], field_name: str): - """Returns true if . uses a default_factory.""" +def _field_default_factory( + dataclass_type: Type[Any], field_name: str +) -> Callable[[], Any] | None: + """Returns the default_factory of . if present.""" for field in dataclasses.fields(dataclass_type): - if field.name == field_name: - return field.default_factory != dataclasses.MISSING - return False + if ( + field.name == field_name + and field.default_factory != dataclasses.MISSING + ): + return cast(Callable[[], Any], field.default_factory) + return None + + +def _is_resolvable(value: Any) -> bool: + return ( + hasattr(value, '__module__') + and hasattr(value, '__qualname__') + and + # Rules out anonymous objects like , foo..bar etc: + '<' not in value.__qualname__ + ) BuildableT = TypeVar('BuildableT', bound=Buildable) diff --git a/fiddle/_src/config_test.py b/fiddle/_src/config_test.py index 7e123cbf..4806c8b6 100644 --- a/fiddle/_src/config_test.py +++ b/fiddle/_src/config_test.py @@ -17,6 +17,7 @@ import copy import dataclasses +import functools import pickle import sys import threading @@ -111,6 +112,13 @@ class DataclassParent: child: DataclassChild = dataclasses.field(default_factory=DataclassChild) +@dataclasses.dataclass +class DataclassParentPartialDefaultFactoryChild: + child: DataclassChild = dataclasses.field( + default_factory=functools.partial(DataclassChild, x=1) + ) + + def raise_error(): raise ValueError('My fancy exception') @@ -1042,21 +1050,40 @@ def test_copy_constructor_with_updates_errors(self): with self.assertRaises(ValueError): fdl.Partial(cfg1, 5, a='a', b='b') - def test_dataclass_default_factory(self): + def test_dataclass_default_factory_can_read_default(self): + cfg = fdl.Config(DataclassParent) + child_config = cfg.child + self.assertIsInstance(child_config, fdl.Config) + self.assertEqual(child_config.__fn_or_cls__, DataclassChild) + self.assertEqual(fdl.build(cfg), DataclassParent(child=DataclassChild(x=0))) + + def test_dataclass_partial_default_factory_can_read_default(self): + cfg = fdl.Config(DataclassParentPartialDefaultFactoryChild) + child_config = cfg.child + self.assertIsInstance(child_config, fdl.Config) + self.assertEqual(child_config.__fn_or_cls__, DataclassChild) + self.assertEqual(child_config.x, 1) + self.assertEqual( + fdl.build(cfg), + DataclassParentPartialDefaultFactoryChild(child=DataclassChild(x=1)), + ) + def test_dataclass_default_factory_overriding_child_config(self): cfg = fdl.Config(DataclassParent) + cfg.child.x = 5 + self.assertEqual(fdl.build(cfg), DataclassParent(DataclassChild(x=5))) + + def test_dataclass_partial_default_factory_overriding_child_config(self): + cfg = fdl.Config(DataclassParentPartialDefaultFactoryChild) + cfg.child.x = 5 + self.assertEqual( + fdl.build(cfg), + DataclassParentPartialDefaultFactoryChild(DataclassChild(x=5)), + ) - with self.subTest('read_default_is_error'): - expected_error = ( - r"Can't get default value for dataclass field DataclassParent\.child " - r'since it uses a default_factory\.') - with self.assertRaisesRegex(ValueError, expected_error): - cfg.child.x = 5 - - with self.subTest('read_ok_after_override'): - cfg.child = fdl.Config(DataclassChild) # override default w/ a value - cfg.child.x = 5 # now it's ok to configure child. - self.assertEqual(fdl.build(cfg), DataclassParent(DataclassChild(5))) + def test_dataclass_default_factory_not_used_when_child_config_given(self): + cfg = fdl.Config(DataclassParent, child=fdl.Config(DataclassChild, x=1)) + self.assertEqual(cfg.child.x, 1) def test_unbound_method(self): sample = fdl.Config(SampleClass, 0, 1)