2222import copy
2323import dataclasses
2424import functools
25+ import inspect
2526import types
26- from typing import Any , Callable , Collection , Dict , FrozenSet , Generic , Iterable , Mapping , NamedTuple , Optional , Set , Tuple , Type , TypeVar , Union
27+ from typing import Any , Callable , Collection , Dict , FrozenSet , Generic , Iterable , List , Mapping , NamedTuple , Optional , Set , Tuple , Type , TypeVar , Union
2728
2829from fiddle ._src import daglish
2930from fiddle ._src import history
@@ -242,10 +243,15 @@ def __init__(
242243 arg_history .add_new_value ('__fn_or_cls__' , fn_or_cls )
243244 super ().__setattr__ ('__argument_history__' , arg_history )
244245 super ().__setattr__ ('__argument_tags__' , collections .defaultdict (set ))
245- arguments = signatures . SignatureInfo . signature_binding (
246- fn_or_cls , * args , ** kwargs
246+ arguments , positional_arguments = (
247+ signatures . SignatureInfo . signature_binding ( fn_or_cls , * args , ** kwargs )
247248 )
248249
250+ if positional_arguments :
251+ self .__arguments__ ['__args__' ] = list (positional_arguments )
252+ for i , value in enumerate (positional_arguments ):
253+ self [i ] = value
254+
249255 for name , value in arguments .items ():
250256 setattr (self , name , value )
251257
@@ -258,6 +264,7 @@ def __init__(
258264 def __init_callable__ (
259265 self , fn_or_cls : Union ['Buildable[T]' , TypeOrCallableProducingT [T ]]
260266 ) -> None :
267+ """Save information on `fn_or_cls` to the `Buildable`."""
261268 if isinstance (fn_or_cls , Buildable ):
262269 raise ValueError (
263270 'Using the Buildable constructor to convert a buildable to a new '
@@ -273,9 +280,11 @@ def __init_callable__(
273280 super ().__setattr__ ('__fn_or_cls__' , fn_or_cls )
274281 super ().__setattr__ ('__arguments__' , {})
275282 signature = signatures .get_signature (fn_or_cls )
283+ # Several attributes are computed automatically by SignatureInfo during
284+ # `__post_init__`.
276285 super ().__setattr__ (
277286 '__signature_info__' ,
278- signatures .SignatureInfo (signature ),
287+ signatures .SignatureInfo (signature = signature ),
279288 )
280289
281290 def __init_subclass__ (cls ):
@@ -311,6 +320,14 @@ def __path_elements__(self) -> Tuple[daglish.Attr]:
311320
312321 def __getattr__ (self , name : str ):
313322 """Get parameter with given ``name``."""
323+ if name == 'posargs' :
324+ if not self .__signature_info__ .has_var_positional :
325+ raise TypeError (
326+ "This function doesn't have variadic positional arguments (*args). "
327+ 'Please set other (including positional-only) arguments by name.'
328+ )
329+
330+ name = '__args__'
314331 value = self .__arguments__ .get (name , _UNSET_SENTINEL )
315332
316333 if value is not _UNSET_SENTINEL :
@@ -340,9 +357,39 @@ def __getattr__(self, name: str):
340357 )
341358 raise AttributeError (msg )
342359
360+ def __setitem__ (self , key : Any , value : Any ):
361+ if not isinstance (key , (int , slice )):
362+ raise TypeError (
363+ 'Setting arguments by index is only supported for variadic '
364+ "arguments (*args), like my_config[4] = 'foo'."
365+ )
366+ if not self .__signature_info__ .has_var_positional :
367+ raise TypeError (
368+ "This function doesn't have variadic positional arguments (*args). "
369+ 'Please set other (including positional-only) arguments by name.'
370+ )
371+
372+ if '__args__' not in self .__arguments__ :
373+ self .__arguments__ ['__args__' ] = []
374+ self .__argument_history__ .add_new_value ('__args__' , [])
375+ self .__arguments__ ['__args__' ][key ] = value
376+ self .__argument_history__ .add_new_value (
377+ '__args__' , self .__arguments__ ['__args__' ]
378+ )
379+
380+ def __getitem__ (self , key : Any ):
381+ if not isinstance (key , slice ):
382+ raise TypeError (
383+ 'Getting arguments by index is only supported when using slice, '
384+ 'for example `v = my_config[:2]`, or using the `posargs` attr '
385+ f'instead, like v = my_config[0]. Got { type (key )} type as key.'
386+ )
387+ return self .posargs [key ]
388+
343389 def __setattr__ (self , name : str , value : Any ):
344390 """Sets parameter ``name`` to ``value``."""
345-
391+ if name == 'posargs' :
392+ name = '__args__'
346393 self .__signature_info__ .validate_param_name (name , self .__fn_or_cls__ )
347394
348395 if isinstance (value , TaggedValueCls ):
@@ -362,6 +409,8 @@ def __setattr__(self, name: str, value: Any):
362409
363410 def __delattr__ (self , name ):
364411 """Unsets parameter ``name``."""
412+ if name == 'posargs' :
413+ name = '__args__'
365414 try :
366415 del self .__arguments__ [name ]
367416 self .__argument_history__ .add_deleted_value (name )
@@ -488,9 +537,7 @@ def __getstate__(self):
488537 Dict of serialized state.
489538 """
490539 result = dict (self .__dict__ )
491- result ['__signature_info__' ] = signatures .SignatureInfo ( # pytype: disable=wrong-arg-types
492- None , result ['__signature_info__' ].has_var_keyword
493- )
540+ result ['__signature_info__' ] = signatures .SignatureInfo (None ) # pytype: disable=wrong-arg-types
494541 return result
495542
496543 def __setstate__ (self , state ) -> None :
@@ -503,8 +550,10 @@ def __setstate__(self, state) -> None:
503550 """
504551 self .__dict__ .update (state ) # Support unpickle.
505552 if self .__signature_info__ .signature is None :
506- self .__signature_info__ .signature = signatures .get_signature (
507- self .__fn_or_cls__
553+ signature = signatures .get_signature (self .__fn_or_cls__ )
554+ super ().__setattr__ (
555+ '__signature_info__' ,
556+ signatures .SignatureInfo (signature = signature ),
508557 )
509558
510559
@@ -637,6 +686,51 @@ def _field_uses_default_factory(dataclass_type: Type[Any], field_name: str):
637686 return False
638687
639688
689+ def _align_var_positional_args (
690+ new_signature : inspect .Signature ,
691+ original_args : Dict [str , Any ],
692+ drop_invalid_args : bool ,
693+ ) -> List [str ]:
694+ """Returns the list of positional arguments to unpack."""
695+ args_start_index = - 1
696+ for index , arg in enumerate (new_signature .parameters .keys ()):
697+ if arg not in original_args .keys ():
698+ args_start_index = index
699+ break
700+ if (args_start_index == - 1 and original_args ['__args__' ]) or (
701+ len (new_signature .parameters )
702+ < args_start_index + 1 + len (original_args ['__args__' ])
703+ ):
704+ if not drop_invalid_args :
705+ raise ValueError (
706+ 'new_callable does not have enough arguments when unpack'
707+ f' *args: { original_args ["__args__" ]} from the original'
708+ ' buildable.'
709+ )
710+ arg_keys = list (new_signature .parameters .keys ())[args_start_index :]
711+ return arg_keys
712+
713+
714+ def _expand_args_history (
715+ arg_keys : List [str ], buildable : Buildable
716+ ) -> List [List [history .HistoryEntry ]]:
717+ """Returns expanded history entries for positional arguments."""
718+ args_history = buildable .__argument_history__ ['__args__' ]
719+ expaneded_history = []
720+ for index in range (len (arg_keys )):
721+ expanded_entries = []
722+ for entry in args_history :
723+ new_entry = copy .copy (entry )
724+ if isinstance (new_entry .new_value , list ):
725+ if index >= len (new_entry .new_value ):
726+ new_entry .new_value = history .NOTSET
727+ else :
728+ new_entry .new_value = new_entry .new_value [index ]
729+ expanded_entries .append (new_entry )
730+ expaneded_history .append (expanded_entries )
731+ return expaneded_history
732+
733+
640734def update_callable (
641735 buildable : Buildable ,
642736 new_callable : TypeOrCallableProducingT ,
@@ -667,23 +761,40 @@ def update_callable(
667761 # Note: can't call `setattr` on all the args to validate them, because that
668762 # will result in duplicate history entries.
669763 original_args = buildable .__arguments__
670- signature = signatures .get_signature (new_callable )
671- if any (
672- param .kind == param .VAR_POSITIONAL
673- for param in signature .parameters .values ()
674- ):
675- raise NotImplementedError (
676- 'Variable positional arguments (aka `*args`) not supported.'
677- )
678- signature_info = signatures .SignatureInfo (signature )
679- object .__setattr__ (
680- buildable ,
681- '__signature_info__' ,
682- signature_info ,
683- )
684- if not signature_info .has_var_keyword :
764+ new_signature = signatures .get_signature (new_callable )
765+ # Update the signature early so that we can set arguments by position.
766+ # Otherwise, parameter validation logics would complain about argument
767+ # name not exists.
768+ object .__setattr__ (buildable , '__signature__' , new_signature )
769+ new_signature_info = signatures .SignatureInfo (signature = new_signature )
770+ original_signature_info = buildable .__signature_info__
771+ object .__setattr__ (buildable , '__signature_info__' , new_signature_info )
772+
773+ if new_signature_info .has_var_positional :
774+ # If only new callable has positional arguments
775+ if not original_signature_info .has_var_positional :
776+ buildable .__arguments__ ['__args__' ] = []
777+ buildable .__argument_history__ .add_new_value ('__args__' , [])
778+ else :
779+ # If only the original config has *args
780+ if original_signature_info .has_var_positional :
781+ arg_keys = _align_var_positional_args (
782+ new_signature , original_args , drop_invalid_args
783+ )
784+ expanded_history = _expand_args_history (arg_keys , buildable )
785+
786+ for arg , value , history_extries in zip (
787+ arg_keys , original_args ['__args__' ], expanded_history
788+ ):
789+ buildable .__setattr__ (arg , value )
790+ buildable .__argument_history__ [arg ] = history_extries
791+ buildable .__delattr__ ('__args__' )
792+
793+ if not new_signature_info .has_var_keyword :
685794 invalid_args = [
686- arg for arg in original_args .keys () if arg not in signature .parameters
795+ arg
796+ for arg in original_args .keys ()
797+ if arg not in new_signature .parameters and arg != '__args__'
687798 ]
688799 if invalid_args :
689800 if drop_invalid_args :
0 commit comments