Skip to content
This repository was archived by the owner on Jun 26, 2021. It is now read-only.

Commit fff41cf

Browse files
Merge pull request #65 from mibaumgartner/load_save
Removed weights_only option
2 parents afa7463 + 962cd26 commit fff41cf

File tree

3 files changed

+38
-154
lines changed

3 files changed

+38
-154
lines changed

delira/io/torch.py

Lines changed: 9 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,10 @@
1111
if "TORCH" in get_backends():
1212

1313
import torch
14-
15-
from torchvision import models as t_models
16-
from torch import nn
17-
from torch.nn import functional as F
18-
from torch import optim
19-
2014
from ..models import AbstractPyTorchNetwork
2115

2216
def save_checkpoint(file: str, model=None, optimizers={},
23-
epoch=None, weights_only=True, **kwargs):
17+
epoch=None, **kwargs):
2418
"""
2519
Save model's parameters
2620
@@ -35,9 +29,6 @@ def save_checkpoint(file: str, model=None, optimizers={},
3529
dictionary containing all optimizers
3630
epoch : int
3731
current epoch (will also be pickled)
38-
weights_only : bool
39-
whether or not to save only the model's weights or also save additional
40-
information (for easy loading)
4132
4233
"""
4334
if isinstance(model, torch.nn.DataParallel):
@@ -66,39 +57,16 @@ def save_checkpoint(file: str, model=None, optimizers={},
6657
"model": model_state,
6758
"epoch": epoch}
6859

69-
if not weights_only:
70-
71-
source = inspect.getsource(_model.__class__)
72-
73-
class_name_model = _model.__class__.__name__
74-
class_names_optim = OrderedDict()
75-
76-
for key in optim_state.keys():
77-
class_names_optim[key] = optimizers[key].__class__.__name__
78-
79-
parent_class = _model.__class__.__mro__[1].__name__
60+
torch.save(state, file, **kwargs)
8061

81-
init_kwargs = _model.init_kwargs
82-
83-
torch.save({'source': source, 'cls_name_model': class_name_model,
84-
'parent_class': parent_class, 'init_kwargs': init_kwargs,
85-
'state_dict': state, 'cls_name_optim': class_names_optim},
86-
file)
87-
88-
else:
89-
torch.save(state, file)
90-
91-
def load_checkpoint(file, weights_only=True, **kwargs):
62+
def load_checkpoint(file, **kwargs):
9263
"""
9364
Loads a saved model
9465
9566
Parameters
9667
----------
9768
file : str
9869
filepath to a file containing a saved model
99-
weights_only : bool
100-
whether the file contains only weights / only weights should be
101-
returned
10270
**kwargs:
10371
Additional keyword arguments (passed to torch.load)
10472
Especially "map_location" is important to change the device the
@@ -107,57 +75,12 @@ def load_checkpoint(file, weights_only=True, **kwargs):
10775
Returns
10876
-------
10977
OrderedDict
110-
checkpoint state_dict if `weights_only=True`
111-
torch.nn.Module, OrderedDict, int
112-
Model, Optimizers, epoch with loaded state_dicts if `weights_only=False`
78+
checkpoint state_dict
11379
11480
"""
115-
if weights_only:
116-
return torch.load(file, **kwargs)
117-
else:
118-
loaded_dict = torch.load(file, **kwargs)
119-
120-
# import parent class
121-
exec("from ..models import " + loaded_dict["parent_class"])
122-
123-
# execute pickled code (to get access to class)
124-
exec(loaded_dict["source"])
125-
126-
# create class instance (default device: CPU)
127-
exec("model = " + loaded_dict["cls_name_model"] +
128-
"(**loaded_dict['init_kwargs'])")
129-
130-
# check for "map_location" kwarg and use device of first weight tensor
131-
# as default argument (weight tensors should be all on same device)
132-
if loaded_dict["state_dict"]["model"]:
133-
default_device = next(
134-
islice(
135-
loaded_dict["state_dict"]["model"].values(), 1)
136-
).device
137-
else:
138-
default_device = torch.device("cpu")
139-
140-
map_location = kwargs.get("map_location",
141-
# use slicing instead of converting to list
142-
# to avoid memory overhead
143-
default_device)
144-
145-
# push created class from CPU to suitable device
146-
locals()['model'].to(map_location)
147-
148-
locals()['model'].load_state_dict(
149-
loaded_dict["state_dict"]["model"])
150-
151-
optims = OrderedDict()
152-
153-
for key in loaded_dict["cls_name_optim"].keys():
154-
exec("_optim = optim.%s(models.parameters())" %
155-
loaded_dict["cls_name_optim"][key])
156-
157-
optims[key] = locals()['_optim']
158-
159-
for key, val in optims.items():
160-
optims[key] = val.load_state_dict(
161-
loaded_dict["state_dict"]["optimizer"][key])
81+
checkpoint = torch.load(file, **kwargs)
16282

163-
return locals()['model'], optims, loaded_dict["state_dict"]["epoch"]
83+
if not all([_key in checkpoint
84+
for _key in ["model", "optimizer", "epoch"]]):
85+
return checkpoint['state_dict']
86+
return checkpoint

delira/training/pytorch_trainer.py

Lines changed: 28 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -187,21 +187,12 @@ def _setup(self, network, optim_fn, optimizer_cls, optimizer_params,
187187

188188
logger.info("Attempting to load state from previous \
189189
training from %s" % latest_state_path)
190-
191190
try:
192-
self.update_state(latest_state_path,
193-
weights_only=False)
191+
self.update_state(latest_state_path)
194192
except KeyError:
195-
try:
196-
self.update_state(latest_state_path,
197-
weights_only=True)
198-
self.start_epoch = max(
199-
latest_epoch, self.start_epoch)
200-
201-
except KeyError:
202-
logger.warn("Previous State could not be loaded, \
203-
although it exists.Training will be \
204-
restarted")
193+
logger.warn("Previous State could not be loaded, \
194+
although it exists.Training will be \
195+
restarted")
205196

206197
# assign closure and prepare batch from network
207198
self.closure_fn = network.closure
@@ -363,8 +354,7 @@ def _at_training_end(self):
363354

364355
# load best model and return it
365356
self.update_state(os.path.join(self.save_path,
366-
'checkpoint_best.pth'),
367-
weights_only=True
357+
'checkpoint_best.pth')
368358
)
369359

370360
return self.module
@@ -394,8 +384,7 @@ def _at_epoch_begin(self, metrics_val, val_score_key, epoch, num_epochs,
394384
for cb in self._callbacks:
395385
self._update_state(cb.at_epoch_begin(self, val_metrics=metrics_val,
396386
val_score_key=val_score_key,
397-
curr_epoch=epoch),
398-
weights_only=False)
387+
curr_epoch=epoch))
399388

400389
def _at_epoch_end(self, metrics_val, val_score_key, epoch, is_best,
401390
**kwargs):
@@ -423,18 +412,17 @@ def _at_epoch_end(self, metrics_val, val_score_key, epoch, is_best,
423412
for cb in self._callbacks:
424413
self._update_state(cb.at_epoch_end(self, val_metrics=metrics_val,
425414
val_score_key=val_score_key,
426-
curr_epoch=epoch),
427-
weights_only=False)
415+
curr_epoch=epoch))
428416

429417
if epoch % self.save_freq == 0:
430418
self.save_state(os.path.join(self.save_path,
431419
"checkpoint_epoch_%d.pth" % epoch),
432-
epoch, False)
420+
epoch)
433421

434422
if is_best:
435423
self.save_state(os.path.join(self.save_path,
436424
"checkpoint_best.pth"),
437-
epoch, False)
425+
epoch)
438426

439427
def _train_single_epoch(self, batchgen: MultiThreadedAugmenter, epoch):
440428
"""
@@ -597,7 +585,7 @@ def predict(self, batchgen, batch_size=None):
597585

598586
return outputs_all, labels_all, val_dict
599587

600-
def save_state(self, file_name, epoch, weights_only=False, **kwargs):
588+
def save_state(self, file_name, epoch, **kwargs):
601589
"""
602590
saves the current state via :func:`delira.io.torch.save_checkpoint`
603591
@@ -607,28 +595,24 @@ def save_state(self, file_name, epoch, weights_only=False, **kwargs):
607595
filename to save the state to
608596
epoch : int
609597
current epoch (will be saved for mapping back)
610-
weights_only : bool
611-
whether to store only weights (default: False)
612598
*args :
613599
positional arguments
614600
**kwargs :
615601
keyword arguments
616602
617603
"""
618-
save_checkpoint(file_name, self.module, self.optimizers, weights_only,
619-
**kwargs)
604+
save_checkpoint(file_name, self.module, self.optimizers,
605+
epoch=epoch, **kwargs)
620606

621607
@staticmethod
622-
def load_state(file_name, weights_only=True, **kwargs):
608+
def load_state(file_name, **kwargs):
623609
"""
624610
Loads the new state from file via :func:`delira.io.torch.load_checkpoint`
625611
626612
Parameters
627613
----------
628614
file_name : str
629615
the file to load the state from
630-
weights_only : bool
631-
whether file contains stored weights only (default: True)
632616
**kwargs : keyword arguments
633617
634618
Returns
@@ -637,24 +621,16 @@ def load_state(file_name, weights_only=True, **kwargs):
637621
new state
638622
639623
"""
640-
if weights_only:
641-
return load_checkpoint(file_name, weights_only, **kwargs)
642-
else:
643-
model, optimizer, epoch = load_checkpoint(file_name, weights_only,
644-
**kwargs)
645-
return {"module": model, "optimizers": optimizer,
646-
"start_epoch": epoch}
624+
return load_checkpoint(file_name, **kwargs)
647625

648-
def update_state(self, file_name, weights_only=True, *args, **kwargs):
626+
def update_state(self, file_name, *args, **kwargs):
649627
"""
650628
Update internal state from a loaded state
651629
652630
Parameters
653631
----------
654632
file_name : str
655633
file containing the new state to load
656-
weights_only : bool
657-
whether to update only weights or not
658634
*args :
659635
positional arguments
660636
**kwargs :
@@ -666,46 +642,35 @@ def update_state(self, file_name, weights_only=True, *args, **kwargs):
666642
the trainer with a modified state
667643
668644
"""
669-
self._update_state(self.load_state(file_name, weights_only,
670-
*args, **kwargs), weights_only)
645+
self._update_state(self.load_state(file_name, *args, **kwargs))
671646

672-
def _update_state(self, new_state, weights_only=True):
647+
def _update_state(self, new_state):
673648
"""
674649
Update the state from a given new state
675650
676651
Parameters
677652
----------
678653
new_state : dict
679654
new state to update internal state from
680-
weights_only : bool
681-
whether to update weights only from statedict or update
682-
everything
683655
684656
Returns
685657
-------
686658
:class:`PyTorchNetworkTrainer`
687659
the trainer with a modified state
688660
689-
# """
661+
"""
690662
# print(",".join(new_state.keys()))
691663

692-
if weights_only:
693-
if "model" in new_state:
694-
model_state = new_state["model"]
695-
else:
696-
model_state = new_state
697-
698-
self.module.load_state_dict(model_state)
664+
if "model" in new_state:
665+
self.module.load_state_dict(new_state.pop("model"))
699666

700-
if "optimizer" in new_state and new_state["optimizer"]:
701-
for key in self.optimizers.keys():
702-
self.optimizers[key].load_state_dict(
703-
new_state["optimizer"][key])
667+
if "optimizer" in new_state and new_state["optimizer"]:
668+
optim_state = new_state.pop("optimizer")
669+
for key in self.optimizers.keys():
670+
self.optimizers[key].load_state_dict(
671+
optim_state[key])
704672

705-
if "epoch" in new_state:
706-
self.start_epoch = new_state["epoch"]
673+
if "epoch" in new_state:
674+
self.start_epoch = new_state.pop("epoch")
707675

708-
return self
709-
710-
else:
711-
return super()._update_state(new_state)
676+
return super()._update_state(new_state)

tests/io/test_torch.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,10 @@ def _build_model(in_channels, n_outputs):
2525
torch.nn.Linear(64, n_outputs)
2626
)
2727

28-
2928
net = DummyNetwork(32, 1)
3029
torch_save_checkpoint("./model.pt", model=net)
31-
# fails with weights_only=False only in pytest-mode not in normal execution
32-
torch_load_checkpoint("./model.pt", weights_only=True)
30+
assert torch_load_checkpoint("./model.pt")
3331

34-
torch_save_checkpoint("./model.pt", net, weights_only=True)
35-
assert torch_load_checkpoint("./model.pt", weights_only=True)
3632

3733
if __name__ == '__main__':
3834
test_load_save()

0 commit comments

Comments
 (0)