Skip to content

Fix casting in SongUNetPosEmbd and shape in CorrDiff generation #982

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Jul 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/weather/corrdiff/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def generate_fn():
net=net_reg,
img_lr=img_lr,
latents_shape=(
cfg.generation.seed_batch_size,
sum(map(len, rank_batches)),
img_out_channels,
img_shape[0],
img_shape[1],
Expand Down
2 changes: 1 addition & 1 deletion physicsnemo/models/diffusion/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def __init__(
self.act_fn = None
self.amp_mode = amp_mode
if self.use_apex_gn:
if self.act:
if self.fused_act:
self.gn = ApexGroupNorm(
num_groups=self.num_groups,
num_channels=num_channels,
Expand Down
55 changes: 27 additions & 28 deletions physicsnemo/models/diffusion/song_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,9 +893,6 @@ def forward(
"Cannot provide both embedding_selector and global_index."
)

if x.dtype != self.pos_embd.dtype:
self.pos_embd = self.pos_embd.to(x.dtype)

# Append positional embedding to input conditioning
if self.pos_embd is not None:
# Select positional embeddings with a selector function
Expand All @@ -909,22 +906,23 @@ def forward(
selected_pos_embd = self.positional_embedding_indexing(
x, global_index=global_index, lead_time_label=lead_time_label
)
x = torch.cat((x, selected_pos_embd), dim=1)
x = torch.cat((x, selected_pos_embd.to(x.dtype)), dim=1)

out = super().forward(x, noise_labels, class_labels, augment_labels)

if self.lead_time_mode:
if self.lead_time_mode and self.prob_channels:
# if training mode, let crossEntropyLoss do softmax. The model outputs logits.
# if eval mode, the model outputs probability
if self.prob_channels and out.dtype != self.scalar.dtype:
self.scalar.data = self.scalar.data.to(out.dtype)
if self.prob_channels and (not self.training):
out[:, self.prob_channels] = (
out[:, self.prob_channels] * self.scalar
).softmax(dim=1)
elif self.prob_channels and self.training:
scalar = self.scalar
if out.dtype != scalar.dtype:
scalar = scalar.to(out.dtype)
if self.training:
out[:, self.prob_channels] = out[:, self.prob_channels] * scalar
else:
out[:, self.prob_channels] = (
out[:, self.prob_channels] * self.scalar
(out[:, self.prob_channels] * scalar)
.softmax(dim=1)
.to(out.dtype)
)
return out

Expand Down Expand Up @@ -983,15 +981,16 @@ def positional_embedding_indexing(
"""
# If no global indices are provided, select all embeddings and expand
# to match the batch size of the input
if x.dtype != self.pos_embd.dtype:
self.pos_embd = self.pos_embd.to(x.dtype)
pos_embd = self.pos_embd
if x.dtype != pos_embd.dtype:
pos_embd = pos_embd.to(x.dtype)

if global_index is None:
if self.lead_time_mode:
selected_pos_embd = []
if self.pos_embd is not None:
if pos_embd is not None:
selected_pos_embd.append(
self.pos_embd[None].expand((x.shape[0], -1, -1, -1))
pos_embd[None].expand((x.shape[0], -1, -1, -1))
)
if self.lt_embd is not None:
selected_pos_embd.append(
Expand All @@ -1008,7 +1007,7 @@ def positional_embedding_indexing(
if len(selected_pos_embd) > 0:
selected_pos_embd = torch.cat(selected_pos_embd, dim=1)
else:
selected_pos_embd = self.pos_embd[None].expand(
selected_pos_embd = pos_embd[None].expand(
(x.shape[0], -1, -1, -1)
) # (B, C_{PE}, H, W)

Expand All @@ -1021,11 +1020,11 @@ def positional_embedding_indexing(
global_index = torch.reshape(
torch.permute(global_index, (1, 0, 2, 3)), (2, -1)
) # (P, 2, X, Y) to (2, P*X*Y)
selected_pos_embd = self.pos_embd[
selected_pos_embd = pos_embd[
:, global_index[0], global_index[1]
] # (C_pe, P*X*Y)
selected_pos_embd = torch.permute(
torch.reshape(selected_pos_embd, (self.pos_embd.shape[0], P, H, W)),
torch.reshape(selected_pos_embd, (pos_embd.shape[0], P, H, W)),
(1, 0, 2, 3),
) # (P, C_pe, X, Y)

Expand All @@ -1036,7 +1035,7 @@ def positional_embedding_indexing(
# Append positional and lead time embeddings to input conditioning
if self.lead_time_mode:
embeds = []
if self.pos_embd is not None:
if pos_embd is not None:
embeds.append(selected_pos_embd) # reuse code below
if self.lt_embd is not None:
lt_embds = self.lt_embd[
Expand Down Expand Up @@ -1122,15 +1121,15 @@ def positional_embedding_selector(
... return patching.apply(emb[None].expand(batch_size, -1, -1, -1))
>>>
"""
if x.dtype != self.pos_embd.dtype:
self.pos_embd = self.pos_embd.to(x.dtype)
pos_embd = self.pos_embd
if x.dtype != pos_embd.dtype:
pos_embd = pos_embd.to(x.dtype)
if lead_time_label is not None:
# all patches share same lead_time_label
embeddings = torch.cat(
[self.pos_embd, self.lt_embd[lead_time_label[0].int()]]
)
# TODO: here we assume all patches share same lead_time_label -->
# need be changed
embeddings = torch.cat([pos_embd, self.lt_embd[lead_time_label[0].int()]])
else:
embeddings = self.pos_embd
embeddings = pos_embd
return embedding_selector(embeddings) # (B, N_pe, H, W)

def _get_positional_embedding(self):
Expand Down
3 changes: 2 additions & 1 deletion test/models/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
# limitations under the License.

from .checkpoints import validate_checkpoint
from .fwdaccuracy import validate_forward_accuracy
from .fwdaccuracy import validate_forward_accuracy, validate_tensor_accuracy
from .inference import check_ort_version, validate_onnx_export, validate_onnx_runtime
from .optimization import (
torch_compile_model,
validate_amp,
validate_combo_optims,
validate_cuda_graphs,
Expand Down
58 changes: 58 additions & 0 deletions test/models/common/fwdaccuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,61 @@ def validate_forward_accuracy(
)

return compare_output(output, output_target, rtol, atol)


@torch.no_grad()
def validate_tensor_accuracy(
output: Tensor,
rtol: float = 1e-3,
atol: float = 1e-3,
file_name: Union[str, None] = None,
) -> bool:
"""Validates the accuracy of a tensor with a reference output

Parameters
----------
output : Tensor
Output tensor
rtol : float, optional
Relative tolerance of error allowed, by default 1e-3
atol : float, optional
Absolute tolerance of error allowed, by default 1e-3
file_name : Union[str, None], optional
Override the default file name of the stored target output, by default None

Returns
-------
bool
Test passed

Raises
------
IOError
Target output tensor file for this model was not found
"""
# File name / path
# Output files should live in test/utils/data

# Always use tuples for this comparison / saving
if isinstance(output, Tensor):
device = output.device
output = (output,)
else:
device = output[0].device

file_name = (
Path(__file__).parents[1].resolve() / Path("data") / Path(file_name.lower())
)
# If file does not exist, we will create it then error
# Model should then reproduce it on next pytest run
if not file_name.exists():
save_output(output, file_name)
raise IOError(
f"Output check file {str(file_name)} wasn't found so one was created. Please re-run the test."
)
# Load tensor dictionary and check
else:
tensor_dict = torch.load(str(file_name))
output_target = tuple([value.to(device) for value in tensor_dict.values()])

return compare_output(output, output_target, rtol, atol)
12 changes: 12 additions & 0 deletions test/models/common/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,18 @@ def forward(*args, **kwargs):
return forward


def torch_compile_model(
model: physicsnemo.Module, fullgraph: bool = True, error_on_recompile: bool = False
) -> physicsnemo.Module:
backend = (
nop_backend # for fast compilation for fx graph capture, use a nop backend
)
torch._dynamo.reset()
torch._dynamo.config.error_on_recompile = error_on_recompile
model = torch.compile(model, backend=backend, fullgraph=fullgraph)
return model


def validate_torch_compile(
model: physicsnemo.Module,
in_args: Tuple[Tensor] = (),
Expand Down
Binary file not shown.
Binary file not shown.
125 changes: 124 additions & 1 deletion test/models/diffusion/test_song_unet_pos_lt_embd.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,73 @@
from physicsnemo.models.diffusion import SongUNetPosLtEmbd as UNet


def setup_model_learnable_embd(img_resolution, lt_steps, lt_channels, N_pos, seed=0):
"""
Create a model with similar architecture to CorrDiff (learnable positional
embeddings, self-attention, learnable lead time embeddings).
"""
# Smaller architecture variant with learnable positional embeddings
# (similar to CorrDiff example)
C_x, C_cond = 4, 3
attn_res = (
img_resolution[0] // 4
if isinstance(img_resolution, list) or isinstance(img_resolution, tuple)
else img_resolution // 4
)
torch.manual_seed(seed)
model = UNet(
img_resolution=img_resolution,
in_channels=C_x + N_pos + C_cond + lt_channels,
out_channels=C_x,
model_channels=16,
channel_mult=[1, 2, 2],
channel_mult_emb=2,
num_blocks=2,
attn_resolutions=[attn_res],
gridtype="learnable",
N_grid_channels=N_pos,
lead_time_steps=lt_steps,
lead_time_channels=lt_channels,
prob_channels=[1, 3],
)
return model


def generate_data_with_patches(H_p, W_p, device):
"""
Utility function to generate input data with patches in a consistent way
accross multiple tests.
"""
torch.manual_seed(0)
P, B, C_x, C_cond, lt_steps = 4, 3, 4, 3, 4
max_offset = 35
input_image = torch.randn([P * B, C_x + C_cond, H_p, W_p]).to(device)
noise_label = torch.randn([P * B]).to(device)
class_label = None
lead_time_label = torch.randint(0, lt_steps, (B,)).to(device)
base_grid = torch.stack(
torch.meshgrid(torch.arange(H_p), torch.arange(W_p), indexing="ij"), dim=0
)[None].to(device)
offset = torch.randint(0, max_offset, (P, 2))[:, :, None, None].to(device)
global_index = base_grid + offset
return input_image, noise_label, class_label, lead_time_label, global_index


def generate_data_no_patches(H, W, device):
"""
Utility function to generate input data without patches in a consistent way
accross multiple tests.
"""
torch.manual_seed(0)
B, C_x, C_cond, lt_steps = 3, 4, 3, 4
input_image = torch.randn([B, C_x + C_cond, H, W]).to(device)
noise_label = torch.randn([B]).to(device)
class_label = None
lead_time_label = torch.randint(0, lt_steps, (B,)).to(device)
global_index = None
return input_image, noise_label, class_label, lead_time_label, global_index


@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
def test_song_unet_forward(device):
torch.manual_seed(0)
Expand Down Expand Up @@ -178,6 +245,62 @@ def test_song_unet_global_indexing(device):
assert torch.equal(pos_embed, global_index)


@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
def test_song_unet_positional_embedding_indexing_no_patches(device):
"""
Test for positional_embedding_indexing method. Does not use patches (i.e.
input image is the entire global image).
"""

# Common parameters
B, lt_steps = 3, 4

# CorrDiff model with rectangular global shape
H, W = 128, 112
N_pos, lt_channels = 6, 8
model = (
setup_model_learnable_embd([H, W], lt_steps, lt_channels, N_pos)
.to(device)
.to(memory_format=torch.channels_last)
)
inputs = generate_data_no_patches(H, W, device)
pos_embed = model.positional_embedding_indexing(inputs[0], inputs[4], inputs[3])
assert pos_embed.shape == (B, N_pos + lt_channels, H, W)
assert common.validate_tensor_accuracy(
pos_embed,
file_name="songunet_pos_lt_embd_pos_embed_indexing_no_patches_corrdiff.pth",
)
# TODO: add non-regression tests for other architectures


@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
def test_song_unet_positional_embedding_indexing_with_patches(device):
"""
Test for positional_embedding_indexing method. Uses patches (i.e. input image
is only a subset of the global image).
"""

# Common parameters
P, B, H_p, W_p, lt_steps = 4, 3, 32, 64, 4

# CorrDiff model with rectangular global shape
H, W = 128, 112
N_pos, lt_channels = 6, 8
model = (
setup_model_learnable_embd([H, W], lt_steps, lt_channels, N_pos)
.to(device)
.to(memory_format=torch.channels_last)
)
inputs = generate_data_with_patches(H_p, W_p, device)
pos_embed = model.positional_embedding_indexing(inputs[0], inputs[4], inputs[3])
assert pos_embed.shape == (P * B, N_pos + lt_channels, H_p, W_p)
assert common.validate_tensor_accuracy(
pos_embed,
file_name="songunet_pos_lt_embd_pos_embed_indexing_with_patches_corrdiff.pth",
)
# TODO: add non-regression tests for other architectures


@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
def test_song_unet_embedding_selector(device):
torch.manual_seed(0)
Expand Down Expand Up @@ -382,7 +505,7 @@ def setup_model():
@pytest.mark.parametrize("device", ["cuda:0"])
def test_song_unet_checkpoint(device):
"""Test Song UNet checkpoint save/load"""
# Construct FNO models

model_1 = UNet(
img_resolution=16,
in_channels=6,
Expand Down
Loading