diff --git a/proteinfoundation/proteina.py b/proteinfoundation/proteina.py index 548eaa5a..971597fe 100644 --- a/proteinfoundation/proteina.py +++ b/proteinfoundation/proteina.py @@ -105,7 +105,9 @@ def load_autoencoder(self, cfg_exp, freeze_params=True): return None, None logger.info(f"Loading autoencoder from {ae_ckp_path}") - autoencoder = AutoEncoder.load_from_checkpoint(ae_ckp_path, strict=False) + autoencoder = AutoEncoder.load_from_checkpoint( + ae_ckp_path, strict=False, map_location=torch.device("cpu") + ) if freeze_params: for param in autoencoder.parameters(): param.requires_grad = False