Improve lead time support for diffusion models #980


Open · wants to merge 20 commits into base `main`

Conversation

@jleinonen (Collaborator)

PhysicsNeMo Pull Request

Description

Adds/fixes support for lead-time labels in various places where it was missing or not working:

  • SongUNetPosEmbd now works properly with either, both, or neither of positional embeddings and lead-time embeddings. Previously, some code paths could try to access attributes of these even when they were set to None.
  • deterministic_sampler now accepts lead-time labels and passes them through to the model, if given.
  • EDMLoss also now supports lead-time labels.
  • Added tests for the above features.

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • An issue is linked to this pull request.

Dependencies

No new dependencies needed.

@jleinonen self-assigned this Jun 17, 2025
@jleinonen (Collaborator, Author)

/blossom-ci

@jleinonen (Collaborator, Author)

/blossom-ci

@jleinonen (Collaborator, Author)

/blossom-ci

@jleinonen (Collaborator, Author)

/blossom-ci

@jleinonen (Collaborator, Author)

/blossom-ci

@CharlelieLrt (Collaborator) left a comment


All the proposed changes look good to me. Just a few details require improvement:

  1. As far as I understand, this PR decouples lead-time embeddings from positional embeddings, in order to allow more flexibility in using them independently from each other. This new functionality does not seem to be used in any of the training recipes/examples. It could be useful to detail in the PR description the broader context (e.g. which applications is it going to be applied to? Will there be a follow-up PR? etc...)

  2. The new flexibility to independently use lead-time and positional embeddings should be clearly explained in the docstrings.

  3. IMO the current implementation of the lead-time embeddings has too many failure modes to be safely exposed to broader applications. For example, in `positional_embedding_indexing`:

  • `lead_time_label` can be `None` while `self.lt_embd` is not `None`, which leads to an error
  • Conversely, `lead_time_label` could be a user-provided tensor while `self.lt_embd` is `None`, which leads to `lead_time_label` being silently ignored.

I strongly support better parameter validation to eliminate these failure modes, either in the `forward` method or in `__init__` when possible.
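For illustration, a minimal, torch-free sketch of the kind of parameter validation suggested above. The names `lt_embd` and `lead_time_label` follow this discussion; the helper itself is hypothetical and not the actual PhysicsNeMo implementation:

```python
def validate_lead_time_inputs(lt_embd, lead_time_label):
    """Fail loudly on the two failure modes described above.

    Hypothetical sketch: `lt_embd` stands in for the model's lead-time
    embedding table (or None if disabled), `lead_time_label` for the
    user-provided labels (or None if not given).
    """
    if lt_embd is not None and lead_time_label is None:
        # previously: an obscure indexing error deep in the forward pass
        raise ValueError(
            "lead_time_label is required because the model was built "
            "with lead-time embeddings (lead_time_channels > 0)."
        )
    if lt_embd is None and lead_time_label is not None:
        # previously: the label was silently ignored
        raise ValueError(
            "lead_time_label was given, but the model has no lead-time "
            "embeddings; it would be silently ignored."
        )
```

A check like this could run at the top of `forward` (or, for configuration-only constraints, in `__init__`), turning both failure modes into immediate, descriptive errors.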

@jleinonen
Copy link
Collaborator Author

Hi @CharlelieLrt,

  1. As far as I understand, this PR decouples lead-time embeddings from positional embeddings, in order to allow more flexibility in using them independently from each other. This new functionality does not seem to be used in any of the training recipes/examples. It could be useful to detail in the PR description the broader context (e.g. which applications is it going to be applied to? Will there be a follow-up PR? etc...)

I would say lead-time embeddings were already decoupled from positional embeddings before the PR. This PR just includes some fixes to make sure that they can be enabled when positional embeddings are disabled, or vice versa.

  2. The new flexibility to independently use lead-time and positional embeddings should be clearly explained in the docstrings.

As they were already implemented independently, I don't think it's a new flexibility, but I can improve the docstrings in that regard.

  3. IMO the current implementation of the lead-time embeddings has too many failure modes to be safely exposed to broader applications. For example, in `positional_embedding_indexing`:
  • `lead_time_label` can be `None` while `self.lt_embd` is not `None`, which leads to an error
  • Conversely, `lead_time_label` could be a user-provided tensor while `self.lt_embd` is `None`, which leads to `lead_time_label` being silently ignored.

I strongly support better parameter validation to eliminate these failure modes, either in the `forward` method or in `__init__` when possible.

I'll add some checks to make sure the inputs conform to the model configuration (but note that, as far as I understand, these failure modes already existed before this PR).

@jleinonen
Copy link
Collaborator Author

/blossom-ci

Comment on lines +835 to +838
if (lead_time_channels is None) or (lead_time_channels <= 0):
raise ValueError(
"`lead_time_channels` must be >= 1 if `lead_time_mode` is enabled."
)
Collaborator


Agree with this validation. However, there is clearly redundancy between lead_time_mode and lead_time_channels. In addition, for positional embeddings, we can disable them by setting N_grid_channels = 0 (there is no boolean parameter positional_embedding_mode), so we expect the same mechanism to work for lead-time embeddings (i.e. disable by setting lead_time_channels = 0).

If your concern is about backward compatibility if removing lead_time_mode:

  1. This parameter was introduced only very recently (<2 months ago), so it could be okay to remove it.
  2. If backward compatibility is still a concern, we can keep `lead_time_mode` but add a deprecation warning in this if statement:
if self.lead_time_mode:
    warnings.warn(
        "The parameter `lead_time_mode` will be deprecated in a future "
        "version. The recommended way to enable (disable) lead-time "
        "embeddings is to set `lead_time_channels > 0` "
        "(`lead_time_channels = 0`).",
        DeprecationWarning,
    )
    [...]

embeddings = self.pos_embd
elif lead_time_label is None: # lead time embedding only
embeddings = self.lt_embd[lead_time_label[0].int()]
Collaborator


Why use only `lead_time_label[0]` here? I know it follows the original implementation, but isn't it an error?

Collaborator


The usage of `[0]` does seem like an error to me... wouldn't that select only the first lead time in either the batch or the lead-time dimension (whichever comes first in that tensor)? I see it was introduced in #913, maybe @tge25 can chime in?

Collaborator


Unrelated, but I think there is a second bug here. The elif clauses should read `elif self.pos_embd is not None` and `elif self.lead_time_label is not None`, right?

Collaborator


I confirmed with @tge25 that the `[0]` is some prototyping code, so we should get rid of it.

`elif self.lead_time_label is not None` is not possible, because `lead_time_label` is not an attribute. But I agree that the elif clause should be different; I think it should be `elif self.lt_embd is not None`, right?
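To make the agreed-on fix concrete, here is a hypothetical, torch-free sketch of the corrected branch structure: the second branch tests `lt_embd is not None` rather than `lead_time_label is None`, and the labels are indexed per sample instead of via `lead_time_label[0]`. Plain lists stand in for the module's tensors, so this shows only the control flow, not actual PhysicsNeMo code:

```python
def select_embeddings(pos_embd, lt_embd, lead_time_label=None):
    """Pick which embeddings to use, mirroring the proposed elif fix."""
    if pos_embd is not None and lt_embd is None:
        # positional embedding only
        return pos_embd
    elif lt_embd is not None and pos_embd is None:
        # lead-time embedding only: index per sample,
        # not lt_embd[lead_time_label[0]]
        return [lt_embd[int(label)] for label in lead_time_label]
    elif pos_embd is not None and lt_embd is not None:
        # both enabled: combination logic omitted in this sketch
        raise NotImplementedError
    return None
```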

@@ -245,6 +253,10 @@ def __call__(self, net, images, condition=None, labels=None, augment_pipe=None):
An optional data augmentation function that takes images as input and
returns augmented images. If not provided, no data augmentation is applied.

lead_time_label: torch.Tensor, optional
Lead-time labels to pass to the model, shape (batch_size, 1).
@CharlelieLrt (Collaborator) commented on Jul 17, 2025


Are you sure about the shape `(batch_size, 1)` of `lead_time_label`? I've seen other places (not in your PR) where it's `(batch_size,)`, and some others where it's `(batch_size, 1, 1, 1)`.
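One way to cope with the shape inconsistency noted above would be a small normalization helper that accepts `(batch_size,)`, `(batch_size, 1)`, or `(batch_size, 1, 1, 1)` and returns one label per sample. This is a hypothetical sketch using nested lists in place of tensors, not actual PhysicsNeMo code:

```python
def _flatten(x):
    """Yield leaf values from arbitrarily nested lists/tuples."""
    for item in x:
        if isinstance(item, (list, tuple)):
            yield from _flatten(item)
        else:
            yield item

def normalize_lead_time_label(label, batch_size):
    """Normalize any of the reported label shapes to a flat per-sample list.

    Raises if the number of labels does not match the batch size, instead
    of silently broadcasting or truncating.
    """
    flat = list(_flatten(label))
    if len(flat) != batch_size:
        raise ValueError(
            f"expected {batch_size} lead-time labels, got {len(flat)}"
        )
    return flat
```

With tensors, the equivalent would be a reshape to `(batch_size,)` after the same size check; either way, the docstrings could then document a single canonical shape.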
