🐛[BUG]: GroupNorm creates unused parameters when use_apex_gn=True #1001

@akshaysubr

Version

1.1.0

On which installation method(s) does this occur?

Pip

Describe the issue

The GroupNorm class in physicsnemo defines the affine parameters here: https://github.com/pytorch/pytorch/blob/134179474539648ba7dee1317959529fbd0e7f89/torch/nn/modules/normalization.py#L291-L292

But Apex defines those internally as well: https://github.com/NVIDIA/apex/blob/d8200a1bec8d29fe3f44d96b0a9e7b7fba850c77/apex/contrib/group_norm/group_norm.py#L321-L324

Gradients are computed only for the Apex-defined weights. The physicsnemo-defined weights are never used, so they are unused parameters that shouldn't exist in the first place. Their gradients stay None and, as a result, DDP hangs waiting for an allreduce on those unused weight gradients. A possible workaround is to use find_unused_parameters=True with DDP, but that is very unsatisfactory. There is a second issue: a model checkpoint trained with use_apex_gn=False will produce wrong results if it is used for inference with use_apex_gn=True.
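Until the duplicate parameters are removed, one stopgap for the checkpoint mismatch might be to copy the wrapper-level affine parameters onto the keys of the inner Apex module before loading. The sketch below is a hypothetical helper, not part of physicsnemo. It assumes the parameter names printed in the minimal example below (`weight`/`bias` on the wrapper, `gn.weight`/`gn.bias` on the Apex module), that both layouts store the affine parameters with the same shape, and that the caller knows the dotted paths of the GroupNorm layers in the model. It operates on any plain mapping, so it could be applied to a loaded state dict before `load_state_dict`.

```python
from collections import OrderedDict


def remap_groupnorm_keys(state_dict, groupnorm_prefixes):
    """Copy wrapper-level affine parameters onto the inner Apex module's keys.

    Hypothetical helper (not part of physicsnemo). `groupnorm_prefixes` lists
    the dotted module paths of GroupNorm layers, e.g. ["enc.norm0"]; use ""
    for a bare GroupNorm module.
    """
    remapped = OrderedDict(state_dict)
    for prefix in groupnorm_prefixes:
        base = f"{prefix}." if prefix else ""
        for leaf in ("weight", "bias"):
            src = f"{base}{leaf}"
            dst = f"{base}gn.{leaf}"
            # Only fill in keys that the checkpoint does not already carry.
            if src in remapped and dst not in remapped:
                remapped[dst] = remapped[src]
    return remapped


# Toy checkpoint in the use_apex_gn=False layout (values are placeholders,
# and "enc.norm0" is an invented module path).
ckpt = {"enc.norm0.weight": [1.0, 2.0], "enc.norm0.bias": [0.5, 0.5]}
fixed = remap_groupnorm_keys(ckpt, ["enc.norm0"])
print(sorted(fixed))
```

The original keys are kept, so the remapped dictionary still matches a module that registers all four parameters. This only papers over the symptom; the underlying fix is still to stop registering the parameters twice.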

Minimum reproducible example

In [1]: from physicsnemo.models.diffusion.layers import GroupNorm

In [2]: import torch

In [3]: device = torch.device("cuda")

In [4]: m = GroupNorm(num_channels=128, eps=1e-6, use_apex_gn=True).to(device)

In [5]: inp = torch.ones(1, 128, 64, 64, device=device)

In [6]: optimizer = torch.optim.SGD(m.parameters(), lr=1e-4, momentum=0.9)

In [7]: out = m(inp)

In [9]: loss = out.mean()

In [10]: loss.backward()

In [11]: optimizer.step()

In [12]: for name, param in m.named_parameters():
    ...:     if param.grad is None:
    ...:         print(f"{name} has `None` gradients")
    ...:     else:
    ...:         print(f"{name} has valid gradients")
    ...: 
weight has `None` gradients
bias has `None` gradients
gn.weight has valid gradients
gn.bias has valid gradients
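
The check at the end of the session above could be wrapped in a small helper and used as a regression test. This is a sketch with a hypothetical function name. It only assumes objects that expose a `.grad` attribute, so it should accept the output of `m.named_parameters()` after `loss.backward()`. The stand-in objects mimic the state printed above so that the snippet runs without a GPU or Apex.

```python
from types import SimpleNamespace


def names_without_grad(named_params):
    """Return names of parameters whose .grad is None after backward()."""
    return [name for name, param in named_params if param.grad is None]


# Stand-ins mirroring the session output; with a real module, pass
# m.named_parameters() instead.
fake_params = [
    ("weight", SimpleNamespace(grad=None)),
    ("bias", SimpleNamespace(grad=None)),
    ("gn.weight", SimpleNamespace(grad=0.0)),
    ("gn.bias", SimpleNamespace(grad=0.0)),
]
unused = names_without_grad(fake_params)
print(unused)  # ['weight', 'bias']
```

Asserting that the returned list is empty for a module built with use_apex_gn=True would flag the duplicate parameters without needing a multi-GPU DDP run.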
