Improve lead time support for diffusion models #980
base: main
Changes from all commits
bbaea98
232c29c
8b94e59
f3338b6
7ba0563
11e4ea5
6f08b02
63b70c9
fd35097
ef8a9a6
d0b1bfb
4d503c4
77a5cde
e600803
4929596
457d50d
522440b
fd96890
29ca853
58974a1
@@ -220,7 +220,15 @@ def __init__(
        self.P_std = P_std
        self.sigma_data = sigma_data

    def __call__(self, net, images, condition=None, labels=None, augment_pipe=None):
    def __call__(
        self,
        net,
        images,
        condition=None,
        labels=None,
        augment_pipe=None,
        lead_time_label=None,
    ):
        """
        Calculate and return the loss corresponding to the EDM formulation.
@@ -245,6 +253,10 @@ def __call__(self, net, images, condition=None, labels=None, augment_pipe=None):
            An optional data augmentation function that takes images as input and
            returns augmented images. If not provided, no data augmentation is applied.
        lead_time_label: torch.Tensor, optional
            Lead-time labels to pass to the model, shape (batch_size, 1).
            If not provided, the model is called without a lead-time label input.

Review comment: Are you sure about the shape?

        Returns:
        -------
        torch.Tensor
@@ -258,16 +270,24 @@ def __call__(self, net, images, condition=None, labels=None, augment_pipe=None):
            augment_pipe(images) if augment_pipe is not None else (images, None)
        )
        n = torch.randn_like(y) * sigma
        additional_labels = {
            "augment_labels": augment_labels,
            "lead_time_label": lead_time_label,
        }
        # drop None items to support models that don't have these arguments in `forward`
        additional_labels = {
            k: v for (k, v) in additional_labels.items() if v is not None
        }
        if condition is not None:
            D_yn = net(
                y + n,
                sigma,
                condition=condition,
                class_labels=labels,
                augment_labels=augment_labels,
                **additional_labels,
            )
        else:
            D_yn = net(y + n, sigma, labels, augment_labels=augment_labels)
            D_yn = net(y + n, sigma, labels, **additional_labels)
        loss = weight * ((D_yn - y) ** 2)
        return loss
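The dictionary-based dispatch above keeps the loss backward compatible: models whose `forward` does not define `augment_labels` or `lead_time_label` are simply called without them. A minimal standalone sketch of the same pattern (the `call_net` helper and `toy_net` are hypothetical, not part of the PR):

    import torch

    def call_net(net, x, sigma, labels, augment_labels=None, lead_time_label=None):
        # Collect optional keyword arguments, then drop the ones that are None so
        # that models lacking these parameters in `forward` still work unchanged.
        additional_labels = {
            "augment_labels": augment_labels,
            "lead_time_label": lead_time_label,
        }
        additional_labels = {k: v for k, v in additional_labels.items() if v is not None}
        return net(x, sigma, labels, **additional_labels)

    # A toy model without a `lead_time_label` argument: leaving lead_time_label=None
    # (the default) omits the keyword entirely, so the call does not fail.
    def toy_net(x, sigma, labels, augment_labels=None):
        return x

    out = call_net(toy_net, torch.zeros(2, 1, 8, 8), torch.ones(2, 1, 1, 1), labels=None)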
@@ -666,7 +666,7 @@ class SongUNetPosEmbd(SongUNet):
    N_grid_channels : int, optional, default=4
        Number of channels :math:`C_{PE}` in the positional embedding grid. For 'sinusoidal' must be 4 or
        multiple of 4. For 'linear' and 'test' must be 2. For 'learnable' can be any
        value.
        value. If 0, positional embedding is disabled (but `lead_time_mode` may still be used).
    lead_time_mode : bool, optional, default=False
        Provided for convenience. It is recommended to use the architecture
        :class:`~physicsnemo.models.diffusion.song_unet.SongUNetPosLtEmbd`
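To illustrate the new `N_grid_channels = 0` option, a hedged construction sketch: the `img_resolution`, `in_channels`, and `out_channels` arguments and their values are assumptions for illustration only, and `in_channels` must already account for any embedding channels the model appends to its input.

    from physicsnemo.models.diffusion.song_unet import SongUNetPosEmbd

    # Positional embedding grid disabled, lead-time embedding enabled (per this PR,
    # lead_time_mode=True now requires lead_time_channels >= 1).
    model = SongUNetPosEmbd(
        img_resolution=64,
        in_channels=6,
        out_channels=2,
        N_grid_channels=0,
        lead_time_mode=True,
        lead_time_channels=4,
        lead_time_steps=9,
    )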
@@ -794,7 +794,7 @@ def __init__(
        profile_mode: bool = False,
        amp_mode: bool = False,
        lead_time_mode: bool = False,
        lead_time_channels: int = None,
        lead_time_channels: int | None = None,
        lead_time_steps: int = 9,
        prob_channels: List[int] = [],
    ):
@@ -826,12 +826,16 @@ def __init__(

        self.gridtype = gridtype
        self.N_grid_channels = N_grid_channels
        if self.gridtype == "learnable":
        if (self.gridtype == "learnable") or (self.N_grid_channels == 0):
            self.pos_embd = self._get_positional_embedding()
        else:
            self.register_buffer("pos_embd", self._get_positional_embedding().float())
        self.lead_time_mode = lead_time_mode
        if self.lead_time_mode:
            if (lead_time_channels is None) or (lead_time_channels <= 0):
                raise ValueError(
                    "`lead_time_channels` must be >= 1 if `lead_time_mode` is enabled."
                )
Review comment on lines +835 to +838: Agree with this validation. However, there is clearly redundancy between `lead_time_mode` and `lead_time_channels`. If your concern is about backward compatibility when removing `lead_time_mode`, a deprecation warning could be added instead:

    if self.lead_time_mode:
        warnings.warn(
            "The parameter `lead_time_mode` will be deprecated in a future version. "
            "The recommended way to enable (disable) lead-time embeddings is to set "
            "`lead_time_channels > 0` (`lead_time_channels = 0`)."
        )
    [...]
            self.lead_time_channels = lead_time_channels
            self.lead_time_steps = lead_time_steps
            self.lt_embd = self._get_lead_time_embedding()

@@ -840,6 +844,12 @@ def __init__(
                self.scalar = torch.nn.Parameter(
                    torch.ones((1, len(self.prob_channels), 1, 1))
                )
        else:
            if lead_time_channels:
                raise ValueError(
                    "When `lead_time_mode` is disabled, `lead_time_channels` may not be set."
                )
            self.lt_embd = None

    def forward(
        self,
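A standalone mirror of the validation added above, useful for seeing which argument combinations are accepted or rejected (the `LeadTimeConfig` class is purely illustrative, not part of the PR):

    class LeadTimeConfig:
        # Reproduces only the lead-time validation logic from the diff above.
        def __init__(self, lead_time_mode=False, lead_time_channels=None):
            self.lead_time_mode = lead_time_mode
            if self.lead_time_mode:
                if (lead_time_channels is None) or (lead_time_channels <= 0):
                    raise ValueError(
                        "`lead_time_channels` must be >= 1 if `lead_time_mode` is enabled."
                    )
            else:
                if lead_time_channels:
                    raise ValueError(
                        "When `lead_time_mode` is disabled, `lead_time_channels` may not be set."
                    )
            self.lead_time_channels = lead_time_channels

    LeadTimeConfig(lead_time_mode=True, lead_time_channels=4)     # accepted
    LeadTimeConfig(lead_time_mode=False)                          # accepted
    # LeadTimeConfig(lead_time_mode=True)                         # raises ValueError
    # LeadTimeConfig(lead_time_mode=False, lead_time_channels=4)  # raises ValueError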
@@ -861,11 +871,11 @@ def forward(
                "Cannot provide both embedding_selector and global_index."
            )

        if x.dtype != self.pos_embd.dtype:
        if (self.pos_embd is not None) and (x.dtype != self.pos_embd.dtype):
            self.pos_embd = self.pos_embd.to(x.dtype)

        # Append positional embedding to input conditioning
        if self.pos_embd is not None:
        if (self.pos_embd is not None) or (self.lt_embd is not None):
            # Select positional embeddings with a selector function
            if embedding_selector is not None:
                selected_pos_embd = self.positional_embedding_selector(
@@ -905,8 +915,13 @@ def positional_embedding_indexing(
        r"""Select positional embeddings using global indices.

        This method uses global indices to select specific subset of the
        positional embedding grid (called *patches*). If no indices are provided,
        the entire positional embedding grid is returned.
        positional embedding grid and/or the lead-time embedding grid (called
        *patches*). If no indices are provided, the entire embedding grid is returned.
        The positional embedding grid is returned if `N_grid_channels > 0`, while
        the lead-time embedding grid is returned if `lead_time_mode == True`. If
        both positional and lead-time embedding are enabled, both are returned
        (concatenated). If neither is enabled, this function should not be called;
        doing so will raise a ValueError.

        Parameters
        ----------
@@ -918,15 +933,24 @@ def positional_embedding_indexing(
            the patches to extract from the positional embedding grid.
            :math:`P` is the number of distinct patches in the input tensor ``x``.
            The channel dimension should contain :math:`j`, :math:`i` indices that
            should represent the indices of the pixels to extract from the embedding grid.
            should represent the indices of the pixels to extract from the
            embedding grid.
        lead_time_label : Optional[torch.Tensor], default=None
            Tensor of shape :math:`(P,)` that corresponds to the lead-time label for each patch.
            Only used if ``lead_time_mode`` is True.

        Returns
        -------
        torch.Tensor
            Selected positional embeddings with shape :math:`(P \times B, C_{PE}, H_{in}, W_{in})`
            (same spatial resolution as ``global_index``) if ``global_index`` is provided.
            If ``global_index`` is None, the entire positional embedding grid
            is duplicated :math:`B` times and returned with shape :math:`(B, C_{PE}, H, W)`.
            Selected embeddings with shape :math:`(P \times B, C_{PE} [+
            C_{LT}], H_{in}, W_{in})`. :math:`C_{PE}` is the number of
            embedding channels in the positional embedding grid, and
            :math:`C_{LT}` is the number of embedding channels in the lead-time
            embedding grid. If ``lead_time_label`` is provided, the lead-time
            embedding channels are included. If ``global_index`` is `None`,
            :math:`P = 1` is assumed, and the positional embedding grid is
            duplicated :math:`B` times and returned with shape
            :math:`(B, C_{PE} [+ C_{LT}], H, W)`.

        Example
        -------
@@ -951,7 +975,7 @@ def positional_embedding_indexing(
        """
        # If no global indices are provided, select all embeddings and expand
        # to match the batch size of the input
        if x.dtype != self.pos_embd.dtype:
        if (self.pos_embd is not None) and (x.dtype != self.pos_embd.dtype):
            self.pos_embd = self.pos_embd.to(x.dtype)

        if global_index is None:
@@ -989,23 +1013,26 @@ def positional_embedding_indexing(
            global_index = torch.reshape(
                torch.permute(global_index, (1, 0, 2, 3)), (2, -1)
            )  # (P, 2, X, Y) to (2, P*X*Y)
            selected_pos_embd = self.pos_embd[
                :, global_index[0], global_index[1]
            ]  # (C_pe, P*X*Y)
            selected_pos_embd = torch.permute(
                torch.reshape(selected_pos_embd, (self.pos_embd.shape[0], P, H, W)),
                (1, 0, 2, 3),
            )  # (P, C_pe, X, Y)

            selected_pos_embd = selected_pos_embd.repeat(
                B, 1, 1, 1
            )  # (B*P, C_pe, X, Y)

            if self.pos_embd is not None:
                selected_pos_embd = self.pos_embd[
                    :, global_index[0], global_index[1]
                ]  # (C_pe, P*X*Y)
                selected_pos_embd = torch.permute(
                    torch.reshape(selected_pos_embd, (self.pos_embd.shape[0], P, H, W)),
                    (1, 0, 2, 3),
                )  # (P, C_pe, X, Y)

                selected_pos_embd = selected_pos_embd.repeat(
                    B, 1, 1, 1
                )  # (B*P, C_pe, X, Y)

                embeds = [selected_pos_embd]
            else:
                embeds = []

            # Append positional and lead time embeddings to input conditioning
            if self.lead_time_mode:
                embeds = []
                if self.pos_embd is not None:
                    embeds.append(selected_pos_embd)  # reuse code below
                if self.lt_embd is not None:
                    lt_embds = self.lt_embd[
                        lead_time_label.int()

@@ -1026,8 +1053,12 @@ def positional_embedding_indexing(
                    )  # (B*P, C_pe, X, Y)
                    embeds.append(selected_lt_pos_embd)

                if len(embeds) > 0:
                    selected_pos_embd = torch.cat(embeds, dim=1)
            if len(embeds) > 0:
                selected_pos_embd = torch.cat(embeds, dim=1)
            else:
                raise ValueError(
                    "`positional_embedding_indexing` should not be called when neither lead-time nor positional embeddings are used."
                )

        return selected_pos_embd
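A quick shape check of the concatenation described in the docstring above: positional channels :math:`C_{PE}` and lead-time channels :math:`C_{LT}` are stacked along the channel dimension. Random tensors stand in for the real embedding patches; this is an illustration, not the actual selection code.

    import torch

    B, C_PE, C_LT, H, W = 2, 4, 3, 8, 8
    selected_pos_embd = torch.randn(B, C_PE, H, W)     # stand-in for the positional patch
    selected_lt_pos_embd = torch.randn(B, C_LT, H, W)  # stand-in for the lead-time patch

    out = torch.cat([selected_pos_embd, selected_lt_pos_embd], dim=1)
    assert out.shape == (B, C_PE + C_LT, H, W)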
@@ -1090,15 +1121,24 @@ def positional_embedding_selector(
        ...     return patching.apply(emb[None].expand(batch_size, -1, -1, -1))
        >>>
        """
        if x.dtype != self.pos_embd.dtype:
        if (self.pos_embd is not None) and (x.dtype != self.pos_embd.dtype):
            self.pos_embd = self.pos_embd.to(x.dtype)
        if lead_time_label is not None:

        if self.pos_embd is not None and lead_time_label is not None:
            # both positional and lead-time embedding
            # all patches share same lead_time_label
            embeddings = torch.cat(
                [self.pos_embd, self.lt_embd[lead_time_label[0].int()]]
            )
        else:
        elif lead_time_label is None:  # positional embedding only
            embeddings = self.pos_embd
        elif self.pos_embd is None:  # lead time embedding only
            embeddings = self.lt_embd[lead_time_label[0].int()]
        else:
            raise ValueError(
                "`positional_embedding_selector` should not be called when neither lead-time nor positional embeddings are used."
            )

        return embedding_selector(embeddings)  # (B, N_pe, H, W)

    def _get_positional_embedding(self):
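For context, a minimal `embedding_selector` in the spirit of the docstring example shown above; the crop-based selector here is a hypothetical stand-in for the real patching utility.

    import torch

    def make_crop_selector(top, left, height, width, batch_size):
        # Crops a window from the (C, H, W) embedding grid and repeats it over the
        # batch, matching the (B, N_pe, H, W) output of `positional_embedding_selector`.
        def selector(emb):
            patch = emb[:, top : top + height, left : left + width]
            return patch[None].expand(batch_size, -1, -1, -1)
        return selector

    emb = torch.randn(6, 32, 32)  # stand-in for concatenated positional + lead-time channels
    selector = make_crop_selector(top=4, left=4, height=16, width=16, batch_size=2)
    print(selector(emb).shape)  # torch.Size([2, 6, 16, 16])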
@@ -1331,7 +1371,7 @@ def __init__(
        resample_filter: List[int] = [1, 1],
        gridtype: str = "sinusoidal",
        N_grid_channels: int = 4,
        lead_time_channels: int = None,
        lead_time_channels: int | None = None,
        lead_time_steps: int = 9,
        prob_channels: List[int] = [],
        checkpoint_level: int = 0,