@@ -52,14 +52,15 @@ def func_lr_scheduler(epoch):
 
 
 class Trainer(TrainerInterface):
-    __checkpoint_version__ = 2
+    __checkpoint_version__ = 3
 
     def __init__(self, hypers):
         super().__init__(hypers)
 
         self.optimizer_state_dict = None
         self.scheduler_state_dict = None
         self.epoch = None
+        self.best_epoch = None
         self.best_metric = None
         self.best_model_state_dict = None
         self.best_optimizer_state_dict = None
@@ -520,6 +521,7 @@ def train(
                 self.best_model_state_dict = copy.deepcopy(
                     (model.module if is_distributed else model).state_dict()
                 )
+                self.best_epoch = epoch
                 self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict())
 
             if epoch % self.hypers["checkpoint_interval"] == 0:
@@ -553,6 +555,7 @@ def save_checkpoint(self, model, path: Union[str, Path]):
             "epoch": self.epoch,
             "optimizer_state_dict": self.optimizer_state_dict,
             "scheduler_state_dict": self.scheduler_state_dict,
+            "best_epoch": self.best_epoch,
             "best_metric": self.best_metric,
             "best_model_state_dict": self.best_model_state_dict,
             "best_optimizer_state_dict": self.best_optimizer_state_dict,
@@ -570,29 +573,25 @@ def load_checkpoint(
         hypers: Dict[str, Any],
         context: Literal["restart", "finetune"],
     ) -> "Trainer":
-        epoch = checkpoint["epoch"]
-        optimizer_state_dict = checkpoint["optimizer_state_dict"]
-        scheduler_state_dict = checkpoint["scheduler_state_dict"]
-        best_metric = checkpoint["best_metric"]
-        best_model_state_dict = checkpoint["best_model_state_dict"]
-        best_optimizer_state_dict = checkpoint["best_optimizer_state_dict"]
-
-        # Create the trainer
         trainer = cls(hypers)
-        trainer.optimizer_state_dict = optimizer_state_dict
-        trainer.scheduler_state_dict = scheduler_state_dict
-        trainer.epoch = epoch
-        trainer.best_metric = best_metric
-        trainer.best_model_state_dict = best_model_state_dict
-        trainer.best_optimizer_state_dict = best_optimizer_state_dict
+        trainer.optimizer_state_dict = checkpoint["optimizer_state_dict"]
+        trainer.scheduler_state_dict = checkpoint["scheduler_state_dict"]
+        trainer.epoch = checkpoint["epoch"]
+        trainer.best_epoch = checkpoint["best_epoch"]
+        trainer.best_metric = checkpoint["best_metric"]
+        trainer.best_model_state_dict = checkpoint["best_model_state_dict"]
+        trainer.best_optimizer_state_dict = checkpoint["best_optimizer_state_dict"]
 
         return trainer
 
     @classmethod
     def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict:
-        if checkpoint["trainer_ckpt_version"] == 1:
-            checkpoints.trainer_update_v1_v2(checkpoint)
-            checkpoint["trainer_ckpt_version"] = 2
+        for v in range(1, cls.__checkpoint_version__):
+            if checkpoint["trainer_ckpt_version"] == v:
+                update = getattr(checkpoints, f"trainer_update_v{v}_v{v + 1}")
+                update(checkpoint)
+                checkpoint["trainer_ckpt_version"] = v + 1
+
         if checkpoint["trainer_ckpt_version"] != cls.__checkpoint_version__:
             raise RuntimeError(
                 f"Unable to upgrade the checkpoint: the checkpoint is using "