Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 57 additions & 14 deletions fiddle/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import dataclasses
import functools
import types
from typing import Any, Callable, Collection, Dict, FrozenSet, Generic, Iterable, Mapping, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union
from typing import Any, Callable, Collection, Dict, FrozenSet, Generic, Iterable, Mapping, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union, cast

from fiddle._src import daglish
from fiddle._src import history
Expand Down Expand Up @@ -341,14 +341,42 @@ def __getattr__(self, name: str):

if value is not _UNSET_SENTINEL:
return value
if dataclasses.is_dataclass(
self.__fn_or_cls__
) and _field_uses_default_factory(self.__fn_or_cls__, name):
raise ValueError(
"Can't get default value for dataclass field "
+ f'{self.__fn_or_cls__.__qualname__}.{name} '
+ 'since it uses a default_factory.'
)
if dataclasses.is_dataclass(self.__fn_or_cls__) and (
default_factory := _field_default_factory(self.__fn_or_cls__, name)
):
if hasattr(default_factory, 'as_buildable'):
# The default_factory is a function wrapped by auto_config. In this case
# it would appear reasonable to populate a child config by setting:
# self.__arguments__[name] = default_factory.as_buildable()
# However paxml.tools.fiddle / praxis.pax_fiddle unfortunately implement
# some tricky magic of their own relating to auto_config and dataclass
# default_factories, which would appear to be broken (or at least
# interfered with) by any support for this here. It is likely a fairly
# rare use case anyway outside of paxml, so we are choosing not to
# support it here.
raise ValueError(
"We don't currently support exposing a sub-config to build the "
f"default value for an auto_config'd default_factory (field {name} "
f'of dataclass {self.__fn_or_cls__}.'
)
elif _is_resolvable(default_factory):
self.__arguments__[name] = Config(default_factory)
elif isinstance(default_factory, functools.partial) and _is_resolvable(
default_factory.func
):
self.__arguments__[name] = Config(
default_factory.func,
*default_factory.args,
**default_factory.keywords,
)
else:
raise ValueError(
"Can't expose a sub-config to build default value for field "
f'{name} of dataclass {self.__fn_or_cls__} since it uses an '
'anonymous default_factory.'
)
return self.__arguments__[name]

if param is not None and param.default is not param.empty:
return param.default
msg = f"No parameter '{name}' has been set on {self!r}."
Expand Down Expand Up @@ -848,12 +876,27 @@ def __build__(self, /, *args: Any, **kwargs: Any) -> T:
return self.__fn_or_cls__(tags=self.tags, *args, **kwargs)


def _field_uses_default_factory(dataclass_type: Type[Any], field_name: str):
"""Returns true if <dataclass_type>.<field_name> uses a default_factory."""
def _field_default_factory(
dataclass_type: Type[Any], field_name: str
) -> Callable[[], Any] | None:
"""Returns the default_factory of <dataclass_type>.<field_name> if present."""
for field in dataclasses.fields(dataclass_type):
if field.name == field_name:
return field.default_factory != dataclasses.MISSING
return False
if (
field.name == field_name
and field.default_factory != dataclasses.MISSING
):
return cast(Callable[[], Any], field.default_factory)
return None


def _is_resolvable(value: Any) -> bool:
return (
hasattr(value, '__module__')
and hasattr(value, '__qualname__')
and
# Rules out anonymous objects like <lambda>, foo.<locals>.bar etc:
'<' not in value.__qualname__
)


BuildableT = TypeVar('BuildableT', bound=Buildable)
Expand Down
51 changes: 39 additions & 12 deletions fiddle/_src/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import copy
import dataclasses
import functools
import pickle
import sys
import threading
Expand Down Expand Up @@ -111,6 +112,13 @@ class DataclassParent:
child: DataclassChild = dataclasses.field(default_factory=DataclassChild)


@dataclasses.dataclass
class DataclassParentPartialDefaultFactoryChild:
child: DataclassChild = dataclasses.field(
default_factory=functools.partial(DataclassChild, x=1)
)


def raise_error():
raise ValueError('My fancy exception')

Expand Down Expand Up @@ -1042,21 +1050,40 @@ def test_copy_constructor_with_updates_errors(self):
with self.assertRaises(ValueError):
fdl.Partial(cfg1, 5, a='a', b='b')

def test_dataclass_default_factory(self):
def test_dataclass_default_factory_can_read_default(self):
cfg = fdl.Config(DataclassParent)
child_config = cfg.child
self.assertIsInstance(child_config, fdl.Config)
self.assertEqual(child_config.__fn_or_cls__, DataclassChild)
self.assertEqual(fdl.build(cfg), DataclassParent(child=DataclassChild(x=0)))

def test_dataclass_partial_default_factory_can_read_default(self):
cfg = fdl.Config(DataclassParentPartialDefaultFactoryChild)
child_config = cfg.child
self.assertIsInstance(child_config, fdl.Config)
self.assertEqual(child_config.__fn_or_cls__, DataclassChild)
self.assertEqual(child_config.x, 1)
self.assertEqual(
fdl.build(cfg),
DataclassParentPartialDefaultFactoryChild(child=DataclassChild(x=1)),
)

def test_dataclass_default_factory_overriding_child_config(self):
cfg = fdl.Config(DataclassParent)
cfg.child.x = 5
self.assertEqual(fdl.build(cfg), DataclassParent(DataclassChild(x=5)))

def test_dataclass_partial_default_factory_overriding_child_config(self):
cfg = fdl.Config(DataclassParentPartialDefaultFactoryChild)
cfg.child.x = 5
self.assertEqual(
fdl.build(cfg),
DataclassParentPartialDefaultFactoryChild(DataclassChild(x=5)),
)

with self.subTest('read_default_is_error'):
expected_error = (
r"Can't get default value for dataclass field DataclassParent\.child "
r'since it uses a default_factory\.')
with self.assertRaisesRegex(ValueError, expected_error):
cfg.child.x = 5

with self.subTest('read_ok_after_override'):
cfg.child = fdl.Config(DataclassChild) # override default w/ a value
cfg.child.x = 5 # now it's ok to configure child.
self.assertEqual(fdl.build(cfg), DataclassParent(DataclassChild(5)))
def test_dataclass_default_factory_not_used_when_child_config_given(self):
cfg = fdl.Config(DataclassParent, child=fdl.Config(DataclassChild, x=1))
self.assertEqual(cfg.child.x, 1)

def test_unbound_method(self):
sample = fdl.Config(SampleClass, 0, 1)
Expand Down