@@ -187,21 +187,12 @@ def _setup(self, network, optim_fn, optimizer_cls, optimizer_params,
187187
188188 logger .info ("Attempting to load state from previous \
189189 training from %s" % latest_state_path )
190-
191190 try :
192- self .update_state (latest_state_path ,
193- weights_only = False )
191+ self .update_state (latest_state_path )
194192 except KeyError :
195- try :
196- self .update_state (latest_state_path ,
197- weights_only = True )
198- self .start_epoch = max (
199- latest_epoch , self .start_epoch )
200-
201- except KeyError :
202- logger .warn ("Previous State could not be loaded, \
203- although it exists.Training will be \
204- restarted" )
193+ logger .warn ("Previous State could not be loaded, \
194+ although it exists.Training will be \
195+ restarted" )
205196
206197 # asssign closure and prepare batch from network
207198 self .closure_fn = network .closure
@@ -363,8 +354,7 @@ def _at_training_end(self):
363354
364355 # load best model and return it
365356 self .update_state (os .path .join (self .save_path ,
366- 'checkpoint_best.pth' ),
367- weights_only = True
357+ 'checkpoint_best.pth' )
368358 )
369359
370360 return self .module
@@ -394,8 +384,7 @@ def _at_epoch_begin(self, metrics_val, val_score_key, epoch, num_epochs,
394384 for cb in self ._callbacks :
395385 self ._update_state (cb .at_epoch_begin (self , val_metrics = metrics_val ,
396386 val_score_key = val_score_key ,
397- curr_epoch = epoch ),
398- weights_only = False )
387+ curr_epoch = epoch ))
399388
400389 def _at_epoch_end (self , metrics_val , val_score_key , epoch , is_best ,
401390 ** kwargs ):
@@ -423,18 +412,17 @@ def _at_epoch_end(self, metrics_val, val_score_key, epoch, is_best,
423412 for cb in self ._callbacks :
424413 self ._update_state (cb .at_epoch_end (self , val_metrics = metrics_val ,
425414 val_score_key = val_score_key ,
426- curr_epoch = epoch ),
427- weights_only = False )
415+ curr_epoch = epoch ))
428416
429417 if epoch % self .save_freq == 0 :
430418 self .save_state (os .path .join (self .save_path ,
431419 "checkpoint_epoch_%d.pth" % epoch ),
432- epoch , False )
420+ epoch )
433421
434422 if is_best :
435423 self .save_state (os .path .join (self .save_path ,
436424 "checkpoint_best.pth" ),
437- epoch , False )
425+ epoch )
438426
439427 def _train_single_epoch (self , batchgen : MultiThreadedAugmenter , epoch ):
440428 """
@@ -597,7 +585,7 @@ def predict(self, batchgen, batch_size=None):
597585
598586 return outputs_all , labels_all , val_dict
599587
600- def save_state (self , file_name , epoch , weights_only = False , ** kwargs ):
588+ def save_state (self , file_name , epoch , ** kwargs ):
601589 """
602590 saves the current state via :func:`delira.io.torch.save_checkpoint`
603591
@@ -607,28 +595,24 @@ def save_state(self, file_name, epoch, weights_only=False, **kwargs):
607595 filename to save the state to
608596 epoch : int
609597 current epoch (will be saved for mapping back)
610- weights_only : bool
611- whether to store only weights (default: False)
612598 *args :
613599 positional arguments
614600 **kwargs :
615601 keyword arguments
616602
617603 """
618- save_checkpoint (file_name , self .module , self .optimizers , weights_only ,
619- ** kwargs )
604+ save_checkpoint (file_name , self .module , self .optimizers ,
605+ epoch = epoch , ** kwargs )
620606
621607 @staticmethod
622- def load_state (file_name , weights_only = True , ** kwargs ):
608+ def load_state (file_name , ** kwargs ):
623609 """
624610 Loads the new state from file via :func:`delira.io.torch.load_checkpoint`
625611
626612 Parameters
627613 ----------
628614 file_name : str
629615 the file to load the state from
630- weights_only : bool
631- whether file contains stored weights only (default: False)
632616 **kwargs : keyword arguments
633617
634618 Returns
@@ -637,24 +621,16 @@ def load_state(file_name, weights_only=True, **kwargs):
637621 new state
638622
639623 """
640- if weights_only :
641- return load_checkpoint (file_name , weights_only , ** kwargs )
642- else :
643- model , optimizer , epoch = load_checkpoint (file_name , weights_only ,
644- ** kwargs )
645- return {"module" : model , "optimizers" : optimizer ,
646- "start_epoch" : epoch }
624+ return load_checkpoint (file_name , ** kwargs )
647625
648- def update_state (self , file_name , weights_only = True , * args , ** kwargs ):
626+ def update_state (self , file_name , * args , ** kwargs ):
649627 """
650628 Update internal state from a loaded state
651629
652630 Parameters
653631 ----------
654632 file_name : str
655633 file containing the new state to load
656- weights_only : bool
657- whether to update only weights or notS
658634 *args :
659635 positional arguments
660636 **kwargs :
@@ -666,46 +642,35 @@ def update_state(self, file_name, weights_only=True, *args, **kwargs):
666642 the trainer with a modified state
667643
668644 """
669- self ._update_state (self .load_state (file_name , weights_only ,
670- * args , ** kwargs ), weights_only )
645+ self ._update_state (self .load_state (file_name , * args , ** kwargs ))
671646
672- def _update_state (self , new_state , weights_only = True ):
647+ def _update_state (self , new_state ):
673648 """
674649 Update the state from a given new state
675650
676651 Parameters
677652 ----------
678653 new_state : dict
679654 new state to update internal state from
680- weights_only : bool
681- whether to update weights only from statedict or update
682- everything
683655
684656 Returns
685657 -------
686658 :class:`PyTorchNetworkTrainer`
687659 the trainer with a modified state
688660
689- # """
661+ """
690662 # print(",".join(new_state.keys()))
691663
692- if weights_only :
693- if "model" in new_state :
694- model_state = new_state ["model" ]
695- else :
696- model_state = new_state
697-
698- self .module .load_state_dict (model_state )
664+ if "model" in new_state :
665+ self .module .load_state_dict (new_state .pop ("model" ))
699666
700- if "optimizer" in new_state and new_state ["optimizer" ]:
701- for key in self .optimizers .keys ():
702- self .optimizers [key ].load_state_dict (
703- new_state ["optimizer" ][key ])
667+ if "optimizer" in new_state and new_state ["optimizer" ]:
668+ optim_state = new_state .pop ("optimizer" )
669+ for key in self .optimizers .keys ():
670+ self .optimizers [key ].load_state_dict (
671+ optim_state [key ])
704672
705- if "epoch" in new_state :
706- self .start_epoch = new_state [ "epoch" ]
673+ if "epoch" in new_state :
674+ self .start_epoch = new_state . pop ( "epoch" )
707675
708- return self
709-
710- else :
711- return super ()._update_state (new_state )
676+ return super ()._update_state (new_state )
0 commit comments