77
88import yaml
99import argparse
10+ import sys
11+ import collections
12+ import inspect
1013
1114
1215def non_string_warning (func ):
@@ -32,6 +35,7 @@ def warning_wrapper(config, key, *args, **kwargs):
3235 key , type (key )), RuntimeWarning )
3336
3437 return func (config , key , * args , ** kwargs )
38+
3539 return warning_wrapper
3640
3741
@@ -513,6 +517,130 @@ def create_from_str(cls, data, formatter=yaml.load, decoder_cls=Decoder,
513517 ** kwargs )
514518 return config
515519
520+ def create_argparser (self ):
521+ '''
522+ Creates an argparser for all values in the config
523+ Following the pattern: `--training.learning_rate 1234`
524+
525+ Returns
526+ -------
527+ argparse.ArgumentParser
528+ parser for all variables in the config
529+ '''
530+ parser = argparse .ArgumentParser (allow_abbrev = False )
531+
532+ def add_val (dict_like , prefix = '' ):
533+ for key , val in dict_like .items ():
534+ name = "--{}" .format (prefix + key )
535+ if val is None :
536+ parser .add_argument (name )
537+ else :
538+ if isinstance (val , int ):
539+ parser .add_argument (name , type = type (val ))
540+ elif isinstance (val , collections .Mapping ):
541+ add_val (val , prefix = key + '.' )
542+ elif isinstance (val , collections .Iterable ):
543+ if len (val ) > 0 and type (val [0 ]) != type :
544+ parser .add_argument (name , type = type (val [0 ]))
545+ else :
546+ parser .add_argument (name )
547+ elif issubclass (val , type ) or inspect .isclass (val ):
548+ parser .add_argument (name , type = val )
549+ else :
550+ parser .add_argument (name , type = type (val ))
551+
552+ add_val (self )
553+ return parser
554+
555+ @staticmethod
556+ def _add_unknown_args (unknown_args ):
557+ '''
558+ Can add unknown args as parsed by argparsers method
559+ `parse_unknown_args`.
560+
561+ Parameters
562+ ------
563+ unknown_args : list
564+ list of unknown args
565+ Returns
566+ ------
567+ Config
568+ a config of the parsed args
569+ '''
570+ # first element in the list must be a key
571+ if not isinstance (unknown_args [0 ], str ):
572+ unknown_args = [str (arg ) for arg in unknown_args ]
573+ if not unknown_args [0 ].startswith ('--' ):
574+ raise ValueError
575+
576+ args = Config ()
577+ # take first key
578+ key = unknown_args [0 ][2 :]
579+ idx , done , val = 1 , False , []
580+ while not done :
581+ try :
582+ item = unknown_args [idx ]
583+ except IndexError :
584+ done = True
585+ if item .startswith ('--' ) or done :
586+ # save key with its value
587+ if len (val ) == 0 :
588+ # key is used as flag
589+ args [key ] = True
590+ elif len (val ) == 1 :
591+ args [key ] = val [0 ]
592+ else :
593+ args [key ] = val
594+ # new key and flush data
595+ key = item [2 :]
596+ val = []
597+ else :
598+ val .append (item )
599+ idx += 1
600+ return args
601+
602+ def update_from_argparse (self , parser = None , add_unknown_items = False ):
603+ '''
604+ Updates the config with all values from the command line.
605+ Following the pattern: `--training.learning_rate 1234`
606+
607+ Raises
608+ ------
609+ TypeError
610+ raised if another datatype than currently in the config is parsed
611+ Returns
612+ -------
613+ dict
614+ dictionary containing only updated arguments
615+ '''
616+
617+ if len (sys .argv ) > 1 :
618+ if not parser :
619+ parser = self .create_argparser ()
620+
621+ params , unknown = parser .parse_known_args ()
622+ params = vars (params )
623+ if unknown and not add_unknown_items :
624+ warnings .warn (
625+ "Called with unknown arguments: {} "
626+ "They will not be stored if you do not set "
627+ "`add_unknown_items` to true." .format (unknown ),
628+ RuntimeWarning )
629+
630+ new_params = Config ()
631+ for key , val in params .items ():
632+ if val is None :
633+ continue
634+ new_params [key ] = val
635+
636+ # update dict
637+ self .update (new_params , overwrite = True )
638+ if add_unknown_items :
639+ additional_params = self ._add_unknown_args (unknown )
640+ self .update (additional_params )
641+ new_params .update (additional_params )
642+ return new_params
643+
516644
517645class LookupConfig (Config ):
518646 """
0 commit comments