From fa191a88debad926623aadc6ec0b2382c9cb39cb Mon Sep 17 00:00:00 2001 From: Charlelie Laurent Date: Thu, 13 Nov 2025 16:23:18 -0800 Subject: [PATCH] Fixed minor bug in shape validation in SongUNet Signed-off-by: Charlelie Laurent --- physicsnemo/models/diffusion/song_unet.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/physicsnemo/models/diffusion/song_unet.py b/physicsnemo/models/diffusion/song_unet.py index 0a8ddc3fde..62d6e02175 100644 --- a/physicsnemo/models/diffusion/song_unet.py +++ b/physicsnemo/models/diffusion/song_unet.py @@ -234,7 +234,7 @@ class SongUNet(Module): architectures. Despite the name, these embeddings encode temporal information about the diffusion process rather than spatial position information. • Limitations on input image resolution: for a model that has :math:`N` levels, - the latent state :math:`\mathbf{x}` must have resolution that is a multiple of :math:`2^N` in each dimension. + the latent state :math:`\mathbf{x}` must have resolution that is a multiple of :math:`2^{N-1}` in each dimension. This is due to a limitation in the decoder that does not support shape mismatch in the residual connections from the encoder to the decoder. For images that do not match this requirement, it is recommended to interpolate your data on a grid of the required resolution @@ -337,7 +337,7 @@ def __init__( self.img_shape_x = img_resolution[1] self._num_levels = len(channel_mult) - self._input_shape_mult = 2**self._num_levels + self._input_shape_mult = 2 ** (self._num_levels - 1) # set the threshold for checkpointing based on image resolution self.checkpoint_threshold = ( @@ -534,7 +534,7 @@ def forward(self, x, noise_labels, class_labels, augment_labels=None): f"got {x.ndim}D tensor with shape {tuple(x.shape)}" ) - # Check spatial dimensions are powers of 2 or multiples of 2^N + # Check spatial dimensions are powers of 2 or multiples of 2^{N-1} for d in x.shape[-2:]: # Check if d is a power of 2 is_power_of_2 = (d & (d - 1)) == 0 and d > 0 @@ -545,7 +545,7 @@ def forward(self, x, noise_labels, class_labels, augment_labels=None): ): raise ValueError( f"Input spatial dimensions ({x.shape[-2:]}) must be " - f"either powers of 2 or multiples of 2**N where " + f"either powers of 2 or multiples of 2**(N-1) where " f"N (={self._num_levels}) is the number of levels " f"in the U-Net." )