Skip to content

Commit 6aad4cb

Browse files
mjwillsoncopybara-github
authored andcommitted
When accessing a dataclass field using a default_factory, populate a child Config or Partial wrapping the default_factory call, rather than raising "Can't get default value for dataclass field ... since it uses a default_factory" error.
PiperOrigin-RevId: 737952148
1 parent d45d7ca commit 6aad4cb

File tree

2 files changed

+43
-22
lines changed

2 files changed

+43
-22
lines changed

fiddle/_src/config.py

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import dataclasses
2424
import functools
2525
import types
26-
from typing import Any, Callable, Collection, Dict, FrozenSet, Generic, Iterable, Mapping, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union
26+
from typing import Any, Callable, Collection, Dict, FrozenSet, Generic, Iterable, Mapping, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union, cast
2727

2828
from fiddle._src import daglish
2929
from fiddle._src import history
@@ -341,14 +341,25 @@ def __getattr__(self, name: str):
341341

342342
if value is not _UNSET_SENTINEL:
343343
return value
344-
if dataclasses.is_dataclass(
345-
self.__fn_or_cls__
346-
) and _field_uses_default_factory(self.__fn_or_cls__, name):
347-
raise ValueError(
348-
"Can't get default value for dataclass field "
349-
+ f'{self.__fn_or_cls__.__qualname__}.{name} '
350-
+ 'since it uses a default_factory.'
351-
)
344+
if dataclasses.is_dataclass(self.__fn_or_cls__) and (
345+
default_factory := _field_default_factory(self.__fn_or_cls__, name)):
346+
if _is_resolvable(default_factory):
347+
self.__arguments__[name] = Config(default_factory)
348+
return self.__arguments__[name]
349+
elif (isinstance(default_factory, functools.partial)
350+
and _is_resolvable(default_factory.func)):
351+
from fiddle._src import partial # pylint: disable=g-import-not-at-top # Avoid cyclic load dependency
352+
self.__arguments__[name] = partial.Partial(
353+
default_factory.func,
354+
*default_factory.args,
355+
**default_factory.keywords)
356+
return self.__arguments__[name]
357+
else:
358+
return ValueError(
359+
"Can't expose a sub-config to build default value of dataclass "
360+
f'field {self.__fn_or_cls__.__qualname__}.{name} since it uses an '
361+
'anonymous default_factory.'
362+
)
352363
if param is not None and param.default is not param.empty:
353364
return param.default
354365
msg = f"No parameter '{name}' has been set on {self!r}."
@@ -848,12 +859,24 @@ def __build__(self, /, *args: Any, **kwargs: Any) -> T:
848859
return self.__fn_or_cls__(tags=self.tags, *args, **kwargs)
849860

850861

851-
def _field_uses_default_factory(dataclass_type: Type[Any], field_name: str):
852-
"""Returns true if <dataclass_type>.<field_name> uses a default_factory."""
862+
def _field_default_factory(
863+
dataclass_type: Type[Any], field_name: str) -> Callable[[], Any] | None:
864+
"""Returns the default_factory of <dataclass_type>.<field_name> if present."""
853865
for field in dataclasses.fields(dataclass_type):
854-
if field.name == field_name:
855-
return field.default_factory != dataclasses.MISSING
856-
return False
866+
if (field.name == field_name and
867+
field.default_factory != dataclasses.MISSING):
868+
return cast(Callable[[], Any], field.default_factory)
869+
return None
870+
871+
872+
def _is_resolvable(value: Any) -> bool:
873+
# TO DO: check we can roundtrip, ideally in a way that guarantees
874+
# compatibility with serialization code.
875+
return (
876+
hasattr(value, '__module__') and
877+
hasattr(value, '__qualname__') and
878+
# Rules out anonymous objects like <lambda>, foo.<locals>.bar etc:
879+
'<' not in value.__qualname__)
857880

858881

859882
BuildableT = TypeVar('BuildableT', bound=Buildable)

fiddle/_src/config_test.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,16 +1046,14 @@ def test_dataclass_default_factory(self):
10461046

10471047
cfg = fdl.Config(DataclassParent)
10481048

1049-
with self.subTest('read_default_is_error'):
1050-
expected_error = (
1051-
r"Can't get default value for dataclass field DataclassParent\.child "
1052-
r'since it uses a default_factory\.')
1053-
with self.assertRaisesRegex(ValueError, expected_error):
1054-
cfg.child.x = 5
1049+
with self.subTest('can_read_default'):
1050+
child_config = cfg.child
1051+
self.assertIsInstance(child_config, fdl.Config)
1052+
self.assertEqual(child_config.__fn_or_cls__, DataclassChild)
1053+
self.assertEqual(fdl.build(cfg), DataclassParent(DataclassChild(1)))
10551054

10561055
with self.subTest('read_ok_after_override'):
1057-
cfg.child = fdl.Config(DataclassChild) # override default w/ a value
1058-
cfg.child.x = 5 # now it's ok to configure child.
1056+
cfg.child.x = 5
10591057
self.assertEqual(fdl.build(cfg), DataclassParent(DataclassChild(5)))
10601058

10611059
def test_unbound_method(self):

0 commit comments

Comments
 (0)