This repository was archived by the owner on Jun 26, 2021. It is now read-only.

Commit 71368ce

Merge pull request #145 from justusschock/bug_fixes
Bug fixes
2 parents: 3c16d7e + 600bd04

9 files changed: +145 additions, −83 deletions

delira/data_loading/data_loader.py

Lines changed: 32 additions & 36 deletions
@@ -1,8 +1,11 @@
 import numpy as np
 from batchgenerators.dataloading.data_loader import SlimDataLoaderBase
+from queue import Empty
+import logging
+
+logger = logging.getLogger(__name__)

 from .dataset import AbstractDataset
-from .sampler import AbstractSampler, SequentialSampler


 class BaseDataLoader(SlimDataLoaderBase):
@@ -12,8 +15,8 @@ class BaseDataLoader(SlimDataLoaderBase):
     """

     def __init__(self, dataset: AbstractDataset,
-                 batch_size=1, num_batches=None, seed=1,
-                 sampler=None):
+                 sampler_queues: list,
+                 batch_size=1, num_batches=None, seed=1):
         """

         Parameters
@@ -22,13 +25,13 @@ def __init__(self, dataset: AbstractDataset,
             dataset to perform sample loading
         batch_size : int
             number of samples per batch
+        sampler_queues : list of :class:`multiprocessing.Queue`
+            the queues the sample indices to load will be put to;
+            necessary for interprocess communication
         num_batches : int
             number of batches to load
         seed : int
             seed for Random Number Generator
-        sampler : AbstractSampler or None
-            class defining the sampling strategy;
-            if None: SequentialSampler will be used

         Raises
         ------
@@ -45,24 +48,16 @@ class defining the sampling strategy;
         # store dataset in self._data
         super().__init__(dataset, batch_size)

-        assert isinstance(sampler, AbstractSampler) or sampler is None, \
-            "Sampler must be instance of subclass of AbstractSampler of None"
-
-        if sampler is None:
-            sampler = SequentialSampler(list(range(len(dataset))))
-
-        self.sampler = sampler
+        self.sampler_queues = sampler_queues

-        self.n_samples = len(sampler)
+        self.n_samples = len(dataset)
         if num_batches is None:
-            num_batches = len(sampler) // batch_size
+            num_batches = len(dataset) // batch_size

         self.num_batches = num_batches
         self._seed = seed
         np.random.seed(seed)

-        self._batches_generated = 0
-
     def generate_train_batch(self):
         """
         Generate Indices which behavior based on self.sampling gets data based
@@ -79,30 +74,31 @@ def generate_train_batch(self):
             If the maximum number of batches has been generated
         """

-        if self._batches_generated >= self.num_batches:
-            raise StopIteration
-        else:
-            self._batches_generated += 1
-
-        idxs = self.sampler(self.batch_size)
+        idxs = None
+        sampler_queue = self.sampler_queues[self.thread_id]
+        while idxs is None:
+            try:
+                idxs = sampler_queue.get(timeout=0.2)

-        result = [self._get_sample(_idx) for _idx in idxs]
+                result = [self._get_sample(_idx) for _idx in idxs]

-        result_dict = {}
+                result_dict = {}

-        # concatenate dict entities by keys
-        for _result_dict in result:
-            for key, val in _result_dict.items():
-                if key in result_dict.keys():
-                    result_dict[key].append(val)
-                else:
-                    result_dict[key] = [val]
+                # concatenate dict entities by keys
+                for _result_dict in result:
+                    for key, val in _result_dict.items():
+                        if key in result_dict.keys():
+                            result_dict[key].append(val)
+                        else:
+                            result_dict[key] = [val]

-        # convert list to numpy arrays
-        for key, val_list in result_dict.items():
-            result_dict[key] = np.asarray(val_list)
+                # convert list to numpy arrays
+                for key, val_list in result_dict.items():
+                    result_dict[key] = np.asarray(val_list)

-        return result_dict
+                return result_dict
+            except Empty:
+                pass

     def _get_sample(self, index):
         """

delira/data_loading/data_manager.py

Lines changed: 69 additions & 25 deletions
@@ -5,6 +5,9 @@
     SingleThreadedAugmenter, SlimDataLoaderBase
 from batchgenerators.transforms import AbstractTransform

+from multiprocessing import Queue
+from queue import Full
+
 from delira import get_current_debug_mode
 from .data_loader import BaseDataLoader
 from .dataset import AbstractDataset, BaseCacheDataset, BaseLazyDataset
@@ -24,8 +27,8 @@ class Augmenter(object):
     """

     def __init__(self, data_loader: BaseDataLoader, transforms,
-                 n_process_augmentation=None, num_cached_per_queue=2,
-                 seeds=None, **kwargs):
+                 n_process_augmentation, sampler, sampler_queues: list,
+                 num_cached_per_queue=2, seeds=None, **kwargs):
         """

         Parameters
@@ -37,6 +40,12 @@ def __init__(self, data_loader: BaseDataLoader, transforms,
         n_process_augmentation : int
             the number of processes to use for augmentation (only necessary
             if not in debug mode)
+        sampler : :class:`AbstractSampler`
+            the sampler to use; must be used here instead of inside the
+            dataloader to avoid duplications and oversampling due to
+            multiprocessing
+        sampler_queues : list of :class:`multiprocessing.Queue`
+            queues to pass the sample indices to the actual dataloader
         num_cached_per_queue : int
             the number of samples to cache per queue (only necessary if not
             in debug mode)
@@ -45,6 +54,9 @@ def __init__(self, data_loader: BaseDataLoader, transforms,
         **kwargs :
             additional keyword arguments
         """
+
+        self._batchsize = data_loader.batch_size
+
         # don't use multiprocessing in debug mode
         if get_current_debug_mode():
             augmenter = SingleThreadedAugmenter(data_loader, transforms)
@@ -72,42 +84,59 @@ def __init__(self, data_loader: BaseDataLoader, transforms,
                                **kwargs)

         self._augmenter = augmenter
+        self._sampler = sampler
+        self._sampler_queues = sampler_queues
+        self._queue_id = 0

-    @property
     def __iter__(self):
         """
-        Property returning the augmenters ``__iter__``
+        Function returning an iterator

         Returns
         -------
-        Callable
-            the augmenters ``__iter__``
+        Augmenter
+            self
         """
-        return self._augmenter.__iter__
+        return self
+
+    def _next_queue(self):
+        idx = self._queue_id
+        self._queue_id = (self._queue_id + 1) % len(self._sampler_queues)
+        return self._sampler_queues[idx]

-    @property
     def __next__(self):
         """
-        Property returning the augmenters ``__next__``
+        Function to sample and load the next batch

         Returns
         -------
-        Callable
-            the augmenters ``__next__``
+        dict
+            the next batch
         """
-        return self._augmenter.__next__
+        idxs = self._sampler(self._batchsize)
+        queue = self._next_queue()
+
+        # don't wait forever; release after a short timeout and try again
+        # to avoid a deadlock
+        while True:
+            try:
+                queue.put(idxs, timeout=0.2)
+                break
+            except Full:
+                continue
+
+        return next(self._augmenter)

-    @property
     def next(self):
         """
-        Property returning the augmenters ``next``
+        Function to sample and load the next batch

         Returns
         -------
-        Callable
-            the augmenters ``next``
+        dict
+            the next batch
         """
-        return self._augmenter.next
+        return next(self)

     @staticmethod
     def __identity_fn(*args, **kwargs):
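Sampling now lives on the producer side: ``__next__`` draws a batch of indices, hands it to one worker queue in round-robin order, then pulls the augmented batch back. A condensed sketch of that hand-off (illustrative helper, not the delira API; assumes ``sampler(batch_size)`` returns a list of indices):

    from queue import Full

    def dispatch_batch(sampler, queues, queue_id, batch_size):
        # sample one batch of indices and push it to the next worker
        # queue, rotating round-robin like Augmenter._next_queue does
        idxs = sampler(batch_size)
        while True:
            try:
                # multiprocessing.Queue.put raises queue.Full on timeout;
                # retrying instead of blocking forever avoids a deadlock
                queues[queue_id].put(idxs, timeout=0.2)
                break
            except Full:
                continue
        return (queue_id + 1) % len(queues)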
@@ -179,7 +208,6 @@ def restart(self):
         """
         return self._fn_checker("restart")

-    @property
     def _finish(self):
         """
         Property to provide uniform API of ``_finish``
@@ -190,7 +218,12 @@ def _finish(self):
             either the augmenter's ``_finish`` method (if available) or
             ``__identity_fn`` (if not available)
         """
-        return self._fn_checker("_finish")
+        ret_val = self._fn_checker("_finish")()
+        for queue in self._sampler_queues:
+            queue.close()
+            queue.join_thread()
+
+        return ret_val

     @property
     def num_batches(self):
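Since ``_finish`` is no longer a property, ``_fn_checker("_finish")`` returns a callable that is now invoked directly, and the sampler queues are torn down afterwards. The teardown order matters; a minimal sketch of the same shutdown (illustrative function name):

    def shutdown_queues(queues):
        for queue in queues:
            queue.close()        # no more data will be put on this queue
            queue.join_thread()  # wait for the feeder thread; must be
                                 # called after close(), never before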
@@ -230,6 +263,7 @@ def __del__(self):
         Function defining what to do, if object should be deleted

         """
+        self._finish()
         del self._augmenter


@@ -360,15 +394,23 @@ def get_batchgen(self, seed=1):
         """
         assert self.n_batches > 0

-        data_loader = self.data_loader_cls(self.dataset,
-                                           batch_size=self.batch_size,
-                                           num_batches=self.n_batches,
-                                           seed=seed,
-                                           sampler=self.sampler
-                                           )
+        sampler_queues = []
+
+        for idx in range(self.n_process_augmentation):
+            sampler_queues.append(Queue())
+
+        data_loader = self.data_loader_cls(
+            self.dataset,
+            batch_size=self.batch_size,
+            num_batches=self.n_batches,
+            seed=seed,
+            sampler_queues=sampler_queues
+        )

         return Augmenter(data_loader, self.transforms,
                          self.n_process_augmentation,
+                         sampler=self.sampler,
+                         sampler_queues=sampler_queues,
                          num_cached_per_queue=2,
                          seeds=self.n_process_augmentation * [seed])

@@ -528,6 +570,8 @@ def n_process_augmentation(self):
             number of augmentation processes
         """

+        if get_current_debug_mode():
+            return 1
         return self._n_process_augmentation

     @n_process_augmentation.setter
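End to end, ``get_batchgen`` now builds one queue per augmentation process and wires the same queue list into both the loader and the Augmenter. A hedged usage sketch (assumes a configured ``BaseDataManager`` instance named ``manager``; the batch key is illustrative):

    batchgen = manager.get_batchgen(seed=1)
    for batch_dict in batchgen:
        # each step samples indices in the parent process, round-robins
        # them to a worker queue, and yields the augmented batch back
        # as a dict of numpy arrays
        inputs = batch_dict["data"]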

delira/training/base_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -454,7 +454,7 @@ def reduce_fn(batch):
         if "val_" + val_score_key not in total_metrics:
             logger.warning(
                 "val_score_key '%s' not a valid key for \
-                    validation metrics ")
+                    validation metrics" % str(val_score_key))

             new_val_score = best_val_score
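The fix interpolates the missing ``val_score_key`` into the warning. As an aside, the stdlib logger can also defer the interpolation by passing the value as an argument, which skips formatting entirely when the level is disabled:

    logger.warning("val_score_key '%s' not a valid key for validation"
                   " metrics", val_score_key)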

delira/training/experiment.py

Lines changed: 4 additions & 3 deletions
@@ -5,6 +5,7 @@
 import os
 from datetime import datetime
 from functools import partial
+import copy

 import numpy as np
 from sklearn.model_selection import KFold, StratifiedKFold, \
@@ -547,8 +548,8 @@ def kfold(self, data: BaseDataManager, metrics: dict, num_epochs=None,
             train_data = data.get_subset(train_idxs)
             test_data = data.get_subset(test_idxs)

-            train_data.update_state_from_dict(train_kwargs)
-            test_data.update_state_from_dict(test_kwargs)
+            train_data.update_state_from_dict(copy.deepcopy(train_kwargs))
+            test_data.update_state_from_dict(copy.deepcopy(test_kwargs))

             val_data = None
             if val_split is not None:
@@ -572,7 +573,7 @@ def kfold(self, data: BaseDataManager, metrics: dict, num_epochs=None,
                 for _train_idxs, _val_idxs in _val_split.split(train_idxs,
                                                                train_labels):
                     val_data = train_data.get_subset(_val_idxs)
-                    val_data.update_state_from_dict(test_kwargs)
+                    val_data.update_state_from_dict(copy.deepcopy(test_kwargs))

                     train_data = train_data.get_subset(_train_idxs)
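Passing the same kwargs dicts into every fold is unsafe if the callee mutates them; deepcopying gives each fold its own copy. A minimal illustration of the failure mode (hypothetical stand-in for ``update_state_from_dict``, not the delira implementation):

    import copy

    def update_state(state, kwargs):
        # hypothetical consumer that mutates a nested entry of the dict
        # it receives, the way a fold-specific update might
        kwargs["transforms"].append("normalize")
        state.update(kwargs)

    train_kwargs = {"transforms": ["crop"]}

    update_state({}, train_kwargs)             # fold 1 mutates shared dict
    print(train_kwargs["transforms"])          # ['crop', 'normalize'] leaked

    train_kwargs = {"transforms": ["crop"]}
    update_state({}, copy.deepcopy(train_kwargs))  # fold 1 gets its own copy
    print(train_kwargs["transforms"])          # ['crop'], later folds safe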
