2626import itertools
2727import logging
2828import types
29- from typing import Any , Callable , Collection , Dict , FrozenSet , Generic , Iterable , Mapping , NamedTuple , Optional , Set , Tuple , Type , TypeVar , Union
29+ from typing import Any , Callable , Collection , Dict , FrozenSet , Generic , Iterable , List , Mapping , NamedTuple , Optional , Set , Tuple , Type , TypeVar , Union
3030
3131from fiddle ._src import arg_factory
3232from fiddle ._src import daglish
@@ -223,6 +223,8 @@ class Buildable(Generic[T], metaclass=abc.ABCMeta):
223223 __arguments__ : Dict [str , Any ]
224224 __argument_history__ : history .History
225225 __argument_tags__ : Dict [str , Set [tag_type .TagType ]]
226+ __positional_arg_names__ : List [str ]
227+ __has_var_positional__ : bool
226228 _has_var_keyword : bool
227229
228230 def __init__ (
@@ -245,19 +247,23 @@ def __init__(
245247 super ().__setattr__ ('__argument_history__' , arg_history )
246248 super ().__setattr__ ('__argument_tags__' , collections .defaultdict (set ))
247249
250+ positional_arguments = ()
248251 arguments = signature .bind_partial (* args , ** kwargs ).arguments
249252 for name in list (arguments .keys ()): # Make a copy in case we mutate.
250253 param = signature .parameters [name ]
251254 if param .kind == param .VAR_POSITIONAL :
252- # TODO(b/197367863): Add *args support.
253- err_msg = (
254- 'Variable positional arguments (aka `*args`) not supported. '
255- f'Found param `{ name } ` in `{ fn_or_cls } `.'
256- )
257- raise NotImplementedError (err_msg )
255+ positional_arguments = arguments .pop (param .name )
258256 elif param .kind == param .VAR_KEYWORD :
259257 arguments .update (arguments .pop (param .name ))
260258
259+ if positional_arguments :
260+ self .__arguments__ ['__args__' ] = list (positional_arguments )
261+ self .__argument_history__ .add_new_value (
262+ '__args__' , self .__arguments__ ['__args__' ]
263+ )
264+
265+ for i , value in enumerate (positional_arguments ):
266+ self [i ] = value
261267 for name , value in arguments .items ():
262268 setattr (self , name , value )
263269
@@ -286,10 +292,25 @@ def __init_callable__(
286292 super ().__setattr__ ('__arguments__' , {})
287293 signature = signatures .get_signature (fn_or_cls )
288294 super ().__setattr__ ('__signature__' , signature )
289- has_var_keyword = any (
290- param .kind == param .VAR_KEYWORD
291- for param in signature .parameters .values ()
292- )
295+
296+ # If *args exists, we must pass things before it in positional format. This
297+ # list tracks those arguments.
298+ maybe_positional_args = []
299+
300+ positional_only_args = []
301+ has_var_positional , has_var_keyword = False , False
302+ for param in signature .parameters .values ():
303+ if param .kind == param .VAR_POSITIONAL :
304+ has_var_positional = True
305+ positional_only_args .extend (maybe_positional_args )
306+ elif param .kind == param .VAR_KEYWORD :
307+ has_var_keyword = True
308+ elif param .kind == param .POSITIONAL_ONLY :
309+ positional_only_args .append (param .name )
310+ elif param .kind == param .POSITIONAL_OR_KEYWORD :
311+ maybe_positional_args .append (param .name )
312+ super ().__setattr__ ('__positional_arg_names__' , positional_only_args )
313+ super ().__setattr__ ('__has_var_positional__' , has_var_positional )
293314 super ().__setattr__ ('_has_var_keyword' , has_var_keyword )
294315 return signature
295316
@@ -326,6 +347,14 @@ def __path_elements__(self) -> Tuple[daglish.Attr]:
326347
327348 def __getattr__ (self , name : str ):
328349 """Get parameter with given ``name``."""
350+ if name == 'posargs' :
351+ if not self .__has_var_positional__ :
352+ raise TypeError (
353+ "This function doesn't have variadic positional arguments (*args). "
354+ 'Please set other (including positional-only) arguments by name.'
355+ )
356+
357+ name = '__args__'
329358 value = self .__arguments__ .get (name , _UNSET_SENTINEL )
330359
331360 if value is not _UNSET_SENTINEL :
@@ -387,6 +416,34 @@ def __validate_param_name__(self, name) -> None:
387416 )
388417 raise TypeError (err_msg )
389418
419+ def __setitem__ (self , key : Any , value : Any ):
420+ if not isinstance (key , (int , slice )):
421+ raise TypeError (
422+ 'Setting arguments by index is only supported for variadic '
423+ "arguments (*args), like my_config[4] = 'foo'."
424+ )
425+ if not self .__has_var_positional__ :
426+ raise TypeError (
427+ "This function doesn't have variadic positional arguments (*args). "
428+ 'Please set other (including positional-only) arguments by name.'
429+ )
430+
431+ # In the future, use a specialized history-tracking list.
432+ if '__args__' not in self .__arguments__ :
433+ self .__arguments__ ['__args__' ] = []
434+ self .__argument_history__ .add_new_value ('__args__' , [])
435+ args = self .__arguments__ ['__args__' ]
436+ args [key ] = value
437+
438+ def __getitem__ (self , key : Any ):
439+ if not isinstance (key , slice ):
440+ raise TypeError (
441+ 'Getting arguments by index is only supported when using slice, '
442+ 'for example `v = my_config[:2]`, or using the `posargs` attr '
443+ f'instead, like v = my_config[0]. Got { type (key )} type as key.'
444+ )
445+ return self .posargs [key ]
446+
390447 def __setattr__ (self , name : str , value : Any ):
391448 """Sets parameter ``name`` to ``value``."""
392449
@@ -950,13 +1007,63 @@ def update_callable(
9501007 # will result in duplicate history entries.
9511008 original_args = buildable .__arguments__
9521009 signature = signatures .get_signature (new_callable )
1010+ # Update the signature early so that we can set arguments by position
1011+ object .__setattr__ (buildable , '__signature__' , signature )
1012+
9531013 if any (
9541014 param .kind == param .VAR_POSITIONAL
9551015 for param in signature .parameters .values ()
9561016 ):
957- raise NotImplementedError (
958- 'Variable positional arguments (aka `*args`) not supported.'
959- )
1017+ # Both callables have *args
1018+ if buildable .__has_var_positional__ :
1019+ args_ptr = - 1
1020+ consumed_args = 0
1021+ for idx , arg in enumerate (signature .parameters .keys ()):
1022+ if arg not in original_args .keys ():
1023+ args_ptr = idx
1024+ break
1025+ all_args_key = list (signature .parameters .keys ())
1026+ while args_ptr < len (all_args_key ):
1027+ key = all_args_key [args_ptr ]
1028+ param = signature .parameters [key ]
1029+ if param .kind == param .VAR_POSITIONAL :
1030+ break
1031+ else :
1032+ value = original_args ['__args__' ][args_ptr ]
1033+ buildable .__setattr__ (key , value )
1034+ args_ptr += 1
1035+ consumed_args += 1
1036+
1037+ buildable .__arguments__ ['__args__' ] = buildable .__arguments__ ['__args__' ][
1038+ consumed_args :
1039+ ]
1040+ # Only new callable has *args
1041+ else :
1042+ object .__setattr__ (buildable , '__args__' , [])
1043+ buildable .__argument_history__ .add_new_value ('__args__' , [])
1044+ else :
1045+ # If only the original config has *args
1046+ if buildable .__has_var_positional__ :
1047+ args_start_at = - 1
1048+ for idx , arg in enumerate (signature .parameters .keys ()):
1049+ if arg not in original_args .keys ():
1050+ args_start_at = idx
1051+ break
1052+
1053+ if len (signature .parameters ) < args_start_at + len (
1054+ original_args ['__args__' ]
1055+ ):
1056+ if not drop_invalid_args :
1057+ raise ValueError (
1058+ 'new_callable does not have enough arguments when unpack *args: '
1059+ f'{ original_args ["__args__" ]} from the original buildable.'
1060+ )
1061+ arg_keys = list (signature .parameters .keys ())[args_start_at :]
1062+ for arg , value in zip (arg_keys , original_args ['__args__' ]):
1063+ buildable .__setattr__ (arg , value )
1064+ del buildable .__args__
1065+ object .__setattr__ (buildable , '__has_var_positional__' , False )
1066+
9601067 has_var_keyword = any (
9611068 param .kind == param .VAR_KEYWORD for param in signature .parameters .values ()
9621069 )
@@ -976,7 +1083,6 @@ def update_callable(
9761083 )
9771084
9781085 object .__setattr__ (buildable , '__fn_or_cls__' , new_callable )
979- object .__setattr__ (buildable , '__signature__' , signature )
9801086 object .__setattr__ (buildable , '_has_var_keyword' , has_var_keyword )
9811087 buildable .__argument_history__ .add_new_value ('__fn_or_cls__' , new_callable )
9821088
0 commit comments