Skip to content

Commit 7e8d2cc

Browse files
Fix: prevent loading best model when PEFT adapters are active (#3470)
* Fix: prevent loading best model when PEFT adapters are active (#3056)
* Reuse _load_from_checkpoint in _load_best_model

---------

Co-authored-by: Tom Aarsen <[email protected]>
1 parent 5eb2a1b commit 7e8d2cc

File tree

1 file changed

+5
-12
lines changed

1 file changed

+5
-12
lines changed

sentence_transformers/trainer.py

Lines changed: 5 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -594,26 +594,19 @@ def evaluation_loop(
594594
def _load_best_model(self) -> None:
595595
# Attempt to load the model from self.state.best_model_checkpoint
596596
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
597-
try:
598-
dummy_model = self.model.__class__(
599-
self.state.best_model_checkpoint,
600-
trust_remote_code=self.model.trust_remote_code,
601-
)
602-
except Exception as exc:
603-
logger.error(f"Could not load the best model from {self.state.best_model_checkpoint}. Error: {str(exc)}")
604-
return
605597

606-
# Store the best model checkpoint in the model card
607598
try:
608599
if checkpoint := self.state.best_model_checkpoint:
609600
step = checkpoint.rsplit("-", 1)[-1]
610601
self.model.model_card_data.set_best_model_step(int(step))
611602
except Exception:
612603
pass
613604

614-
# Ideally, the only changes between self.model and the dummy model are the weights
615-
# so we should be able to just copy the state dict
616-
self.model.load_state_dict(dummy_model.state_dict())
605+
try:
606+
self._load_from_checkpoint(self.state.best_model_checkpoint)
607+
except Exception as exc:
608+
logger.error(f"Could not load the best model from {self.state.best_model_checkpoint}. Error: {str(exc)}")
609+
return
617610

618611
def validate_column_names(self, dataset: Dataset, dataset_name: str | None = None) -> None:
619612
if isinstance(dataset, dict):

0 commit comments

Comments
 (0)