|
23 | 23 | import dataclasses |
24 | 24 | import functools |
25 | 25 | import types |
26 | | -from typing import Any, Callable, Collection, Dict, FrozenSet, Generic, Iterable, Mapping, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union |
| 26 | +from typing import Any, Callable, Collection, Dict, FrozenSet, Generic, Iterable, Mapping, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union, cast |
27 | 27 |
|
28 | 28 | from fiddle._src import daglish |
29 | 29 | from fiddle._src import history |
@@ -341,14 +341,42 @@ def __getattr__(self, name: str): |
341 | 341 |
|
342 | 342 | if value is not _UNSET_SENTINEL: |
343 | 343 | return value |
344 | | - if dataclasses.is_dataclass( |
345 | | - self.__fn_or_cls__ |
346 | | - ) and _field_uses_default_factory(self.__fn_or_cls__, name): |
347 | | - raise ValueError( |
348 | | - "Can't get default value for dataclass field " |
349 | | - + f'{self.__fn_or_cls__.__qualname__}.{name} ' |
350 | | - + 'since it uses a default_factory.' |
351 | | - ) |
| 344 | + if dataclasses.is_dataclass(self.__fn_or_cls__) and ( |
| 345 | + default_factory := _field_default_factory(self.__fn_or_cls__, name) |
| 346 | + ): |
| 347 | + if hasattr(default_factory, 'as_buildable'): |
| 348 | + # The default_factory is a function wrapped by auto_config. In this case |
| 349 | + # it would appear reasonable to populate a child config by setting: |
| 350 | + # self.__arguments__[name] = default_factory.as_buildable() |
| 351 | + # However paxml.tools.fiddle / praxis.pax_fiddle unfortunately implement |
| 352 | + # some tricky magic of their own relating to auto_config and dataclass |
| 353 | + # default_factories, which would appear to be broken (or at least |
| 354 | + # interfered with) by any support for this here. It is likely a fairly |
| 355 | + # rare use case anyway outside of paxml, so we are choosing not to |
| 356 | + # support it here. |
| 357 | + raise ValueError( |
| 358 | + "We don't currently support exposing a sub-config to build the " |
| 359 | + f"default value for an auto_config'd default_factory (field {name} " |
| 360 | + f'of dataclass {self.__fn_or_cls__}.' |
| 361 | + ) |
| 362 | + elif _is_resolvable(default_factory): |
| 363 | + self.__arguments__[name] = Config(default_factory) |
| 364 | + elif isinstance(default_factory, functools.partial) and _is_resolvable( |
| 365 | + default_factory.func |
| 366 | + ): |
| 367 | + self.__arguments__[name] = Config( |
| 368 | + default_factory.func, |
| 369 | + *default_factory.args, |
| 370 | + **default_factory.keywords, |
| 371 | + ) |
| 372 | + else: |
| 373 | + raise ValueError( |
| 374 | + "Can't expose a sub-config to build default value for field " |
| 375 | + f'{name} of dataclass {self.__fn_or_cls__} since it uses an ' |
| 376 | + 'anonymous default_factory.' |
| 377 | + ) |
| 378 | + return self.__arguments__[name] |
| 379 | + |
352 | 380 | if param is not None and param.default is not param.empty: |
353 | 381 | return param.default |
354 | 382 | msg = f"No parameter '{name}' has been set on {self!r}." |
@@ -848,12 +876,27 @@ def __build__(self, /, *args: Any, **kwargs: Any) -> T: |
848 | 876 | return self.__fn_or_cls__(tags=self.tags, *args, **kwargs) |
849 | 877 |
|
850 | 878 |
|
851 | | -def _field_uses_default_factory(dataclass_type: Type[Any], field_name: str): |
852 | | - """Returns true if <dataclass_type>.<field_name> uses a default_factory.""" |
| 879 | +def _field_default_factory( |
| 880 | + dataclass_type: Type[Any], field_name: str |
| 881 | +) -> Callable[[], Any] | None: |
| 882 | + """Returns the default_factory of <dataclass_type>.<field_name> if present.""" |
853 | 883 | for field in dataclasses.fields(dataclass_type): |
854 | | - if field.name == field_name: |
855 | | - return field.default_factory != dataclasses.MISSING |
856 | | - return False |
| 884 | + if ( |
| 885 | + field.name == field_name |
| 886 | + and field.default_factory != dataclasses.MISSING |
| 887 | + ): |
| 888 | + return cast(Callable[[], Any], field.default_factory) |
| 889 | + return None |
| 890 | + |
| 891 | + |
| 892 | +def _is_resolvable(value: Any) -> bool: |
| 893 | + return ( |
| 894 | + hasattr(value, '__module__') |
| 895 | + and hasattr(value, '__qualname__') |
| 896 | + and |
| 897 | + # Rules out anonymous objects like <lambda>, foo.<locals>.bar etc: |
| 898 | + '<' not in value.__qualname__ |
| 899 | + ) |
857 | 900 |
|
858 | 901 |
|
859 | 902 | BuildableT = TypeVar('BuildableT', bound=Buildable) |
|
0 commit comments