|
23 | 23 | import dataclasses |
24 | 24 | import functools |
25 | 25 | import types |
26 | | -from typing import Any, Callable, Collection, Dict, FrozenSet, Generic, Iterable, Mapping, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union |
| 26 | +from typing import Any, Callable, Collection, Dict, FrozenSet, Generic, Iterable, Mapping, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union, cast |
27 | 27 |
|
28 | 28 | from fiddle._src import daglish |
29 | 29 | from fiddle._src import history |
@@ -341,14 +341,25 @@ def __getattr__(self, name: str): |
341 | 341 |
|
342 | 342 | if value is not _UNSET_SENTINEL: |
343 | 343 | return value |
344 | | - if dataclasses.is_dataclass( |
345 | | - self.__fn_or_cls__ |
346 | | - ) and _field_uses_default_factory(self.__fn_or_cls__, name): |
347 | | - raise ValueError( |
348 | | - "Can't get default value for dataclass field " |
349 | | - + f'{self.__fn_or_cls__.__qualname__}.{name} ' |
350 | | - + 'since it uses a default_factory.' |
351 | | - ) |
| 344 | + if dataclasses.is_dataclass(self.__fn_or_cls__) and ( |
| 345 | + default_factory := _field_default_factory(self.__fn_or_cls__, name)): |
| 346 | + if _is_resolvable(default_factory): |
| 347 | + self.__arguments__[name] = Config(default_factory) |
| 348 | + return self.__arguments__[name] |
| 349 | + elif (isinstance(default_factory, functools.partial) |
| 350 | + and _is_resolvable(default_factory.func)): |
| 351 | + from fiddle._src import partial # pylint: disable=g-import-not-at-top # Avoid cyclic load dependency |
| 352 | + self.__arguments__[name] = partial.Partial( |
| 353 | + default_factory.func, |
| 354 | + *default_factory.args, |
| 355 | + **default_factory.keywords) |
| 356 | + return self.__arguments__[name] |
| 357 | + else: |
| 358 | + return ValueError( |
| 359 | + "Can't expose a sub-config to build default value of dataclass " |
| 360 | + f'field {self.__fn_or_cls__.__qualname__}.{name} since it uses an ' |
| 361 | + 'anonymous default_factory.' |
| 362 | + ) |
352 | 363 | if param is not None and param.default is not param.empty: |
353 | 364 | return param.default |
354 | 365 | msg = f"No parameter '{name}' has been set on {self!r}." |
@@ -848,12 +859,24 @@ def __build__(self, /, *args: Any, **kwargs: Any) -> T: |
848 | 859 | return self.__fn_or_cls__(tags=self.tags, *args, **kwargs) |
849 | 860 |
|
850 | 861 |
|
851 | | -def _field_uses_default_factory(dataclass_type: Type[Any], field_name: str): |
852 | | - """Returns true if <dataclass_type>.<field_name> uses a default_factory.""" |
| 862 | +def _field_default_factory( |
| 863 | + dataclass_type: Type[Any], field_name: str) -> Callable[[], Any] | None: |
| 864 | + """Returns the default_factory of <dataclass_type>.<field_name> if present.""" |
853 | 865 | for field in dataclasses.fields(dataclass_type): |
854 | | - if field.name == field_name: |
855 | | - return field.default_factory != dataclasses.MISSING |
856 | | - return False |
| 866 | + if (field.name == field_name and |
| 867 | + field.default_factory != dataclasses.MISSING): |
| 868 | + return cast(Callable[[], Any], field.default_factory) |
| 869 | + return None |
| 870 | + |
| 871 | + |
| 872 | +def _is_resolvable(value: Any) -> bool: |
| 873 | + # TO DO: check we can roundtrip, ideally in a way that guarantees |
| 874 | + # compatibility with serialization code. |
| 875 | + return ( |
| 876 | + hasattr(value, '__module__') and |
| 877 | + hasattr(value, '__qualname__') and |
| 878 | + # Rules out anonymous objects like <lambda>, foo.<locals>.bar etc: |
| 879 | + '<' not in value.__qualname__) |
857 | 880 |
|
858 | 881 |
|
859 | 882 | BuildableT = TypeVar('BuildableT', bound=Buildable) |
|
0 commit comments