55 SingleThreadedAugmenter , SlimDataLoaderBase
66from batchgenerators .transforms import AbstractTransform
77
8+ from multiprocessing import Queue
9+ from queue import Full
10+
811from delira import get_current_debug_mode
912from .data_loader import BaseDataLoader
1013from .dataset import AbstractDataset , BaseCacheDataset , BaseLazyDataset
@@ -24,8 +27,8 @@ class Augmenter(object):
2427 """
2528
2629 def __init__ (self , data_loader : BaseDataLoader , transforms ,
27- n_process_augmentation = None , num_cached_per_queue = 2 ,
28- seeds = None , ** kwargs ):
30+ n_process_augmentation , sampler , sampler_queues : list ,
31+ num_cached_per_queue = 2 , seeds = None , ** kwargs ):
2932 """
3033
3134 Parameters
@@ -37,6 +40,12 @@ def __init__(self, data_loader: BaseDataLoader, transforms,
3740 n_process_augmentation : int
3841 the number of processes to use for augmentation (only necessary if
3942 not in debug mode)
43+ sampler : :class:`AbstractSampler`
44+ the sampler to use; must be used here instead of inside the
45+ dataloader to avoid duplications and oversampling due to
46+ multiprocessing
47+ sampler_queues : list of :class:`multiprocessing.Queue`
48+ queues to pass the sample indices to the actual dataloader
4049 num_cached_per_queue : int
4150 the number of samples to cache per queue (only necessary if not in
4251 debug mode)
@@ -45,6 +54,9 @@ def __init__(self, data_loader: BaseDataLoader, transforms,
4554 **kwargs :
4655 additional keyword arguments
4756 """
57+
58+ self ._batchsize = data_loader .batch_size
59+
4860 # don't use multiprocessing in debug mode
4961 if get_current_debug_mode ():
5062 augmenter = SingleThreadedAugmenter (data_loader , transforms )
@@ -72,42 +84,59 @@ def __init__(self, data_loader: BaseDataLoader, transforms,
7284 ** kwargs )
7385
7486 self ._augmenter = augmenter
87+ self ._sampler = sampler
88+ self ._sampler_queues = sampler_queues
89+ self ._queue_id = 0
7590
76- @property
7791 def __iter__ (self ):
7892 """
79- Property returning the augmenters ``__iter__``
93+ Function returning an iterator
8094
8195 Returns
8296 -------
83- Callable
84- the augmenters ``__iter__``
97+ Augmenter
98+ self
8599 """
86- return self ._augmenter .__iter__
100+ return self
101+
102+ def _next_queue (self ):
103+ idx = self ._queue_id
104+ self ._queue_id = (self ._queue_id + 1 ) % len (self ._sampler_queues )
105+ return self ._sampler_queues [idx ]
87106
88- @property
89107 def __next__ (self ):
90108 """
91- Property returning the augmenters ``__next__``
109+ Function to sample and load the next batch
92110
93111 Returns
94112 -------
95- Callable
96- the augmenters ``__next__``
113+ dict
114+ the next batch
97115 """
98- return self ._augmenter .__next__
116+ idxs = self ._sampler (self ._batchsize )
117+ queue = self ._next_queue ()
118+
119+ # dont't wait forever. Release this after short timeout and try again
120+ # to avoid deadlock
121+ while True :
122+ try :
123+ queue .put (idxs , timeout = 0.2 )
124+ break
125+ except Full :
126+ continue
127+
128+ return next (self ._augmenter )
99129
100- @property
101130 def next (self ):
102131 """
103- Property returning the augmenters ``next``
132+ Function to sample and load
104133
105134 Returns
106135 -------
107- Callable
108- the augmenters `` next``
136+ dict
137+ the next batch
109138 """
110- return self . _augmenter . next
139+ return next ( self )
111140
112141 @staticmethod
113142 def __identity_fn (* args , ** kwargs ):
@@ -179,7 +208,6 @@ def restart(self):
179208 """
180209 return self ._fn_checker ("restart" )
181210
182- @property
183211 def _finish (self ):
184212 """
185213 Property to provide uniform API of ``_finish``
@@ -190,7 +218,12 @@ def _finish(self):
190218 either the augmenter's ``_finish`` method (if available) or
191219 ``__identity_fn`` (if not available)
192220 """
193- return self ._fn_checker ("_finish" )
221+ ret_val = self ._fn_checker ("_finish" )()
222+ for queue in self ._sampler_queues :
223+ queue .close ()
224+ queue .join_thread ()
225+
226+ return ret_val
194227
195228 @property
196229 def num_batches (self ):
@@ -230,6 +263,7 @@ def __del__(self):
230263 Function defining what to do, if object should be deleted
231264
232265 """
266+ self ._finish ()
233267 del self ._augmenter
234268
235269
@@ -360,15 +394,23 @@ def get_batchgen(self, seed=1):
360394 """
361395 assert self .n_batches > 0
362396
363- data_loader = self .data_loader_cls (self .dataset ,
364- batch_size = self .batch_size ,
365- num_batches = self .n_batches ,
366- seed = seed ,
367- sampler = self .sampler
368- )
397+ sampler_queues = []
398+
399+ for idx in range (self .n_process_augmentation ):
400+ sampler_queues .append (Queue ())
401+
402+ data_loader = self .data_loader_cls (
403+ self .dataset ,
404+ batch_size = self .batch_size ,
405+ num_batches = self .n_batches ,
406+ seed = seed ,
407+ sampler_queues = sampler_queues
408+ )
369409
370410 return Augmenter (data_loader , self .transforms ,
371411 self .n_process_augmentation ,
412+ sampler = self .sampler ,
413+ sampler_queues = sampler_queues ,
372414 num_cached_per_queue = 2 ,
373415 seeds = self .n_process_augmentation * [seed ])
374416
@@ -528,6 +570,8 @@ def n_process_augmentation(self):
528570 number of augmentation processes
529571 """
530572
573+ if get_current_debug_mode ():
574+ return 1
531575 return self ._n_process_augmentation
532576
533577 @n_process_augmentation .setter
0 commit comments