diff --git a/CHANGELOG.md b/CHANGELOG.md
index b583742985..e5d73a26f5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
+- Improved lead-time support for diffusion models.
 - Improved documentation for diffusion models and diffusion utils.
 - Safe API to override `__init__`'s arguments saved in checkpoint file with
   `Module.from_checkpoint("chkpt.mdlus", models_args)`.
diff --git a/physicsnemo/metrics/diffusion/loss.py b/physicsnemo/metrics/diffusion/loss.py
index 83708b4cb1..1967504e01 100644
--- a/physicsnemo/metrics/diffusion/loss.py
+++ b/physicsnemo/metrics/diffusion/loss.py
@@ -220,7 +220,15 @@ def __init__(
         self.P_std = P_std
         self.sigma_data = sigma_data
 
-    def __call__(self, net, images, condition=None, labels=None, augment_pipe=None):
+    def __call__(
+        self,
+        net,
+        images,
+        condition=None,
+        labels=None,
+        augment_pipe=None,
+        lead_time_label=None,
+    ):
         """
         Calculate and return the loss corresponding to the EDM formulation.
 
@@ -245,6 +253,10 @@ def __call__(self, net, images, condition=None, labels=None, augment_pipe=None):
             An optional data augmentation function that takes images as input and
             returns augmented images. If not provided, no data augmentation is applied.
 
+        lead_time_label: torch.Tensor, optional
+            Lead-time labels to pass to the model, shape (batch_size, 1).
+            If not provided, the model is called without a lead-time label input.
+
         Returns:
         -------
         torch.Tensor
@@ -258,16 +270,24 @@ def __call__(self, net, images, condition=None, labels=None, augment_pipe=None):
             augment_pipe(images) if augment_pipe is not None else (images, None)
         )
         n = torch.randn_like(y) * sigma
+        additional_labels = {
+            "augment_labels": augment_labels,
+            "lead_time_label": lead_time_label,
+        }
+        # Drop None items to support models that don't have these arguments in `forward`
+        additional_labels = {
+            k: v for (k, v) in additional_labels.items() if v is not None
+        }
         if condition is not None:
             D_yn = net(
                 y + n,
                 sigma,
                 condition=condition,
                 class_labels=labels,
-                augment_labels=augment_labels,
+                **additional_labels,
             )
         else:
-            D_yn = net(y + n, sigma, labels, augment_labels=augment_labels)
+            D_yn = net(y + n, sigma, labels, **additional_labels)
         loss = weight * ((D_yn - y) ** 2)
         return loss
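Reviewer note: the `additional_labels` filtering above is what keeps `EDMLoss` backward compatible with networks whose `forward` does not accept the new keyword. A minimal, self-contained sketch of the pattern — the two toy nets and `call_net` below are hypothetical stand-ins, not PhysicsNeMo code:

```python
import torch

def legacy_net(y, sigma, class_labels=None):
    # Older net: its signature has no `lead_time_label` (or `augment_labels`).
    return y

def lead_time_net(y, sigma, class_labels=None, lead_time_label=None):
    # Newer net: optionally conditioned on a lead-time label.
    return y

def call_net(net, y, sigma, labels, lead_time_label=None, augment_labels=None):
    additional_labels = {
        "augment_labels": augment_labels,
        "lead_time_label": lead_time_label,
    }
    # Drop None entries so nets that lack these keywords still work.
    additional_labels = {k: v for k, v in additional_labels.items() if v is not None}
    return net(y, sigma, class_labels=labels, **additional_labels)

y = torch.randn(2, 1, 8, 8)
sigma = torch.ones(2, 1, 1, 1)
call_net(legacy_net, y, sigma, labels=None)  # no unexpected-keyword error
call_net(lead_time_net, y, sigma, labels=None, lead_time_label=torch.tensor([[0], [1]]))
```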
diff --git a/physicsnemo/models/diffusion/song_unet.py b/physicsnemo/models/diffusion/song_unet.py
index 7e48f60bf1..c9a17838eb 100644
--- a/physicsnemo/models/diffusion/song_unet.py
+++ b/physicsnemo/models/diffusion/song_unet.py
@@ -666,7 +666,7 @@ class SongUNetPosEmbd(SongUNet):
     N_grid_channels : int, optional, default=4
         Number of channels :math:`C_{PE}` in the positional embedding grid. For 'sinusoidal'
         must be 4 or multiple of 4. For 'linear' and 'test' must be 2. For 'learnable' can be any
-        value.
+        value. If 0, positional embedding is disabled (but `lead_time_mode` may still be used).
     lead_time_mode : bool, optional, default=False
         Provided for convenience.
         It is recommended to use the architecture :class:`~physicsnemo.models.diffusion.song_unet.SongUNetPosLtEmbd`
@@ -794,7 +794,7 @@ def __init__(
         profile_mode: bool = False,
         amp_mode: bool = False,
         lead_time_mode: bool = False,
-        lead_time_channels: int = None,
+        lead_time_channels: int | None = None,
         lead_time_steps: int = 9,
         prob_channels: List[int] = [],
     ):
@@ -826,12 +826,16 @@ def __init__(
         self.gridtype = gridtype
         self.N_grid_channels = N_grid_channels
-        if self.gridtype == "learnable":
+        if (self.gridtype == "learnable") or (self.N_grid_channels == 0):
             self.pos_embd = self._get_positional_embedding()
         else:
             self.register_buffer("pos_embd", self._get_positional_embedding().float())
 
         self.lead_time_mode = lead_time_mode
         if self.lead_time_mode:
+            if (lead_time_channels is None) or (lead_time_channels <= 0):
+                raise ValueError(
+                    "`lead_time_channels` must be >= 1 if `lead_time_mode` is enabled."
+                )
             self.lead_time_channels = lead_time_channels
             self.lead_time_steps = lead_time_steps
             self.lt_embd = self._get_lead_time_embedding()
@@ -840,6 +844,12 @@ def __init__(
             self.scalar = torch.nn.Parameter(
                 torch.ones((1, len(self.prob_channels), 1, 1))
             )
+        else:
+            if lead_time_channels:
+                raise ValueError(
+                    "When `lead_time_mode` is disabled, `lead_time_channels` may not be set."
+                )
+            self.lt_embd = None
 
     def forward(
         self,
@@ -861,11 +871,11 @@ def forward(
                 "Cannot provide both embedding_selector and global_index."
             )
 
-        if x.dtype != self.pos_embd.dtype:
+        if (self.pos_embd is not None) and (x.dtype != self.pos_embd.dtype):
             self.pos_embd = self.pos_embd.to(x.dtype)
 
         # Append positional embedding to input conditioning
-        if self.pos_embd is not None:
+        if (self.pos_embd is not None) or (self.lt_embd is not None):
             # Select positional embeddings with a selector function
             if embedding_selector is not None:
                 selected_pos_embd = self.positional_embedding_selector(
@@ -905,8 +915,13 @@ def positional_embedding_indexing(
         r"""Select positional embeddings using global indices.
 
         This method uses global indices to select a specific subset of the
-        positional embedding grid (called *patches*). If no indices are provided,
-        the entire positional embedding grid is returned.
+        positional embedding grid and/or the lead-time embedding grid (called
+        *patches*). If no indices are provided, the entire embedding grid is returned.
+        The positional embedding grid is returned if `N_grid_channels > 0`, while
+        the lead-time embedding grid is returned if `lead_time_mode == True`. If
+        both positional and lead-time embeddings are enabled, both are returned
+        (concatenated). If neither is enabled, this method should not be called;
+        doing so raises a ValueError.
 
         Parameters
         ----------
@@ -918,15 +933,24 @@ def positional_embedding_indexing(
            the patches to extract from the positional embedding grid.
            :math:`P` is the number of distinct patches in the input tensor ``x``.
            The channel dimension should contain :math:`j`, :math:`i` indices that
-            represent the indices of the pixels to extract from the embedding grid.
+            represent the indices of the pixels to extract from the
+            embedding grid.
+        lead_time_label : Optional[torch.Tensor], default=None
+            Tensor of shape :math:`(P,)` that contains the lead-time label for each patch.
+            Only used if ``lead_time_mode`` is True.
 
         Returns
         -------
         torch.Tensor
-            Selected positional embeddings with shape :math:`(P \times B, C_{PE}, H_{in}, W_{in})`
-            (same spatial resolution as ``global_index``) if ``global_index`` is provided.
-            If ``global_index`` is None, the entire positional embedding grid
-            is duplicated :math:`B` times and returned with shape :math:`(B, C_{PE}, H, W)`.
+            Selected embeddings with shape :math:`(P \times B, C_{PE} [+
+            C_{LT}], H_{in}, W_{in})`. :math:`C_{PE}` is the number of
+            embedding channels in the positional embedding grid, and
+            :math:`C_{LT}` is the number of embedding channels in the lead-time
+            embedding grid. If ``lead_time_label`` is provided, the lead-time
+            embedding channels are included. If ``global_index`` is `None`,
+            :math:`P = 1` is assumed, and the positional embedding grid is
+            duplicated :math:`B` times and returned with shape
+            :math:`(B, C_{PE} [+ C_{LT}], H, W)`.
 
         Example
         -------
@@ -951,7 +975,7 @@ def positional_embedding_indexing(
         """
         # If no global indices are provided, select all embeddings and expand
         # to match the batch size of the input
-        if x.dtype != self.pos_embd.dtype:
+        if (self.pos_embd is not None) and (x.dtype != self.pos_embd.dtype):
             self.pos_embd = self.pos_embd.to(x.dtype)
 
         if global_index is None:
@@ -989,23 +1013,26 @@ def positional_embedding_indexing(
            global_index = torch.reshape(
                torch.permute(global_index, (1, 0, 2, 3)), (2, -1)
            )  # (P, 2, X, Y) to (2, P*X*Y)
-            selected_pos_embd = self.pos_embd[
-                :, global_index[0], global_index[1]
-            ]  # (C_pe, P*X*Y)
-            selected_pos_embd = torch.permute(
-                torch.reshape(selected_pos_embd, (self.pos_embd.shape[0], P, H, W)),
-                (1, 0, 2, 3),
-            )  # (P, C_pe, X, Y)
-
-            selected_pos_embd = selected_pos_embd.repeat(
-                B, 1, 1, 1
-            )  # (B*P, C_pe, X, Y)
+
+            if self.pos_embd is not None:
+                selected_pos_embd = self.pos_embd[
+                    :, global_index[0], global_index[1]
+                ]  # (C_pe, P*X*Y)
+                selected_pos_embd = torch.permute(
+                    torch.reshape(selected_pos_embd, (self.pos_embd.shape[0], P, H, W)),
+                    (1, 0, 2, 3),
+                )  # (P, C_pe, X, Y)
+
+                selected_pos_embd = selected_pos_embd.repeat(
+                    B, 1, 1, 1
+                )  # (B*P, C_pe, X, Y)
+
+                embeds = [selected_pos_embd]
+            else:
+                embeds = []
 
            # Append positional and lead time embeddings to input conditioning
            if self.lead_time_mode:
-                embeds = []
-                if self.pos_embd is not None:
-                    embeds.append(selected_pos_embd)  # reuse code below
                if self.lt_embd is not None:
                    lt_embds = self.lt_embd[
                        lead_time_label.int()
@@ -1026,8 +1053,12 @@ def positional_embedding_indexing(
                    )  # (B*P, C_pe, X, Y)
                    embeds.append(selected_lt_pos_embd)
 
-                if len(embeds) > 0:
-                    selected_pos_embd = torch.cat(embeds, dim=1)
+            if len(embeds) > 0:
+                selected_pos_embd = torch.cat(embeds, dim=1)
+            else:
+                raise ValueError(
+                    "`positional_embedding_indexing` should not be called when neither lead-time nor positional embeddings are used."
+                )
 
         return selected_pos_embd
 
@@ -1090,15 +1121,24 @@ def positional_embedding_selector(
         ...     return patching.apply(emb[None].expand(batch_size, -1, -1, -1))
         >>>
         """
-        if x.dtype != self.pos_embd.dtype:
+        if (self.pos_embd is not None) and (x.dtype != self.pos_embd.dtype):
             self.pos_embd = self.pos_embd.to(x.dtype)
-        if lead_time_label is not None:
+
+        if self.pos_embd is not None and lead_time_label is not None:
+            # both positional and lead-time embeddings;
             # all patches share same lead_time_label
             embeddings = torch.cat(
                 [self.pos_embd, self.lt_embd[lead_time_label[0].int()]]
             )
-        else:
+        elif self.pos_embd is not None:  # positional embedding only
             embeddings = self.pos_embd
+        elif lead_time_label is not None:  # lead-time embedding only
+            embeddings = self.lt_embd[lead_time_label[0].int()]
+        else:
+            raise ValueError(
+                "`positional_embedding_selector` should not be called when neither lead-time nor positional embeddings are used."
+            )
+
         return embedding_selector(embeddings)  # (B, N_pe, H, W)
 
     def _get_positional_embedding(self):
@@ -1331,7 +1371,7 @@ def __init__(
         resample_filter: List[int] = [1, 1],
         gridtype: str = "sinusoidal",
         N_grid_channels: int = 4,
-        lead_time_channels: int = None,
+        lead_time_channels: int | None = None,
         lead_time_steps: int = 9,
         prob_channels: List[int] = [],
         checkpoint_level: int = 0,
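Reviewer note: to make the new decoupling concrete, here is a hedged sketch mirroring the parametrized test added further below. It assumes `SongUNetPosEmbd` is importable from `physicsnemo.models.diffusion` (the test imports it under the alias `UNet`); the shape bookkeeping follows the test's `in_channels` formula:

```python
import torch
from physicsnemo.models.diffusion import SongUNetPosEmbd

# Lead-time embeddings only: positional embeddings disabled via N_grid_channels=0.
model = SongUNetPosEmbd(
    img_resolution=16,
    in_channels=2 + 0 + 2,  # data channels + C_PE (0) + lead_time_channels (2)
    out_channels=2,
    N_grid_channels=0,
    lead_time_mode=True,
    lead_time_channels=2,
    lead_time_steps=9,
)

x = torch.ones(2, 2, 16, 16)
noise_labels = torch.randn(2)
class_labels = torch.randint(0, 1, (2, 1))
lead_time_label = torch.as_tensor([0, 1])  # one label per batch element, < lead_time_steps

out = model(x, noise_labels, class_labels, lead_time_label=lead_time_label)
assert out.shape == (2, 2, 16, 16)

# Inconsistent configurations now fail fast with a ValueError, e.g.:
#   lead_time_mode=True,  lead_time_channels=0  -> ValueError
#   lead_time_mode=False, lead_time_channels=2  -> ValueError
```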
diff --git a/physicsnemo/utils/diffusion/deterministic_sampler.py b/physicsnemo/utils/diffusion/deterministic_sampler.py
index 6dce486abc..36f275d5ae 100644
--- a/physicsnemo/utils/diffusion/deterministic_sampler.py
+++ b/physicsnemo/utils/diffusion/deterministic_sampler.py
@@ -49,6 +49,7 @@ def deterministic_sampler(
     S_min: float = 0.0,
     S_max: float = float("inf"),
     S_noise: float = 1.0,
+    lead_time_label: torch.Tensor | None = None,
 ) -> torch.Tensor:
     r"""
     Generalized sampler, representing the superset of all sampling methods
@@ -157,6 +158,9 @@ def deterministic_sampler(
        stochastic sampler. Added signal noise is proportional to
        :math:`\epsilon_i` where :math:`\epsilon_i \sim \mathcal{N}(0, S_{noise}^2)`.
        Defaults to 1.0.
+    lead_time_label: torch.Tensor, optional
+        Lead-time labels to pass to the model, shape (batch_size, 1).
+        If not provided, the model is called without a lead-time label input.
 
     Returns
     -------
@@ -167,6 +171,11 @@ def deterministic_sampler(
     # conditioning
     x_lr = img_lr
 
+    # Do not pass lead-time labels to nets that may not support them
+    additional_labels = (
+        {} if lead_time_label is None else {"lead_time_label": lead_time_label}
+    )
+
     if solver not in ["euler", "heun"]:
         raise ValueError(f"Unknown solver {solver}")
     if discretization not in ["vp", "ve", "iddpm", "edm"]:
@@ -198,8 +207,7 @@ def deterministic_sampler(
     ve_sigma_deriv = lambda t: 0.5 / t.sqrt()
     ve_sigma_inv = lambda sigma: sigma**2
 
-    # Select default noise level range based on the specified
-    # time step discretization.
+    # Select default noise level range based on the specified time step discretization.
     if sigma_min is None:
         vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=epsilon_s)
         sigma_min = {"vp": vp_def, "ve": 0.02, "iddpm": 0.002, "edm": 0.002}[
@@ -304,11 +312,12 @@ def deterministic_sampler(
                 sigma(t_hat),
                 condition=x_lr,
                 class_labels=class_labels,
+                **additional_labels,
             ).to(torch.float64)
         else:
-            denoised = net(x_hat / s(t_hat), x_lr, sigma(t_hat), class_labels).to(
-                torch.float64
-            )
+            denoised = net(
+                x_hat / s(t_hat), x_lr, sigma(t_hat), class_labels, **additional_labels
+            ).to(torch.float64)
         d_cur = (
             sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)
         ) * x_hat - sigma_deriv(t_hat) * s(t_hat) / sigma(t_hat) * denoised
@@ -326,10 +335,15 @@ def deterministic_sampler(
                 sigma(t_prime),
                 condition=x_lr,
                 class_labels=class_labels,
+                **additional_labels,
             ).to(torch.float64)
         else:
             denoised = net(
-                x_prime / s(t_prime), x_lr, sigma(t_prime), class_labels
+                x_prime / s(t_prime),
+                x_lr,
+                sigma(t_prime),
+                class_labels,
+                **additional_labels,
             ).to(torch.float64)
         d_prime = (
             sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)
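Reviewer note: end to end, the sampler can now thread a lead-time label through to a lead-time-aware net, while nets lacking the keyword keep working because the label is only forwarded when provided. A hedged sketch, modeled on `MockNetLt` from the new test below; the `sigma_min`/`sigma_max`/`round_sigma` attributes are assumptions about what the sampler reads from the net, not a documented interface:

```python
import torch
from physicsnemo.utils.diffusion import deterministic_sampler

class ToyLtNet:
    # Stand-in for a lead-time-aware net (hypothetical, mirrors the test mock).
    sigma_min = 0.002  # assumed attribute
    sigma_max = 80.0   # assumed attribute

    def round_sigma(self, sigma):
        return torch.tensor(sigma)

    def __call__(self, x, img_lr, sigma, class_labels, lead_time_label=None):
        return x  # identity denoiser, for illustration only

latents = torch.randn(1, 3, 64, 64)
img_lr = torch.randn(1, 3, 64, 64)

# Omitting `lead_time_label` calls the net without that keyword.
samples = deterministic_sampler(
    net=ToyLtNet(),
    latents=latents,
    img_lr=img_lr,
    lead_time_label=torch.randint(0, 10, (1, 1)),
)
```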
diff --git a/test/metrics/diffusion/test_losses.py b/test/metrics/diffusion/test_losses.py
index 6e940e3dbc..5eea50dcff 100644
--- a/test/metrics/diffusion/test_losses.py
+++ b/test/metrics/diffusion/test_losses.py
@@ -139,6 +139,17 @@ def fake_condition_net(y, sigma, condition, class_labels=None, augment_labels=No
     def fake_net(y, sigma, labels, augment_labels=None):
         return torch.tensor([1.0])
 
+    def fake_condition_net_lt(
+        y,
+        sigma,
+        condition,
+        class_labels=None,
+        augment_labels=None,
+        lead_time_label=None,
+    ):
+        assert lead_time_label is not None  # test that this is properly passed through
+        return torch.tensor([1.0])
+
     loss_func = EDMLoss()
 
     img = torch.tensor([[[[1.0]]]])
@@ -160,6 +171,16 @@ def mock_augment_pipe(imgs):
     loss_value_with_augmentation = loss_func(fake_net, img, labels, mock_augment_pipe)
     assert isinstance(loss_value_with_augmentation, torch.Tensor)
 
+    lead_time_label = torch.tensor([1])
+    loss_value = loss_func(
+        fake_condition_net_lt,
+        img,
+        condition=condition,
+        labels=labels,
+        lead_time_label=lead_time_label,
+    )
+    assert isinstance(loss_value, torch.Tensor)
+
 
 # RegressionLoss tests
diff --git a/test/models/diffusion/test_song_unet_pos_embd.py b/test/models/diffusion/test_song_unet_pos_embd.py
index dfcaf865f2..c9dc6f0dfe 100644
--- a/test/models/diffusion/test_song_unet_pos_embd.py
+++ b/test/models/diffusion/test_song_unet_pos_embd.py
@@ -328,3 +328,53 @@ def test_son_unet_deploy(device):
     assert common.validate_onnx_runtime(
         model, (*[input_image, noise_labels, class_labels],)
     )
+
+
+@pytest.mark.parametrize("device", ["cuda:0", "cpu"])
+@pytest.mark.parametrize("lead_time_mode", [False, True])
+@pytest.mark.parametrize("N_grid_channels", [0, 4])
+@pytest.mark.parametrize("lead_time_channels", [0, 2])
+def test_song_unet_positional_leadtime(
+    device, lead_time_mode, N_grid_channels, lead_time_channels
+):
+    """Test that both positional and lead-time embeddings can be used independently"""
+
+    img_resolution = 16
+    out_channels = 2
+    lead_time_steps = 2
+    in_channels = 2 + N_grid_channels + (lead_time_channels if lead_time_mode else 0)
+
+    def _create_model():
+        return UNet(
+            img_resolution=img_resolution,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            N_grid_channels=N_grid_channels,
+            lead_time_mode=lead_time_mode,
+            lead_time_channels=lead_time_channels,
+            lead_time_steps=lead_time_steps,
+        ).to(device)
+
+    if (lead_time_channels > 0) != lead_time_mode:
+        with pytest.raises(ValueError):
+            model = _create_model()
+        return
+    else:
+        model = _create_model()
+
+    noise_labels = torch.randn([2]).to(device)
+    class_labels = torch.randint(0, 1, (2, 1)).to(device)
+    input_image = torch.ones([2, 2, 16, 16]).to(device)
+    lead_time_label = torch.as_tensor([0, 1]).to(device)
+
+    assert bool(N_grid_channels) == (model.pos_embd is not None)
+    assert lead_time_mode == (hasattr(model, "lt_embd") and (model.lt_embd is not None))
+
+    if lead_time_mode:
+        output_image = model(
+            input_image, noise_labels, class_labels, lead_time_label=lead_time_label
+        )
+    else:
+        output_image = model(input_image, noise_labels, class_labels)
+
+    assert output_image.shape == (2, out_channels, img_resolution, img_resolution)
diff --git a/test/utils/generative/test_deterministic_sampler.py b/test/utils/generative/test_deterministic_sampler.py
index 2aa89489ca..85f6ed14fb 100644
--- a/test/utils/generative/test_deterministic_sampler.py
+++ b/test/utils/generative/test_deterministic_sampler.py
@@ -33,12 +33,23 @@ def round_sigma(self, sigma):
         return torch.tensor(sigma)
 
 
+# Version of MockNet that supports lead-time labels
+class MockNetLt(MockNet):
+    def __call__(self, x, img_lr, sigma, class_labels, lead_time_label=None):
+        return x
+
+
 # Define a fixture for the network
 @pytest.fixture
 def mock_net():
     return MockNet()
 
 
+@pytest.fixture
+def mock_net_lt():
+    return MockNetLt()
+
+
 # Basic functionality test
 @import_or_fail("cftime")
 def test_deterministic_sampler_output_type_and_shape(mock_net, pytestconfig):
@@ -179,6 +190,22 @@ def test_deterministic_sampler_scaling_validation(mock_net, scaling, pytestconfi
     assert isinstance(output, torch.Tensor)
 
 
+# Test support for lead-time labels
+@import_or_fail("cftime")
+def test_deterministic_sampler_lead_time(mock_net_lt, pytestconfig):
+
+    from physicsnemo.utils.diffusion import deterministic_sampler
+
+    latents = torch.randn(1, 3, 64, 64)
+    img_lr = torch.randn(1, 3, 64, 64)
+    lt_label = torch.randint(0, 10, (1, 1))
+
+    output = deterministic_sampler(
+        net=mock_net_lt, latents=latents, img_lr=img_lr, lead_time_label=lt_label
+    )
+    assert isinstance(output, torch.Tensor)
+
+
 # Test correctness with known ODE solution
 @import_or_fail("cftime")
 def test_deterministic_sampler_correctness(pytestconfig):