Refactor GroupNorm and log unmatched state_dict keys #989

Open · wants to merge 2 commits into base: main

13 changes: 12 additions & 1 deletion physicsnemo/launch/utils/checkpoint.py
@@ -386,7 +386,18 @@ def load_checkpoint(
model.load(file_name)
else:
file_to_load = _cache_if_needed(file_name)
model.load_state_dict(torch.load(file_to_load, map_location=device))
missing_keys, unexpected_keys = model.load_state_dict(
torch.load(file_to_load, map_location=device)
)
if missing_keys:
checkpoint_logging.warning(
f"Missing keys when loading {name}: {missing_keys}"
)
if unexpected_keys:
checkpoint_logging.warning(
f"Unexpected keys when loading {name}: {unexpected_keys}"
)

checkpoint_logging.success(
f"Loaded model state dictionary {file_name} to device {device}"
)
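
For context on the pattern above: torch.nn.Module.load_state_dict returns an _IncompatibleKeys named tuple of (missing_keys, unexpected_keys), which is what the new code unpacks and logs. A minimal, self-contained sketch of the same idea follows; the toy model, the fabricated state dict, and the use of strict=False (so mismatches are reported rather than raised) are illustrative assumptions, not the PhysicsNeMo implementation.

import logging

import torch

logger = logging.getLogger("checkpoint")

# Toy model whose state_dict keys are "weight" and "bias".
model = torch.nn.Linear(4, 4)

# Hypothetical checkpoint: "bias" is absent and "extra.scale" is unknown.
state = {"weight": torch.zeros(4, 4), "extra.scale": torch.ones(4)}

missing_keys, unexpected_keys = model.load_state_dict(state, strict=False)
if missing_keys:
    logger.warning(f"Missing keys when loading model: {missing_keys}")
if unexpected_keys:
    logger.warning(f"Unexpected keys when loading model: {unexpected_keys}")
# missing_keys == ["bias"], unexpected_keys == ["extra.scale"]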
1 change: 1 addition & 0 deletions physicsnemo/models/diffusion/__init__.py
@@ -20,6 +20,7 @@
Conv2d,
FourierEmbedding,
GroupNorm,
get_group_norm,
Linear,
PositionalEmbedding,
UNetBlock,
4 changes: 2 additions & 2 deletions physicsnemo/models/diffusion/dhariwal_unet.py
@@ -23,10 +23,10 @@

from physicsnemo.models.diffusion import (
Conv2d,
GroupNorm,
Linear,
PositionalEmbedding,
UNetBlock,
get_group_norm,
)
from physicsnemo.models.meta import ModelMetaData
from physicsnemo.models.module import Module
@@ -264,7 +264,7 @@ def __init__(
attention=(res in attn_resolutions),
**block_kwargs,
)
self.out_norm = GroupNorm(num_channels=cout)
self.out_norm = get_group_norm(num_channels=cout)
self.out_conv = Conv2d(
in_channels=cout, out_channels=out_channels, kernel=3, **init_zero
)
148 changes: 81 additions & 67 deletions physicsnemo/models/diffusion/layers.py
@@ -290,6 +290,69 @@ def forward(self, x):
return x


def get_group_norm(
num_channels: int,
num_groups: int = 32,
min_channels_per_group: int = 4,
eps: float = 1e-5,
use_apex_gn: bool = False,
act: str = None,
amp_mode: bool = False,
):
"""
Utility function to get the GroupNorm layer, either from apex or from torch.

Parameters
----------
num_channels : int
Number of channels in the input tensor.
num_groups : int, optional
Desired number of groups to divide the input channels into, by default 32.
The value may be reduced based on `min_channels_per_group`.
min_channels_per_group : int, optional
Minimum number of channels required per group; the number of groups is
reduced so that no group has fewer channels than this. By default 4.
eps : float, optional
A small number added to the variance to prevent division by zero, by default
1e-5.
use_apex_gn : bool, optional
Whether to use Apex GroupNorm, which supports the NHWC (channels-last)
layout. Must be set to False on CPU. Defaults to False.
act : str, optional
Name of the activation function to fuse with GroupNorm. Defaults to None
(no fused activation).
amp_mode : bool, optional
Whether mixed-precision (AMP) training is enabled. Defaults to False.

Returns
-------
torch.nn.Module
The requested normalization layer: an Apex GroupNorm instance when
`use_apex_gn` is True, otherwise a `GroupNorm` instance.

Notes
-----
The number of groups is first clamped to
`ceil(num_channels / min_channels_per_group)`; if `num_channels` is not
divisible by the resulting group count, a ValueError is raised.
"""

num_groups = min(num_groups, (num_channels + min_channels_per_group - 1) // min_channels_per_group)
if num_channels % num_groups != 0:
raise ValueError(
"num_channels must be divisible by num_groups or min_channels_per_group"
)

if use_apex_gn and not _is_apex_available:
raise ValueError("'apex' is not installed, set `use_apex_gn=False`")

act = act.lower() if act else act
if use_apex_gn:
return ApexGroupNorm(
num_groups=num_groups,
num_channels=num_channels,
eps=eps,
affine=True,
act=act,
)
else:
return GroupNorm(
num_groups=num_groups,
num_channels=num_channels,
eps=eps,
act=act,
amp_mode=amp_mode,
)
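
A hypothetical usage sketch for the new helper (the calls below are illustrative and not part of the PR): with the defaults, the group count is clamped to ceil(num_channels / min_channels_per_group) so that every group keeps at least min_channels_per_group channels, and a channel count that is not divisible by the resulting group count raises a ValueError. The "silu" name is assumed to be an activation that GroupNorm.get_activation_function recognizes.

import torch

from physicsnemo.models.diffusion import get_group_norm

# 64 channels: num_groups is clamped from 32 to min(32, ceil(64 / 4)) = 16.
norm = get_group_norm(num_channels=64)
y = norm(torch.randn(2, 64, 8, 8))

# Fused activation on the torch fallback path (apex not required).
norm_silu = get_group_norm(num_channels=64, act="silu")

# 30 channels: groups are clamped to ceil(30 / 4) = 8, but 30 % 8 != 0,
# so this call would raise a ValueError.
# get_group_norm(num_channels=30)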


class GroupNorm(torch.nn.Module):
"""
A custom Group Normalization layer implementation.
@@ -301,22 +364,13 @@ class GroupNorm(torch.nn.Module):

Parameters
----------
num_groups : int
Desired number of groups to divide the input channels.
num_channels : int
Number of channels in the input tensor.
num_groups : int, optional
Desired number of groups to divide the input channels, by default 32.
This might be adjusted based on the `min_channels_per_group`.
min_channels_per_group : int, optional
Minimum channels required per group. This ensures that no group has fewer
channels than this number. By default 4.
eps : float, optional
A small number added to the variance to prevent division by zero, by default
1e-5.
use_apex_gn : bool, optional
A boolean flag indicating whether we want to use Apex GroupNorm for NHWC layout.
Need to set this as False on cpu. Defaults to False.
fused_act : bool, optional
Whether to fuse the activation function with GroupNorm. Defaults to False.
act : str, optional
The activation function to use when fusing activation with GroupNorm. Defaults to None.
amp_mode : bool, optional
@@ -329,68 +383,32 @@ class GroupNorm(torch.nn.Module):

def __init__(
self,
num_groups: int,
num_channels: int,
num_groups: int = 32,
min_channels_per_group: int = 4,
eps: float = 1e-5,
use_apex_gn: bool = False,
fused_act: bool = False,
act: str = None,
amp_mode: bool = False,
):
if fused_act and act is None:
raise ValueError("'act' must be specified when 'fused_act' is set to True.")

super().__init__()
self.num_groups = min(
num_groups,
(num_channels + min_channels_per_group - 1) // min_channels_per_group,
)
if num_channels % self.num_groups != 0:
raise ValueError(
"num_channels must be divisible by num_groups or min_channels_per_group"
)
self.num_groups = num_groups
self.eps = eps
self.weight = torch.nn.Parameter(torch.ones(num_channels))
self.bias = torch.nn.Parameter(torch.zeros(num_channels))
if use_apex_gn and not _is_apex_available:
raise ValueError("'apex' is not installed, set `use_apex_gn=False`")
self.use_apex_gn = use_apex_gn
self.fused_act = fused_act
self.act = act.lower() if act else act
self.act_fn = None
self.amp_mode = amp_mode
if self.use_apex_gn:
if self.act:
self.gn = ApexGroupNorm(
num_groups=self.num_groups,
num_channels=num_channels,
eps=self.eps,
affine=True,
act=self.act,
)

else:
self.gn = ApexGroupNorm(
num_groups=self.num_groups,
num_channels=num_channels,
eps=self.eps,
affine=True,
)
if self.fused_act:
if self.act is not None:
self.act_fn = self.get_activation_function()
self.amp_mode = amp_mode

def forward(self, x):
weight, bias = self.weight, self.bias
if not self.amp_mode:
if not self.use_apex_gn:
if weight.dtype != x.dtype:
weight = self.weight.to(x.dtype)
if bias.dtype != x.dtype:
bias = self.bias.to(x.dtype)
if self.use_apex_gn:
x = self.gn(x)
elif self.training:
if weight.dtype != x.dtype:
weight = self.weight.to(x.dtype)
if bias.dtype != x.dtype:
bias = self.bias.to(x.dtype)

if self.training:
# Use default torch implementation of GroupNorm for training
# This does not support channels last memory format
x = torch.nn.functional.group_norm(
@@ -400,8 +418,6 @@ def forward(self, x):
bias=bias,
eps=self.eps,
)
if self.fused_act:
x = self.act_fn(x)
else:
# Use custom GroupNorm implementation that supports channels last
# memory layout for inference
@@ -418,8 +434,8 @@ def forward(self, x):
bias = rearrange(bias, "c -> 1 c 1 1")
x = x * weight + bias

if self.fused_act:
x = self.act_fn(x)
if self.act_fn is not None:
x = self.act_fn(x)
return x

def get_activation_function(self):
@@ -583,11 +599,10 @@ def __init__(
self.adaptive_scale = adaptive_scale
self.profile_mode = profile_mode
self.amp_mode = amp_mode
self.norm0 = GroupNorm(
self.norm0 = get_group_norm(
num_channels=in_channels,
eps=eps,
use_apex_gn=use_apex_gn,
fused_act=True,
act=act,
amp_mode=amp_mode,
)
@@ -609,19 +624,18 @@ def __init__(
**init,
)
if self.adaptive_scale:
self.norm1 = GroupNorm(
self.norm1 = get_group_norm(
num_channels=out_channels,
eps=eps,
use_apex_gn=use_apex_gn,
amp_mode=amp_mode,
)
else:
self.norm1 = GroupNorm(
self.norm1 = get_group_norm(
num_channels=out_channels,
eps=eps,
use_apex_gn=use_apex_gn,
act=act,
fused_act=True,
amp_mode=amp_mode,
)
self.conv1 = Conv2d(
@@ -650,7 +664,7 @@ def __init__(
)

if self.num_heads:
self.norm2 = GroupNorm(
self.norm2 = get_group_norm(
num_channels=out_channels,
eps=eps,
use_apex_gn=use_apex_gn,
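
As a side note on the refactored GroupNorm inference branch above (the channels-last path that rearranges weight and bias to "1 c 1 1" before the affine step): the statistics it normalizes with are the standard per-group mean and variance, which can be sanity-checked against torch.nn.functional.group_norm. The sketch below is illustrative only and is not the PhysicsNeMo implementation.

import torch
from einops import rearrange

N, C, H, W, G = 2, 32, 8, 8, 8
eps = 1e-5
x = torch.randn(N, C, H, W)
weight = torch.randn(C)
bias = torch.randn(C)

# Reference: PyTorch's built-in group normalization.
ref = torch.nn.functional.group_norm(x, G, weight=weight, bias=bias, eps=eps)

# Manual version: normalize each group of channels over (C // G, H, W),
# then apply the per-channel affine transform, as in the inference branch.
xg = rearrange(x, "n (g c) h w -> n g (c h w)", g=G)
mean = xg.mean(dim=-1, keepdim=True)
var = xg.var(dim=-1, unbiased=False, keepdim=True)
xg = (xg - mean) / torch.sqrt(var + eps)
out = rearrange(xg, "n g (c h w) -> n (g c) h w", c=C // G, h=H, w=W)
out = out * rearrange(weight, "c -> 1 c 1 1") + rearrange(bias, "c -> 1 c 1 1")

print(torch.allclose(ref, out, atol=1e-5))  # expected: True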
8 changes: 5 additions & 3 deletions physicsnemo/models/diffusion/song_unet.py
@@ -27,10 +27,10 @@
from physicsnemo.models.diffusion import (
Conv2d,
FourierEmbedding,
GroupNorm,
Linear,
PositionalEmbedding,
UNetBlock,
get_group_norm,
)
from physicsnemo.models.meta import ModelMetaData
from physicsnemo.models.module import Module
@@ -483,7 +483,7 @@ def __init__(
resample_filter=resample_filter,
amp_mode=amp_mode,
)
self.dec[f"{res}x{res}_aux_norm"] = GroupNorm(
self.dec[f"{res}x{res}_aux_norm"] = get_group_norm(
num_channels=cout,
eps=1e-6,
use_apex_gn=use_apex_gn,
@@ -829,7 +829,9 @@ def __init__(
if self.gridtype == "learnable":
self.pos_embd = self._get_positional_embedding()
else:
self.register_buffer("pos_embd", self._get_positional_embedding().float())
self.register_buffer(
"pos_embd", self._get_positional_embedding().float(), persistent=False
)
self.lead_time_mode = lead_time_mode
if self.lead_time_mode:
self.lead_time_channels = lead_time_channels
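
Regarding the persistent=False change above: a non-persistent buffer is still registered on the module (it moves with .to(device) and is usable in forward), but it is excluded from state_dict(), so the precomputed positional embedding is rebuilt by the constructor instead of being written to or read from checkpoints. A tiny illustration with a toy module (not SongUNet) follows.

import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("pos_embd", torch.arange(4).float(), persistent=False)

toy = Toy()
print("pos_embd" in toy.state_dict())  # False: not saved with checkpoints
print(toy.pos_embd)                    # still available at runtime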