This repository was archived by the owner on Jun 26, 2021. It is now read-only.

Commit c0e5ac8

Merge branch 'master' into update_from_sys_arg

2 parents: 2efaecb + 64ce57b
6 files changed: +136 −8 lines changed

delira/logging/base_backend.py

Lines changed: 1 addition & 1 deletion
@@ -151,7 +151,7 @@ def _log_item(self):

         """
         # get item from dict
-        process_item = self._queue.get(timeout=0.5)
+        process_item = self._queue.get(timeout=0.001)
         # log item if item is dict
         if isinstance(process_item, dict):

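The only change here is the polling timeout: the logging worker now gives up on an empty queue after one millisecond instead of half a second, so its loop comes back around much more quickly. A minimal standalone sketch of the underlying standard-library behaviour, not delira code:

    # Queue.get blocks for at most `timeout` seconds and raises queue.Empty
    # if nothing arrives in time, so a short timeout turns the worker loop
    # into a fast poll instead of a long block.
    from queue import Queue, Empty

    q = Queue()
    try:
        item = q.get(timeout=0.001)   # returns almost immediately when empty
    except Empty:
        item = None                   # caller can simply retry on the next pass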

delira/logging/base_logger.py

Lines changed: 6 additions & 4 deletions
@@ -1,5 +1,6 @@
-from multiprocessing import Queue, Event
-from queue import Full
+from multiprocessing.queues import Queue as MpQueue
+from threading import Event
+from queue import Queue, Full
 from delira.logging.base_backend import BaseBackend
 from delira.utils.dict_reductions import get_reduction, possible_reductions, \
     reduce_dict
@@ -231,8 +232,9 @@ def close(self):

         """
         if hasattr(self, "_flush_queue"):
-            self._flush_queue.close()
-            self._flush_queue.join_thread()
+            if isinstance(self._flush_queue, MpQueue):
+                self._flush_queue.close()
+                self._flush_queue.join_thread()

         if hasattr(self, "abort_event"):
             self._abort_event.set()
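With these imports the flush queue can now be either a thread-side `queue.Queue` or a multiprocessing queue, and `close()` only runs the multiprocessing-specific teardown when the queue actually is one, since a plain `queue.Queue` has neither `close()` nor `join_thread()`. A small standalone sketch of that distinction (standard library only, not delira code):

    # multiprocessing queues need explicit teardown; thread-side queues do not.
    import multiprocessing as mp
    from multiprocessing.queues import Queue as MpQueue
    from queue import Queue

    mp_queue = mp.Queue()
    thread_queue = Queue()

    print(isinstance(mp_queue, MpQueue))       # True  -> close()/join_thread() exist
    print(isinstance(thread_queue, MpQueue))   # False -> no such methods

    mp_queue.close()
    mp_queue.join_thread()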

delira/training/callbacks/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -20,3 +20,5 @@
     ReduceLROnPlateauCallback as ReduceLROnPlateauCallbackPyTorch
 from delira.training.callbacks.pytorch_schedulers import StepLRCallback \
     as StepLRCallbackPyTorch
+from delira.training.callbacks.pytorch_schedulers import \
+    OneCycleLRCallback as OneCycleLRCallbackPyTorch
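This re-export follows the existing pattern of suffixing backend-specific callbacks with "PyTorch", so user code can import the new callback as:

    from delira.training.callbacks import OneCycleLRCallbackPyTorch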

delira/training/callbacks/pytorch_schedulers.py

Lines changed: 121 additions & 1 deletion
@@ -3,7 +3,8 @@

 if 'TORCH' in get_backends():
     from torch.optim.lr_scheduler import ReduceLROnPlateau, \
-        CosineAnnealingLR, ExponentialLR, LambdaLR, MultiStepLR, StepLR
+        CosineAnnealingLR, ExponentialLR, LambdaLR, MultiStepLR, StepLR, \
+        OneCycleLR

     class DefaultPyTorchSchedulerCallback(AbstractCallback):
         """
@@ -47,6 +48,125 @@ def at_epoch_end(self, trainer, **kwargs):
             self.scheduler.step(epoch=kwargs.get("curr_epoch", None))
             return {}

+    class OneCycleLRCallback(DefaultPyTorchSchedulerCallback):
+        """
+        Wraps PyTorch's `OneCycleLR` Scheduler as Callback
+
+        """
+
+        def __init__(
+                self,
+                optimizer,
+                max_lr,
+                total_steps=None,
+                epochs=None,
+                steps_per_epoch=None,
+                pct_start=0.3,
+                anneal_strategy='cos',
+                cycle_momentum=True,
+                base_momentum=0.85,
+                max_momentum=0.95,
+                div_factor=25.0,
+                final_div_factor=10000.0,
+                last_epoch=-1):
+            """
+
+            Parameters
+            ----------
+            optimizer (Optimizer): Wrapped optimizer.
+            max_lr (float or list): Upper learning rate boundaries in the cycle
+                for each parameter group.
+            total_steps (int): The total number of steps in the cycle. Note
+                that if a value is not provided here, then it must be inferred
+                by providing a value for epochs and steps_per_epoch.
+                Default: None
+            epochs (int): The number of epochs to train for. This is used along
+                with steps_per_epoch in order to infer the total number of
+                steps in the cycle if a value for total_steps is not provided.
+                Default: None
+            steps_per_epoch (int): The number of steps per epoch to train for.
+                This is used along with epochs in order to infer the total
+                number of steps in the cycle if a value for total_steps is
+                not provided.
+                Default: None
+            pct_start (float): The percentage of the cycle (in number of steps)
+                spent increasing the learning rate.
+                Default: 0.3
+            anneal_strategy (str): {'cos', 'linear'}
+                Specifies the annealing strategy.
+                Default: 'cos'
+            cycle_momentum (bool): If ``True``, momentum is cycled inversely
+                to learning rate between 'base_momentum' and 'max_momentum'.
+                Default: True
+            base_momentum (float or list): Lower momentum boundaries in the
+                cycle for each parameter group. Note that momentum is cycled
+                inversely to learning rate; at the peak of a cycle, momentum is
+                'base_momentum' and learning rate is 'max_lr'.
+                Default: 0.85
+            max_momentum (float or list): Upper momentum boundaries in the
+                cycle for each parameter group. Functionally,
+                it defines the cycle amplitude (max_momentum - base_momentum).
+                Note that momentum is cycled inversely
+                to learning rate; at the start of a cycle, momentum is
+                'max_momentum' and learning rate is 'base_lr'
+                Default: 0.95
+            div_factor (float): Determines the initial learning rate via
+                initial_lr = max_lr/div_factor
+                Default: 25
+            final_div_factor (float): Determines the minimum learning rate via
+                min_lr = initial_lr/final_div_factor
+                Default: 1e4
+            last_epoch (int): The index of the last batch. This parameter is
+                used when resuming a training job. Since `step()` should be
+                invoked after each batch instead of after each epoch, this
+                number represents the total number of *batches* computed,
+                not the total number of epochs computed.
+                When last_epoch=-1, the schedule is started from the
+                beginning.
+                Default: -1
+            """
+            super().__init__()
+            self.scheduler = OneCycleLR(
+                optimizer,
+                max_lr,
+                total_steps,
+                epochs,
+                steps_per_epoch,
+                pct_start,
+                anneal_strategy,
+                cycle_momentum,
+                base_momentum,
+                max_momentum,
+                div_factor,
+                final_div_factor,
+                last_epoch)
+
+        def at_iter_begin(self, trainer, train,
+                          **kwargs):
+            """
+            Executes a single scheduling step
+
+            Parameters
+            ----------
+            trainer : :class:`PyTorchNetworkTrainer`
+                the trainer class, which can be changed
+            kwargs :
+                additional keyword arguments
+
+            Returns
+            -------
+            :class:`PyTorchNetworkTrainer`
+                modified trainer
+
+            """
+            if train:
+                self.scheduler.step()
+
+            return {}
+
+        def at_epoch_end(self, trainer, **kwargs):
+            return {}
+
     class ReduceLROnPlateauCallback(DefaultPyTorchSchedulerCallback):
         """
         Wraps PyTorch's `ReduceLROnPlateau` Scheduler as Callback
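The new `OneCycleLRCallback` constructor mirrors the signature of `torch.optim.lr_scheduler.OneCycleLR` and, unlike the epoch-based schedulers in this module, steps the scheduler at the beginning of every training iteration while doing nothing at epoch end. A hypothetical construction sketch; the toy model, optimizer, and the way the callback would be registered with a trainer are illustrative assumptions, not part of this diff:

    import torch
    from delira.training.callbacks import OneCycleLRCallbackPyTorch

    # toy model/optimizer purely for illustration
    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    # total_steps is left to be inferred from epochs * steps_per_epoch
    callback = OneCycleLRCallbackPyTorch(
        optimizer,
        max_lr=0.1,
        epochs=10,
        steps_per_epoch=100)

    # registered with a PyTorchNetworkTrainer, at_iter_begin(trainer, train=True)
    # would then call scheduler.step() once per training batch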

delira/training/predictor.py

Lines changed: 1 addition & 1 deletion
@@ -275,7 +275,7 @@ def predict_data_mgr(
         batch_list = []

         for i, batch in iterable:
-            self._at_iter_begin(iter_num=i)
+            Predictor._at_iter_begin(self, iter_num=i)

             if not batch_list and (n_batches - i) < batchsize:
                 batchsize = n_batches - i
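Calling `Predictor._at_iter_begin(self, iter_num=i)` instead of `self._at_iter_begin(iter_num=i)` pins the call to the base-class hook, presumably so that a subclass override (for example a trainer hook that steps per-iteration schedulers such as the new OneCycleLR callback) is not triggered while merely predicting. The mechanism is ordinary Python method dispatch; a minimal illustration, not delira code:

    # Calling through the class bypasses any override in a subclass.
    class Base:
        def hook(self):
            return "base hook"

    class Child(Base):
        def hook(self):
            return "child override"

    obj = Child()
    print(obj.hook())        # "child override"  -- dynamic dispatch via self
    print(Base.hook(obj))    # "base hook"       -- explicit base-class call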

delira/utils/dict_reductions.py

Lines changed: 5 additions & 1 deletion
@@ -205,7 +205,11 @@ def reduce_dict(items: list, reduce_fn) -> dict:

     for k, v in result_dict.items():
         # check if all items are equal
-        if all([_v == v[0] for _v in v[1:]]):
+        equals = [_v == v[0] for _v in v[1:]]
+        for idx, equality in enumerate(equals):
+            if isinstance(equality, np.ndarray):
+                equals[idx] = equality.all()
+        if all(equals):
             # use first item since they are equal
             result_dict[k] = v[0]
         else:
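The previous one-liner breaks as soon as the collected values are numpy arrays: elementwise `==` then yields an array whose truth value is ambiguous inside `all()`, so the comparison results are collapsed with `.all()` first. A short standalone illustration of the failure mode and the fix (numpy only, not delira code):

    import numpy as np

    v = [np.array([1, 2]), np.array([1, 2])]

    # elementwise comparison yields an array, not a bool ...
    equals = [_v == v[0] for _v in v[1:]]       # [array([ True,  True])]

    # ... and truth-testing it inside all() raises
    # "ValueError: The truth value of an array with more than one element
    #  is ambiguous"; collapsing array results with .all() avoids this:
    equals = [e.all() if isinstance(e, np.ndarray) else e for e in equals]
    print(all(equals))                          # True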
