diff --git a/physicsnemo/models/module.py b/physicsnemo/models/module.py index 4eaf73b138..2b4f5aff44 100644 --- a/physicsnemo/models/module.py +++ b/physicsnemo/models/module.py @@ -18,6 +18,7 @@ import importlib import inspect +import io import json import keyword import logging @@ -26,6 +27,7 @@ import tarfile import tempfile import warnings +import zipfile from pathlib import Path from typing import Any, Dict, Optional, Set, Union @@ -383,21 +385,41 @@ def debug(self): # TODO: set up debug log # fh = logging.FileHandler(f'physicsnemo-core-{self.meta.name}.log') - def save(self, file_name: Union[str, None] = None, verbose: bool = False) -> None: - """Simple utility for saving just the model + def save( + self, + file_name: Union[str, None] = None, + verbose: bool = False, + legacy_format: bool = False, + ) -> None: + """ + Utility method for saving a ``Module`` instance to a '.mdlus' checkpoint file. Parameters ---------- - file_name : Union[str,None], optional - File name to save model weight to. When none is provide it will default to - the model's name set in the meta data, by default None - verbose : bool, optional - Whether to save the model in verbose mode which will include git hash, etc, by default False + file_name : Union[str,None], optional, default=None + File name to save the model checkpoint to. When ``None`` is provided it will default to + the model's name set in the meta data (the model's metadata must + have a 'name' attribute in this case). + verbose : bool, optional, default=False + Whether to save the model in verbose mode which will include git hash, etc. + legacy_format : bool, optional, default=False + Whether to save the model in legacy tar format. If True, saves as tar archive. + If False (default), saves as zip archive. Raises ------ ValueError If file_name does not end with .mdlus extension + + Examples + -------- + >>> from physicsnemo.models.mlp import FullyConnected + >>> model = FullyConnected(in_features=32, out_features=64) + >>> # Save a checkpoint with the default file name 'FullyConnected.mdlus'. + >>> # In this case, the model.meta.name coincides with the model class name, but that is not always the case. + >>> model.save() + >>> # Save a checkpoint to a specified file name 'my_model.mdlus' + >>> model.save("my_model.mdlus") """ # Define some helper functions @@ -478,50 +500,127 @@ def _save_process(module, args, metadata, mod_prefix="") -> None: self._orig_mod.save(file_name, verbose) return - with tempfile.TemporaryDirectory() as temp_dir: - local_path = Path(temp_dir) + # Save the physicsnemo version and git hash (if available) + metadata_info = { + "physicsnemo_version": physicsnemo.__version__, + "mdlus_file_version": self.__model_checkpoint_version__, + } - torch.save(self.state_dict(), local_path / "model.pt") + if verbose: + import git - # Save the physicsnemo version and git hash (if available) - metadata_info = { - "physicsnemo_version": physicsnemo.__version__, - "mdlus_file_version": self.__model_checkpoint_version__, - } + try: + repo = git.Repo(search_parent_directories=True) + metadata_info["git_hash"] = repo.head.object.hexsha + except git.InvalidGitRepositoryError: + metadata_info["git_hash"] = None + + # Copy self._args to avoid side effects + _args = self._args.copy() + + # Recursively populate _args and metadata_info with submodules + # information + _save_process(self, _args, metadata_info) + + # If file_name is not provided, use the model's name from the metadata + if file_name is None: + meta_name = getattr(self.meta, "name", None) + if meta_name is None: + raise ValueError( + "Model metadata does not have a 'name' attribute, please set it " + "explicitly or pass a 'file_name' argument to save a checkpoint." + ) + file_name = f"{meta_name}.mdlus" - if verbose: - import git + # Write checkpoint file + fs = _get_fs(file_name) - try: - repo = git.Repo(search_parent_directories=True) - metadata_info["git_hash"] = repo.head.object.hexsha - except git.InvalidGitRepositoryError: - metadata_info["git_hash"] = None + if not legacy_format: + # Save in zip format (default) + try: + with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as tmp: + tmp_path = tmp.name + + with zipfile.ZipFile(tmp_path, "w", zipfile.ZIP_DEFLATED) as archive: + # Save model state dict + state_dict_buffer = io.BytesIO() + torch.save(self.state_dict(), state_dict_buffer) + archive.writestr("model.pt", state_dict_buffer.getvalue()) + + # Save args + args_str = json.dumps(_args) + archive.writestr("args.json", args_str) + + # Save metadata + metadata_str = json.dumps(metadata_info) + archive.writestr("metadata.json", metadata_str) + + # Upload to final destination + fs.put(tmp_path, file_name) + finally: + # Clean up temporary file + if os.path.exists(tmp_path): + os.remove(tmp_path) + else: + # Save in legacy tar format + with tempfile.TemporaryDirectory() as temp_dir: + local_path = Path(temp_dir) - # Copy self._args to avoid side effects - _args = self._args.copy() + # Save model state dict + torch.save(self.state_dict(), local_path / "model.pt") - # Recursively populate _args and metadata_info with submodules - # information - _save_process(self, _args, metadata_info) + # Save args + with open(local_path / "args.json", "w") as f: + json.dump(_args, f) - with open(local_path / "args.json", "w") as f: - json.dump(_args, f) + # Save metadata + with open(local_path / "metadata.json", "w") as f: + json.dump(metadata_info, f) - with open(local_path / "metadata.json", "w") as f: - json.dump(metadata_info, f) + # Create tar archive + with tarfile.open(local_path / "model.tar", "w") as tar: + for file in local_path.iterdir(): + tar.add(str(file), arcname=file.name) - # Once all files are saved, package them into a tar file - with tarfile.open(local_path / "model.tar", "w") as tar: - for file in local_path.iterdir(): - tar.add(str(file), arcname=file.name) + # Upload to final destination + fs.put(local_path / "model.tar", file_name) - if file_name is None: - file_name = self.meta.name + ".mdlus" + @staticmethod + def _detect_checkpoint_format(file_path: str) -> str: + """Detect whether checkpoint is zip or tar format - # Save files to remote destination - fs = _get_fs(file_name) - fs.put(str(local_path / "model.tar"), file_name) + Parameters + ---------- + file_path : str + Path to checkpoint file + + Returns + ------- + str + Either 'zip' or 'tar' + + Raises + ------ + IOError + If file format cannot be determined + """ + try: + # NOTE: the check for tarfile MUST come first, as older checkpoints + # will be both zip and tar archives, but newer checkpoints will + # only be zip. + if tarfile.is_tarfile(file_path): + return "tar" + elif zipfile.is_zipfile(file_path): + return "zip" + else: + raise IOError( + f"Checkpoint file {file_path} is neither a valid zip " + f"nor tar archive" + ) + except Exception as e: + raise IOError( + f"Could not determine checkpoint format for {file_path}: {e}" + ) from e @staticmethod def _check_checkpoint(local_path: Path | str) -> None: @@ -537,51 +636,109 @@ def load( map_location: Union[None, str, torch.device] = None, strict: bool = True, ) -> None: - """Simple utility for loading the model weights from checkpoint + """ + Utility method for loading the model weights from a '.mdlus' + checkpoint file. Unlike + :meth:`~physicsnemo.models.module.Module.from_checkpoint`, this method + *does not* instantiate the model, but rather loads the ``state_dict`` for an + already instantiated model. Parameters ---------- file_name : str - Checkpoint file name - map_location : Union[None, str, torch.device], optional - Map location for loading the model weights, by default None will use model's device - strict: bool, optional - whether to strictly enforce that the keys in state_dict match, by default True + Checkpoint file name. Must be a valid '.mdlus' checkpoint file. + map_location : Union[None, str, torch.device], optional, default=None + Map location for loading the model weights, ``None`` will use the model's device. + strict: bool, optional, default=True + Whether to strictly enforce that the keys in ``state_dict`` match. Raises ------ IOError - If file_name provided does not exist or is not a valid checkpoint + If ``file_name`` provided does not exist or is not a valid checkpoint + + Examples + -------- + Basic example loading the model weights (state_dict) from a checkpoint: + + .. code-block:: python + + from physicsnemo.models.mlp import FullyConnected + + # Create a model with the same architecture as the saved one + model = FullyConnected(in_features=32, out_features=64) + + # Load the weights from checkpoint + model.load("FullyConnected.mdlus") + + Loading with specific device mapping: + + .. code-block:: python + + import torch + from physicsnemo.models.mlp import FullyConnected + + model = FullyConnected(in_features=32, out_features=64) + + # Load checkpoint to CPU even if it was saved on GPU + model.load("FullyConnected.mdlus", map_location="cpu") + + # Or load to a specific GPU + model.load("FullyConnected.mdlus", map_location=torch.device("cuda:0")) """ # Download and cache the checkpoint file if needed cached_file_name = _download_cached(file_name) - # Use a temporary directory to extract the tar file - with tempfile.TemporaryDirectory() as temp_dir: - local_path = Path(temp_dir) - - # Open the tar file and extract its contents to the temporary directory - with tarfile.open(cached_file_name, "r") as tar: - # Safely extract while supporting Python versions < 3.12 that lack the - # ``filter`` keyword. Starting with 3.12, ``filter="data"`` is the - # recommended way to avoid unsafe members - extract_kwargs = dict( - path=local_path, - members=list(Module._safe_members(tar, local_path)), - ) - if "filter" in tar.extractall.__code__.co_varnames: - extract_kwargs["filter"] = "data" - tar.extractall(**extract_kwargs) # noqa: S202 + # Detect checkpoint format + checkpoint_format = Module._detect_checkpoint_format(cached_file_name) - # Check if the checkpoint is valid - Module._check_checkpoint(local_path) + device = map_location if map_location is not None else self.device - # Load the model weights - device = map_location if map_location is not None else self.device - model_dict = torch.load( - local_path.joinpath("model.pt"), map_location=device - ) + if checkpoint_format == "zip": + # Load directly from zip file (no extraction needed) + with zipfile.ZipFile(cached_file_name, "r") as archive: + # Check if all expected files are present + expected_files = ["args.json", "metadata.json", "model.pt"] + archive_files = archive.namelist() + for expected_file in expected_files: + if expected_file not in archive_files: + raise IOError(f"File '{expected_file}' not found in checkpoint") + + # Read into memory + model_bytes = archive.read("model.pt") + + # Load state dict after closing archive + model_dict = torch.load(io.BytesIO(model_bytes), map_location=device) + + # Load state_dict into the model + _load_state_dict_with_logging(self, model_dict, strict=strict) + + else: # tar format (backward compatibility) + # Use a temporary directory to extract the tar file + with tempfile.TemporaryDirectory() as temp_dir: + local_path = Path(temp_dir) + + # Open tar file and extract contents to temporary directory + with tarfile.open(cached_file_name, "r") as tar: + # Safely extract while supporting Python < 3.12 + extract_kwargs = dict( + path=local_path, + members=list(Module._safe_members(tar, local_path)), + ) + if "filter" in tar.extractall.__code__.co_varnames: + extract_kwargs["filter"] = "data" + tar.extractall(**extract_kwargs) # noqa: S202 + + # Check if the checkpoint is valid + Module._check_checkpoint(local_path) + + # Load the model weights + model_dict = torch.load( + local_path.joinpath("model.pt"), map_location=device + ) + + # Load state dict into the model _load_state_dict_with_logging(self, model_dict, strict=strict) @classmethod @@ -591,12 +748,14 @@ def from_checkpoint( override_args: Optional[Dict[str, Any]] = None, strict: bool = True, ) -> physicsnemo.Module: - """Simple utility for constructing a model from a checkpoint + """ + Utility class method for instantiating and loading a ``Module`` + instance from a '.mdlus' checkpoint file. Parameters ---------- file_name : str - Checkpoint file name + Checkpoint file name. Must be a valid '.mdlus' checkpoint file. override_args : Optional[Dict[str, Any]], optional, default=None Dictionary of arguments to override the ``__init__`` method's arguments saved in the checkpoint. The override of arguments occurs @@ -822,48 +981,87 @@ def _from_checkpoint_process( # Download and cache the checkpoint file if needed cached_file_name = _download_cached(file_name) - # Use a temporary directory to extract the tar file - with tempfile.TemporaryDirectory() as temp_dir: - local_path = Path(temp_dir) - - # Open the tar file and extract its contents to the temporary directory - with tarfile.open(cached_file_name, "r") as tar: - # Safely extract while supporting Python versions < 3.12 that lack the - # ``filter`` keyword. Starting with 3.12, ``filter="data"`` is the - # recommended way to avoid unsafe members; - extract_kwargs = dict( - path=local_path, - members=list(Module._safe_members(tar, local_path)), + # Detect checkpoint format + checkpoint_format = Module._detect_checkpoint_format(cached_file_name) + + if checkpoint_format == "zip": + # Load directly from zip file (no extraction needed) + with zipfile.ZipFile(cached_file_name, "r") as archive: + # Check if all expected files are present + expected_files = ["args.json", "metadata.json", "model.pt"] + archive_files = archive.namelist() + for expected_file in expected_files: + if expected_file not in archive_files: + raise IOError(f"File '{expected_file}' not found in checkpoint") + + # Load model arguments and instantiate the model + with archive.open("args.json") as f: + args = json.loads(f.read().decode("utf-8")) + + # Load metadata to get version + with archive.open("metadata.json") as f: + metadata = json.loads(f.read().decode("utf-8")) + + model = _from_checkpoint_process( + cls, + args, + metadata, + override_args, + strict, ) - if "filter" in tar.extractall.__code__.co_varnames: - extract_kwargs["filter"] = "data" - tar.extractall(**extract_kwargs) # noqa: S202 - - # Check if the checkpoint is valid - Module._check_checkpoint(local_path) - - # Load model arguments and instantiate the model - with open(local_path.joinpath("args.json"), "r") as f: - args = json.load(f) - - # Load metadata to get version - with open(local_path.joinpath("metadata.json"), "r") as f: - metadata = json.load(f) - - model = _from_checkpoint_process( - cls, - args, - metadata, - override_args, - strict, - ) - # Load the model weights - model_dict = torch.load( - local_path.joinpath("model.pt"), map_location=model.device - ) + # Read into memory + model_bytes = archive.read("model.pt") + + # Load state dict after closing archive + model_dict = torch.load(io.BytesIO(model_bytes), map_location=model.device) + # Load state_dict into the model _load_state_dict_with_logging(model, model_dict, strict=strict) + + else: # tar format (backward compatibility) + # Use a temporary directory to extract the tar file + with tempfile.TemporaryDirectory() as temp_dir: + local_path = Path(temp_dir) + + # Open tar file and extract contents to temporary directory + with tarfile.open(cached_file_name, "r") as tar: + # Safely extract while supporting Python < 3.12 + extract_kwargs = dict( + path=local_path, + members=list(Module._safe_members(tar, local_path)), + ) + if "filter" in tar.extractall.__code__.co_varnames: + extract_kwargs["filter"] = "data" + tar.extractall(**extract_kwargs) # noqa: S202 + + # Check if the checkpoint is valid + Module._check_checkpoint(local_path) + + # Load model arguments and instantiate the model + with open(local_path.joinpath("args.json"), "r") as f: + args = json.load(f) + + # Load metadata to get version + with open(local_path.joinpath("metadata.json"), "r") as f: + metadata = json.load(f) + + model = _from_checkpoint_process( + cls, + args, + metadata, + override_args, + strict, + ) + + # Load the model weights + model_dict = torch.load( + local_path.joinpath("model.pt"), map_location=model.device + ) + + # Load state_dict into the model + _load_state_dict_with_logging(model, model_dict, strict=strict) + return model @staticmethod