This repository was archived by the owner on Jun 26, 2021. It is now read-only.

Commit 3c16d7e

Merge pull request #141 from justusschock/code-cleanup
Remove mutable default arguments
2 parents c849c6a + 2bee4d4 commit 3c16d7e
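
Background for this cleanup, as a note on the commit message above: Python evaluates a function's default argument values once, when the def statement is executed, so a mutable default such as {} or [] becomes a single object shared by every call that omits that argument, and any mutation leaks into later calls. The fix applied throughout this PR is the conventional None sentinel. A minimal sketch of the pitfall and the fix (function names are illustrative, not from delira):

def register_broken(name, registry={}):
    # one dict is created when "def" runs and is then shared by every call
    registry[name] = True
    return registry

register_broken("a")   # -> {'a': True}
register_broken("b")   # -> {'a': True, 'b': True}: state leaked from the first call

def register_fixed(name, registry=None):
    # None sentinel: defer creation to call time so each call gets a fresh dict
    if registry is None:
        registry = {}
    registry[name] = True
    return registry

register_fixed("a")    # -> {'a': True}
register_fixed("b")    # -> {'b': True}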

17 files changed: +168 -86 lines changed


delira/data_loading/data_manager.py

Lines changed: 4 additions & 5 deletions

@@ -1,19 +1,16 @@
 import inspect
 import logging

-import numpy as np
-import typing
-import inspect
 from batchgenerators.dataloading import MultiThreadedAugmenter, \
     SingleThreadedAugmenter, SlimDataLoaderBase
 from batchgenerators.transforms import AbstractTransform

+from delira import get_current_debug_mode
 from .data_loader import BaseDataLoader
 from .dataset import AbstractDataset, BaseCacheDataset, BaseLazyDataset
 from .load_utils import default_load_fn_2d
 from .sampler import SequentialSampler, AbstractSampler
 from ..utils.decorators import make_deprecated
-from delira import get_current_debug_mode

 logger = logging.getLogger(__name__)

@@ -245,7 +242,7 @@ class BaseDataManager(object):

     def __init__(self, data, batch_size, n_process_augmentation,
                  transforms, sampler_cls=SequentialSampler,
-                 sampler_kwargs={},
+                 sampler_kwargs=None,
                  data_loader_cls=None, dataset_cls=None,
                  load_fn=default_load_fn_2d, from_disc=True, **kwargs):
         """
@@ -292,6 +289,8 @@ class defining the sampling strategy
         """

         # Instantiate Hidden variables for property access
+        if sampler_kwargs is None:
+            sampler_kwargs = {}
         self._batch_size = None
         self._n_process_augmentation = None
         self._transforms = None

delira/data_loading/dataset.py

Lines changed: 8 additions & 3 deletions

@@ -583,8 +583,13 @@ class Nii3DLazyDataset(BaseLazyDataset):
     """

     @make_deprecated('LoadSample')
-    def __init__(self, data_path, load_fn, img_extensions, gt_extensions,
-                 img_files, label_file, **load_kwargs):
+    def __init__(
+            self,
+            data_path,
+            load_fn,
+            img_files,
+            label_file,
+            **load_kwargs):
         """
         Parameters
         ----------
@@ -639,7 +644,7 @@ class Nii3DCacheDatset(BaseCacheDataset):
     """

     @make_deprecated('LoadSample')
-    def __init__(self, data_path, load_fn, img_extensions, gt_extensions,
+    def __init__(self, data_path, load_fn,
                  img_files, label_file, **load_kwargs):
         """
         Parameters

delira/data_loading/load_utils.py

Lines changed: 13 additions & 4 deletions

@@ -78,7 +78,11 @@ def is_valid_image_file(fname, img_extensions, gt_extensions):
     ----------
     fname : str
         filename of image path
-    Returns
+    img_extensions : list
+        list of valid image file extensions
+    gt_extensions : list
+        list of valid gt file extensions
+    Returns
     -------
     bool
         is valid data sample
@@ -148,7 +152,7 @@ class LoadSample:
     def __init__(self,
                  sample_ext: dict,
                  sample_fn: collections.abc.Callable,
-                 dtype={}, normalize=(), norm_fn=norm_range('-1,1'),
+                 dtype=None, normalize=(), norm_fn=norm_range('-1,1'),
                  **kwargs):
         """

@@ -185,6 +189,8 @@ def __init__(self,
         >>>         'seg': 'uint8'},
         >>>     normalize=('data',))
         """
+        if dtype is None:
+            dtype = {}
         self._sample_ext = sample_ext
         self._sample_fn = sample_fn
         self._dtype = dtype
@@ -235,7 +241,7 @@ def __init__(self,
                  sample_fn: collections.abc.Callable,
                  label_ext: collections.abc.Iterable,
                  label_fn: collections.abc.Callable,
-                 sample_kwargs={}, **kwargs):
+                 sample_kwargs=None, **kwargs):
         """
         Load sample and label from folder

@@ -264,6 +270,9 @@ def __init__(self,
         --------
         :class: `LoadSample`
         """
+        if sample_kwargs is None:
+            sample_kwargs = {}
+
         super().__init__(sample_ext, sample_fn, **sample_kwargs)
         self._label_ext = label_ext
         self._label_fn = label_fn
@@ -282,7 +291,7 @@ def __call__(self, path):
         dict
             dict with data and label
         """
-        sample_dict = super(LoadSampleLabel, self).__call__(path)
+        sample_dict = super().__call__(path)
         label_dict = self._label_fn(os.path.join(path, self._label_ext),
                                     **self._label_kwargs)
         sample_dict.update(label_dict)
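
A side note on the final hunk above: since Python 3, the zero-argument form super() resolves the current class and instance implicitly, so super(LoadSampleLabel, self).__call__(path) and super().__call__(path) behave identically here, and the short form also survives class renames. A toy sketch of the idiom (class names are illustrative, not from delira):

class Base:
    def __call__(self, path):
        return {"data": path}

class Child(Base):
    def __call__(self, path):
        # zero-argument super(): Python 3 infers (Child, self) from context
        sample = super().__call__(path)
        sample["label"] = 1
        return sample

Child()("some/path")   # -> {'data': 'some/path', 'label': 1}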

delira/io/torch.py

Lines changed: 3 additions & 1 deletion

@@ -10,7 +10,7 @@
 import torch
 from ..models import AbstractPyTorchNetwork

-def save_checkpoint(file: str, model=None, optimizers={},
+def save_checkpoint(file: str, model=None, optimizers=None,
                     epoch=None, **kwargs):
     """
     Save model's parameters
@@ -28,6 +28,8 @@ def save_checkpoint(file: str, model=None, optimizers={},
         current epoch (will also be pickled)

     """
+    if optimizers is None:
+        optimizers = {}
     if isinstance(model, torch.nn.DataParallel):
         _model = model.module
     else:
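
Context for the DataParallel branch visible in the second hunk: torch.nn.DataParallel wraps a model and exposes the original network as model.module, and checkpoints are typically saved from the unwrapped module so parameter names carry no "module." prefix. A minimal sketch combining that idiom with the new sentinel default (save_my_checkpoint is a hypothetical stand-in, not delira's API):

import torch

def save_my_checkpoint(file, model=None, optimizers=None, epoch=None):
    # None sentinel instead of a mutable {} default, as in the diff above
    if optimizers is None:
        optimizers = {}
    # unwrap DataParallel so saved parameter names carry no "module." prefix
    if isinstance(model, torch.nn.DataParallel):
        model = model.module
    state = {"epoch": epoch,
             "model": model.state_dict() if model is not None else {},
             "optimizers": {k: o.state_dict() for k, o in optimizers.items()}}
    torch.save(state, file)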

delira/models/abstract_network.py

Lines changed: 6 additions & 2 deletions

@@ -53,8 +53,8 @@ def __call__(self, *args, **kwargs):

     @staticmethod
     @abc.abstractmethod
-    def closure(model, data_dict: dict, optimizers: dict, losses={},
-                metrics={}, fold=0, **kwargs):
+    def closure(model, data_dict: dict, optimizers: dict, losses=None,
+                metrics=None, fold=0, **kwargs):
         """
         Function which handles prediction from batch, logging, loss calculation
         and optimizer step
@@ -90,6 +90,10 @@ def closure(model, data_dict: dict, optimizers: dict, losses={},
             If not overwritten by subclass

         """
+        if losses is None:
+            losses = {}
+        if metrics is None:
+            metrics = {}
         raise NotImplementedError()

     @staticmethod

delira/models/classification/classification_network.py

Lines changed: 5 additions & 1 deletion

@@ -67,7 +67,7 @@ def forward(self, input_batch: torch.Tensor):

     @staticmethod
     def closure(model: AbstractPyTorchNetwork, data_dict: dict,
-                optimizers: dict, losses={}, metrics={},
+                optimizers: dict, losses=None, metrics=None,
                 fold=0, **kwargs):
         """
         closure method to do a single backpropagation step
@@ -108,6 +108,10 @@ def closure(model: AbstractPyTorchNetwork, data_dict: dict,

         """

+        if losses is None:
+            losses = {}
+        if metrics is None:
+            metrics = {}
         assert (optimizers and losses) or not optimizers, \
             "Criterion dict cannot be emtpy, if optimizers are passed"

delira/models/classification/classification_network_tf.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def _build_model(n_outputs: int, **kwargs):
131131

132132
@staticmethod
133133
def closure(model: typing.Type[AbstractTfNetwork], data_dict: dict,
134-
metrics={}, fold=0, **kwargs):
134+
metrics=None, fold=0, **kwargs):
135135
"""
136136
closure method to do a single prediction.
137137
This is followed by backpropagation or not based state of
@@ -163,9 +163,10 @@ def closure(model: typing.Type[AbstractTfNetwork], data_dict: dict,
163163
164164
"""
165165

166+
if metrics is None:
167+
metrics = {}
166168
loss_vals = {}
167169
metric_vals = {}
168-
image_names = "input_images"
169170

170171
inputs = data_dict.pop('data')
171172

delira/models/gan/generative_adversarial_network.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def forward(self, real_image_batch):
9494

9595
@staticmethod
9696
def closure(model, data_dict: dict,
97-
optimizers: dict, losses={}, metrics={},
97+
optimizers: dict, losses=None, metrics=None,
9898
fold=0, **kwargs):
9999
"""
100100
closure method to do a single backpropagation step
@@ -134,6 +134,10 @@ def closure(model, data_dict: dict,
134134
135135
"""
136136

137+
if losses is None:
138+
losses = {}
139+
if metrics is None:
140+
metrics = {}
137141
loss_vals = {}
138142
metric_vals = {}
139143
total_loss_discr_real = 0

delira/models/segmentation/unet.py

Lines changed: 12 additions & 5 deletions

@@ -7,7 +7,6 @@
 import torch
 import torch.nn.functional as F
 from torch.nn import init
-import logging
 from ..abstract_network import AbstractPyTorchNetwork

 class UNet2dPyTorch(AbstractPyTorchNetwork):
@@ -175,8 +174,8 @@ def forward(self, x):
         return {"pred": x}

     @staticmethod
-    def closure(model, data_dict: dict, optimizers: dict, losses={},
-                metrics={}, fold=0, **kwargs):
+    def closure(model, data_dict: dict, optimizers: dict, losses=None,
+                metrics=None, fold=0, **kwargs):
         """
         closure method to do a single backpropagation step

@@ -216,6 +215,10 @@ def closure(model, data_dict: dict, optimizers: dict, losses={},



+        if losses is None:
+            losses = {}
+        if metrics is None:
+            metrics = {}
         assert (optimizers and losses) or not optimizers, \
             "Loss dict cannot be emtpy, if optimizers are passed"

@@ -618,8 +621,8 @@ def forward(self, x):
         return {"pred": x}

     @staticmethod
-    def closure(model, data_dict: dict, optimizers: dict, losses={},
-                metrics={}, fold=0, **kwargs):
+    def closure(model, data_dict: dict, optimizers: dict, losses=None,
+                metrics=None, fold=0, **kwargs):
         """
         closure method to do a single backpropagation step

@@ -659,6 +662,10 @@ def closure(model, data_dict: dict, optimizers: dict, losses={},



+        if losses is None:
+            losses = {}
+        if metrics is None:
+            metrics = {}
         assert (optimizers and losses) or not optimizers, \
             "Loss dict cannot be emtpy, if optimizers are passed"

delira/training/base_trainer.py

Lines changed: 10 additions & 10 deletions

@@ -1,18 +1,16 @@
-from abc import abstractmethod
 import logging
+import os
 import pickle
 import typing

-from ..data_loading.data_manager import Augmenter
-
-from .predictor import Predictor
-from .callbacks import AbstractCallback
-from ..models import AbstractNetwork
-
 import numpy as np
-import os
 from tqdm import tqdm
+
 from delira.logging import TrixiHandler
+from .callbacks import AbstractCallback
+from .predictor import Predictor
+from ..data_loading.data_manager import Augmenter
+from ..models import AbstractNetwork

 logger = logging.getLogger(__name__)

@@ -614,7 +612,7 @@ def _update_state(self, new_state):

         """
         for key, val in new_state.items():
-            if (key.startswith("__") and key.endswith("__")):
+            if key.startswith("__") and key.endswith("__"):
                 continue

             try:
@@ -729,7 +727,7 @@ def _reinitialize_logging(self, logging_type, logging_kwargs: dict):
                               handlers=new_handlers)

     @staticmethod
-    def _search_for_prev_state(path, extensions=[]):
+    def _search_for_prev_state(path, extensions=None):
         """
         Helper function to search in a given path for previous epoch states
         (indicated by extensions)
@@ -752,6 +750,8 @@ def _search_for_prev_state(path, extensions=[]):
             the latest epoch (1 if no checkpoint was found)

         """
+        if extensions is None:
+            extensions = []
         files = []
         for file in os.listdir(path):
             for ext in extensions:
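
The same pitfall covers list defaults: the old extensions=[] above was one list shared by every call of _search_for_prev_state; harmless while only read, but a single accidental append would have persisted across calls. An illustrative sketch (hypothetical helpers, not the trainer's code):

def find_states_broken(path, extensions=[]):
    extensions.append(".pth")      # mutates the single shared default list
    return list(extensions)

find_states_broken("run1")         # -> ['.pth']
find_states_broken("run2")         # -> ['.pth', '.pth']: residue from the first call

def find_states(path, extensions=None):
    if extensions is None:         # sentinel: fresh list on every call
        extensions = []
    extensions.append(".pth")
    return list(extensions)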
