
Commit f453f64

Log best epoch when loading checkpoint (#699)
Co-authored-by: Filippo Bigi <[email protected]>
1 parent 47e53a4 commit f453f64

17 files changed: +161 -104 lines

src/metatrain/cli/eval.py
Lines changed: 1 addition & 1 deletion

@@ -238,7 +238,7 @@ def _eval_targets(
     mean_per_atom = np.mean(timings_per_atom)
     std_per_atom = np.std(timings_per_atom)
     logging.info(
-        f"evaluation time: {total_time:.2f} s "
+        f"Evaluation time: {total_time:.2f} s "
         f"[{1000.0 * mean_per_atom:.4f} ± "
         f"{1000.0 * std_per_atom:.4f} ms per atom]"
     )

src/metatrain/cli/train.py
Lines changed: 4 additions & 0 deletions

@@ -568,7 +568,11 @@ def train_model(
         trainer.save_checkpoint(model, checkpoint_output)
     except Exception as e:
         raise ArchitectureError(e)
+
     if checkpoint_output.exists():
+        # Reload ensuring (best) model intended for inference
+        model = load_model(checkpoint_output)
+
         logging.info(f"Final checkpoint: {checkpoint_output.absolute().resolve()}")
 
     mts_atomistic_model = model.export()
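
The reload matters because the in-memory model still holds the last-epoch weights, while the checkpoint written by trainer.save_checkpoint stores both the latest and the best state dicts; re-reading the checkpoint makes the exported model the best-epoch one rather than the last one. A generic illustration of this pattern in plain torch (not metatrain's actual API; every name here is a stand-in):

    import torch
    import torch.nn as nn

    # A checkpoint that stores both the latest and the best weights.
    model = nn.Linear(4, 1)
    ckpt = {
        "model_state_dict": model.state_dict(),       # last epoch
        "best_model_state_dict": model.state_dict(),  # best epoch (same here)
    }
    torch.save(ckpt, "model.ckpt")

    # For export, re-read from disk and take the best weights, not the
    # weights still sitting in memory after the final epoch.
    best = nn.Linear(4, 1)
    best.load_state_dict(torch.load("model.ckpt")["best_model_state_dict"])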

src/metatrain/pet/checkpoints.py
Lines changed: 36 additions & 19 deletions

@@ -1,25 +1,42 @@
-def model_update_v1_v2(state_dict):
-    # This if-statement is necessary to handle cases when
-    # best_model_state_dict and model_state_dict are the same.
-    # In that case, the both are updated within the first call of
-    # this function in the PET.update_checkpoint() method.
-    if (
-        state_dict is not None
-        and "additive_models.0.model.type_to_index" not in state_dict
-    ):
-        state_dict["additive_models.0.model.type_to_index"] = state_dict.pop(
-            "additive_models.0.type_to_index"
-        )
+###########################
+# MODEL ###################
+###########################
+
+
+def model_update_v1_v2(checkpoint):
+    for key in ["model_state_dict", "best_model_state_dict"]:
+        if (state_dict := checkpoint.get(key)) is not None:
+            state_dict["additive_models.0.model.type_to_index"] = state_dict.pop(
+                "additive_models.0.type_to_index"
+            )
+
+
+def model_update_v2_v3(checkpoint):
+    for key in ["model_state_dict", "best_model_state_dict"]:
+        if (state_dict := checkpoint.get(key)) is not None:
+            if "train_hypers" in state_dict:
+                finetune_config = state_dict["train_hypers"].get("finetune", {})
+            else:
+                finetune_config = {}
+            state_dict["finetune_config"] = finetune_config
+
+
+def model_update_v3_v4(checkpoint):
+    checkpoint["epoch"] = checkpoint.get("epoch")
+    checkpoint["best_epoch"] = checkpoint.get("best_epoch")
+
+    if checkpoint["best_model_state_dict"] is not None:
+        checkpoint["best_model_state_dict"] = checkpoint.get("best_model_state_dict")
+
+
+###########################
+# TRAINER #################
+###########################
 
 
 def trainer_update_v1_v2(checkpoint):
     checkpoint["train_hypers"] = checkpoint["train_hypers"].get("scheduler_factor", 0.5)
 
 
-def model_update_v2_v3(state_dict):
-    if state_dict is not None:
-        if "train_hypers" in state_dict:
-            finetune_config = state_dict["train_hypers"].get("finetune", {})
-        else:
-            finetune_config = {}
-        state_dict["finetune_config"] = finetune_config
+def trainer_update_v2_v3(checkpoint):
+    checkpoint["best_epoch"] = checkpoint.get("best_epoch")

src/metatrain/pet/model.py
Lines changed: 15 additions & 17 deletions

@@ -1,3 +1,4 @@
+import logging
 import warnings
 from math import prod
 from typing import Any, Dict, List, Literal, Optional

@@ -40,7 +41,7 @@ class PET(ModelInterface):
 
     """
 
-    __checkpoint_version__ = 3
+    __checkpoint_version__ = 4
     __supported_devices__ = ["cuda", "cpu"]
     __supported_dtypes__ = [torch.float32, torch.float64]
     __default_metadata__ = ModelMetadata(

@@ -686,25 +687,23 @@ def load_checkpoint(
         checkpoint: Dict[str, Any],
         context: Literal["restart", "finetune", "export"],
     ) -> "PET":
-        model_data = checkpoint["model_data"]
-
         if context == "restart":
+            logging.info(f"Using latest model from epoch {checkpoint['epoch']}")
             model_state_dict = checkpoint["model_state_dict"]
-        elif context == "finetune" or context == "export":
+        elif context in {"finetune", "export"}:
+            logging.info(f"Using best model from epoch {checkpoint['best_epoch']}")
             model_state_dict = checkpoint["best_model_state_dict"]
-            if model_state_dict is None:
-                model_state_dict = checkpoint["model_state_dict"]
         else:
             raise ValueError("Unknown context tag for checkpoint loading!")
 
-        finetune_config = model_state_dict.pop("finetune_config", {})
-
         # Create the model
+        model_data = checkpoint["model_data"]
         model = cls(
             hypers=model_data["model_hypers"],
             dataset_info=model_data["dataset_info"],
         )
 
+        finetune_config = model_state_dict.pop("finetune_config", {})
         if finetune_config:
             # Apply the finetuning strategy
             model = apply_finetuning_strategy(model, finetune_config)

@@ -890,14 +889,11 @@ def _get_system_indices_and_labels(
 
     @classmethod
     def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict:
-        if checkpoint["model_ckpt_version"] == 1:
-            checkpoints.model_update_v1_v2(checkpoint["model_state_dict"])
-            checkpoints.model_update_v1_v2(checkpoint["best_model_state_dict"])
-            checkpoint["model_ckpt_version"] = 2
-        if checkpoint["model_ckpt_version"] == 2:
-            checkpoints.model_update_v2_v3(checkpoint["model_state_dict"])
-            checkpoints.model_update_v2_v3(checkpoint["best_model_state_dict"])
-            checkpoint["model_ckpt_version"] = 3
+        for v in range(1, cls.__checkpoint_version__):
+            if checkpoint["model_ckpt_version"] == v:
+                update = getattr(checkpoints, f"model_update_v{v}_v{v + 1}")
+                update(checkpoint)
+                checkpoint["model_ckpt_version"] = v + 1
 
         if checkpoint["model_ckpt_version"] != cls.__checkpoint_version__:
             raise RuntimeError(

@@ -919,7 +915,9 @@ def get_checkpoint(self) -> Dict:
                 "model_hypers": self.hypers,
                 "dataset_info": self.dataset_info,
             },
+            "epoch": None,
+            "best_epoch": None,
             "model_state_dict": model_state_dict,
-            "best_model_state_dict": None,
+            "best_model_state_dict": self.state_dict(),
        }
        return checkpoint
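
The fallback to model_state_dict when best_model_state_dict was None is gone: get_checkpoint (above) now always populates best_model_state_dict, and loading logs which epoch the weights come from. A sketch of the resulting behavior, where model stands in for any PET instance constructed as in the tests below:

    import logging

    logging.basicConfig(level=logging.INFO)

    checkpoint = model.get_checkpoint()  # model: a PET instance (stand-in)
    PET.load_checkpoint(checkpoint, context="restart")
    # INFO: Using latest model from epoch None   (fresh, untrained checkpoint)
    PET.load_checkpoint(checkpoint, context="export")
    # INFO: Using best model from epoch None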
Two binary files changed (14.5 KB each); contents not shown.

src/metatrain/pet/tests/test_checkpoints.py
Lines changed: 9 additions & 1 deletion

@@ -1,4 +1,5 @@
 import copy
+import logging
 
 import pytest
 import torch

@@ -81,7 +82,7 @@ def model_trainer():
 
 
 @pytest.mark.parametrize("context", ["finetune", "restart", "export"])
-def test_get_checkpoint(context):
+def test_get_checkpoint(context, caplog):
     """
     Test that the checkpoint created by the model.get_checkpoint()
     function can be loaded back in all possible contexts.

@@ -93,8 +94,15 @@ def test_get_checkpoint(context):
     )
     model = PET(MODEL_HYPERS, dataset_info)
     checkpoint = model.get_checkpoint()
+
+    caplog.set_level(logging.INFO)
     PET.load_checkpoint(checkpoint, context)
 
+    if context == "restart":
+        assert "Using latest model from epoch None" in caplog.text
+    else:
+        assert "Using best model from epoch None" in caplog.text
+
 
 @pytest.mark.parametrize("cls_type", ["model", "trainer"])
 def test_failed_checkpoint_upgrade(cls_type):

src/metatrain/pet/trainer.py
Lines changed: 17 additions & 18 deletions

@@ -52,14 +52,15 @@ def func_lr_scheduler(epoch):
 
 
 class Trainer(TrainerInterface):
-    __checkpoint_version__ = 2
+    __checkpoint_version__ = 3
 
     def __init__(self, hypers):
         super().__init__(hypers)
 
         self.optimizer_state_dict = None
         self.scheduler_state_dict = None
         self.epoch = None
+        self.best_epoch = None
         self.best_metric = None
         self.best_model_state_dict = None
         self.best_optimizer_state_dict = None

@@ -520,6 +521,7 @@ def train(
                 self.best_model_state_dict = copy.deepcopy(
                     (model.module if is_distributed else model).state_dict()
                 )
+                self.best_epoch = epoch
                 self.best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict())
 
             if epoch % self.hypers["checkpoint_interval"] == 0:

@@ -553,6 +555,7 @@ def save_checkpoint(self, model, path: Union[str, Path]):
             "epoch": self.epoch,
             "optimizer_state_dict": self.optimizer_state_dict,
             "scheduler_state_dict": self.scheduler_state_dict,
+            "best_epoch": self.best_epoch,
             "best_metric": self.best_metric,
             "best_model_state_dict": self.best_model_state_dict,
             "best_optimizer_state_dict": self.best_optimizer_state_dict,

@@ -570,29 +573,25 @@ def load_checkpoint(
         hypers: Dict[str, Any],
         context: Literal["restart", "finetune"],
     ) -> "Trainer":
-        epoch = checkpoint["epoch"]
-        optimizer_state_dict = checkpoint["optimizer_state_dict"]
-        scheduler_state_dict = checkpoint["scheduler_state_dict"]
-        best_metric = checkpoint["best_metric"]
-        best_model_state_dict = checkpoint["best_model_state_dict"]
-        best_optimizer_state_dict = checkpoint["best_optimizer_state_dict"]
-
-        # Create the trainer
         trainer = cls(hypers)
-        trainer.optimizer_state_dict = optimizer_state_dict
-        trainer.scheduler_state_dict = scheduler_state_dict
-        trainer.epoch = epoch
-        trainer.best_metric = best_metric
-        trainer.best_model_state_dict = best_model_state_dict
-        trainer.best_optimizer_state_dict = best_optimizer_state_dict
+        trainer.optimizer_state_dict = checkpoint["optimizer_state_dict"]
+        trainer.scheduler_state_dict = checkpoint["scheduler_state_dict"]
+        trainer.epoch = checkpoint["epoch"]
+        trainer.best_epoch = checkpoint["best_epoch"]
+        trainer.best_metric = checkpoint["best_metric"]
+        trainer.best_model_state_dict = checkpoint["best_model_state_dict"]
+        trainer.best_optimizer_state_dict = checkpoint["best_optimizer_state_dict"]
 
         return trainer
 
     @classmethod
     def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict:
-        if checkpoint["trainer_ckpt_version"] == 1:
-            checkpoints.trainer_update_v1_v2(checkpoint)
-            checkpoint["trainer_ckpt_version"] = 2
+        for v in range(1, cls.__checkpoint_version__):
+            if checkpoint["trainer_ckpt_version"] == v:
+                update = getattr(checkpoints, f"trainer_update_v{v}_v{v + 1}")
+                update(checkpoint)
+                checkpoint["trainer_ckpt_version"] = v + 1
+
        if checkpoint["trainer_ckpt_version"] != cls.__checkpoint_version__:
            raise RuntimeError(
                f"Unable to upgrade the checkpoint: the checkpoint is using "
src/metatrain/soap_bpnn/checkpoints.py
Lines changed: 28 additions & 12 deletions

@@ -1,12 +1,28 @@
-def update_v1_v2(state_dict):
-    # This if-statement is necessary to handle cases when
-    # best_model_state_dict and model_state_dict are the same.
-    # In that case, the both are updated within the first call of
-    # this function in the PET.update_checkpoint() method.
-    if (
-        state_dict is not None
-        and "additive_models.0.model.type_to_index" not in state_dict
-    ):
-        state_dict["additive_models.0.model.type_to_index"] = state_dict.pop(
-            "additive_models.0.type_to_index"
-        )
+###########################
+# MODEL ###################
+###########################
+
+
+def model_update_v1_v2(checkpoint):
+    for key in ["model_state_dict", "best_model_state_dict"]:
+        if (state_dict := checkpoint.get(key)) is not None:
+            state_dict["additive_models.0.model.type_to_index"] = state_dict.pop(
+                "additive_models.0.type_to_index"
+            )
+
+
+def model_update_v2_v3(checkpoint):
+    checkpoint["epoch"] = checkpoint.get("epoch")
+    checkpoint["best_epoch"] = checkpoint.get("best_epoch")
+
+    if checkpoint["best_model_state_dict"] is not None:
+        checkpoint["best_model_state_dict"] = checkpoint.get("best_model_state_dict")
+
+
+###########################
+# TRAINER #################
+###########################
+
+
+def trainer_update_v1_v2(checkpoint):
+    checkpoint["best_epoch"] = checkpoint.get("best_epoch")

src/metatrain/soap_bpnn/model.py
Lines changed: 14 additions & 11 deletions

@@ -1,3 +1,4 @@
+import logging
 from typing import Any, Dict, List, Literal, Optional
 
 import metatensor.torch as mts

@@ -171,7 +172,7 @@ def concatenate_structures(
 
 
 class SoapBpnn(ModelInterface):
-    __checkpoint_version__ = 2
+    __checkpoint_version__ = 3
     __supported_devices__ = ["cuda", "cpu"]
     __supported_dtypes__ = [torch.float64, torch.float32]
     __default_metadata__ = ModelMetadata(

@@ -670,18 +671,17 @@ def load_checkpoint(
         checkpoint: Dict[str, Any],
         context: Literal["restart", "finetune", "export"],
     ) -> "SoapBpnn":
-        model_data = checkpoint["model_data"]
-
         if context == "restart":
+            logging.info(f"Using latest model from epoch {checkpoint['epoch']}")
             model_state_dict = checkpoint["model_state_dict"]
-        elif context == "finetune" or context == "export":
+        elif context in {"finetune", "export"}:
+            logging.info(f"Using best model from epoch {checkpoint['best_epoch']}")
             model_state_dict = checkpoint["best_model_state_dict"]
-            if model_state_dict is None:
-                model_state_dict = checkpoint["model_state_dict"]
         else:
             raise ValueError("Unknown context tag for checkpoint loading!")
 
         # Create the model
+        model_data = checkpoint["model_data"]
         model = cls(
             hypers=model_data["model_hypers"],
             dataset_info=model_data["dataset_info"],

@@ -858,10 +858,11 @@ def _add_output(self, target_name: str, target: TargetInfo) -> None:
 
     @classmethod
     def upgrade_checkpoint(cls, checkpoint: Dict) -> Dict:
-        if checkpoint["model_ckpt_version"] == 1:
-            checkpoints.update_v1_v2(checkpoint["model_state_dict"])
-            checkpoints.update_v1_v2(checkpoint["best_model_state_dict"])
-            checkpoint["model_ckpt_version"] = 2
+        for v in range(1, cls.__checkpoint_version__):
+            if checkpoint["model_ckpt_version"] == v:
+                update = getattr(checkpoints, f"model_update_v{v}_v{v + 1}")
+                update(checkpoint)
+                checkpoint["model_ckpt_version"] = v + 1
 
         if checkpoint["model_ckpt_version"] != cls.__checkpoint_version__:
             raise RuntimeError(

@@ -880,8 +881,10 @@ def get_checkpoint(self) -> Dict:
                 "model_hypers": self.hypers,
                 "dataset_info": self.dataset_info,
             },
+            "epoch": None,
+            "best_epoch": None,
             "model_state_dict": self.state_dict(),
-            "best_model_state_dict": None,
+            "best_model_state_dict": self.state_dict(),
         }
         return checkpoint
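
With both state dicts populated from the start, a fresh checkpoint is loadable in every context without the removed None fallback. Schematically, showing only the keys visible in the hunk above (real checkpoints carry more, such as the version fields):

    # Shape of a fresh SoapBpnn checkpoint after this commit (schematic).
    checkpoint = {
        "model_data": {"model_hypers": ..., "dataset_info": ...},
        "epoch": None,                 # not trained yet
        "best_epoch": None,
        "model_state_dict": ...,       # current weights
        "best_model_state_dict": ...,  # same weights, duplicated on purpose
    }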
