Skip to content

Commit 3b416cc

Browse files
panzhufengcopybara-github
authored andcommitted
Add positional args support for fdl.Config.
`config.posargs` returns the *args list and can be accessed directly. To access positional args: ```python v = config.posargs # the full list v = config.posargs[-1] # normal index v = config.posargs[:3] # slice index v = config[-1] # normal index v = config[:] # slice index ``` To modify positional args: ```python config.posargs = [1, 2] # assign to a new list, config.posargs = config.posargs.append(3) # append one item config.posargs = config.posargs + [3] # append one item config[0] = 0 config[:] = [1, 2, 3] ``` PiperOrigin-RevId: 549731020
1 parent 45b9d95 commit 3b416cc

File tree

6 files changed

+389
-65
lines changed

6 files changed

+389
-65
lines changed

fiddle/_src/building.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,14 +96,21 @@ def _make_message(current_path: daglish.Path, buildable: config_lib.Buildable,
9696

9797
def call_buildable(
9898
buildable: config_lib.Buildable,
99-
arguments: Dict[str, Any],
99+
kwargs: Dict[str, Any],
100100
*,
101101
current_path: daglish.Path,
102102
) -> Any:
103-
make_message = functools.partial(_make_message, current_path, buildable,
104-
arguments)
103+
"""Run the __build__ method on a Buildable given keyword arguments."""
104+
make_message = functools.partial(
105+
_make_message, current_path, buildable, kwargs
106+
)
107+
args = []
108+
for name in buildable.__signature_info__.positional_arg_names:
109+
if name in kwargs:
110+
args.append(kwargs.pop(name))
111+
args.extend(kwargs.pop('__args__', []))
105112
with reraised_exception.try_with_lazy_message(make_message):
106-
return buildable.__build__(**arguments)
113+
return buildable.__build__(*args, **kwargs)
107114

108115

109116
# Define typing overload for `build(Partial[T])`

fiddle/_src/config.py

Lines changed: 137 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@
2222
import copy
2323
import dataclasses
2424
import functools
25+
import inspect
2526
import types
26-
from typing import Any, Callable, Collection, Dict, FrozenSet, Generic, Iterable, Mapping, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union
27+
from typing import Any, Callable, Collection, Dict, FrozenSet, Generic, Iterable, List, Mapping, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union
2728

2829
from fiddle._src import daglish
2930
from fiddle._src import history
@@ -242,10 +243,15 @@ def __init__(
242243
arg_history.add_new_value('__fn_or_cls__', fn_or_cls)
243244
super().__setattr__('__argument_history__', arg_history)
244245
super().__setattr__('__argument_tags__', collections.defaultdict(set))
245-
arguments = signatures.SignatureInfo.signature_binding(
246-
fn_or_cls, *args, **kwargs
246+
arguments, positional_arguments = (
247+
signatures.SignatureInfo.signature_binding(fn_or_cls, *args, **kwargs)
247248
)
248249

250+
if positional_arguments:
251+
self.__arguments__['__args__'] = list(positional_arguments)
252+
for i, value in enumerate(positional_arguments):
253+
self[i] = value
254+
249255
for name, value in arguments.items():
250256
setattr(self, name, value)
251257

@@ -258,6 +264,7 @@ def __init__(
258264
def __init_callable__(
259265
self, fn_or_cls: Union['Buildable[T]', TypeOrCallableProducingT[T]]
260266
) -> None:
267+
"""Save information on `fn_or_cls` to the `Buildable`."""
261268
if isinstance(fn_or_cls, Buildable):
262269
raise ValueError(
263270
'Using the Buildable constructor to convert a buildable to a new '
@@ -273,9 +280,11 @@ def __init_callable__(
273280
super().__setattr__('__fn_or_cls__', fn_or_cls)
274281
super().__setattr__('__arguments__', {})
275282
signature = signatures.get_signature(fn_or_cls)
283+
# Several attributes are computed automatically by SignatureInfo during
284+
# `__post_init__`.
276285
super().__setattr__(
277286
'__signature_info__',
278-
signatures.SignatureInfo(signature),
287+
signatures.SignatureInfo(signature=signature),
279288
)
280289

281290
def __init_subclass__(cls):
@@ -311,6 +320,14 @@ def __path_elements__(self) -> Tuple[daglish.Attr]:
311320

312321
def __getattr__(self, name: str):
313322
"""Get parameter with given ``name``."""
323+
if name == 'posargs':
324+
if not self.__signature_info__.has_var_positional:
325+
raise TypeError(
326+
"This function doesn't have variadic positional arguments (*args). "
327+
'Please set other (including positional-only) arguments by name.'
328+
)
329+
330+
name = '__args__'
314331
value = self.__arguments__.get(name, _UNSET_SENTINEL)
315332

316333
if value is not _UNSET_SENTINEL:
@@ -340,9 +357,39 @@ def __getattr__(self, name: str):
340357
)
341358
raise AttributeError(msg)
342359

360+
def __setitem__(self, key: Any, value: Any):
361+
if not isinstance(key, (int, slice)):
362+
raise TypeError(
363+
'Setting arguments by index is only supported for variadic '
364+
"arguments (*args), like my_config[4] = 'foo'."
365+
)
366+
if not self.__signature_info__.has_var_positional:
367+
raise TypeError(
368+
"This function doesn't have variadic positional arguments (*args). "
369+
'Please set other (including positional-only) arguments by name.'
370+
)
371+
372+
if '__args__' not in self.__arguments__:
373+
self.__arguments__['__args__'] = []
374+
self.__argument_history__.add_new_value('__args__', [])
375+
self.__arguments__['__args__'][key] = value
376+
self.__argument_history__.add_new_value(
377+
'__args__', self.__arguments__['__args__']
378+
)
379+
380+
def __getitem__(self, key: Any):
381+
if not isinstance(key, slice):
382+
raise TypeError(
383+
'Getting arguments by index is only supported when using slice, '
384+
'for example `v = my_config[:2]`, or using the `posargs` attr '
385+
f'instead, like v = my_config[0]. Got {type(key)} type as key.'
386+
)
387+
return self.posargs[key]
388+
343389
def __setattr__(self, name: str, value: Any):
344390
"""Sets parameter ``name`` to ``value``."""
345-
391+
if name == 'posargs':
392+
name = '__args__'
346393
self.__signature_info__.validate_param_name(name, self.__fn_or_cls__)
347394

348395
if isinstance(value, TaggedValueCls):
@@ -362,6 +409,8 @@ def __setattr__(self, name: str, value: Any):
362409

363410
def __delattr__(self, name):
364411
"""Unsets parameter ``name``."""
412+
if name == 'posargs':
413+
name = '__args__'
365414
try:
366415
del self.__arguments__[name]
367416
self.__argument_history__.add_deleted_value(name)
@@ -488,9 +537,7 @@ def __getstate__(self):
488537
Dict of serialized state.
489538
"""
490539
result = dict(self.__dict__)
491-
result['__signature_info__'] = signatures.SignatureInfo( # pytype: disable=wrong-arg-types
492-
None, result['__signature_info__'].has_var_keyword
493-
)
540+
result['__signature_info__'] = signatures.SignatureInfo(None) # pytype: disable=wrong-arg-types
494541
return result
495542

496543
def __setstate__(self, state) -> None:
@@ -503,8 +550,10 @@ def __setstate__(self, state) -> None:
503550
"""
504551
self.__dict__.update(state) # Support unpickle.
505552
if self.__signature_info__.signature is None:
506-
self.__signature_info__.signature = signatures.get_signature(
507-
self.__fn_or_cls__
553+
signature = signatures.get_signature(self.__fn_or_cls__)
554+
super().__setattr__(
555+
'__signature_info__',
556+
signatures.SignatureInfo(signature=signature),
508557
)
509558

510559

@@ -637,6 +686,51 @@ def _field_uses_default_factory(dataclass_type: Type[Any], field_name: str):
637686
return False
638687

639688

689+
def _align_var_positional_args(
690+
new_signature: inspect.Signature,
691+
original_args: Dict[str, Any],
692+
drop_invalid_args: bool,
693+
) -> List[str]:
694+
"""Returns the list of positional arguments to unpack."""
695+
args_start_index = -1
696+
for index, arg in enumerate(new_signature.parameters.keys()):
697+
if arg not in original_args.keys():
698+
args_start_index = index
699+
break
700+
if (args_start_index == -1 and original_args['__args__']) or (
701+
len(new_signature.parameters)
702+
< args_start_index + 1 + len(original_args['__args__'])
703+
):
704+
if not drop_invalid_args:
705+
raise ValueError(
706+
'new_callable does not have enough arguments when unpack'
707+
f' *args: {original_args["__args__"]} from the original'
708+
' buildable.'
709+
)
710+
arg_keys = list(new_signature.parameters.keys())[args_start_index:]
711+
return arg_keys
712+
713+
714+
def _expand_args_history(
715+
arg_keys: List[str], buildable: Buildable
716+
) -> List[List[history.HistoryEntry]]:
717+
"""Returns expanded history entries for positional arguments."""
718+
args_history = buildable.__argument_history__['__args__']
719+
expaneded_history = []
720+
for index in range(len(arg_keys)):
721+
expanded_entries = []
722+
for entry in args_history:
723+
new_entry = copy.copy(entry)
724+
if isinstance(new_entry.new_value, list):
725+
if index >= len(new_entry.new_value):
726+
new_entry.new_value = history.NOTSET
727+
else:
728+
new_entry.new_value = new_entry.new_value[index]
729+
expanded_entries.append(new_entry)
730+
expaneded_history.append(expanded_entries)
731+
return expaneded_history
732+
733+
640734
def update_callable(
641735
buildable: Buildable,
642736
new_callable: TypeOrCallableProducingT,
@@ -667,23 +761,40 @@ def update_callable(
667761
# Note: can't call `setattr` on all the args to validate them, because that
668762
# will result in duplicate history entries.
669763
original_args = buildable.__arguments__
670-
signature = signatures.get_signature(new_callable)
671-
if any(
672-
param.kind == param.VAR_POSITIONAL
673-
for param in signature.parameters.values()
674-
):
675-
raise NotImplementedError(
676-
'Variable positional arguments (aka `*args`) not supported.'
677-
)
678-
signature_info = signatures.SignatureInfo(signature)
679-
object.__setattr__(
680-
buildable,
681-
'__signature_info__',
682-
signature_info,
683-
)
684-
if not signature_info.has_var_keyword:
764+
new_signature = signatures.get_signature(new_callable)
765+
# Update the signature early so that we can set arguments by position.
766+
# Otherwise, parameter validation logics would complain about argument
767+
# name not exists.
768+
object.__setattr__(buildable, '__signature__', new_signature)
769+
new_signature_info = signatures.SignatureInfo(signature=new_signature)
770+
original_signature_info = buildable.__signature_info__
771+
object.__setattr__(buildable, '__signature_info__', new_signature_info)
772+
773+
if new_signature_info.has_var_positional:
774+
# If only new callable has positional arguments
775+
if not original_signature_info.has_var_positional:
776+
buildable.__arguments__['__args__'] = []
777+
buildable.__argument_history__.add_new_value('__args__', [])
778+
else:
779+
# If only the original config has *args
780+
if original_signature_info.has_var_positional:
781+
arg_keys = _align_var_positional_args(
782+
new_signature, original_args, drop_invalid_args
783+
)
784+
expanded_history = _expand_args_history(arg_keys, buildable)
785+
786+
for arg, value, history_extries in zip(
787+
arg_keys, original_args['__args__'], expanded_history
788+
):
789+
buildable.__setattr__(arg, value)
790+
buildable.__argument_history__[arg] = history_extries
791+
buildable.__delattr__('__args__')
792+
793+
if not new_signature_info.has_var_keyword:
685794
invalid_args = [
686-
arg for arg in original_args.keys() if arg not in signature.parameters
795+
arg
796+
for arg in original_args.keys()
797+
if arg not in new_signature.parameters and arg != '__args__'
687798
]
688799
if invalid_args:
689800
if drop_invalid_args:

0 commit comments

Comments
 (0)