diff --git a/physicsnemo/launch/utils/checkpoint.py b/physicsnemo/launch/utils/checkpoint.py
index a99c7a92de..e4acb553c1 100644
--- a/physicsnemo/launch/utils/checkpoint.py
+++ b/physicsnemo/launch/utils/checkpoint.py
@@ -386,7 +386,18 @@ def load_checkpoint(
             model.load(file_name)
         else:
             file_to_load = _cache_if_needed(file_name)
-            model.load_state_dict(torch.load(file_to_load, map_location=device))
+            missing_keys, unexpected_keys = model.load_state_dict(
+                torch.load(file_to_load, map_location=device)
+            )
+            if missing_keys:
+                checkpoint_logging.warning(
+                    f"Missing keys when loading {name}: {missing_keys}"
+                )
+            if unexpected_keys:
+                checkpoint_logging.warning(
+                    f"Unexpected keys when loading {name}: {unexpected_keys}"
+                )
+
         checkpoint_logging.success(
             f"Loaded model state dictionary {file_name} to device {device}"
         )
diff --git a/physicsnemo/models/diffusion/__init__.py b/physicsnemo/models/diffusion/__init__.py
index 5ecf9ef280..788952c7bf 100644
--- a/physicsnemo/models/diffusion/__init__.py
+++ b/physicsnemo/models/diffusion/__init__.py
@@ -20,6 +20,7 @@
     Conv2d,
     FourierEmbedding,
     GroupNorm,
+    get_group_norm,
     Linear,
     PositionalEmbedding,
     UNetBlock,
diff --git a/physicsnemo/models/diffusion/dhariwal_unet.py b/physicsnemo/models/diffusion/dhariwal_unet.py
index a59aa1f3cd..6858a44a44 100644
--- a/physicsnemo/models/diffusion/dhariwal_unet.py
+++ b/physicsnemo/models/diffusion/dhariwal_unet.py
@@ -23,10 +23,10 @@
 
 from physicsnemo.models.diffusion import (
     Conv2d,
-    GroupNorm,
     Linear,
     PositionalEmbedding,
     UNetBlock,
+    get_group_norm,
 )
 from physicsnemo.models.meta import ModelMetaData
 from physicsnemo.models.module import Module
@@ -264,7 +264,7 @@ def __init__(
             attention=(res in attn_resolutions),
             **block_kwargs,
         )
-        self.out_norm = GroupNorm(num_channels=cout)
+        self.out_norm = get_group_norm(num_channels=cout)
         self.out_conv = Conv2d(
             in_channels=cout, out_channels=out_channels, kernel=3, **init_zero
         )
diff --git a/physicsnemo/models/diffusion/layers.py b/physicsnemo/models/diffusion/layers.py
index e60e843316..40f03a9995 100644
--- a/physicsnemo/models/diffusion/layers.py
+++ b/physicsnemo/models/diffusion/layers.py
@@ -290,6 +290,69 @@ def forward(self, x):
         return x
 
 
+def get_group_norm(
+    num_channels: int,
+    num_groups: int = 32,
+    min_channels_per_group: int = 4,
+    eps: float = 1e-5,
+    use_apex_gn: bool = False,
+    act: str = None,
+    amp_mode: bool = False,
+):
+    """
+    Utility function to get a GroupNorm layer, either from Apex or from torch.
+
+    Parameters
+    ----------
+    num_channels : int
+        Number of channels in the input tensor.
+    num_groups : int, optional
+        Desired number of groups to divide the input channels, by default 32.
+        This might be adjusted based on the `min_channels_per_group`.
+    min_channels_per_group : int, optional
+        Minimum channels required per group. This ensures that no group has fewer
+        channels than this number. By default 4.
+    eps : float, optional
+        A small number added to the variance to prevent division by zero, by default
+        1e-5.
+    use_apex_gn : bool, optional
+        A boolean flag indicating whether to use Apex GroupNorm for the NHWC layout.
+        Must be set to False on CPU. Defaults to False.
+    act : str, optional
+        The activation function to use when fusing activation with GroupNorm.
+        Defaults to None.
+    amp_mode : bool, optional
+        A boolean flag indicating whether mixed-precision (AMP) training is enabled.
+        Defaults to False.
+
+    Notes
+    -----
+    If `num_channels` is not divisible by `num_groups`, the actual number of groups
+    might be adjusted to satisfy the `min_channels_per_group` condition.
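+
+    Examples
+    --------
+    A minimal usage sketch (with the defaults, the torch-based `GroupNorm` below
+    is returned, so Apex is not required; shapes are illustrative only):
+
+    >>> import torch
+    >>> norm = get_group_norm(num_channels=64)  # 32 groups capped at 64 // 4 = 16
+    >>> x = torch.randn(2, 64, 8, 8)
+    >>> norm(x).shape
+    torch.Size([2, 64, 8, 8])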
+ """ + + num_groups = min(num_groups, (num_channels + min_channels_per_group - 1) // min_channels_per_group) + if num_channels % num_groups != 0: + raise ValueError( + "num_channels must be divisible by num_groups or min_channels_per_group" + ) + + if use_apex_gn and not _is_apex_available: + raise ValueError("'apex' is not installed, set `use_apex_gn=False`") + + act = act.lower() if act else act + if use_apex_gn: + return ApexGroupNorm( + num_groups=num_groups, + num_channels=num_channels, + eps=eps, + affine=True, + act=act, + ) + else: + return GroupNorm( + num_groups=num_groups, + num_channels=num_channels, + eps=eps, + act=act, + amp_mode=amp_mode, + ) + + class GroupNorm(torch.nn.Module): """ A custom Group Normalization layer implementation. @@ -301,22 +364,13 @@ class GroupNorm(torch.nn.Module): Parameters ---------- + num_groups : int + Desired number of groups to divide the input channels. num_channels : int Number of channels in the input tensor. - num_groups : int, optional - Desired number of groups to divide the input channels, by default 32. - This might be adjusted based on the `min_channels_per_group`. - min_channels_per_group : int, optional - Minimum channels required per group. This ensures that no group has fewer - channels than this number. By default 4. eps : float, optional A small number added to the variance to prevent division by zero, by default 1e-5. - use_apex_gn : bool, optional - A boolean flag indicating whether we want to use Apex GroupNorm for NHWC layout. - Need to set this as False on cpu. Defaults to False. - fused_act : bool, optional - Whether to fuse the activation function with GroupNorm. Defaults to False. act : str, optional The activation function to use when fusing activation with GroupNorm. Defaults to None. 
     amp_mode : bool, optional
@@ -329,68 +383,32 @@ class GroupNorm(torch.nn.Module):
 
     def __init__(
         self,
+        num_groups: int,
         num_channels: int,
-        num_groups: int = 32,
-        min_channels_per_group: int = 4,
         eps: float = 1e-5,
-        use_apex_gn: bool = False,
-        fused_act: bool = False,
         act: str = None,
         amp_mode: bool = False,
     ):
-        if fused_act and act is None:
-            raise ValueError("'act' must be specified when 'fused_act' is set to True.")
-
         super().__init__()
-        self.num_groups = min(
-            num_groups,
-            (num_channels + min_channels_per_group - 1) // min_channels_per_group,
-        )
-        if num_channels % self.num_groups != 0:
-            raise ValueError(
-                "num_channels must be divisible by num_groups or min_channels_per_group"
-            )
+        self.num_groups = num_groups
         self.eps = eps
         self.weight = torch.nn.Parameter(torch.ones(num_channels))
         self.bias = torch.nn.Parameter(torch.zeros(num_channels))
-        if use_apex_gn and not _is_apex_available:
-            raise ValueError("'apex' is not installed, set `use_apex_gn=False`")
-        self.use_apex_gn = use_apex_gn
-        self.fused_act = fused_act
         self.act = act.lower() if act else act
         self.act_fn = None
-        self.amp_mode = amp_mode
-        if self.use_apex_gn:
-            if self.act:
-                self.gn = ApexGroupNorm(
-                    num_groups=self.num_groups,
-                    num_channels=num_channels,
-                    eps=self.eps,
-                    affine=True,
-                    act=self.act,
-                )
-
-            else:
-                self.gn = ApexGroupNorm(
-                    num_groups=self.num_groups,
-                    num_channels=num_channels,
-                    eps=self.eps,
-                    affine=True,
-                )
-        if self.fused_act:
+        if self.act is not None:
             self.act_fn = self.get_activation_function()
+        self.amp_mode = amp_mode
 
     def forward(self, x):
         weight, bias = self.weight, self.bias
         if not self.amp_mode:
-            if not self.use_apex_gn:
-                if weight.dtype != x.dtype:
-                    weight = self.weight.to(x.dtype)
-                if bias.dtype != x.dtype:
-                    bias = self.bias.to(x.dtype)
-        if self.use_apex_gn:
-            x = self.gn(x)
-        elif self.training:
+            if weight.dtype != x.dtype:
+                weight = self.weight.to(x.dtype)
+            if bias.dtype != x.dtype:
+                bias = self.bias.to(x.dtype)
+
+        if self.training:
             # Use default torch implementation of GroupNorm for training
             # This does not support channels last memory format
             x = torch.nn.functional.group_norm(
@@ -400,8 +418,6 @@ def forward(self, x):
                 x,
                 num_groups=self.num_groups,
                 weight=weight,
                 bias=bias,
                 eps=self.eps,
             )
-            if self.fused_act:
-                x = self.act_fn(x)
         else:
             # Use custom GroupNorm implementation that supports channels last
             # memory layout for inference
@@ -418,8 +434,8 @@ def forward(self, x):
                 bias = rearrange(bias, "c -> 1 c 1 1")
             x = x * weight + bias
 
-            if self.fused_act:
-                x = self.act_fn(x)
+        if self.act_fn is not None:
+            x = self.act_fn(x)
         return x
 
     def get_activation_function(self):
@@ -583,11 +599,10 @@ def __init__(
         self.adaptive_scale = adaptive_scale
         self.profile_mode = profile_mode
         self.amp_mode = amp_mode
-        self.norm0 = GroupNorm(
+        self.norm0 = get_group_norm(
            num_channels=in_channels,
            eps=eps,
            use_apex_gn=use_apex_gn,
-            fused_act=True,
            act=act,
            amp_mode=amp_mode,
        )
@@ -609,19 +624,18 @@ def __init__(
             **init,
         )
         if self.adaptive_scale:
-            self.norm1 = GroupNorm(
+            self.norm1 = get_group_norm(
                 num_channels=out_channels,
                 eps=eps,
                 use_apex_gn=use_apex_gn,
                 amp_mode=amp_mode,
             )
         else:
-            self.norm1 = GroupNorm(
+            self.norm1 = get_group_norm(
                 num_channels=out_channels,
                 eps=eps,
                 use_apex_gn=use_apex_gn,
                 act=act,
-                fused_act=True,
                 amp_mode=amp_mode,
             )
         self.conv1 = Conv2d(
@@ -650,7 +664,7 @@ def __init__(
             )
 
         if self.num_heads:
-            self.norm2 = GroupNorm(
+            self.norm2 = get_group_norm(
                 num_channels=out_channels,
                 eps=eps,
                 use_apex_gn=use_apex_gn,
diff --git a/physicsnemo/models/diffusion/song_unet.py b/physicsnemo/models/diffusion/song_unet.py
index 7e48f60bf1..199283df53 100644
--- a/physicsnemo/models/diffusion/song_unet.py
+++ b/physicsnemo/models/diffusion/song_unet.py
@@ -27,10 +27,10 @@
 from physicsnemo.models.diffusion import (
     Conv2d,
     FourierEmbedding,
-    GroupNorm,
     Linear,
     PositionalEmbedding,
     UNetBlock,
+    get_group_norm,
 )
 from physicsnemo.models.meta import ModelMetaData
 from physicsnemo.models.module import Module
@@ -483,7 +483,7 @@ def __init__(
                 resample_filter=resample_filter,
                 amp_mode=amp_mode,
             )
-            self.dec[f"{res}x{res}_aux_norm"] = GroupNorm(
+            self.dec[f"{res}x{res}_aux_norm"] = get_group_norm(
                 num_channels=cout,
                 eps=1e-6,
                 use_apex_gn=use_apex_gn,
@@ -829,7 +829,9 @@
         if self.gridtype == "learnable":
             self.pos_embd = self._get_positional_embedding()
         else:
-            self.register_buffer("pos_embd", self._get_positional_embedding().float())
+            self.register_buffer(
+                "pos_embd", self._get_positional_embedding().float(), persistent=False
+            )
         self.lead_time_mode = lead_time_mode
         if self.lead_time_mode:
             self.lead_time_channels = lead_time_channels
diff --git a/physicsnemo/models/module.py b/physicsnemo/models/module.py
index ff5191139d..ad94fd1d87 100644
--- a/physicsnemo/models/module.py
+++ b/physicsnemo/models/module.py
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import copy
 import importlib
 import inspect
 import json
@@ -30,11 +29,36 @@
 
 import physicsnemo
 from physicsnemo.models.meta import ModelMetaData
-from physicsnemo.models.util_compatibility import convert_ckp_apex
 from physicsnemo.registry import ModelRegistry
 from physicsnemo.utils.filesystem import _download_cached, _get_fs
 
 
+def load_state_dict_with_logging(
+    module: torch.nn.Module, state_dict: Dict[str, Any], *args, **kwargs
+):
+    """Load a state dictionary and log any missing or unexpected keys.
+
+    Parameters
+    ----------
+    module : torch.nn.Module
+        Module to load the state dictionary into
+    state_dict : Dict[str, Any]
+        State dictionary to load
+    *args, **kwargs
+        Additional arguments passed through to ``module.load_state_dict``
+
+    Returns
+    -------
+    missing_keys : List[str]
+        Keys expected by ``module`` but absent from ``state_dict``
+    unexpected_keys : List[str]
+        Keys present in ``state_dict`` but not expected by ``module``
+    """
+    missing_keys, unexpected_keys = module.load_state_dict(state_dict, *args, **kwargs)
+    if missing_keys:
+        logging.warning(
+            f"Missing keys when loading {module.__class__.__name__}: {missing_keys}"
+        )
+    if unexpected_keys:
+        logging.warning(
+            f"Unexpected keys when loading {module.__class__.__name__}: {unexpected_keys}"
+        )
+    return missing_keys, unexpected_keys
+
+
 class Module(torch.nn.Module):
     """The base class for all network models in PhysicsNeMo.
@@ -417,11 +441,11 @@ def load(
             model_dict = torch.load(
                 local_path.joinpath("model.pt"), map_location=device
             )
-        self.load_state_dict(model_dict, strict=strict)
+        load_state_dict_with_logging(self, model_dict, strict=strict)
 
     @classmethod
     def from_checkpoint(
-        cls, file_name: str, override_args: Optional[Dict[str, Any]] = None
+        cls,
+        file_name: str,
+        override_args: Optional[Dict[str, Any]] = None,
+        strict: bool = True,
     ) -> "Module":
         """Simple utility for constructing a model from a checkpoint
@@ -447,6 +471,8 @@ class attribute. Attempting to override any other argument will raise
             a ``ValueError``. This API should be used with caution and only if you
             fully understand the implications of the override.
+        strict : bool, optional
+            Whether to strictly enforce that the keys in the checkpoint state
+            dictionary match the model's, by default True
 
         Returns
         -------
@@ -478,8 +504,6 @@ class attribute. Attempting to override any other argument will raise
         with open(local_path.joinpath("args.json"), "r") as f:
             args = json.load(f)
 
-        ckp_args = copy.deepcopy(args)
-
         # Load metadata to get version
         with open(local_path.joinpath("metadata.json"), "r") as f:
             metadata = json.load(f)
@@ -515,8 +539,7 @@ class attribute. Attempting to override any other argument will raise
             local_path.joinpath("model.pt"), map_location=model.device
         )
 
-        model_dict = convert_ckp_apex(ckp_args, override_args, model_dict)
-        model.load_state_dict(model_dict, strict=False)
+        load_state_dict_with_logging(model, model_dict, strict=strict)
         return model
 
     @staticmethod
diff --git a/physicsnemo/models/util_compatibility.py b/physicsnemo/models/util_compatibility.py
deleted file mode 100644
index 361c3383ce..0000000000
--- a/physicsnemo/models/util_compatibility.py
+++ /dev/null
@@ -1,102 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
-# SPDX-FileCopyrightText: All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from typing import Any, Dict
-
-
-def convert_ckp_apex(
-    ckp_args_dict: Dict[str, Any],
-    model_args: Dict[str, Any],
-    model_dict: Dict[str, Any],
-) -> Dict[str, Any]:
-
-    """Utility for converting Apex GroupNorm-related keys in a checkpoint.
-
-    This function modifies the checkpoint arguments and model dictionary
-    to ensure compatibility when switching between Apex-optimized models
-    and standard PyTorch models.
-
-    Parameters
-    ----------
-    ckp_args_dict : Dict[str, Any]
-        Dictionary of checkpoint arguments (e.g., configuration parameters saved during training).
-    model_args : Dict[str, Any]
-        Dictionary of model initialization arguments that may need updating.
-    model_dict : Dict[str, Any]
-        Dictionary containing model state_dict (weights) loaded from checkpoint.
-
-    Returns
-    -------
-    Dict[str, Any]
-        Updated model_dict with necessary key modifications applied for compatibility.
-
-    Raises
-    ------
-    KeyError
-        If essential expected keys are missing during the conversion process.
- """ - - apex_in_ckp = ("use_apex_gn" in ckp_args_dict["__args__"].keys()) and ( - ckp_args_dict["__args__"]["use_apex_gn"] - ) - apex_in_workflow = ( - (model_args is not None) - and ("use_apex_gn" in model_args.keys()) - and (model_args["use_apex_gn"]) - ) - - filtered_state_dict = {} - # case1: try to use non-optimized ckp in optimized workflow - if (not apex_in_ckp) and apex_in_workflow: - # transfer GN weight & bias to apex GN weight & bias - for key, value in model_dict.items(): - is_duplicate = False - for norm_layer in ["norm0", "norm1", "norm2", "aux_norm"]: - if f"{norm_layer}.weight" in key: - new_key = key.replace( - f"{norm_layer}.weight", f"{norm_layer}.gn.weight" - ) - filtered_state_dict[new_key] = value - is_duplicate = True - elif f"{norm_layer}.bias" in key: - new_key = key.replace(f"{norm_layer}.bias", f"{norm_layer}.gn.bias") - filtered_state_dict[new_key] = value - is_duplicate = True - if not is_duplicate: - filtered_state_dict[key] = value - - # case2: try to use optimized ckp in non-optimized workflow - elif apex_in_ckp and (not apex_in_workflow): - # transfer apex GN weight & bias to GN weight & bias - for key, value in model_dict.items(): - is_duplicate = False - for norm_layer in ["norm0", "norm1", "norm2", "aux_norm"]: - if f"{norm_layer}.gn.weight" in key: - new_key = key.replace( - f"{norm_layer}.gn.weight", f"{norm_layer}.weight" - ) - filtered_state_dict[new_key] = value - is_duplicate = True - elif f"{norm_layer}.bias" in key: - new_key = key.replace(f"{norm_layer}.gn.bias", f"{norm_layer}.bias") - filtered_state_dict[new_key] = value - is_duplicate = True - if not is_duplicate: - filtered_state_dict[key] = value - else: - # no need to convert ckp - return model_dict - - return filtered_state_dict