Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/weather/corrdiff/conf/base/dataset/cwb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ in_channels: [0, 1, 2, 3, 4, 9, 10, 11, 12, 17, 18, 19]
# Indices of output channels
out_channels: [0, 1, 2, 3]
# Shape of the image
img_shape_x: 448
img_shape_y: 448
img_shape_x: 448 # domain width
img_shape_y: 448 # domain height
# Add grid coordinates to the image
add_grid: true
# Factor to downscale the image
Expand Down
8 changes: 4 additions & 4 deletions examples/weather/corrdiff/datasets/cwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,9 +399,9 @@ def __getitem__(self, idx):
# crop and downsamples
# rolling
if self.train and self.roll:
y_roll = random.randint(0, self.img_shape_y)
x_roll = random.randint(0, self.img_shape_x)
else:
y_roll = 0
x_roll = 0

# channels
input = input[self.in_channels, :, :]
Expand All @@ -411,7 +411,7 @@ def __getitem__(self, idx):
target = self._create_lowres_(target, factor=self.ds_factor)

reshape_args = (
y_roll,
x_roll,
self.train,
self.n_history,
self.in_channels,
Expand Down Expand Up @@ -468,7 +468,7 @@ def time(self):

def image_shape(self):
"""Get the shape of the image (same for input and output)."""
return (self.img_shape_x, self.img_shape_y)
return (self.img_shape_y, self.img_shape_x)

def normalize_input(self, x):
"""Convert input from physical units to normalized data."""
Expand Down
12 changes: 6 additions & 6 deletions examples/weather/corrdiff/datasets/img_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
def reshape_fields(
img,
inp_or_tar,
y_roll,
x_roll,
train,
n_history,
in_channels,
Expand All @@ -39,7 +39,7 @@ def reshape_fields(
):
"""
Takes in np array of size (n_history+1, c, h, w) and returns torch tensor of
size ((n_channels*(n_history+1), img_shape_x, img_shape_y)
size ((n_channels*(n_history+1), img_shape_y, img_shape_x)
"""

if len(np.shape(img)) == 3:
Expand All @@ -59,7 +59,7 @@ def reshape_fields(
means = np.load(global_means_path)[:, channels]
stds = np.load(global_stds_path)[:, channels]

img = img[:, :, :img_shape_x, :img_shape_y]
img = img[:, :, :img_shape_y, :img_shape_x]

if normalize and train:
if normalization == "minmax":
Expand All @@ -70,11 +70,11 @@ def reshape_fields(
img /= stds

if roll:
img = np.roll(img, y_roll, axis=-1)
img = np.roll(img, x_roll, axis=-1)

if inp_or_tar == "inp":
img = np.reshape(img, (n_channels * (n_history + 1), img_shape_x, img_shape_y))
img = np.reshape(img, (n_channels * (n_history + 1), img_shape_y, img_shape_x))
elif inp_or_tar == "tar":
img = np.reshape(img, (n_channels, img_shape_x, img_shape_y))
img = np.reshape(img, (n_channels, img_shape_y, img_shape_x))

return torch.as_tensor(img)