@CharlelieLrt
Collaborator

PhysicsNeMo Pull Request

Description

Fixes inconsistencies with the shape of non-square images in the CorrDiff example.

Closes #1011.

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • An issue is linked to this pull request.

Dependencies

@CharlelieLrt
Collaborator Author

/blossom-ci

Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR fixes shape inconsistencies for non-square images in the CorrDiff weather forecasting example. The issue occurred when img_shape_x ≠ img_shape_y, due to confusion between the width/height semantics of these parameters and NumPy's (H, W) convention.
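For readers unfamiliar with the convention mentioned above, the snippet below is a minimal illustration (not code from the PR) of NumPy's (height, width) layout, using the 320×336 example values from this summary: the height (img_shape_y) is the first axis and the width (img_shape_x) is the second.

```python
import numpy as np

# Example values from this PR: img_shape_x is the domain width,
# img_shape_y is the domain height.
img_shape_x = 320  # width  -> number of columns (west_east)
img_shape_y = 336  # height -> number of rows (south_north)

# NumPy (and torch) store 2D fields as (rows, cols) == (height, width),
# so the height must come first when allocating or reshaping.
field = np.zeros((img_shape_y, img_shape_x), dtype=np.float32)

height, width = field.shape
print(height, width)  # 336 320

# Indexing follows the same order: field[row, col] == field[y, x].
field[10, 20] = 1.0  # row 10 (y), column 20 (x)
```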

Key changes:

  • Clarified configuration: Added comments to cwb.yaml specifying that img_shape_x represents domain width and img_shape_y represents domain height
  • Fixed image_shape() return order: Changed from (img_shape_x, img_shape_y) to (img_shape_y, img_shape_x) to match numpy/torch convention of (height, width)
  • Corrected rolling logic: Renamed y_roll to x_roll and fixed the random range from [0, img_shape_y] to [0, img_shape_x] since the roll operation occurs along the horizontal axis
  • Updated tensor operations: Fixed slicing and reshaping in img_utils.py to use the correct dimension ordering throughout
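The slicing, rolling, and reshaping steps listed above can be sketched as follows. This is a simplified, hypothetical reconstruction in plain NumPy, not the actual reshape_fields from img_utils.py: the name reshape_fields_sketch and its signature are assumptions, the real function takes additional arguments, and it returns a torch tensor rather than a NumPy array. The input layout (n_history+1, c, h, w) is taken from the bot's second sequence diagram.

```python
import numpy as np

def reshape_fields_sketch(img, x_roll, img_shape_x, img_shape_y, roll=True):
    """Hypothetical, simplified stand-in for img_utils.reshape_fields.

    img: array of shape (n_history + 1, c, h, w), with h >= img_shape_y
    and w >= img_shape_x. Returns (c * (n_history + 1), img_shape_y, img_shape_x).
    """
    n_steps, n_channels = img.shape[0], img.shape[1]
    # Crop rows (height, y) first, then columns (width, x): NumPy's (H, W) order.
    img = img[:, :, :img_shape_y, :img_shape_x]
    if roll:
        # Shift along the last axis, i.e. horizontally (width / x).
        img = np.roll(img, x_roll, axis=-1)
    # Merge history steps and channels into a single channel axis: (C, H, W).
    return img.reshape(n_steps * n_channels, img_shape_y, img_shape_x)


# Usage: 2 time steps x 3 channels on a raw grid larger than the target domain.
rng = np.random.default_rng(0)
raw = rng.standard_normal((2, 3, 350, 330)).astype(np.float32)
out = reshape_fields_sketch(raw, x_roll=5, img_shape_x=320, img_shape_y=336)
print(out.shape)  # (6, 336, 320)
```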

The fix ensures that when users specify non-square dimensions (e.g., 320×336), the resulting tensors have the correct shape (C, 336, 320) instead of the incorrect (C, 320, 336), properly matching the intended height×width configuration.
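The bug stays hidden for square domains because, whenever the stored grid is at least as large as both target dimensions, the old and new slicing orders both succeed silently and differ only in which axis receives which length. A minimal NumPy illustration with a hypothetical 400×400 raw grid:

```python
import numpy as np

img_shape_x, img_shape_y = 320, 336  # width, height (values from this PR's example)

# Hypothetical raw field larger than the target domain: (c, rows, cols).
raw = np.zeros((4, 400, 400), dtype=np.float32)

# Old ordering: x used for rows and y for columns -> (C, W, H).
old = raw[:, :img_shape_x, :img_shape_y]
# New ordering: y for rows (height), x for columns (width) -> (C, H, W).
new = raw[:, :img_shape_y, :img_shape_x]

print(old.shape)  # (4, 320, 336)
print(new.shape)  # (4, 336, 320)
```

For a square domain (img_shape_x == img_shape_y) both lines return the same shape, which is why only non-square configurations exposed the issue.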

Confidence Score: 5/5

  • This PR is safe to merge with high confidence - it fixes a clear shape inconsistency bug with systematic, well-reasoned changes
  • The changes are well-structured and address a genuine bug where parameter names didn't match numpy conventions. All modifications are internally consistent: the roll variable is correctly renamed and uses the right dimension, image_shape() now returns (H,W) as expected by downstream code, and tensor operations maintain proper dimension ordering. The fix aligns with how the data is stored (south_north, west_east) and how other parts of the codebase use img_shape[0] for height and img_shape[1] for width.
  • No files require special attention
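A minimal sketch of the downstream contract described above. ToyDataset is a hypothetical class, much simpler than the real ZarrDataset: image_shape() returns (height, width), and consumers read img_shape[0] as height and img_shape[1] as width.

```python
import numpy as np

class ToyDataset:
    """Hypothetical minimal stand-in for the CorrDiff dataset's shape API."""

    def __init__(self, img_shape_x, img_shape_y, n_channels=2):
        self.img_shape_x = img_shape_x  # width
        self.img_shape_y = img_shape_y  # height
        self.n_channels = n_channels

    def image_shape(self):
        # (height, width), matching NumPy/torch (H, W) ordering.
        return (self.img_shape_y, self.img_shape_x)

    def __getitem__(self, idx):
        # Samples are laid out as (C, H, W).
        return np.zeros((self.n_channels, *self.image_shape()), dtype=np.float32)


ds = ToyDataset(img_shape_x=320, img_shape_y=336)
img_shape = ds.image_shape()
height, width = img_shape[0], img_shape[1]  # downstream convention
sample = ds[0]
assert sample.shape[-2:] == (height, width)
print(height, width)  # 336 320
```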

Important Files Changed

File Analysis

| Filename | Score | Overview |
|---|---|---|
| examples/weather/corrdiff/conf/base/dataset/cwb.yaml | 5/5 | Added clarifying comments to distinguish width vs height parameters |
| examples/weather/corrdiff/datasets/cwb.py | 5/5 | Fixed image_shape() return order to match the NumPy (H, W) convention and corrected roll variable naming/range |
| examples/weather/corrdiff/datasets/img_utils.py | 5/5 | Updated slicing and reshaping operations to use the correct dimension ordering (H, W) |

Sequence Diagram

sequenceDiagram
    participant Config as cwb.yaml Config
    participant Dataset as ZarrDataset
    participant Utils as img_utils.reshape_fields
    participant Model as Training/Generation

    Note over Config: img_shape_x=320 (width)<br/>img_shape_y=336 (height)
    
    Config->>Dataset: Initialize with img_shape_x, img_shape_y
    
    Dataset->>Dataset: __getitem__(idx)
    Note over Dataset: Generate x_roll in [0, img_shape_x]<br/>(was y_roll in [0, img_shape_y])
    
    Dataset->>Utils: reshape_fields(img, x_roll, ...)
    Note over Utils: Slice to [:img_shape_y, :img_shape_x]<br/>(was [:img_shape_x, :img_shape_y])
    
    Utils->>Utils: np.roll(img, x_roll, axis=-1)
    Note over Utils: Roll along X dimension
    
    Utils->>Utils: reshape to (C, img_shape_y, img_shape_x)
    Note over Utils: Output shape (C, H, W)<br/>(was (C, W, H))
    
    Utils-->>Dataset: Return tensor with correct shape
    
    Dataset->>Model: image_shape() = (img_shape_y, img_shape_x)
    Note over Model: Returns (height, width)<br/>(was (width, height))
    
    Model->>Model: Use img_shape[0] for height<br/>Use img_shape[1] for width
3 files reviewed, no comments

@CharlelieLrt
Collaborator Author

/blossom-ci

Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR fixes shape inconsistencies for non-square images in the CorrDiff dataset implementation, resolving issue #1011 where tensor dimensions were swapped.

Key Changes:

  • Configuration: Added clarifying comments to cwb.yaml documenting that img_shape_x represents domain width and img_shape_y represents domain height
  • Dataset (cwb.py):
    • Fixed image_shape() to return (img_shape_y, img_shape_x) instead of (img_shape_x, img_shape_y), correctly following NumPy's (height, width) convention
    • Corrected rolling logic by renaming y_roll to x_roll and using img_shape_x for the range, fixing semantic mismatch where the old code generated y-dimension roll amounts but applied them to the x-axis
  • Image utilities (img_utils.py): Updated array slicing, reshaping, and rolling to use (img_shape_y, img_shape_x) ordering throughout, ensuring consistency with the (height, width) convention
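The rolling fix can be checked in isolation. The sketch below uses plain NumPy with the example values; the exclusive upper bound on the random draw is an assumption and not necessarily what cwb.py does. It draws x_roll from the width range and rolls along axis=-1, confirming that only columns move and that a shift equal to the width is a no-op.

```python
import numpy as np

img_shape_x, img_shape_y = 320, 336  # width, height (example values)
rng = np.random.default_rng(0)

# Corrected: the roll amount comes from the width range, since it is applied
# along axis=-1 (columns / x). The exclusive upper bound is an assumption.
x_roll = int(rng.integers(0, img_shape_x))

img = np.arange(img_shape_y * img_shape_x, dtype=np.float32).reshape(img_shape_y, img_shape_x)
rolled = np.roll(img, x_roll, axis=-1)

# Rolling only moves columns: the shape is unchanged and column 0 lands at x_roll.
assert rolled.shape == img.shape
assert np.array_equal(rolled[:, x_roll], img[:, 0])

# A shift equal to the width is a no-op, so [0, img_shape_x) enumerates every
# distinct horizontal shift exactly once.
assert np.array_equal(np.roll(img, img_shape_x, axis=-1), img)
```

Because np.roll wraps modulo the axis length, drawing the amount from a range wider than the width (as the old y_roll did when img_shape_y > img_shape_x) does not raise an error, but it skews the distribution: with an exclusive upper bound of 336, shifts 0 to 15 would be twice as likely as the others.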

Impact:
For non-square images (e.g., 320×336), tensors now correctly have shape (channels, 336, 320) instead of (channels, 320, 336), where 336 is the height (img_shape_y) and 320 is the width (img_shape_x). This aligns the tensor shapes with the configuration parameters and NumPy conventions.

Confidence Score: 4/5

  • This PR is safe to merge, with one minor concern about pre-existing logic flagged in the additional comment
  • The changes correctly fix the shape inconsistency by aligning with NumPy's (height, width) convention throughout the codebase. The fix addresses the reported issue and maintains internal consistency. The score is 4 rather than 5 because of the flagged logic comment: although that code was not introduced by this PR, it highlights an area that should be verified.
  • examples/weather/corrdiff/datasets/cwb.py - verify the latitude/longitude slicing logic is correct for non-square domains in testing

Important Files Changed

File Analysis

| Filename | Score | Overview |
|---|---|---|
| examples/weather/corrdiff/conf/base/dataset/cwb.yaml | 5/5 | Added clarifying comments for img_shape_x (width) and img_shape_y (height) to document the dimension meanings |
| examples/weather/corrdiff/datasets/cwb.py | 4/5 | Fixed rolling logic by renaming y_roll to x_roll and using img_shape_x for the range, correcting the semantic mismatch with axis=-1 (width). Updated image_shape() to return (img_shape_y, img_shape_x), matching the NumPy (height, width) convention |
| examples/weather/corrdiff/datasets/img_utils.py | 4/5 | Renamed parameter y_roll to x_roll, fixed array slicing and reshaping to use (img_shape_y, img_shape_x) order consistent with the NumPy (height, width) convention, updated docstring |

Sequence Diagram

sequenceDiagram
    participant Config as cwb.yaml Config
    participant Dataset as ZarrDataset
    participant ImgUtils as reshape_fields()
    participant Model as Training/Generation
    
    Note over Config: img_shape_x: 320 (width)<br/>img_shape_y: 336 (height)
    
    Config->>Dataset: Initialize with img_shape_x, img_shape_y
    
    rect rgb(240, 248, 255)
        Note over Dataset: __getitem__() called
        Dataset->>Dataset: Generate x_roll = random(0, img_shape_x)
        Note over Dataset: Rolling along x-axis (width)<br/>using x_roll amount
        
        Dataset->>ImgUtils: reshape_fields(img, x_roll, img_shape_x, img_shape_y, ...)
        
        Note over ImgUtils: Input: (n_history+1, c, h, w)
        ImgUtils->>ImgUtils: Slice: img[:, :, :img_shape_y, :img_shape_x]<br/>(height, width) order
        ImgUtils->>ImgUtils: Roll: np.roll(img, x_roll, axis=-1)<br/>Roll along width dimension
        ImgUtils->>ImgUtils: Reshape: (c*(n_history+1), img_shape_y, img_shape_x)<br/>(channels, height, width)
        
        ImgUtils-->>Dataset: Return torch tensor (C, H, W)
    end
    
    Dataset->>Dataset: image_shape() returns (img_shape_y, img_shape_x)
    Note over Dataset: Returns (height, width)<br/>following NumPy convention
    
    Dataset-->>Model: Tensors with shape (C, 336, 320)
    Note over Model: Correct shape: (channels, height, width)<br/>Matches img_shape_y=336, img_shape_x=320
Additional Comments (1)

  1. examples/weather/corrdiff/datasets/cwb.py, lines 458-463

    logic: the slicing order in longitude() and latitude() is inconsistent with the new convention; it should be [:img_shape_y, :img_shape_x], not [:, :img_shape_y, :img_shape_x]
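Whether the two-index or the three-index form in this comment is correct depends on the rank of the stored latitude/longitude arrays, which the comment does not state. One rank-agnostic option, sketched here as an alternative rather than what the PR does, is to slice with an Ellipsis so the last two axes are always cropped to (img_shape_y, img_shape_x). The crop_coords helper and the coordinate array shapes are hypothetical.

```python
import numpy as np

img_shape_x, img_shape_y = 320, 336  # width, height

def crop_coords(coord):
    """Hypothetical helper: crop a coordinate array to (..., img_shape_y, img_shape_x).

    The Ellipsis keeps any leading axes and always crops the last two
    (rows = height, cols = width), so it works for 2D and 3D inputs alike.
    """
    return coord[..., :img_shape_y, :img_shape_x]

lat_2d = np.zeros((450, 450))     # hypothetical (south_north, west_east)
lat_3d = np.zeros((1, 450, 450))  # hypothetical leading time/member axis

print(crop_coords(lat_2d).shape)  # (336, 320)
print(crop_coords(lat_3d).shape)  # (1, 336, 320)

# The three-index form only works when a leading axis exists:
try:
    lat_2d[:, :img_shape_y, :img_shape_x]
except IndexError as err:
    print("2D array:", err)
```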

3 files reviewed, 1 comment

@CharlelieLrt CharlelieLrt merged commit 219ed0d into NVIDIA:main Oct 31, 2025
1 check passed
Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Questions about corrdiff shape_x and shape_y

2 participants