Skip to content

Commit 78b48c5

Browse files
panzhufengcopybara-github
authored andcommitted
Add positional args support for fdl.Config
PiperOrigin-RevId: 549731020
1 parent c882b09 commit 78b48c5

File tree

4 files changed

+251
-33
lines changed

4 files changed

+251
-33
lines changed

fiddle/_src/building.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,14 +95,21 @@ def _make_message(current_path: daglish.Path, buildable: config_lib.Buildable,
9595

9696
def call_buildable(
9797
buildable: config_lib.Buildable,
98-
arguments: Dict[str, Any],
98+
kwargs: Dict[str, Any],
9999
*,
100100
current_path: daglish.Path,
101101
) -> Any:
102-
make_message = functools.partial(_make_message, current_path, buildable,
103-
arguments)
102+
"""Run the __build__ method on a Buildable given keyword arguments."""
103+
make_message = functools.partial(
104+
_make_message, current_path, buildable, kwargs
105+
)
106+
args = []
107+
for name in buildable.__positional_arg_names__:
108+
if name in kwargs:
109+
args.append(kwargs.pop(name))
110+
args.extend(kwargs.pop('__args__', []))
104111
with reraised_exception.try_with_lazy_message(make_message):
105-
return buildable.__build__(**arguments)
112+
return buildable.__build__(*args, **kwargs)
106113

107114

108115
# Define typing overload for `build(Partial[T])`

fiddle/_src/config.py

Lines changed: 121 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import itertools
2727
import logging
2828
import types
29-
from typing import Any, Callable, Collection, Dict, FrozenSet, Generic, Iterable, Mapping, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union
29+
from typing import Any, Callable, Collection, Dict, FrozenSet, Generic, Iterable, List, Mapping, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union
3030

3131
from fiddle._src import arg_factory
3232
from fiddle._src import daglish
@@ -223,6 +223,8 @@ class Buildable(Generic[T], metaclass=abc.ABCMeta):
223223
__arguments__: Dict[str, Any]
224224
__argument_history__: history.History
225225
__argument_tags__: Dict[str, Set[tag_type.TagType]]
226+
__positional_arg_names__: List[str]
227+
__has_var_positional__: bool
226228
_has_var_keyword: bool
227229

228230
def __init__(
@@ -245,19 +247,23 @@ def __init__(
245247
super().__setattr__('__argument_history__', arg_history)
246248
super().__setattr__('__argument_tags__', collections.defaultdict(set))
247249

250+
positional_arguments = ()
248251
arguments = signature.bind_partial(*args, **kwargs).arguments
249252
for name in list(arguments.keys()): # Make a copy in case we mutate.
250253
param = signature.parameters[name]
251254
if param.kind == param.VAR_POSITIONAL:
252-
# TODO(b/197367863): Add *args support.
253-
err_msg = (
254-
'Variable positional arguments (aka `*args`) not supported. '
255-
f'Found param `{name}` in `{fn_or_cls}`.'
256-
)
257-
raise NotImplementedError(err_msg)
255+
positional_arguments = arguments.pop(param.name)
258256
elif param.kind == param.VAR_KEYWORD:
259257
arguments.update(arguments.pop(param.name))
260258

259+
if positional_arguments:
260+
self.__arguments__['__args__'] = list(positional_arguments)
261+
self.__argument_history__.add_new_value(
262+
'__args__', self.__arguments__['__args__']
263+
)
264+
265+
for i, value in enumerate(positional_arguments):
266+
self[i] = value
261267
for name, value in arguments.items():
262268
setattr(self, name, value)
263269

@@ -286,10 +292,25 @@ def __init_callable__(
286292
super().__setattr__('__arguments__', {})
287293
signature = signatures.get_signature(fn_or_cls)
288294
super().__setattr__('__signature__', signature)
289-
has_var_keyword = any(
290-
param.kind == param.VAR_KEYWORD
291-
for param in signature.parameters.values()
292-
)
295+
296+
# If *args exists, we must pass things before it in positional format. This
297+
# list tracks those arguments.
298+
maybe_positional_args = []
299+
300+
positional_only_args = []
301+
has_var_positional, has_var_keyword = False, False
302+
for param in signature.parameters.values():
303+
if param.kind == param.VAR_POSITIONAL:
304+
has_var_positional = True
305+
positional_only_args.extend(maybe_positional_args)
306+
elif param.kind == param.VAR_KEYWORD:
307+
has_var_keyword = True
308+
elif param.kind == param.POSITIONAL_ONLY:
309+
positional_only_args.append(param.name)
310+
elif param.kind == param.POSITIONAL_OR_KEYWORD:
311+
maybe_positional_args.append(param.name)
312+
super().__setattr__('__positional_arg_names__', positional_only_args)
313+
super().__setattr__('__has_var_positional__', has_var_positional)
293314
super().__setattr__('_has_var_keyword', has_var_keyword)
294315
return signature
295316

@@ -326,6 +347,14 @@ def __path_elements__(self) -> Tuple[daglish.Attr]:
326347

327348
def __getattr__(self, name: str):
328349
"""Get parameter with given ``name``."""
350+
if name == 'posargs':
351+
if not self.__has_var_positional__:
352+
raise TypeError(
353+
"This function doesn't have variadic positional arguments (*args). "
354+
'Please set other (including positional-only) arguments by name.'
355+
)
356+
357+
name = '__args__'
329358
value = self.__arguments__.get(name, _UNSET_SENTINEL)
330359

331360
if value is not _UNSET_SENTINEL:
@@ -387,6 +416,34 @@ def __validate_param_name__(self, name) -> None:
387416
)
388417
raise TypeError(err_msg)
389418

419+
def __setitem__(self, key: Any, value: Any):
420+
if not isinstance(key, (int, slice)):
421+
raise TypeError(
422+
'Setting arguments by index is only supported for variadic '
423+
"arguments (*args), like my_config[4] = 'foo'."
424+
)
425+
if not self.__has_var_positional__:
426+
raise TypeError(
427+
"This function doesn't have variadic positional arguments (*args). "
428+
'Please set other (including positional-only) arguments by name.'
429+
)
430+
431+
# In the future, use a specialized history-tracking list.
432+
if '__args__' not in self.__arguments__:
433+
self.__arguments__['__args__'] = []
434+
self.__argument_history__.add_new_value('__args__', [])
435+
args = self.__arguments__['__args__']
436+
args[key] = value
437+
438+
def __getitem__(self, key: Any):
439+
if not isinstance(key, slice):
440+
raise TypeError(
441+
'Getting arguments by index is only supported when using slice, '
442+
'for example `v = my_config[:2]`, or using the `posargs` attr '
443+
f'instead, like v = my_config[0]. Got {type(key)} type as key.'
444+
)
445+
return self.posargs[key]
446+
390447
def __setattr__(self, name: str, value: Any):
391448
"""Sets parameter ``name`` to ``value``."""
392449

@@ -950,13 +1007,63 @@ def update_callable(
9501007
# will result in duplicate history entries.
9511008
original_args = buildable.__arguments__
9521009
signature = signatures.get_signature(new_callable)
1010+
# Update the signature early so that we can set arguments by position
1011+
object.__setattr__(buildable, '__signature__', signature)
1012+
9531013
if any(
9541014
param.kind == param.VAR_POSITIONAL
9551015
for param in signature.parameters.values()
9561016
):
957-
raise NotImplementedError(
958-
'Variable positional arguments (aka `*args`) not supported.'
959-
)
1017+
# Both callables have *args
1018+
if buildable.__has_var_positional__:
1019+
args_ptr = -1
1020+
consumed_args = 0
1021+
for idx, arg in enumerate(signature.parameters.keys()):
1022+
if arg not in original_args.keys():
1023+
args_ptr = idx
1024+
break
1025+
all_args_key = list(signature.parameters.keys())
1026+
while args_ptr < len(all_args_key):
1027+
key = all_args_key[args_ptr]
1028+
param = signature.parameters[key]
1029+
if param.kind == param.VAR_POSITIONAL:
1030+
break
1031+
else:
1032+
value = original_args['__args__'][args_ptr]
1033+
buildable.__setattr__(key, value)
1034+
args_ptr += 1
1035+
consumed_args += 1
1036+
1037+
buildable.__arguments__['__args__'] = buildable.__arguments__['__args__'][
1038+
consumed_args:
1039+
]
1040+
# Only new callable has *args
1041+
else:
1042+
object.__setattr__(buildable, '__args__', [])
1043+
buildable.__argument_history__.add_new_value('__args__', [])
1044+
else:
1045+
# If only the original config has *args
1046+
if buildable.__has_var_positional__:
1047+
args_start_at = -1
1048+
for idx, arg in enumerate(signature.parameters.keys()):
1049+
if arg not in original_args.keys():
1050+
args_start_at = idx
1051+
break
1052+
1053+
if len(signature.parameters) < args_start_at + len(
1054+
original_args['__args__']
1055+
):
1056+
if not drop_invalid_args:
1057+
raise ValueError(
1058+
'new_callable does not have enough arguments when unpack *args: '
1059+
f'{original_args["__args__"]} from the original buildable.'
1060+
)
1061+
arg_keys = list(signature.parameters.keys())[args_start_at:]
1062+
for arg, value in zip(arg_keys, original_args['__args__']):
1063+
buildable.__setattr__(arg, value)
1064+
del buildable.__args__
1065+
object.__setattr__(buildable, '__has_var_positional__', False)
1066+
9601067
has_var_keyword = any(
9611068
param.kind == param.VAR_KEYWORD for param in signature.parameters.values()
9621069
)
@@ -976,7 +1083,6 @@ def update_callable(
9761083
)
9771084

9781085
object.__setattr__(buildable, '__fn_or_cls__', new_callable)
979-
object.__setattr__(buildable, '__signature__', signature)
9801086
object.__setattr__(buildable, '_has_var_keyword', has_var_keyword)
9811087
buildable.__argument_history__.add_new_value('__fn_or_cls__', new_callable)
9821088

fiddle/_src/config_test.py

Lines changed: 109 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,84 @@ def test_config_for_functions_with_var_args_and_kwargs(self):
224224
'kwargs': 'kwarg_called_kwarg'
225225
})
226226

227+
# "args" below refer to positional arguments, typcally `*args``
228+
def test_args_config_access(self):
229+
fn_config = fdl.Config(fn_with_var_args, 'foo', 'bar', 'baz')
230+
231+
with self.subTest('ordered_arguments'):
232+
self.assertEqual(
233+
fdl.ordered_arguments(fn_config),
234+
{
235+
'arg1': 'foo',
236+
'__args__': ['bar', 'baz'],
237+
},
238+
)
239+
240+
with self.subTest('posargs_access'):
241+
self.assertEqual(fn_config.posargs[0], 'bar')
242+
self.assertEqual(fn_config.posargs[1], 'baz')
243+
self.assertSequenceEqual(fn_config.posargs, ['bar', 'baz'])
244+
245+
with self.subTest('index_access'):
246+
with self.assertRaisesRegex(
247+
TypeError,
248+
'Getting arguments by index is only supported when using slice',
249+
):
250+
_ = fn_config[0]
251+
252+
with self.subTest('slice_access'):
253+
self.assertEmpty(fn_config[:0])
254+
self.assertSequenceEqual(fn_config[:1], ['bar'])
255+
self.assertSequenceEqual(fn_config[:], ['bar', 'baz'])
256+
257+
def test_args_config_posargs_append(self):
258+
fn_config = fdl.Config(fn_with_var_args, 'foo', 'bar', 'baz')
259+
fn_config.posargs.append('foo')
260+
self.assertSequenceEqual(fn_config.posargs, ['bar', 'baz', 'foo'])
261+
262+
def test_args_config_slice_mutation(self):
263+
fn_config = fdl.Config(fn_with_var_args, 'foo', 'bar', 'baz')
264+
self.assertSequenceEqual(fn_config[:], ['bar', 'baz'])
265+
fn_config[:1] = ['zero', 'one']
266+
self.assertSequenceEqual(fn_config[:], ['zero', 'one', 'baz'])
267+
268+
def test_args_config_shallow_copy(self):
269+
fn_config = fdl.Config(fn_with_var_args, 'foo', 'bar', 'baz')
270+
self.assertLen(fn_config[:], 2)
271+
a_copy = fn_config[:]
272+
a_copy.append('foo')
273+
self.assertLen(fn_config[:], 2)
274+
self.assertLen(a_copy, 3)
275+
276+
def test_index_mutation(self):
277+
fn_config = fdl.Config(fn_with_var_args, 'foo', 'bar', 'baz')
278+
fn_config[0] = 'foo'
279+
self.assertEqual(fn_config.posargs[0], 'foo')
280+
fn_config[-1] = 'last'
281+
self.assertLen(fn_config.posargs, 2)
282+
self.assertEqual(fn_config.posargs[1], 'last')
283+
self.assertEqual(fn_config.posargs[-1], 'last')
284+
285+
def test_index_out_of_range(self):
286+
fn_config = fdl.Config(fn_with_var_args, 'foo', 'bar', 'baz')
287+
self.assertLen(fn_config[:], 2)
288+
with self.assertRaisesRegex(
289+
IndexError, 'list assignment index out of range'
290+
):
291+
fn_config[2] = 'index-2'
292+
293+
def test_args_config_build(self):
294+
fn_config = fdl.Config(fn_with_var_args, 'foo', 'bar', 'baz')
295+
fn_args = fdl.build(fn_config)
296+
self.assertEqual(
297+
fn_args,
298+
{
299+
'arg1': 'foo',
300+
'args': ('bar', 'baz'),
301+
'kwarg1': None,
302+
},
303+
)
304+
227305
def test_config_for_dicts(self):
228306
dict_config = fdl.Config(dict, a=1, b=2)
229307
dict_config.c = 3
@@ -858,12 +936,6 @@ def test_nonexistent_var_args_parameter_error(self):
858936
with self.assertRaisesRegex(TypeError, expected_msg):
859937
fn_config.args = (1, 2, 3)
860938

861-
def test_unsupported_var_args_error(self):
862-
expected_msg = (r'Variable positional arguments \(aka `\*args`\) not '
863-
r'supported\.')
864-
with self.assertRaisesRegex(NotImplementedError, expected_msg):
865-
fdl.Config(fn_with_var_args, 1, 2, 3)
866-
867939
def test_build_inside_build(self):
868940

869941
def inner_build(x: int) -> str:
@@ -1211,11 +1283,37 @@ def test_update_callable_new_kwargs(self):
12111283
}
12121284
}, fdl.build(cfg))
12131285

1214-
def test_update_callable_varargs(self):
1215-
cfg = fdl.Config(fn_with_var_kwargs, 1, 2)
1216-
with self.assertRaisesRegex(NotImplementedError,
1217-
'Variable positional arguments'):
1218-
fdl.update_callable(cfg, fn_with_var_args_and_kwargs)
1286+
def test_update_args_to_args(self):
1287+
cfg = fdl.Config(fn_with_var_args, 1, 2, kwarg1=3)
1288+
fdl.update_callable(cfg, fn_with_var_args_and_kwargs)
1289+
self.assertEqual(
1290+
cfg.__arguments__, {'arg1': 1, '__args__': [2], 'kwarg1': 3}
1291+
)
1292+
self.assertEqual(
1293+
{'arg1': 1, 'args': (2,), 'kwarg1': 3, 'kwargs': {}}, fdl.build(cfg)
1294+
)
1295+
1296+
def test_update_args_to_no_args(self):
1297+
cfg = fdl.Config(fn_with_var_args, 1, 2, kwarg1=3)
1298+
fdl.update_callable(cfg, basic_fn)
1299+
cfg.arg2 = 22
1300+
self.assertEqual(cfg.__arguments__, {'arg1': 1, 'arg2': 22, 'kwarg1': 3})
1301+
self.assertEqual(
1302+
{'arg1': 1, 'arg2': 22, 'kwarg1': 3, 'kwarg2': None}, fdl.build(cfg)
1303+
)
1304+
1305+
def test_update_args_kwargs(self):
1306+
def my_fn(*args, **kwargs):
1307+
del args, kwargs
1308+
1309+
cfg = fdl.Config(my_fn, 1, 2, 3, kwarg1=4, kwarg2=5)
1310+
cfg.posargs[0] = 10
1311+
cfg.kwarg1 = 40
1312+
config_lib.update_callable(cfg, fn_with_var_args_and_kwargs)
1313+
self.assertEqual(
1314+
cfg.__arguments__,
1315+
{'arg1': 10, '__args__': [2, 3], 'kwarg1': 40, 'kwarg2': 5},
1316+
)
12191317

12201318
def test_get_callable(self):
12211319
cfg = fdl.Config(basic_fn)

fiddle/_src/mutate_buildable.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,16 @@
1717

1818
from fiddle._src import config
1919

20-
_buildable_internals_keys = ('__fn_or_cls__', '__signature__', '__arguments__',
21-
'_has_var_keyword', '__argument_tags__',
22-
'__argument_history__')
20+
_buildable_internals_keys = (
21+
'__fn_or_cls__',
22+
'__signature__',
23+
'__arguments__',
24+
'_has_var_keyword',
25+
'__argument_tags__',
26+
'__argument_history__',
27+
'__has_var_positional__',
28+
'__positional_arg_names__',
29+
)
2330

2431

2532
def move_buildable_internals(*, source: config.Buildable,

0 commit comments

Comments
 (0)