Skip to content
This repository was archived by the owner on Jun 26, 2021. It is now read-only.

Commit be0f9f0

Browse files
authored
Merge pull request #231 from delira-dev/update_from_sys_arg
Update Config from system args
2 parents 64ce57b + c0e5ac8 commit be0f9f0

File tree

2 files changed

+173
-0
lines changed

2 files changed

+173
-0
lines changed

delira/utils/config.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
import yaml
99
import argparse
10+
import sys
11+
import collections
12+
import inspect
1013

1114

1215
def non_string_warning(func):
@@ -32,6 +35,7 @@ def warning_wrapper(config, key, *args, **kwargs):
3235
key, type(key)), RuntimeWarning)
3336

3437
return func(config, key, *args, **kwargs)
38+
3539
return warning_wrapper
3640

3741

@@ -513,6 +517,130 @@ def create_from_str(cls, data, formatter=yaml.load, decoder_cls=Decoder,
513517
**kwargs)
514518
return config
515519

520+
def create_argparser(self):
521+
'''
522+
Creates an argparser for all values in the config
523+
Following the pattern: `--training.learning_rate 1234`
524+
525+
Returns
526+
-------
527+
argparse.ArgumentParser
528+
parser for all variables in the config
529+
'''
530+
parser = argparse.ArgumentParser(allow_abbrev=False)
531+
532+
def add_val(dict_like, prefix=''):
533+
for key, val in dict_like.items():
534+
name = "--{}".format(prefix + key)
535+
if val is None:
536+
parser.add_argument(name)
537+
else:
538+
if isinstance(val, int):
539+
parser.add_argument(name, type=type(val))
540+
elif isinstance(val, collections.Mapping):
541+
add_val(val, prefix=key + '.')
542+
elif isinstance(val, collections.Iterable):
543+
if len(val) > 0 and type(val[0]) != type:
544+
parser.add_argument(name, type=type(val[0]))
545+
else:
546+
parser.add_argument(name)
547+
elif issubclass(val, type) or inspect.isclass(val):
548+
parser.add_argument(name, type=val)
549+
else:
550+
parser.add_argument(name, type=type(val))
551+
552+
add_val(self)
553+
return parser
554+
555+
@staticmethod
556+
def _add_unknown_args(unknown_args):
557+
'''
558+
Can add unknown args as parsed by argparsers method
559+
`parse_unknown_args`.
560+
561+
Parameters
562+
------
563+
unknown_args : list
564+
list of unknown args
565+
Returns
566+
------
567+
Config
568+
a config of the parsed args
569+
'''
570+
# first element in the list must be a key
571+
if not isinstance(unknown_args[0], str):
572+
unknown_args = [str(arg) for arg in unknown_args]
573+
if not unknown_args[0].startswith('--'):
574+
raise ValueError
575+
576+
args = Config()
577+
# take first key
578+
key = unknown_args[0][2:]
579+
idx, done, val = 1, False, []
580+
while not done:
581+
try:
582+
item = unknown_args[idx]
583+
except IndexError:
584+
done = True
585+
if item.startswith('--') or done:
586+
# save key with its value
587+
if len(val) == 0:
588+
# key is used as flag
589+
args[key] = True
590+
elif len(val) == 1:
591+
args[key] = val[0]
592+
else:
593+
args[key] = val
594+
# new key and flush data
595+
key = item[2:]
596+
val = []
597+
else:
598+
val.append(item)
599+
idx += 1
600+
return args
601+
602+
def update_from_argparse(self, parser=None, add_unknown_items=False):
603+
'''
604+
Updates the config with all values from the command line.
605+
Following the pattern: `--training.learning_rate 1234`
606+
607+
Raises
608+
------
609+
TypeError
610+
raised if another datatype than currently in the config is parsed
611+
Returns
612+
-------
613+
dict
614+
dictionary containing only updated arguments
615+
'''
616+
617+
if len(sys.argv) > 1:
618+
if not parser:
619+
parser = self.create_argparser()
620+
621+
params, unknown = parser.parse_known_args()
622+
params = vars(params)
623+
if unknown and not add_unknown_items:
624+
warnings.warn(
625+
"Called with unknown arguments: {} "
626+
"They will not be stored if you do not set "
627+
"`add_unknown_items` to true.".format(unknown),
628+
RuntimeWarning)
629+
630+
new_params = Config()
631+
for key, val in params.items():
632+
if val is None:
633+
continue
634+
new_params[key] = val
635+
636+
# update dict
637+
self.update(new_params, overwrite=True)
638+
if add_unknown_items:
639+
additional_params = self._add_unknown_args(unknown)
640+
self.update(additional_params)
641+
new_params.update(additional_params)
642+
return new_params
643+
516644

517645
class LookupConfig(Config):
518646
"""

tests/utils/test_config.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import unittest
22
import os
3+
import sys
34
import copy
45
import argparse
6+
from unittest.mock import patch
57
from delira._version import get_versions
68

79
from delira.utils.config import Config, LookupConfig, DeliraConfig
810
from delira.logging import Logger, TensorboardBackend, make_logger, \
911
register_logger
12+
import warnings
1013

1114
from . import check_for_no_backend
1215

@@ -215,6 +218,48 @@ def test_internal_type(self):
215218
cf = self.config_cls.create_from_dict(self.example_dict)
216219
self.assertTrue(isinstance(cf["deep"], self.config_cls))
217220

221+
@unittest.skipUnless(
222+
check_for_no_backend(),
223+
"Test should only be executed if no backend is specified")
224+
def test_create_argparser(self):
225+
cf = self.config_cls.create_from_dict(self.example_dict)
226+
testargs = [
227+
'--shallowNum',
228+
'10',
229+
'--deep.deepStr',
230+
'check',
231+
'--testlist',
232+
'ele1',
233+
'ele2',
234+
'--setflag']
235+
parser = cf.create_argparser()
236+
known, unknown = parser.parse_known_args(testargs)
237+
self.assertEqual(vars(known)['shallowNum'], 10)
238+
self.assertEqual(vars(known)['deep.deepStr'], 'check')
239+
self.assertEqual(unknown, ['--testlist', 'ele1', 'ele2', '--setflag'])
240+
241+
@unittest.skipUnless(
242+
check_for_no_backend(),
243+
"Test should only be executed if no backend is specified")
244+
def test_update_from_argparse(self):
245+
cf = self.config_cls.create_from_dict(self.example_dict)
246+
testargs = ['--shallowNum', '10',
247+
'--deep.deepStr', 'check',
248+
'--testlist', 'ele1', 'ele2',
249+
'--setflag']
250+
# placeholder pyfile because argparser omits first argument from sys
251+
# argv
252+
with patch.object(sys, 'argv', ['pyfile.py'] + testargs):
253+
cf.update_from_argparse(add_unknown_items=True)
254+
self.assertEqual(cf['shallowNum'], int(testargs[1]))
255+
self.assertEqual(cf['deep']['deepStr'], testargs[3])
256+
self.assertEqual(cf['testlist'], testargs[5:7])
257+
self.assertEqual(cf['setflag'], True)
258+
with warnings.catch_warnings(record=True) as w:
259+
with patch.object(sys, 'argv', ['pyfile.py', '--unknown', 'arg']):
260+
cf.update_from_argparse(add_unknown_items=False)
261+
self.assertEqual(len(w), 1)
262+
218263

219264
class LookupConfigTest(ConfigTest):
220265
def setUp(self):

0 commit comments

Comments
 (0)