Skip to content

Commit 14b4026

Browse files
mjwillsoncopybara-github
authored andcommitted
In a fdl.Config for a dataclass, when accessing a field that uses a default_factory, populate a child Config (if not already present) wrapping the default_factory call, rather than raising a "Can't get default value for dataclass field ... since it uses a default_factory" error.
This allows easy overriding of properties of child dataclasses, without requiring extra boilerplate in all parent configs to explicitly configure the defaults for child dataclasses. PiperOrigin-RevId: 737952148
1 parent 8536ef2 commit 14b4026

File tree

2 files changed

+96
-26
lines changed

2 files changed

+96
-26
lines changed

fiddle/_src/config.py

Lines changed: 57 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import dataclasses
2424
import functools
2525
import types
26-
from typing import Any, Callable, Collection, Dict, FrozenSet, Generic, Iterable, Mapping, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union
26+
from typing import Any, Callable, Collection, Dict, FrozenSet, Generic, Iterable, Mapping, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union, cast
2727

2828
from fiddle._src import daglish
2929
from fiddle._src import history
@@ -341,14 +341,42 @@ def __getattr__(self, name: str):
341341

342342
if value is not _UNSET_SENTINEL:
343343
return value
344-
if dataclasses.is_dataclass(
345-
self.__fn_or_cls__
346-
) and _field_uses_default_factory(self.__fn_or_cls__, name):
347-
raise ValueError(
348-
"Can't get default value for dataclass field "
349-
+ f'{self.__fn_or_cls__.__qualname__}.{name} '
350-
+ 'since it uses a default_factory.'
351-
)
344+
if dataclasses.is_dataclass(self.__fn_or_cls__) and (
345+
default_factory := _field_default_factory(self.__fn_or_cls__, name)
346+
):
347+
if hasattr(default_factory, 'as_buildable'):
348+
# The default_factory is a function wrapped by auto_config. In this case
349+
# it would appear reasonable to populate a child config by setting:
350+
# self.__arguments__[name] = default_factory.as_buildable()
351+
# However paxml.tools.fiddle / praxis.pax_fiddle unfortunately implement
352+
# some tricky magic of their own relating to auto_config and dataclass
353+
# default_factories, which would appear to be broken (or at least
354+
# interfered with) by any support for this here. It is likely a fairly
355+
# rare use case anyway outside of paxml, so we are choosing not to
356+
# support it here.
357+
raise ValueError(
358+
"We don't currently support exposing a sub-config to build the "
359+
f"default value for an auto_config'd default_factory (field {name} "
360+
f'of dataclass {self.__fn_or_cls__}.'
361+
)
362+
elif _is_resolvable(default_factory):
363+
self.__arguments__[name] = Config(default_factory)
364+
elif isinstance(default_factory, functools.partial) and _is_resolvable(
365+
default_factory.func
366+
):
367+
self.__arguments__[name] = Config(
368+
default_factory.func,
369+
*default_factory.args,
370+
**default_factory.keywords,
371+
)
372+
else:
373+
raise ValueError(
374+
"Can't expose a sub-config to build default value for field "
375+
f'{name} of dataclass {self.__fn_or_cls__} since it uses an '
376+
'anonymous default_factory.'
377+
)
378+
return self.__arguments__[name]
379+
352380
if param is not None and param.default is not param.empty:
353381
return param.default
354382
msg = f"No parameter '{name}' has been set on {self!r}."
@@ -848,12 +876,27 @@ def __build__(self, /, *args: Any, **kwargs: Any) -> T:
848876
return self.__fn_or_cls__(tags=self.tags, *args, **kwargs)
849877

850878

851-
def _field_uses_default_factory(dataclass_type: Type[Any], field_name: str):
852-
"""Returns true if <dataclass_type>.<field_name> uses a default_factory."""
879+
def _field_default_factory(
880+
dataclass_type: Type[Any], field_name: str
881+
) -> Callable[[], Any] | None:
882+
"""Returns the default_factory of <dataclass_type>.<field_name> if present."""
853883
for field in dataclasses.fields(dataclass_type):
854-
if field.name == field_name:
855-
return field.default_factory != dataclasses.MISSING
856-
return False
884+
if (
885+
field.name == field_name
886+
and field.default_factory != dataclasses.MISSING
887+
):
888+
return cast(Callable[[], Any], field.default_factory)
889+
return None
890+
891+
892+
def _is_resolvable(value: Any) -> bool:
893+
return (
894+
hasattr(value, '__module__')
895+
and hasattr(value, '__qualname__')
896+
and
897+
# Rules out anonymous objects like <lambda>, foo.<locals>.bar etc:
898+
'<' not in value.__qualname__
899+
)
857900

858901

859902
BuildableT = TypeVar('BuildableT', bound=Buildable)

fiddle/_src/config_test.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import copy
1919
import dataclasses
20+
import functools
2021
import pickle
2122
import sys
2223
import threading
@@ -111,6 +112,13 @@ class DataclassParent:
111112
child: DataclassChild = dataclasses.field(default_factory=DataclassChild)
112113

113114

115+
@dataclasses.dataclass
116+
class DataclassParentPartialDefaultFactoryChild:
117+
child: DataclassChild = dataclasses.field(
118+
default_factory=functools.partial(DataclassChild, x=1)
119+
)
120+
121+
114122
def raise_error():
115123
raise ValueError('My fancy exception')
116124

@@ -1042,21 +1050,40 @@ def test_copy_constructor_with_updates_errors(self):
10421050
with self.assertRaises(ValueError):
10431051
fdl.Partial(cfg1, 5, a='a', b='b')
10441052

1045-
def test_dataclass_default_factory(self):
1053+
def test_dataclass_default_factory_can_read_default(self):
1054+
cfg = fdl.Config(DataclassParent)
1055+
child_config = cfg.child
1056+
self.assertIsInstance(child_config, fdl.Config)
1057+
self.assertEqual(child_config.__fn_or_cls__, DataclassChild)
1058+
self.assertEqual(fdl.build(cfg), DataclassParent(child=DataclassChild(x=0)))
1059+
1060+
def test_dataclass_partial_default_factory_can_read_default(self):
1061+
cfg = fdl.Config(DataclassParentPartialDefaultFactoryChild)
1062+
child_config = cfg.child
1063+
self.assertIsInstance(child_config, fdl.Config)
1064+
self.assertEqual(child_config.__fn_or_cls__, DataclassChild)
1065+
self.assertEqual(child_config.x, 1)
1066+
self.assertEqual(
1067+
fdl.build(cfg),
1068+
DataclassParentPartialDefaultFactoryChild(child=DataclassChild(x=1)),
1069+
)
10461070

1071+
def test_dataclass_default_factory_overriding_child_config(self):
10471072
cfg = fdl.Config(DataclassParent)
1073+
cfg.child.x = 5
1074+
self.assertEqual(fdl.build(cfg), DataclassParent(DataclassChild(x=5)))
1075+
1076+
def test_dataclass_partial_default_factory_overriding_child_config(self):
1077+
cfg = fdl.Config(DataclassParentPartialDefaultFactoryChild)
1078+
cfg.child.x = 5
1079+
self.assertEqual(
1080+
fdl.build(cfg),
1081+
DataclassParentPartialDefaultFactoryChild(DataclassChild(x=5)),
1082+
)
10481083

1049-
with self.subTest('read_default_is_error'):
1050-
expected_error = (
1051-
r"Can't get default value for dataclass field DataclassParent\.child "
1052-
r'since it uses a default_factory\.')
1053-
with self.assertRaisesRegex(ValueError, expected_error):
1054-
cfg.child.x = 5
1055-
1056-
with self.subTest('read_ok_after_override'):
1057-
cfg.child = fdl.Config(DataclassChild) # override default w/ a value
1058-
cfg.child.x = 5 # now it's ok to configure child.
1059-
self.assertEqual(fdl.build(cfg), DataclassParent(DataclassChild(5)))
1084+
def test_dataclass_default_factory_not_used_when_child_config_given(self):
1085+
cfg = fdl.Config(DataclassParent, child=fdl.Config(DataclassChild, x=1))
1086+
self.assertEqual(cfg.child.x, 1)
10601087

10611088
def test_unbound_method(self):
10621089
sample = fdl.Config(SampleClass, 0, 1)

0 commit comments

Comments
 (0)