-
Notifications
You must be signed in to change notification settings - Fork 393
Refactor GroupNorm and log unmatched state_dict keys #989
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
physicsnemo/models/module.py
Outdated
@@ -461,8 +483,7 @@ def from_checkpoint( | |||
local_path.joinpath("model.pt"), map_location=model.device | |||
) | |||
|
|||
model_dict = convert_ckp_apex(ckp_args, model_args, model_dict) | |||
model.load_state_dict(model_dict, strict=False) | |||
load_state_dict_with_logging(model, model_dict, strict=False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm generally uncomfortable with strict=False
here. I realize it was injected in #809 rather than this PR, but since this is a widely used function affecting all trained models in physicsnemo, can we revert it to strict=True
here? Seems like backwards-compat handling should be taken care of by the time this line is run. Thoughts @CharlelieLrt ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@pzharrington good point! I was actually wondering why we were setting strict=False
here, but I thought it was there since the beginning and I've never realized it was introduced by #809 . We should definitely revert it back to True
. AFAIK it's only useful when fine-tuning some parts of the model, or other things like this...
@juliusberner does your application requires strict=False
there, or would it be okay to revert to True
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably also worth checking with @LostnEkko -- was it introduced in #809 as part of the apex checkpoint handling, and would you be able to test your use-case off of this PR to see if strict=True
works?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the best option would be to expose strict
(similar as in the load
method above). We can then default it to True
if we want --- if the user sets it to False
, we would still log unexpected/missing keys with this PR (preventing silent errors which are happening right now since the checkpoint conversion has a bug in https://github.com/NVIDIA/physicsnemo/blob/d1c9391f0f594f7279c8990bff70b8227a6d1f93/physicsnemo/models/util_compatibility.py#L92C17-L92C50)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably also worth checking with @LostnEkko
@jialusui1102 for viz
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@juliusberner sounds good to me! Feel free to update your PR accordingly
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@CharlelieLrt I adapted it
might be adjusted to satisfy the `min_channels_per_group` condition. | ||
""" | ||
|
||
num_groups = min(num_groups, num_channels // min_channels_per_group) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should include the fix from #996 here as well
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, I rebased and included it now!
Signed-off-by: Julius Berner <[email protected]>
8767605
to
f8e01c7
Compare
PhysicsNeMo Pull Request
Description
GroupNorm
and addget_group_norm
to keep thestate_dict
consistent with previous versions.persistent=False
for deterministic, non-learnable positional embeddings.Closes #1001 .
Checklist