diff --git a/examples/weather/corrdiff/generate.py b/examples/weather/corrdiff/generate.py index a155904d7b..3c7e8676e5 100644 --- a/examples/weather/corrdiff/generate.py +++ b/examples/weather/corrdiff/generate.py @@ -208,7 +208,7 @@ def generate_fn(): net=net_reg, img_lr=img_lr, latents_shape=( - cfg.generation.seed_batch_size, + sum(map(len, rank_batches)), img_out_channels, img_shape[0], img_shape[1], diff --git a/physicsnemo/models/diffusion/layers.py b/physicsnemo/models/diffusion/layers.py index b93c2889bf..b41f279096 100644 --- a/physicsnemo/models/diffusion/layers.py +++ b/physicsnemo/models/diffusion/layers.py @@ -393,7 +393,7 @@ def __init__( self.act_fn = None self.amp_mode = amp_mode if self.use_apex_gn: - if self.act: + if self.fused_act: self.gn = ApexGroupNorm( num_groups=self.num_groups, num_channels=num_channels, diff --git a/physicsnemo/models/diffusion/song_unet.py b/physicsnemo/models/diffusion/song_unet.py index 326d1b11c5..14893e749b 100644 --- a/physicsnemo/models/diffusion/song_unet.py +++ b/physicsnemo/models/diffusion/song_unet.py @@ -893,9 +893,6 @@ def forward( "Cannot provide both embedding_selector and global_index." ) - if x.dtype != self.pos_embd.dtype: - self.pos_embd = self.pos_embd.to(x.dtype) - # Append positional embedding to input conditioning if self.pos_embd is not None: # Select positional embeddings with a selector function @@ -909,22 +906,23 @@ def forward( selected_pos_embd = self.positional_embedding_indexing( x, global_index=global_index, lead_time_label=lead_time_label ) - x = torch.cat((x, selected_pos_embd), dim=1) + x = torch.cat((x, selected_pos_embd.to(x.dtype)), dim=1) out = super().forward(x, noise_labels, class_labels, augment_labels) - if self.lead_time_mode: + if self.lead_time_mode and self.prob_channels: # if training mode, let crossEntropyLoss do softmax. The model outputs logits. 
            # if eval mode, the model outputs probability
-            if self.prob_channels and out.dtype != self.scalar.dtype:
-                self.scalar.data = self.scalar.data.to(out.dtype)
-            if self.prob_channels and (not self.training):
-                out[:, self.prob_channels] = (
-                    out[:, self.prob_channels] * self.scalar
-                ).softmax(dim=1)
-            elif self.prob_channels and self.training:
+            scalar = self.scalar
+            if out.dtype != scalar.dtype:
+                scalar = scalar.to(out.dtype)
+            if self.training:
+                out[:, self.prob_channels] = out[:, self.prob_channels] * scalar
+            else:
                 out[:, self.prob_channels] = (
-                    out[:, self.prob_channels] * self.scalar
+                    (out[:, self.prob_channels] * scalar)
+                    .softmax(dim=1)
+                    .to(out.dtype)
                 )
         return out
@@ -983,15 +981,16 @@ def positional_embedding_indexing(
         """
         # If no global indices are provided, select all embeddings and expand
         # to match the batch size of the input
-        if x.dtype != self.pos_embd.dtype:
-            self.pos_embd = self.pos_embd.to(x.dtype)
+        pos_embd = self.pos_embd
+        if x.dtype != pos_embd.dtype:
+            pos_embd = pos_embd.to(x.dtype)
         if global_index is None:
             if self.lead_time_mode:
                 selected_pos_embd = []
-                if self.pos_embd is not None:
+                if pos_embd is not None:
                     selected_pos_embd.append(
-                        self.pos_embd[None].expand((x.shape[0], -1, -1, -1))
+                        pos_embd[None].expand((x.shape[0], -1, -1, -1))
                     )
                 if self.lt_embd is not None:
                     selected_pos_embd.append(
@@ -1008,7 +1007,7 @@
                 if len(selected_pos_embd) > 0:
                     selected_pos_embd = torch.cat(selected_pos_embd, dim=1)
             else:
-                selected_pos_embd = self.pos_embd[None].expand(
+                selected_pos_embd = pos_embd[None].expand(
                     (x.shape[0], -1, -1, -1)
                 )  # (B, C_{PE}, H, W)
@@ -1021,11 +1020,11 @@
             global_index = torch.reshape(
                 torch.permute(global_index, (1, 0, 2, 3)), (2, -1)
             )  # (P, 2, X, Y) to (2, P*X*Y)
-            selected_pos_embd = self.pos_embd[
+            selected_pos_embd = pos_embd[
                 :, global_index[0], global_index[1]
             ]  # (C_pe, P*X*Y)
             selected_pos_embd = torch.permute(
-                torch.reshape(selected_pos_embd, (self.pos_embd.shape[0], P, H, W)),
+                torch.reshape(selected_pos_embd, (pos_embd.shape[0], P, H, W)),
                 (1, 0, 2, 3),
             )  # (P, C_pe, X, Y)
 
@@ -1036,7 +1035,7 @@
         # Append positional and lead time embeddings to input conditioning
         if self.lead_time_mode:
             embeds = []
-            if self.pos_embd is not None:
+            if pos_embd is not None:
                 embeds.append(selected_pos_embd)  # reuse code below
             if self.lt_embd is not None:
                 lt_embds = self.lt_embd[
@@ -1122,15 +1121,15 @@ def positional_embedding_selector(
         ...     return patching.apply(emb[None].expand(batch_size, -1, -1, -1))
         >>>
         """
-        if x.dtype != self.pos_embd.dtype:
-            self.pos_embd = self.pos_embd.to(x.dtype)
+        pos_embd = self.pos_embd
+        if x.dtype != pos_embd.dtype:
+            pos_embd = pos_embd.to(x.dtype)
         if lead_time_label is not None:
-            # all patches share same lead_time_label
-            embeddings = torch.cat(
-                [self.pos_embd, self.lt_embd[lead_time_label[0].int()]]
-            )
+            # TODO: here we assume all patches share the same lead_time_label -->
+            # needs to be changed
+            embeddings = torch.cat([pos_embd, self.lt_embd[lead_time_label[0].int()]])
         else:
-            embeddings = self.pos_embd
+            embeddings = pos_embd
         return embedding_selector(embeddings)  # (B, N_pe, H, W)
 
     def _get_positional_embedding(self):
diff --git a/test/models/common/__init__.py b/test/models/common/__init__.py
index b732b1977e..eb02be2cc6 100644
--- a/test/models/common/__init__.py
+++ b/test/models/common/__init__.py
@@ -15,9 +15,10 @@
 # limitations under the License.
 
 from .checkpoints import validate_checkpoint
-from .fwdaccuracy import validate_forward_accuracy
+from .fwdaccuracy import validate_forward_accuracy, validate_tensor_accuracy
 from .inference import check_ort_version, validate_onnx_export, validate_onnx_runtime
 from .optimization import (
+    torch_compile_model,
     validate_amp,
     validate_combo_optims,
     validate_cuda_graphs,
diff --git a/test/models/common/fwdaccuracy.py b/test/models/common/fwdaccuracy.py
index 77c3b1470a..b63ae9dd97 100644
--- a/test/models/common/fwdaccuracy.py
+++ b/test/models/common/fwdaccuracy.py
@@ -131,3 +131,61 @@ def validate_forward_accuracy(
         )
 
     return compare_output(output, output_target, rtol, atol)
+
+
+@torch.no_grad()
+def validate_tensor_accuracy(
+    output: Tensor,
+    rtol: float = 1e-3,
+    atol: float = 1e-3,
+    file_name: Union[str, None] = None,
+) -> bool:
+    """Validates the accuracy of a tensor against a stored reference output
+
+    Parameters
+    ----------
+    output : Tensor
+        Output tensor
+    rtol : float, optional
+        Relative tolerance of error allowed, by default 1e-3
+    atol : float, optional
+        Absolute tolerance of error allowed, by default 1e-3
+    file_name : Union[str, None], optional
+        File name of the stored target output, by default None
+
+    Returns
+    -------
+    bool
+        Test passed
+
+    Raises
+    ------
+    IOError
+        Target output tensor file for this model was not found
+    """
+    # File name / path
+    # Output files should live in test/models/data
+
+    # Always use tuples for this comparison / saving
+    if isinstance(output, Tensor):
+        device = output.device
+        output = (output,)
+    else:
+        device = output[0].device
+
+    file_name = (
+        Path(__file__).parents[1].resolve() / Path("data") / Path(file_name.lower())
+    )
+    # If the file does not exist, create it and raise an error;
+    # the test should then reproduce it on the next pytest run
+    if not file_name.exists():
+        save_output(output, file_name)
+        raise IOError(
+            f"Output check file {str(file_name)} wasn't found so one was created. Please re-run the test."
+        )
+    # Load tensor dictionary and check
+    else:
+        tensor_dict = torch.load(str(file_name))
+        output_target = tuple([value.to(device) for value in tensor_dict.values()])
+
+    return compare_output(output, output_target, rtol, atol)
diff --git a/test/models/common/optimization.py b/test/models/common/optimization.py
index 5e693245b7..1bfeb0ffa3 100644
--- a/test/models/common/optimization.py
+++ b/test/models/common/optimization.py
@@ -204,6 +204,18 @@ def forward(*args, **kwargs):
         return forward
 
 
+def torch_compile_model(
+    model: physicsnemo.Module, fullgraph: bool = True, error_on_recompile: bool = False
+) -> physicsnemo.Module:
+    # Use a no-op backend: only fx graph capture is exercised, which keeps
+    # compilation fast
+    backend = nop_backend
+    torch._dynamo.reset()
+    torch._dynamo.config.error_on_recompile = error_on_recompile
+    model = torch.compile(model, backend=backend, fullgraph=fullgraph)
+    return model
+
+
 def validate_torch_compile(
     model: physicsnemo.Module,
     in_args: Tuple[Tensor] = (),
diff --git a/test/models/data/songunet_pos_lt_embd_pos_embed_indexing_no_patches_corrdiff.pth b/test/models/data/songunet_pos_lt_embd_pos_embed_indexing_no_patches_corrdiff.pth
new file mode 100644
index 0000000000..2aa4471f99
Binary files /dev/null and b/test/models/data/songunet_pos_lt_embd_pos_embed_indexing_no_patches_corrdiff.pth differ
diff --git a/test/models/data/songunet_pos_lt_embd_pos_embed_indexing_with_patches_corrdiff.pth b/test/models/data/songunet_pos_lt_embd_pos_embed_indexing_with_patches_corrdiff.pth
new file mode 100644
index 0000000000..44e2aa484d
Binary files /dev/null and b/test/models/data/songunet_pos_lt_embd_pos_embed_indexing_with_patches_corrdiff.pth differ
diff --git a/test/models/diffusion/test_song_unet_pos_lt_embd.py b/test/models/diffusion/test_song_unet_pos_lt_embd.py
index fc9579a071..1bff7a2c8f 100644
--- a/test/models/diffusion/test_song_unet_pos_lt_embd.py
+++ b/test/models/diffusion/test_song_unet_pos_lt_embd.py
@@ -28,6 +28,73 @@
 from physicsnemo.models.diffusion import SongUNetPosLtEmbd as UNet
 
 
+def setup_model_learnable_embd(img_resolution, lt_steps, lt_channels, N_pos, seed=0):
+    """
+    Create a model with similar architecture to CorrDiff (learnable positional
+    embeddings, self-attention, learnable lead time embeddings).
+    """
+    # Smaller architecture variant with learnable positional embeddings
+    # (similar to CorrDiff example)
+    C_x, C_cond = 4, 3
+    attn_res = (
+        img_resolution[0] // 4
+        if isinstance(img_resolution, list) or isinstance(img_resolution, tuple)
+        else img_resolution // 4
+    )
+    torch.manual_seed(seed)
+    model = UNet(
+        img_resolution=img_resolution,
+        in_channels=C_x + N_pos + C_cond + lt_channels,
+        out_channels=C_x,
+        model_channels=16,
+        channel_mult=[1, 2, 2],
+        channel_mult_emb=2,
+        num_blocks=2,
+        attn_resolutions=[attn_res],
+        gridtype="learnable",
+        N_grid_channels=N_pos,
+        lead_time_steps=lt_steps,
+        lead_time_channels=lt_channels,
+        prob_channels=[1, 3],
+    )
+    return model
+
+
+def generate_data_with_patches(H_p, W_p, device):
+    """
+    Utility function to generate input data with patches in a consistent way
+    across multiple tests.
+ """ + torch.manual_seed(0) + P, B, C_x, C_cond, lt_steps = 4, 3, 4, 3, 4 + max_offset = 35 + input_image = torch.randn([P * B, C_x + C_cond, H_p, W_p]).to(device) + noise_label = torch.randn([P * B]).to(device) + class_label = None + lead_time_label = torch.randint(0, lt_steps, (B,)).to(device) + base_grid = torch.stack( + torch.meshgrid(torch.arange(H_p), torch.arange(W_p), indexing="ij"), dim=0 + )[None].to(device) + offset = torch.randint(0, max_offset, (P, 2))[:, :, None, None].to(device) + global_index = base_grid + offset + return input_image, noise_label, class_label, lead_time_label, global_index + + +def generate_data_no_patches(H, W, device): + """ + Utility function to generate input data without patches in a consistent way + accross multiple tests. + """ + torch.manual_seed(0) + B, C_x, C_cond, lt_steps = 3, 4, 3, 4 + input_image = torch.randn([B, C_x + C_cond, H, W]).to(device) + noise_label = torch.randn([B]).to(device) + class_label = None + lead_time_label = torch.randint(0, lt_steps, (B,)).to(device) + global_index = None + return input_image, noise_label, class_label, lead_time_label, global_index + + @pytest.mark.parametrize("device", ["cuda:0", "cpu"]) def test_song_unet_forward(device): torch.manual_seed(0) @@ -178,6 +245,62 @@ def test_song_unet_global_indexing(device): assert torch.equal(pos_embed, global_index) +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_song_unet_positional_embedding_indexing_no_patches(device): + """ + Test for positional_embedding_indexing method. Does not use patches (i.e. + input image is the entire global image). + """ + + # Common parameters + B, lt_steps = 3, 4 + + # CorrDiff model with rectangular global shape + H, W = 128, 112 + N_pos, lt_channels = 6, 8 + model = ( + setup_model_learnable_embd([H, W], lt_steps, lt_channels, N_pos) + .to(device) + .to(memory_format=torch.channels_last) + ) + inputs = generate_data_no_patches(H, W, device) + pos_embed = model.positional_embedding_indexing(inputs[0], inputs[4], inputs[3]) + assert pos_embed.shape == (B, N_pos + lt_channels, H, W) + assert common.validate_tensor_accuracy( + pos_embed, + file_name="songunet_pos_lt_embd_pos_embed_indexing_no_patches_corrdiff.pth", + ) + # TODO: add non-regression tests for other architectures + + +@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) +def test_song_unet_positional_embedding_indexing_with_patches(device): + """ + Test for positional_embedding_indexing method. Uses patches (i.e. input image + is only a subset of the global image). 
+ """ + + # Common parameters + P, B, H_p, W_p, lt_steps = 4, 3, 32, 64, 4 + + # CorrDiff model with rectangular global shape + H, W = 128, 112 + N_pos, lt_channels = 6, 8 + model = ( + setup_model_learnable_embd([H, W], lt_steps, lt_channels, N_pos) + .to(device) + .to(memory_format=torch.channels_last) + ) + inputs = generate_data_with_patches(H_p, W_p, device) + pos_embed = model.positional_embedding_indexing(inputs[0], inputs[4], inputs[3]) + assert pos_embed.shape == (P * B, N_pos + lt_channels, H_p, W_p) + assert common.validate_tensor_accuracy( + pos_embed, + file_name="songunet_pos_lt_embd_pos_embed_indexing_with_patches_corrdiff.pth", + ) + # TODO: add non-regression tests for other architectures + + @pytest.mark.parametrize("device", ["cuda:0", "cpu"]) def test_song_unet_embedding_selector(device): torch.manual_seed(0) @@ -382,7 +505,7 @@ def setup_model(): @pytest.mark.parametrize("device", ["cuda:0"]) def test_song_unet_checkpoint(device): """Test Song UNet checkpoint save/load""" - # Construct FNO models + model_1 = UNet( img_resolution=16, in_channels=6, diff --git a/test/models/diffusion/test_song_unet_pos_lt_embd_agn_amp.py b/test/models/diffusion/test_song_unet_pos_lt_embd_agn_amp.py new file mode 100644 index 0000000000..f417e31b6a --- /dev/null +++ b/test/models/diffusion/test_song_unet_pos_lt_embd_agn_amp.py @@ -0,0 +1,604 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ruff: noqa: E402 +import os +import sys + +import pytest +import torch + +script_path = os.path.abspath(__file__) +sys.path.append(os.path.join(os.path.dirname(script_path), "..")) + +import common + +from physicsnemo.models.diffusion import SongUNetPosLtEmbd + + +def setup_model_learnable_embd(img_resolution, lt_steps, lt_channels, N_pos, seed=0): + """ + Create a model with similar architecture to CorrDiff (learnable positional + embeddings, self-attention, learnable lead time embeddings). + """ + # Smaller architecture variant with learnable positional embeddings + # (similar to CorrDiff example) + C_x, C_cond = 4, 3 + attn_res = ( + img_resolution[0] // 4 + if isinstance(img_resolution, list) or isinstance(img_resolution, tuple) + else img_resolution // 4 + ) + torch.manual_seed(seed) + model = SongUNetPosLtEmbd( + img_resolution=img_resolution, + in_channels=C_x + N_pos + C_cond + lt_channels, + out_channels=C_x, + model_channels=16, + channel_mult=[1, 2, 2], + channel_mult_emb=2, + num_blocks=2, + attn_resolutions=[attn_res], + gridtype="learnable", + N_grid_channels=N_pos, + lead_time_steps=lt_steps, + lead_time_channels=lt_channels, + use_apex_gn=True, + amp_mode=True, + prob_channels=[1, 3], + ) + return model + + +def setup_model_ddm_plus_plus(img_resolution, lt_steps, lt_channels, seed=0): + """ + Create a model with similar architecture to DDM++. 
+ """ + C_x, N_pos, C_cond = 4, 4, 3 + torch.manual_seed(seed) + model = SongUNetPosLtEmbd( + img_resolution=img_resolution, + in_channels=C_x + N_pos + C_cond + lt_channels, + out_channels=C_x, + lead_time_steps=lt_steps, + lead_time_channels=lt_channels, + use_apex_gn=True, + amp_mode=True, + prob_channels=[1, 3], + ) + return model + + +def setup_model_ncsn_plus_plus(img_resolution, lt_steps, lt_channels, seed=0): + """ + Create a model with similar architecture to NCSN++. + """ + C_x, N_pos, C_cond = 4, 4, 3 + torch.manual_seed(seed) + model = SongUNetPosLtEmbd( + img_resolution=img_resolution, + in_channels=C_x + N_pos + C_cond + lt_channels, + out_channels=C_x, + embedding_type="fourier", + channel_mult_noise=2, + encoder_type="residual", + resample_filter=[1, 3, 3, 1], + lead_time_steps=lt_steps, + lead_time_channels=lt_channels, + use_apex_gn=True, + amp_mode=True, + prob_channels=[1, 3], + ) + return model + + +def generate_data_with_patches(H_p, W_p, device): + """ + Utility function to generate input data with patches in a consistent way + accross multiple tests. + """ + torch.manual_seed(0) + P, B, C_x, C_cond, lt_steps = 4, 3, 4, 3, 4 + max_offset = 35 + input_image = torch.randn([P * B, C_x + C_cond, H_p, W_p]).to(device) + noise_label = torch.randn([P * B]).to(device) + class_label = None + lead_time_label = torch.randint(0, lt_steps, (B,)).to(device) + base_grid = torch.stack( + torch.meshgrid(torch.arange(H_p), torch.arange(W_p), indexing="ij"), dim=0 + )[None].to(device) + offset = torch.randint(0, max_offset, (P, 2))[:, :, None, None].to(device) + global_index = base_grid + offset + return input_image, noise_label, class_label, lead_time_label, global_index + + +def generate_data_no_patches(H, W, device): + """ + Utility function to generate input data without patches in a consistent way + accross multiple tests. + """ + torch.manual_seed(0) + B, C_x, C_cond, lt_steps = 3, 4, 3, 4 + input_image = torch.randn([B, C_x + C_cond, H, W]).to(device) + noise_label = torch.randn([B]).to(device) + class_label = None + lead_time_label = torch.randint(0, lt_steps, (B,)).to(device) + global_index = None + return input_image, noise_label, class_label, lead_time_label, global_index + + +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_song_unet_constructor(device): + """ + Test the SongUNetPosLtEmbd constructor for different architectures and shapes. + Also test the shapes of the positional and lead time embeddings. 
+ """ + + # Test DDM++ with square shape + lt_steps, lt_channels = 4, 8 + H = W = 16 + model = ( + setup_model_ddm_plus_plus(H, lt_steps, lt_channels) + .to(device) + .to(memory_format=torch.channels_last) + ) + assert model.pos_embd.shape == (4, H, W) + assert model.lt_embd.shape == (lt_steps, lt_channels, H, W) + + # Test DDM++ with rectangular shape + lt_steps, lt_channels = 4, 8 + H, W = 16, 32 + model = ( + setup_model_ddm_plus_plus([H, W], lt_steps, lt_channels) + .to(device) + .to(memory_format=torch.channels_last) + ) + assert model.pos_embd.shape == (4, H, W) + assert model.lt_embd.shape == (lt_steps, lt_channels, H, W) + + # Test NCSN++ with rectangular shape + lt_steps, lt_channels = 4, 8 + H, W = 16, 32 + model = ( + setup_model_ncsn_plus_plus([H, W], lt_steps, lt_channels) + .to(device) + .to(memory_format=torch.channels_last) + ) + assert model.pos_embd.shape == (4, H, W) + assert model.lt_embd.shape == (lt_steps, lt_channels, H, W) + + # Test corrdiff model with rectangular shape + N_pos, lt_steps, lt_channels = 6, 4, 8 + H, W = 16, 32 + model = ( + setup_model_learnable_embd([H, W], lt_steps, lt_channels, N_pos) + .to(device) + .to(memory_format=torch.channels_last) + ) + assert model.pos_embd.shape == (N_pos, H, W) + assert model.lt_embd.shape == (lt_steps, lt_channels, H, W) + + +# TODO: duplicate tests for model.eval() +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_song_unet_forward_no_patches(device): + """ + Test the forward method of the SongUNetPosLtEmbd for different architectures + without patches (i.e. input image is the entire global image). Uses AMP, Apex GN, + and compile (for small models only). Also test backward propagation through + the model. + """ + torch._dynamo.reset() + + # Common parameters + B, C_x, lt_steps = 3, 4, 4 + + # DDM++ model with square global shape (no compile because model too large) + H = W = 128 + N_pos, lt_channels = 4, 8 + model = ( + setup_model_ddm_plus_plus(H, lt_steps, lt_channels) + .to(device) + .to(memory_format=torch.channels_last) + ) + with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True): + output_image = model(*generate_data_no_patches(H, W, device)) + assert output_image.shape == (B, C_x, H, W) + loss = output_image.sum() + loss.backward() + # TODO: add non-regression test + + # NCSN++ model with square global shape (no compile because model too large) + H = W = 128 + N_pos, lt_channels = 4, 8 + model = ( + setup_model_ncsn_plus_plus(H, lt_steps, lt_channels) + .to(device) + .to(memory_format=torch.channels_last) + ) + with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True): + output_image = model(*generate_data_no_patches(H, W, device)) + assert output_image.shape == (B, C_x, H, W) + loss = output_image.sum() + loss.backward() + # TODO: add non-regression test + + # CorrDiff model with rectangular global shape + H, W = 128, 112 + N_pos, lt_channels = 6, 8 + model = ( + setup_model_learnable_embd([H, W], lt_steps, lt_channels, N_pos) + .to(device) + .to(memory_format=torch.channels_last) + ) + # Compile model + model = common.torch_compile_model(model) + with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True): + output_image = model(*generate_data_no_patches(H, W, device)) + assert output_image.shape == (B, C_x, H, W) + loss = output_image.sum() + loss.backward() + # TODO: add non-regression test + + return + + +# TODO: duplicate tests for model.eval() mode +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_song_unet_forward_with_patches(device): + """ + Test the forward method 
of the SongUNetPosLtEmbd for different architectures
+    with patches (i.e. only a subset of the global image). Uses AMP, Apex GN,
+    and compile (for small models only). Also test backward propagation through
+    the model.
+    """
+    torch._dynamo.reset()
+
+    # Common parameters
+    P, B, C_x, H_p, W_p, lt_steps = 4, 3, 4, 32, 64, 4
+
+    # DDM++ model with square global shape (no compile because model too large)
+    H = W = 128
+    N_pos, lt_channels = 4, 8
+    model = (
+        setup_model_ddm_plus_plus(H, lt_steps, lt_channels)
+        .to(device)
+        .to(memory_format=torch.channels_last)
+    )
+    with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True):
+        output_image = model(*generate_data_with_patches(H_p, W_p, device))
+    assert output_image.shape == (P * B, C_x, H_p, W_p)
+    loss = output_image.sum()
+    loss.backward()
+    # TODO: add non-regression test
+
+    # NCSN++ model with square global shape (no compile because model too large)
+    H = W = 128
+    N_pos, lt_channels = 4, 8
+    model = (
+        setup_model_ncsn_plus_plus(H, lt_steps, lt_channels)
+        .to(device)
+        .to(memory_format=torch.channels_last)
+    )
+    with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True):
+        output_image = model(*generate_data_with_patches(H_p, W_p, device))
+    assert output_image.shape == (P * B, C_x, H_p, W_p)
+    loss = output_image.sum()
+    loss.backward()
+    # TODO: add non-regression test
+
+    # CorrDiff model with rectangular global shape
+    H, W = 128, 112
+    N_pos, lt_channels = 6, 8
+    model = (
+        setup_model_learnable_embd([H, W], lt_steps, lt_channels, N_pos)
+        .to(device)
+        .to(memory_format=torch.channels_last)
+    )
+    # Compile model
+    model = common.torch_compile_model(model)
+    with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True):
+        output_image = model(*generate_data_with_patches(H_p, W_p, device))
+    assert output_image.shape == (P * B, C_x, H_p, W_p)
+    loss = output_image.sum()
+    loss.backward()
+    # TODO: add non-regression test
+
+    return
+
+
+@pytest.mark.parametrize("device", ["cuda:0"])
+def test_song_unet_positional_embedding_indexing_no_patches(device):
+    """
+    Test for positional_embedding_indexing method. Does not use patches (i.e.
+    input image is the entire global image).
+    """
+
+    # Common parameters
+    B, lt_steps = 3, 4
+
+    # CorrDiff model with rectangular global shape
+    H, W = 128, 112
+    N_pos, lt_channels = 6, 8
+    model = (
+        setup_model_learnable_embd([H, W], lt_steps, lt_channels, N_pos)
+        .to(device)
+        .to(memory_format=torch.channels_last)
+    )
+    inputs = generate_data_no_patches(H, W, device)
+    pos_embed = model.positional_embedding_indexing(inputs[0], inputs[4], inputs[3])
+    assert pos_embed.shape == (B, N_pos + lt_channels, H, W)
+    assert common.validate_tensor_accuracy(
+        pos_embed,
+        file_name="songunet_pos_lt_embd_pos_embed_indexing_no_patches_corrdiff.pth",
+    )
+    # TODO: add non-regression tests for other architectures
+
+
+@pytest.mark.parametrize("device", ["cuda:0"])
+def test_song_unet_positional_embedding_indexing_with_patches(device):
+    """
+    Test for positional_embedding_indexing method. Uses patches (i.e. input image
+    is only a subset of the global image).
+ """ + + # Common parameters + P, B, H_p, W_p, lt_steps = 4, 3, 32, 64, 4 + + # CorrDiff model with rectangular global shape + H, W = 128, 112 + N_pos, lt_channels = 6, 8 + model = ( + setup_model_learnable_embd([H, W], lt_steps, lt_channels, N_pos) + .to(device) + .to(memory_format=torch.channels_last) + ) + inputs = generate_data_with_patches(H_p, W_p, device) + pos_embed = model.positional_embedding_indexing(inputs[0], inputs[4], inputs[3]) + assert pos_embed.shape == (P * B, N_pos + lt_channels, H_p, W_p) + assert common.validate_tensor_accuracy( + pos_embed, + file_name="songunet_pos_lt_embd_pos_embed_indexing_with_patches_corrdiff.pth", + ) + # TODO: add non-regression tests for other architectures + + +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_song_unet_optims_no_patches(device): + """Test SongUNetPosLtEmbd optimizations (CUDA graphs, JIT, AMP). Uses input + data without patches (i.e. the entire global image).""" + + # NOTE: for now only test the corrdiff architecture + def setup_model(): + H, W = 128, 112 + N_pos, lt_steps, lt_channels = 6, 4, 8 + model = ( + setup_model_learnable_embd([H, W], lt_steps, lt_channels, N_pos) + .to(device) + .to(memory_format=torch.channels_last) + ) + return model, generate_data_no_patches(H, W, device) + + # Ideally always check graphs first + model, invar = setup_model() + assert common.validate_cuda_graphs(model, (*invar,)) + # Check JIT + model, invar = setup_model() + with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True): + assert common.validate_jit(model, (*invar,)) + # Check AMP + model, invar = setup_model() + with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True): + assert common.validate_amp(model, (*invar,)) + + +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_song_unet_optims_with_patches(device): + """Test SongUNetPosLtEmbd optimizations (CUDA graphs, JIT, AMP). Uses input + data with patches (i.e. input image is only a subset of the global image).""" + + # NOTE: for now only test the corrdiff architecture + def setup_model(): + H, W = 128, 112 + H_p, W_p = 32, 64 + N_pos, lt_steps, lt_channels = 6, 4, 8 + model = ( + setup_model_learnable_embd([H, W], lt_steps, lt_channels, N_pos) + .to(device) + .to(memory_format=torch.channels_last) + ) + return model, generate_data_with_patches(H_p, W_p, device) + + # Ideally always check graphs first + model, invar = setup_model() + assert common.validate_cuda_graphs(model, (*invar,)) + # Check JIT + model, invar = setup_model() + with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True): + assert common.validate_jit(model, (*invar,)) + # Check AMP + model, invar = setup_model() + with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True): + assert common.validate_amp(model, (*invar,)) + + +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_song_unet_checkpoint_no_patches(device): + """Test SongUNetPosLtEmbd checkpoint save/load for different + architectures. Uses input data without patches (i.e. 
input image is the + entire global image).""" + + # Common parameters + lt_steps = 4 + + # DDM++ model with square global shape + H = W = 128 + lt_steps, lt_channels = 4, 8 + model_1 = ( + setup_model_ddm_plus_plus(H, lt_steps, lt_channels, seed=0) + .to(device) + .to(memory_format=torch.channels_last) + ) + model_2 = ( + setup_model_ddm_plus_plus(H, lt_steps, lt_channels, seed=1) + .to(device) + .to(memory_format=torch.channels_last) + ) + assert common.validate_checkpoint( + model_1, + model_2, + generate_data_no_patches(H, W, device), + enable_autocast=True, + ) + + # NCSN++ model with square global shape + H = W = 128 + lt_steps, lt_channels = 4, 8 + model_1 = ( + setup_model_ncsn_plus_plus(H, lt_steps, lt_channels, seed=0) + .to(device) + .to(memory_format=torch.channels_last) + ) + model_2 = ( + setup_model_ncsn_plus_plus(H, lt_steps, lt_channels, seed=1) + .to(device) + .to(memory_format=torch.channels_last) + ) + assert common.validate_checkpoint( + model_1, + model_2, + generate_data_no_patches(H, W, device), + enable_autocast=True, + ) + + # CorrDiff model with rectangular global shape + H, W = 128, 112 + N_pos, lt_steps, lt_channels = 6, 4, 8 + model_1 = ( + setup_model_learnable_embd([H, W], lt_steps, lt_channels, N_pos, seed=0) + .to(device) + .to(memory_format=torch.channels_last) + ) + model_2 = ( + setup_model_learnable_embd([H, W], lt_steps, lt_channels, N_pos, seed=1) + .to(device) + .to(memory_format=torch.channels_last) + ) + assert common.validate_checkpoint( + model_1, + model_2, + generate_data_no_patches(H, W, device), + enable_autocast=True, + ) + + return + + +@pytest.mark.parametrize("device", ["cuda:0"]) +def test_song_unet_checkpoint_with_patches(device): + """Test SongUNetPosLtEmbd checkpoint save/load for different + architectures. Uses input data with patches (i.e. 
input image is only a
+    subset of the global image)."""
+
+    # Common parameters
+    H_p, W_p, lt_steps = 32, 64, 4
+
+    # DDM++ model with square global shape
+    H = W = 128
+    lt_steps, lt_channels = 4, 8
+    model_1 = (
+        setup_model_ddm_plus_plus(H, lt_steps, lt_channels, seed=0)
+        .to(device)
+        .to(memory_format=torch.channels_last)
+    )
+    model_2 = (
+        setup_model_ddm_plus_plus(H, lt_steps, lt_channels, seed=1)
+        .to(device)
+        .to(memory_format=torch.channels_last)
+    )
+    assert common.validate_checkpoint(
+        model_1,
+        model_2,
+        generate_data_with_patches(H_p, W_p, device),
+        enable_autocast=True,
+    )
+
+    # NCSN++ model with square global shape
+    H = W = 128
+    lt_steps, lt_channels = 4, 8
+    model_1 = (
+        setup_model_ncsn_plus_plus(H, lt_steps, lt_channels, seed=0)
+        .to(device)
+        .to(memory_format=torch.channels_last)
+    )
+    model_2 = (
+        setup_model_ncsn_plus_plus(H, lt_steps, lt_channels, seed=1)
+        .to(device)
+        .to(memory_format=torch.channels_last)
+    )
+    assert common.validate_checkpoint(
+        model_1,
+        model_2,
+        generate_data_with_patches(H_p, W_p, device),
+        enable_autocast=True,
+    )
+
+    # CorrDiff model with rectangular global shape
+    H, W = 128, 112
+    N_pos, lt_steps, lt_channels = 6, 4, 8
+    model_1 = (
+        setup_model_learnable_embd([H, W], lt_steps, lt_channels, N_pos, seed=0)
+        .to(device)
+        .to(memory_format=torch.channels_last)
+    )
+    model_2 = (
+        setup_model_learnable_embd([H, W], lt_steps, lt_channels, N_pos, seed=1)
+        .to(device)
+        .to(memory_format=torch.channels_last)
+    )
+    assert common.validate_checkpoint(
+        model_1,
+        model_2,
+        generate_data_with_patches(H_p, W_p, device),
+        enable_autocast=True,
+    )
+
+    return
+
+
+@common.check_ort_version()
+@pytest.mark.parametrize("device", ["cuda:0"])
+def test_song_unet_deploy(device):
+    """Test Song UNet deployment support"""
+
+    # Common parameters
+    H_p, W_p, lt_steps = 32, 64, 4
+
+    # CorrDiff model with rectangular global shape
+    H, W = 128, 112
+    N_pos, lt_steps, lt_channels = 6, 4, 8
+    model = (
+        setup_model_learnable_embd([H, W], lt_steps, lt_channels, N_pos)
+        .to(device)
+        .to(memory_format=torch.channels_last)
+    )
+    assert common.validate_onnx_export(
+        model,
+        generate_data_with_patches(H_p, W_p, device),
+    )
+    assert common.validate_onnx_runtime(
+        model,
+        generate_data_with_patches(H_p, W_p, device),
+    )
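
Note on the recurring pattern in this patch: the hunks in song_unet.py replace in-place mutation of module state (self.pos_embd = self.pos_embd.to(x.dtype), self.scalar.data = ...) with a cast into a local variable. Mutating attributes inside forward changes the module on every dtype mismatch, which invalidates torch.compile guards (triggering recompiles) and interferes with CUDA graph capture. A minimal sketch of the idea, standalone and outside physicsnemo; the EmbedModule below is hypothetical, not part of the patch:

import torch
import torch.nn as nn


class EmbedModule(nn.Module):
    """Hypothetical module illustrating the local-cast pattern used in this patch."""

    def __init__(self, channels: int = 4, h: int = 8, w: int = 8):
        super().__init__()
        # Learnable embedding kept in fp32, the dtype the optimizer expects
        self.pos_embd = nn.Parameter(torch.randn(channels, h, w))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Cast a *local* reference to the activation dtype instead of
        # reassigning self.pos_embd, so module state stays stable across calls
        pos_embd = self.pos_embd
        if x.dtype != pos_embd.dtype:
            pos_embd = pos_embd.to(x.dtype)
        # Append the embedding to the input conditioning along channels
        return torch.cat((x, pos_embd[None].expand(x.shape[0], -1, -1, -1)), dim=1)


model = EmbedModule()
x = torch.randn(3, 2, 8, 8, dtype=torch.bfloat16)  # mixed-precision activations
out = model(x)  # (3, 6, 8, 8) in bf16; model.pos_embd is still fp32

Keeping the parameter in fp32 while casting a per-call local copy leaves optimizer and checkpoint state untouched, which is why the same pattern is applied to self.scalar in forward above.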