-
Notifications
You must be signed in to change notification settings - Fork 393
Solve checkpoint and validation bugs in training corrdiff unet #1000
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
…y script as the regression loss does not support it
@CharlelieLrt Can you please review it? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, just left a suggestion to make the checkpointing more robust for high/low aspect ratio images.
…me dimension used to name the layers
idx_x = torch.arange(self.img_shape_y) | ||
idx_y = torch.arange(self.img_shape_x) | ||
mesh_x, mesh_y = torch.meshgrid(idx_x, idx_y) | ||
grid = torch.stack((mesh_x, mesh_y), dim=0) # (2, img_shape_y, img_shape_x) | ||
idx_x = torch.arange(self.img_shape_x) | ||
idx_y = torch.arange(self.img_shape_y) | ||
mesh_y, mesh_x = torch.meshgrid(idx_y, idx_x) | ||
grid = torch.stack((mesh_y, mesh_x), dim=0) # (2, img_shape_y, img_shape_x) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Your changes are consistent with the original implementation, but I think the original implementation is wrong.
If you look at the shape of grid, I think it is transposed in comparison to the case self.gridtype == "linear"
.
The case self.gridtype == "test"
is only used in CI tests, but it might be better to fix the implementation + fix any test that is failing (probably not many because we have only very few tests for non-square images)
Closes #994 .