Refactor GroupNorm and log unmatched state_dict keys #989


Open · wants to merge 2 commits into main

Conversation

juliusberner (Contributor) commented on Jun 24, 2025:

PhysicsNeMo Pull Request

Description

  • Refactor GroupNorm and add get_group_norm to keep the state_dict consistent with previous versions.
  • Log missing and unexpected keys when loading checkpoints (see the sketch after this list).
  • Register deterministic, non-learnable positional embeddings with persistent=False (also sketched below).

Closes #1001 .
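Below is a minimal, hedged sketch of the two behaviors described above. The helper name load_state_dict_with_logging matches the diff later on this page, but this particular implementation, and the SinusoidalEmbedding module illustrating persistent=False, are editorial assumptions rather than the actual physicsnemo code.

```python
import logging

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)


def load_state_dict_with_logging(model: nn.Module, state_dict, strict: bool = True):
    """Load a state_dict and log mismatched keys instead of dropping them silently."""
    result = model.load_state_dict(state_dict, strict=strict)
    if result.missing_keys:
        logger.warning("Missing keys in state_dict: %s", result.missing_keys)
    if result.unexpected_keys:
        logger.warning("Unexpected keys in state_dict: %s", result.unexpected_keys)
    return result


class SinusoidalEmbedding(nn.Module):
    """Deterministic, non-learnable positional embedding (hypothetical example)."""

    def __init__(self, num_channels: int):
        super().__init__()
        half = num_channels // 2
        freqs = (1.0 / 10_000) ** (torch.arange(half, dtype=torch.float32) / half)
        # persistent=False keeps this deterministic buffer out of the
        # state_dict, so checkpoints stay consistent across versions.
        self.register_buffer("freqs", freqs, persistent=False)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        # t: 1-D tensor of positions/timesteps.
        x = t.float().outer(self.freqs)
        return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
```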

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • An issue is linked to this pull request.

CharlelieLrt self-requested a review on Jun 24, 2025.
CharlelieLrt added the labels bug (Something isn't working) and 3 - Ready for Review (Ready for review by team) on Jun 24, 2025.
```diff
@@ -461,8 +483,7 @@ def from_checkpoint(
         local_path.joinpath("model.pt"), map_location=model.device
     )

     model_dict = convert_ckp_apex(ckp_args, model_args, model_dict)
-    model.load_state_dict(model_dict, strict=False)
+    load_state_dict_with_logging(model, model_dict, strict=False)
```
pzharrington (Collaborator) commented:

I'm generally uncomfortable with strict=False here. I realize it was injected in #809 rather than this PR, but since this is a widely used function affecting all trained models in physicsnemo, can we revert it to strict=True here? It seems like backwards-compat handling should be taken care of by the time this line is run. Thoughts, @CharlelieLrt?

CharlelieLrt (Collaborator) commented:

@pzharrington good point! I was actually wondering why we were setting strict=False here, but I thought it had been there since the beginning and never realized it was introduced by #809. We should definitely revert it to True. AFAIK strict=False is only useful when fine-tuning parts of the model, or similar cases.

@juliusberner does your application require strict=False there, or would it be okay to revert to True?

A collaborator commented:

Probably also worth checking with @LostnEkko -- was it introduced in #809 as part of the Apex checkpoint handling, and would you be able to test your use case off of this PR to see if strict=True works?

juliusberner (Contributor, author) commented on Jul 3, 2025:

I think the best option would be to expose strict (similar to the load method above). We can then default it to True if we want. If a user sets it to False, this PR would still log unexpected/missing keys, preventing the silent errors that currently occur because the checkpoint conversion has a bug in https://github.com/NVIDIA/physicsnemo/blob/d1c9391f0f594f7279c8990bff70b8227a6d1f93/physicsnemo/models/util_compatibility.py#L92C17-L92C50
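For concreteness, a sketch of this proposal, reusing the load_state_dict_with_logging helper sketched earlier; the simplified signature is an assumption, since the real from_checkpoint also reconstructs the model and converts Apex checkpoints:

```python
import torch


def from_checkpoint(model, path, strict: bool = True):
    """Hypothetical simplified version that exposes `strict`, defaulting to True.

    A caller who opts into strict=False (e.g. for partial fine-tuning) still
    gets missing/unexpected keys logged rather than silently dropped.
    """
    model_dict = torch.load(path, map_location="cpu")
    return load_state_dict_with_logging(model, model_dict, strict=strict)
```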

CharlelieLrt (Collaborator) commented on Jul 3, 2025:

> Probably also worth checking with @LostnEkko

@jialusui1102 for viz

CharlelieLrt (Collaborator) commented:

@juliusberner sounds good to me! Feel free to update your PR accordingly

juliusberner (Contributor, author) commented:

@CharlelieLrt I adapted it

From the new get_group_norm helper:

```python
    might be adjusted to satisfy the `min_channels_per_group` condition.
    """

    num_groups = min(num_groups, num_channels // min_channels_per_group)
```
A collaborator commented:

Should include the fix from #996 here as well

juliusberner (Contributor, author) commented:

Thanks, I rebased and included it now!
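For reference, a hedged sketch of what a get_group_norm helper in this spirit could look like; the clamping details and the plain nn.GroupNorm return are assumptions (the actual helper also handles the Apex variant, and the exact fix from #996 may differ):

```python
import torch.nn as nn


def get_group_norm(
    num_channels: int,
    num_groups: int = 32,
    min_channels_per_group: int = 4,
    eps: float = 1e-5,
) -> nn.GroupNorm:
    # Shrink num_groups so that each group has at least
    # min_channels_per_group channels; clamp to 1 so small channel
    # counts never produce zero groups.
    num_groups = max(min(num_groups, num_channels // min_channels_per_group), 1)
    # Assumes num_channels is divisible by the resulting num_groups,
    # as is typical for the power-of-two widths used in these models.
    return nn.GroupNorm(num_groups=num_groups, num_channels=num_channels, eps=eps)
```

For example, get_group_norm(num_channels=128) returns nn.GroupNorm(32, 128), keeping the weight and bias parameter keys exactly where a plain GroupNorm module would put them in the state_dict.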

Successfully merging this pull request may close: 🐛[BUG]: GroupNorm creates unused parameters when use_apex_gn=True (#1001)

4 participants