Improve lead time support for diffusion models #980

Open
wants to merge 20 commits into base: main
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Improved lead time support for diffusion models
- Improved documentation for diffusion models and diffusion utils.
- Safe API to override `__init__`'s arguments saved in checkpoint file with
`Module.from_checkpoint("chkpt.mdlus", models_args)`.
26 changes: 23 additions & 3 deletions physicsnemo/metrics/diffusion/loss.py
@@ -220,7 +220,15 @@ def __init__(
self.P_std = P_std
self.sigma_data = sigma_data

def __call__(self, net, images, condition=None, labels=None, augment_pipe=None):
def __call__(
self,
net,
images,
condition=None,
labels=None,
augment_pipe=None,
lead_time_label=None,
):
"""
Calculate and return the loss corresponding to the EDM formulation.

@@ -245,6 +253,10 @@ def __call__(self, net, images, condition=None, labels=None, augment_pipe=None):
An optional data augmentation function that takes images as input and
returns augmented images. If not provided, no data augmentation is applied.

lead_time_label: torch.Tensor, optional
Lead-time labels to pass to the model, shape (batch_size, 1).
[Review comment — @CharlelieLrt (Collaborator), Jul 17, 2025]
Are you sure about the shape (batch_size, 1) of lead_time_label? I've seen other places (not in your PR) where it's (batch_size,), and some other places where it's (batch_size, 1, 1, 1).

If not provided, the model is called without a lead-time label input.

Returns:
-------
torch.Tensor
@@ -258,16 +270,24 @@ def __call__(self, net, images, condition=None, labels=None, augment_pipe=None):
augment_pipe(images) if augment_pipe is not None else (images, None)
)
n = torch.randn_like(y) * sigma
additional_labels = {
"augment_labels": augment_labels,
"lead_time_label": lead_time_label,
}
# drop None items to support models that don't have these arguments in `forward`
additional_labels = {
k: v for (k, v) in additional_labels.items() if v is not None
}
if condition is not None:
D_yn = net(
y + n,
sigma,
condition=condition,
class_labels=labels,
augment_labels=augment_labels,
**additional_labels,
)
else:
D_yn = net(y + n, sigma, labels, augment_labels=augment_labels)
D_yn = net(y + n, sigma, labels, **additional_labels)
loss = weight * ((D_yn - y) ** 2)
return loss

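For reference, a minimal sketch of driving the updated EDMLoss with a lead-time label, mirroring the unit test added later in this PR. The toy denoiser and the import path are illustrative assumptions, not part of the diff:

    import torch
    from physicsnemo.metrics.diffusion import EDMLoss  # import path assumed

    def toy_lead_time_net(y, sigma, condition, class_labels=None, lead_time_label=None):
        # Stand-in for a lead-time-aware denoiser: lead_time_label reaches it only
        # because the loss drops None-valued keyword arguments before calling the net.
        assert lead_time_label is not None
        return torch.zeros_like(y)

    loss_fn = EDMLoss()
    images = torch.randn(4, 3, 32, 32)
    condition = torch.randn(4, 3, 32, 32)
    lead_time_label = torch.arange(4).reshape(4, 1)  # shape per the docstring; see the shape question above

    loss = loss_fn(
        toy_lead_time_net,
        images,
        condition=condition,
        lead_time_label=lead_time_label,
    )
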
106 changes: 73 additions & 33 deletions physicsnemo/models/diffusion/song_unet.py
@@ -666,7 +666,7 @@ class SongUNetPosEmbd(SongUNet):
N_grid_channels : int, optional, default=4
Number of channels :math:`C_{PE}` in the positional embedding grid. For 'sinusoidal' must be 4 or
multiple of 4. For 'linear' and 'test' must be 2. For 'learnable' can be any
value.
value. If 0, positional embedding is disabled (but `lead_time_mode` may still be used).
lead_time_mode : bool, optional, default=False
Provided for convenience. It is recommended to use the architecture
:class:`~physicsnemo.models.diffusion.song_unet.SongUNetPosLtEmbd`
@@ -794,7 +794,7 @@ def __init__(
profile_mode: bool = False,
amp_mode: bool = False,
lead_time_mode: bool = False,
lead_time_channels: int = None,
lead_time_channels: int | None = None,
lead_time_steps: int = 9,
prob_channels: List[int] = [],
):
@@ -826,12 +826,16 @@ def __init__(

self.gridtype = gridtype
self.N_grid_channels = N_grid_channels
if self.gridtype == "learnable":
if (self.gridtype == "learnable") or (self.N_grid_channels == 0):
self.pos_embd = self._get_positional_embedding()
else:
self.register_buffer("pos_embd", self._get_positional_embedding().float())
self.lead_time_mode = lead_time_mode
if self.lead_time_mode:
if (lead_time_channels is None) or (lead_time_channels <= 0):
raise ValueError(
"`lead_time_channels` must be >= 1 if `lead_time_mode` is enabled."
)
[Review comment on lines +835 to +838 — Collaborator]

Agree with this validation. However, there is clearly redundancy between lead_time_mode and lead_time_channels. In addition, for positional embeddings, we can disable them by setting N_grid_channels = 0 (there is no boolean parameter positional_embedding_mode), so we expect the same mechanism to work for lead-time embeddings (i.e. disable by setting lead_time_channels = 0).

If your concern is about backward compatibility if removing lead_time_mode:

  1. This parameter was introduced only very recently (<2 months), so it could be okay to remove it
  2. If still a concern, can we keep lead_time_mode but add a deprecation warning in this if statement:
if self.lead_time_mode:
    warnings.warn(
        f"The parameter `lead_time_mode` will be deprecated in a future version. "
        f"The recommended way to enable (disable) lead-time embeddings is to set `lead_time_channels > 0` (`lead_time_channels = 0`)")
    [...]

self.lead_time_channels = lead_time_channels
self.lead_time_steps = lead_time_steps
self.lt_embd = self._get_lead_time_embedding()
@@ -840,6 +844,12 @@ def __init__(
self.scalar = torch.nn.Parameter(
torch.ones((1, len(self.prob_channels), 1, 1))
)
else:
if lead_time_channels:
raise ValueError(
"When `lead_time_mode` is disabled, `lead_time_channels` may not be set."
)
self.lt_embd = None

def forward(
self,
@@ -861,11 +871,11 @@ def forward(
"Cannot provide both embedding_selector and global_index."
)

if x.dtype != self.pos_embd.dtype:
if (self.pos_embd is not None) and (x.dtype != self.pos_embd.dtype):
self.pos_embd = self.pos_embd.to(x.dtype)

# Append positional embedding to input conditioning
if self.pos_embd is not None:
if (self.pos_embd is not None) or (self.lt_embd is not None):
# Select positional embeddings with a selector function
if embedding_selector is not None:
selected_pos_embd = self.positional_embedding_selector(
@@ -905,8 +915,13 @@ def positional_embedding_indexing(
r"""Select positional embeddings using global indices.

        This method uses global indices to select a specific subset of the
positional embedding grid (called *patches*). If no indices are provided,
the entire positional embedding grid is returned.
positional embedding grid and/or the lead-time embedding grid (called
*patches*). If no indices are provided, the entire embedding grid is returned.
The positional embedding grid is returned if `N_grid_channels > 0`, while
the lead-time embedding grid is returned if `lead_time_mode == True`. If
both positional and lead-time embedding are enabled, both are returned
(concatenated). If neither is enabled, this function should not be called;
doing so will raise a ValueError.

Parameters
----------
@@ -918,15 +933,24 @@
the patches to extract from the positional embedding grid.
:math:`P` is the number of distinct patches in the input tensor ``x``.
The channel dimension should contain :math:`j`, :math:`i` indices that
should represent the indices of the pixels to extract from the embedding grid.
should represent the indices of the pixels to extract from the
embedding grid.
lead_time_label : Optional[torch.Tensor], default=None
Tensor of shape :math:`(P,)` that corresponds to the lead-time label for each patch.
Only used if ``lead_time_mode`` is True.

Returns
-------
torch.Tensor
Selected positional embeddings with shape :math:`(P \times B, C_{PE}, H_{in}, W_{in})`
(same spatial resolution as ``global_index``) if ``global_index`` is provided.
If ``global_index`` is None, the entire positional embedding grid
is duplicated :math:`B` times and returned with shape :math:`(B, C_{PE}, H, W)`.
Selected embeddings with shape :math:`(P \times B, C_{PE} [+
C_{LT}], H_{in}, W_{in})`. :math:`C_{PE}` is the number of
embedding channels in the positional embedding grid, and
:math:`C_{LT}` is the number of embedding channels in the lead-time
embedding grid. If ``lead_time_label`` is provided, the lead-time
embedding channels are included. If ``global_index`` is `None`,
:math:`P = 1` is assumed, and the positional embedding grid is
duplicated :math:`B` times and returned with shape
:math:`(B, C_{PE} [+ C_{LT}], H, W)`.

Example
-------
@@ -951,7 +975,7 @@
"""
# If no global indices are provided, select all embeddings and expand
# to match the batch size of the input
if x.dtype != self.pos_embd.dtype:
if (self.pos_embd is not None) and (x.dtype != self.pos_embd.dtype):
self.pos_embd = self.pos_embd.to(x.dtype)

if global_index is None:
@@ -989,23 +1013,26 @@
global_index = torch.reshape(
torch.permute(global_index, (1, 0, 2, 3)), (2, -1)
) # (P, 2, X, Y) to (2, P*X*Y)
selected_pos_embd = self.pos_embd[
:, global_index[0], global_index[1]
] # (C_pe, P*X*Y)
selected_pos_embd = torch.permute(
torch.reshape(selected_pos_embd, (self.pos_embd.shape[0], P, H, W)),
(1, 0, 2, 3),
) # (P, C_pe, X, Y)

selected_pos_embd = selected_pos_embd.repeat(
B, 1, 1, 1
) # (B*P, C_pe, X, Y)

if self.pos_embd is not None:
selected_pos_embd = self.pos_embd[
:, global_index[0], global_index[1]
] # (C_pe, P*X*Y)
selected_pos_embd = torch.permute(
torch.reshape(selected_pos_embd, (self.pos_embd.shape[0], P, H, W)),
(1, 0, 2, 3),
) # (P, C_pe, X, Y)

selected_pos_embd = selected_pos_embd.repeat(
B, 1, 1, 1
) # (B*P, C_pe, X, Y)

embeds = [selected_pos_embd]
else:
embeds = []

# Append positional and lead time embeddings to input conditioning
if self.lead_time_mode:
embeds = []
if self.pos_embd is not None:
embeds.append(selected_pos_embd) # reuse code below
if self.lt_embd is not None:
lt_embds = self.lt_embd[
lead_time_label.int()
@@ -1026,8 +1053,12 @@ def positional_embedding_indexing(
) # (B*P, C_pe, X, Y)
embeds.append(selected_lt_pos_embd)

if len(embeds) > 0:
selected_pos_embd = torch.cat(embeds, dim=1)
if len(embeds) > 0:
selected_pos_embd = torch.cat(embeds, dim=1)
else:
raise ValueError(
"`positional_embedding_indexing` should not be called when neither lead-time nor positional embeddings are used."
)

return selected_pos_embd

@@ -1090,15 +1121,24 @@ def positional_embedding_selector(
... return patching.apply(emb[None].expand(batch_size, -1, -1, -1))
>>>
"""
if x.dtype != self.pos_embd.dtype:
if (self.pos_embd is not None) and (x.dtype != self.pos_embd.dtype):
self.pos_embd = self.pos_embd.to(x.dtype)
if lead_time_label is not None:

if self.pos_embd is not None and lead_time_label is not None:
# both positional and lead-time embedding
# all patches share same lead_time_label
embeddings = torch.cat(
[self.pos_embd, self.lt_embd[lead_time_label[0].int()]]
)
else:
elif self.pos_embd is None: # positional embedding only
embeddings = self.pos_embd
elif lead_time_label is None: # lead time embedding only
embeddings = self.lt_embd[lead_time_label[0].int()]
[Review comment — Collaborator]
Why use only lead_time_label[0] here? I know it follows the original implementation, but isn't it an error?

[Review comment — Collaborator]
The usage of [0] does seem like an error to me... wouldn't that select only the first lead time in the batch or lead-time dimension (whichever comes first in that tensor)? I see it was introduced in #913, maybe @tge25 can chime in?

[Review comment — Collaborator]
Unrelated, I think there is a second bug here. The elif clauses should read elif self.pos_embd is not None and elif self.lead_time_label is not None, right?

[Review comment — Collaborator]
I confirmed with @tge25 that the [0] is some prototyping code, so we should get rid of it.

This elif self.lead_time_label is not None is not possible, because lead_time_label is not an attribute. But I agree that the elif clause should be different. I think it should be elif self.lt_embd is not None, right?

else:
raise ValueError(
"`positional_embedding_selector` should not be called when neither lead-time nor positional embeddings are used."
)

return embedding_selector(embeddings) # (B, N_pe, H, W)

def _get_positional_embedding(self):
@@ -1331,7 +1371,7 @@ def __init__(
resample_filter: List[int] = [1, 1],
gridtype: str = "sinusoidal",
N_grid_channels: int = 4,
lead_time_channels: int = None,
lead_time_channels: int | None = None,
lead_time_steps: int = 9,
prob_channels: List[int] = [],
checkpoint_level: int = 0,
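Following the review thread on positional_embedding_selector above, here is a sketch of how the branch conditions could be rewritten along the lines the reviewers suggest. It is illustrative only and not part of the diff; whether the [0] indexing should be removed entirely is still unresolved, so it is kept as a placeholder:

    # Sketch of the selector body with the branches keyed on which embedding
    # grids actually exist, as suggested in the review discussion.
    if self.pos_embd is not None and lead_time_label is not None:
        # both positional and lead-time embeddings
        embeddings = torch.cat(
            [self.pos_embd, self.lt_embd[lead_time_label[0].int()]]
        )
    elif self.pos_embd is not None:
        # positional embedding only
        embeddings = self.pos_embd
    elif self.lt_embd is not None and lead_time_label is not None:
        # lead-time embedding only
        embeddings = self.lt_embd[lead_time_label[0].int()]
    else:
        raise ValueError(
            "`positional_embedding_selector` should not be called when neither "
            "lead-time nor positional embeddings are used."
        )
    return embedding_selector(embeddings)  # (B, N_pe, H, W)
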
26 changes: 20 additions & 6 deletions physicsnemo/utils/diffusion/deterministic_sampler.py
@@ -49,6 +49,7 @@ def deterministic_sampler(
S_min: float = 0.0,
S_max: float = float("inf"),
S_noise: float = 1.0,
lead_time_label: torch.Tensor | None = None,
) -> torch.Tensor:
r"""
Generalized sampler, representing the superset of all sampling methods
@@ -157,6 +158,9 @@
        stochastic sampler. Added signal noise is proportional to
:math:`\epsilon_i` where :math:`\epsilon_i \sim \mathcal{N}(0, S_{noise}^2)`. Defaults
to 1.0.
lead_time_label: torch.Tensor, optional
Lead-time labels to pass to the model, shape (batch_size, 1).
If not provided, the model is called without a lead-time label input.

Returns
-------
@@ -167,6 +171,11 @@
# conditioning
x_lr = img_lr

# do not pass lead time labels to nets that may not support them
additional_labels = (
{} if lead_time_label is None else {"lead_time_label": lead_time_label}
)

if solver not in ["euler", "heun"]:
raise ValueError(f"Unknown solver {solver}")
if discretization not in ["vp", "ve", "iddpm", "edm"]:
@@ -198,8 +207,7 @@
ve_sigma_deriv = lambda t: 0.5 / t.sqrt()
ve_sigma_inv = lambda sigma: sigma**2

# Select default noise level range based on the specified
# time step discretization.
# Select default noise level range based on the specified time step discretization.
if sigma_min is None:
vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=epsilon_s)
sigma_min = {"vp": vp_def, "ve": 0.02, "iddpm": 0.002, "edm": 0.002}[
@@ -304,11 +312,12 @@
sigma(t_hat),
condition=x_lr,
class_labels=class_labels,
**additional_labels,
).to(torch.float64)
else:
denoised = net(x_hat / s(t_hat), x_lr, sigma(t_hat), class_labels).to(
torch.float64
)
denoised = net(
x_hat / s(t_hat), x_lr, sigma(t_hat), class_labels, **additional_labels
).to(torch.float64)
d_cur = (
sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)
) * x_hat - sigma_deriv(t_hat) * s(t_hat) / sigma(t_hat) * denoised
@@ -326,10 +335,15 @@
sigma(t_prime),
condition=x_lr,
class_labels=class_labels,
**additional_labels,
).to(torch.float64)
else:
denoised = net(
x_prime / s(t_prime), x_lr, sigma(t_prime), class_labels
x_prime / s(t_prime),
x_lr,
sigma(t_prime),
class_labels,
**additional_labels,
).to(torch.float64)
d_prime = (
sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)
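The sampler forwards the lead-time label with the same None-filtering idiom used in EDMLoss: the label is passed to the network only when the caller supplies one, so models whose forward() does not declare lead_time_label keep working. A standalone illustration of that idiom (the helper and toy network here are hypothetical, not physicsnemo API):

    import torch

    def call_with_optional_labels(net, x, sigma, lead_time_label=None):
        # Build the optional keyword arguments and drop None entries, so that
        # `net` is never handed a keyword its forward() does not declare.
        extra = {"lead_time_label": lead_time_label}
        extra = {k: v for k, v in extra.items() if v is not None}
        return net(x, sigma, **extra)

    plain_net = lambda x, sigma: x - sigma  # a net without lead-time support
    out = call_with_optional_labels(plain_net, torch.zeros(2, 3), torch.tensor(0.5))
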
21 changes: 21 additions & 0 deletions test/metrics/diffusion/test_losses.py
@@ -139,6 +139,17 @@ def fake_condition_net(y, sigma, condition, class_labels=None, augment_labels=No
def fake_net(y, sigma, labels, augment_labels=None):
return torch.tensor([1.0])

def fake_condition_net_lt(
y,
sigma,
condition,
class_labels=None,
augment_labels=None,
lead_time_label=None,
):
assert lead_time_label is not None # test that this is properly passed through
return torch.tensor([1.0])

loss_func = EDMLoss()

img = torch.tensor([[[[1.0]]]])
@@ -160,6 +171,16 @@ def mock_augment_pipe(imgs):
loss_value_with_augmentation = loss_func(fake_net, img, labels, mock_augment_pipe)
assert isinstance(loss_value_with_augmentation, torch.Tensor)

lead_time_label = torch.tensor([1])
loss_value = loss_func(
fake_condition_net_lt,
img,
condition=condition,
labels=labels,
lead_time_label=lead_time_label,
)
assert isinstance(loss_value, torch.Tensor)


# RegressionLoss tests
