From 7481391952a808458de40f06d8e75e73a98555ad Mon Sep 17 00:00:00 2001 From: Charlelie Laurent Date: Fri, 31 Oct 2025 11:31:03 -0700 Subject: [PATCH] Fixed inconsistency for shapes of non-square images Signed-off-by: Charlelie Laurent --- examples/weather/corrdiff/conf/base/dataset/cwb.yaml | 4 ++-- examples/weather/corrdiff/datasets/cwb.py | 8 ++++---- examples/weather/corrdiff/datasets/img_utils.py | 12 ++++++------ 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/examples/weather/corrdiff/conf/base/dataset/cwb.yaml b/examples/weather/corrdiff/conf/base/dataset/cwb.yaml index 5ec2b73601..78f55f0c52 100644 --- a/examples/weather/corrdiff/conf/base/dataset/cwb.yaml +++ b/examples/weather/corrdiff/conf/base/dataset/cwb.yaml @@ -23,8 +23,8 @@ in_channels: [0, 1, 2, 3, 4, 9, 10, 11, 12, 17, 18, 19] # Indices of output channels out_channels: [0, 1, 2, 3] # Shape of the image -img_shape_x: 448 -img_shape_y: 448 +img_shape_x: 448 # domain width +img_shape_y: 448 # domain height # Add grid coordinates to the image add_grid: true # Factor to downscale the image diff --git a/examples/weather/corrdiff/datasets/cwb.py b/examples/weather/corrdiff/datasets/cwb.py index 9a35f917cb..ad20a2826a 100644 --- a/examples/weather/corrdiff/datasets/cwb.py +++ b/examples/weather/corrdiff/datasets/cwb.py @@ -399,9 +399,9 @@ def __getitem__(self, idx): # crop and downsamples # rolling if self.train and self.roll: - y_roll = random.randint(0, self.img_shape_y) + x_roll = random.randint(0, self.img_shape_x) else: - y_roll = 0 + x_roll = 0 # channels input = input[self.in_channels, :, :] @@ -411,7 +411,7 @@ def __getitem__(self, idx): target = self._create_lowres_(target, factor=self.ds_factor) reshape_args = ( - y_roll, + x_roll, self.train, self.n_history, self.in_channels, @@ -468,7 +468,7 @@ def time(self): def image_shape(self): """Get the shape of the image (same for input and output).""" - return (self.img_shape_x, self.img_shape_y) + return (self.img_shape_y, self.img_shape_x) def normalize_input(self, x): """Convert input from physical units to normalized data.""" diff --git a/examples/weather/corrdiff/datasets/img_utils.py b/examples/weather/corrdiff/datasets/img_utils.py index f53dc81724..5bce1ab962 100644 --- a/examples/weather/corrdiff/datasets/img_utils.py +++ b/examples/weather/corrdiff/datasets/img_utils.py @@ -22,7 +22,7 @@ def reshape_fields( img, inp_or_tar, - y_roll, + x_roll, train, n_history, in_channels, @@ -39,7 +39,7 @@ def reshape_fields( ): """ Takes in np array of size (n_history+1, c, h, w) and returns torch tensor of - size ((n_channels*(n_history+1), img_shape_x, img_shape_y) + size ((n_channels*(n_history+1), img_shape_y, img_shape_x) """ if len(np.shape(img)) == 3: @@ -59,7 +59,7 @@ def reshape_fields( means = np.load(global_means_path)[:, channels] stds = np.load(global_stds_path)[:, channels] - img = img[:, :, :img_shape_x, :img_shape_y] + img = img[:, :, :img_shape_y, :img_shape_x] if normalize and train: if normalization == "minmax": @@ -70,11 +70,11 @@ def reshape_fields( img /= stds if roll: - img = np.roll(img, y_roll, axis=-1) + img = np.roll(img, x_roll, axis=-1) if inp_or_tar == "inp": - img = np.reshape(img, (n_channels * (n_history + 1), img_shape_x, img_shape_y)) + img = np.reshape(img, (n_channels * (n_history + 1), img_shape_y, img_shape_x)) elif inp_or_tar == "tar": - img = np.reshape(img, (n_channels, img_shape_x, img_shape_y)) + img = np.reshape(img, (n_channels, img_shape_y, img_shape_x)) return torch.as_tensor(img)