This repository was archived by the owner on Jun 26, 2021. It is now read-only.

Commit df0a7e4

Merge pull request #215 from delira-dev/register_logger
Register logger
2 parents 47d2a14 + 84432a4 commit df0a7e4

File tree

5 files changed: 107 additions, 9 deletions

delira/training/base_trainer.py

Lines changed: 16 additions & 5 deletions
@@ -13,6 +13,7 @@
 from .predictor import Predictor
 from ..data_loading import Augmenter, DataManager
 from ..models import AbstractNetwork
+from ..logging import register_logger, make_logger
 
 logger = logging.getLogger(__name__)
 
@@ -322,8 +323,11 @@ def _at_iter_begin(self, iter_num, epoch=0, **kwargs):
         """
         for cb in self._callbacks:
             self._update_state(cb.at_iter_begin(
-                self, iter_num=iter_num, curr_epoch=epoch,
-                global_iter_num=self._global_iter_num, **kwargs,
+                self, iter_num=iter_num,
+                curr_epoch=epoch,
+                global_iter_num=self._global_iter_num,
+                train=True,
+                **kwargs,
             ))
 
     def _at_iter_end(self, iter_num, data_dict, metrics, epoch=0, **kwargs):
@@ -347,9 +351,12 @@ def _at_iter_end(self, iter_num, data_dict, metrics, epoch=0, **kwargs):
 
         for cb in self._callbacks:
             self._update_state(cb.at_iter_end(
-                self, iter_num=iter_num, data_dict=data_dict,
-                metrics=metrics, curr_epoch=epoch,
+                self, iter_num=iter_num,
+                data_dict=data_dict,
+                metrics=metrics,
+                curr_epoch=epoch,
                 global_iter_num=self._global_iter_num,
+                train=True,
                 **kwargs,
             ))
 
@@ -833,12 +840,16 @@ def _reinitialize_logging(self, logging_type, logging_kwargs: dict,
 
         level = _logging_kwargs.pop("level")
 
+        logger = backend_cls(_logging_kwargs)
+
         self.register_callback(
             logging_callback_cls(
-                backend_cls(logging_kwargs), level=level,
+                logger, level=level,
                 logging_frequencies=logging_frequencies,
                 reduce_types=reduce_types))
 
+        register_logger(self._callbacks[-1]._logger, self.name)
+
     @staticmethod
     def _search_for_prev_state(path, extensions=None):
         """

delira/training/callbacks/abstract_callback.py

Lines changed: 3 additions & 2 deletions
@@ -20,7 +20,7 @@ def __init__(self, *args, **kwargs):
             keyword arguments
 
         """
-        pass
+        super().__init__(*args, **kwargs)
 
     def at_epoch_begin(self, trainer, *args, **kwargs):
         """
@@ -124,7 +124,7 @@ def at_iter_begin(self, trainer, *args, **kwargs):
         Notes
         -----
         The predictor calls the callbacks with the following additional
-        arguments: `iter_num`(int)
+        arguments: `iter_num`(int), `train`(bool)
 
         The basetrainer adds following arguments (wrt the predictor):
         `curr_epoch`(int), `global_iter_num`(int)
@@ -153,6 +153,7 @@ def at_iter_end(self, trainer, *args, **kwargs):
         The predictor calls the callbacks with the following additional
         arguments: `iter_num`(int), `metrics`(dict),
         `data_dict`(dict, contains prediction and input data),
+        `train`(bool)
 
         The basetrainer adds following arguments (wrt the predictor):
         `curr_epoch`(int), `global_iter_num`(int)
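As a hedged illustration (not part of this commit), a user-defined callback could now branch on the `train` flag documented above: the trainer passes `train=True`, the predictor `train=False`. The import path and the class name `PhaseAwareCallback` are assumptions for this sketch.

from delira.training.callbacks import AbstractCallback  # import path assumed


class PhaseAwareCallback(AbstractCallback):
    """Example callback that distinguishes training from prediction."""

    def at_iter_end(self, trainer, train=False, metrics=None, **kwargs):
        phase = "training" if train else "prediction/validation"
        print("iteration finished during %s, metrics: %s" % (phase, metrics))
        # callbacks return a dict of state updates; nothing to change here
        return {}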

delira/training/callbacks/logging_callback.py

Lines changed: 12 additions & 2 deletions
@@ -50,7 +50,8 @@ def __init__(self, backend: BaseBackend, max_queue_size: int = None,
             logging_frequencies=logging_frequencies,
             reduce_types=reduce_types, level=level)
 
-    def at_iter_end(self, trainer, iter_num=None, data_dict=None, **kwargs):
+    def at_iter_end(self, trainer, iter_num=None, data_dict=None, train=False,
+                    **kwargs):
         """
         Function logging the metrics at the end of each iteration
 
@@ -63,6 +64,8 @@ def at_iter_end(self, trainer, iter_num=None, data_dict=None, **kwargs):
             (unused in this callback)
         data_dict : dict
             the current data dict (including predictions)
+        train: bool
+            signals if callback is called by trainer or predictor
         **kwargs :
             additional keyword arguments
 
@@ -76,7 +79,14 @@ def at_iter_end(self, trainer, iter_num=None, data_dict=None, **kwargs):
         global_step = kwargs.get("global_iter_num", None)
 
         for k, v in metrics.items():
-            self._logger.log({"scalar": {"tag": k, "scalar_value": v,
+            self._logger.log({"scalar": {"tag": self.create_tag(k, train),
+                                         "scalar_value": v,
                                          "global_step": global_step}})
 
         return {}
+
+    @staticmethod
+    def create_tag(tag: str, train: bool):
+        if train:
+            tag = tag + "_val"
+        return tag
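For reference, the tag handling added above is a pure function; here it is copied out of the diff with two quick checks. Note that, as written in this commit, `train=True` appends the suffix `_val` to the tag.

def create_tag(tag: str, train: bool):
    if train:
        tag = tag + "_val"
    return tag


assert create_tag("loss", train=True) == "loss_val"
assert create_tag("loss", train=False) == "loss"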

delira/training/predictor.py

Lines changed: 2 additions & 0 deletions
@@ -177,6 +177,7 @@ def _at_iter_begin(self, iter_num, **kwargs):
         for cb in self._callbacks:
             return_dict.update(cb.at_iter_begin(self,
                                                 iter_num=iter_num,
+                                                train=False,
                                                 **kwargs))
 
         return return_dict
@@ -208,6 +209,7 @@ def _at_iter_end(self, iter_num, data_dict, metrics, **kwargs):
                                               iter_num=iter_num,
                                               data_dict=data_dict,
                                               metrics=metrics,
+                                              train=False,
                                               **kwargs))
 
         return return_dict
Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+import unittest
+from delira.logging import log
+from delira.training import BaseNetworkTrainer
+from delira.models import AbstractNetwork
+import os
+from tests.utils import check_for_tf_graph_backend
+
+try:
+    import tensorflow as tf
+except ImportError:
+    tf = None
+
+
+class LoggingOutsideTrainerTestCase(unittest.TestCase):
+
+    @unittest.skipUnless(check_for_tf_graph_backend(),
+                         "TF Backend not installed")
+    def test_logging_freq(self):
+        save_path = os.path.abspath("./logs")
+        config = {
+            "num_epochs": 2,
+            "losses": {},
+            "optimizer_cls": None,
+            "optimizer_params": {"learning_rate": 1e-3},
+            "metrics": {},
+            "lr_scheduler_cls": None,
+            "lr_scheduler_params": {}
+        }
+        trainer = BaseNetworkTrainer(
+            AbstractNetwork(),
+            save_path,
+            **config,
+            gpu_ids=[],
+            save_freq=1,
+            optim_fn=None,
+            key_mapping={},
+            logging_type="tensorboardx",
+            logging_kwargs={
+                'logdir': save_path
+            })
+
+        trainer._setup(
+            AbstractNetwork(),
+            lr_scheduler_cls=None,
+            lr_scheduler_params={},
+            gpu_ids=[],
+            key_mapping={},
+            convert_batch_to_npy_fn=None,
+            prepare_batch_fn=None,
+            callbacks=[])
+
+        tag = 'dummy'
+
+        log({"scalar": {"scalar_value": 1234, "tag": tag}})
+
+        file = [os.path.join(save_path, x)
+                for x in os.listdir(save_path)
+                if os.path.isfile(os.path.join(save_path, x))][0]
+
+        ret_val = False
+        if tf is not None:
+            for e in tf.train.summary_iterator(file):
+                for v in e.summary.value:
+                    if v.tag == tag:
+                        ret_val = True
+                        break
+                if ret_val:
+                    break
+
+        self.assertTrue(ret_val)
+
+
+if __name__ == '__main__':
+    unittest.main()
