Commit 1a52284

Fixed minor bug in shape validation in SongUNet (#1230)
Signed-off-by: Charlelie Laurent <[email protected]>
1 parent adc6602 commit 1a52284

1 file changed

+4
-4
lines changed

physicsnemo/models/diffusion/song_unet.py

Lines changed: 4 additions & 4 deletions
@@ -234,7 +234,7 @@ class SongUNet(Module):
       architectures. Despite the name, these embeddings encode temporal information about the
       diffusion process rather than spatial position information.
     • Limitations on input image resolution: for a model that has :math:`N` levels,
-      the latent state :math:`\mathbf{x}` must have resolution that is a multiple of :math:`2^N` in each dimension.
+      the latent state :math:`\mathbf{x}` must have resolution that is a multiple of :math:`2^{N-1}` in each dimension.
       This is due to a limitation in the decoder that does not support shape mismatch
       in the residual connections from the encoder to the decoder. For images that do not match
       this requirement, it is recommended to interpolate your data on a grid of the required resolution
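The docstring above recommends resampling data onto a grid of a valid resolution. A minimal sketch of how such a target size could be computed follows; `next_valid_size` is a hypothetical helper and is not part of `SongUNet`. It simply rounds up to the next multiple of `2**(N-1)`, which always satisfies the corrected rule.

```python
def next_valid_size(d: int, num_levels: int) -> int:
    """Round d up to the nearest multiple of 2**(num_levels - 1).

    Hypothetical helper for illustration only; not part of SongUNet.
    """
    mult = 2 ** (num_levels - 1)
    # Ceiling division, then scale back up to a multiple of mult.
    return ((d + mult - 1) // mult) * mult


# With 4 levels the required multiple is 2**3 = 8.
print(next_valid_size(721, 4))
print(next_valid_size(64, 4))
```

Each spatial dimension can be rounded independently, and the data can then be interpolated to the resulting shape before being passed to the model.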
@@ -337,7 +337,7 @@ def __init__(
         self.img_shape_x = img_resolution[1]
 
         self._num_levels = len(channel_mult)
-        self._input_shape_mult = 2**self._num_levels
+        self._input_shape_mult = 2 ** (self._num_levels - 1)
 
         # set the threshold for checkpointing based on image resolution
         self.checkpoint_threshold = (
@@ -534,7 +534,7 @@ def forward(self, x, noise_labels, class_labels, augment_labels=None):
                 f"got {x.ndim}D tensor with shape {tuple(x.shape)}"
             )
 
-        # Check spatial dimensions are powers of 2 or multiples of 2^N
+        # Check spatial dimensions are powers of 2 or multiples of 2^{N-1}
         for d in x.shape[-2:]:
             # Check if d is a power of 2
             is_power_of_2 = (d & (d - 1)) == 0 and d > 0
@@ -545,7 +545,7 @@ def forward(self, x, noise_labels, class_labels, augment_labels=None):
             ):
                 raise ValueError(
                     f"Input spatial dimensions ({x.shape[-2:]}) must be "
-                    f"either powers of 2 or multiples of 2**N where "
+                    f"either powers of 2 or multiples of 2**(N-1) where "
                     f"N (={self._num_levels}) is the number of levels "
                     f"in the U-Net."
                 )
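For reviewers who want to try the corrected rule outside the model, here is a standalone sketch of the check. `is_valid_spatial_dim` is a hypothetical name, and because the lines between the two `forward` hunks are not shown in this diff, the way the two conditions are combined is an assumption based on the error message.

```python
def is_valid_spatial_dim(d: int, num_levels: int) -> bool:
    """Standalone sketch of the corrected check (hypothetical helper).

    Accepts d if it is a power of 2 or a multiple of 2**(num_levels - 1).
    """
    input_shape_mult = 2 ** (num_levels - 1)
    # Same bit trick as in the diff: a power of 2 has a single set bit.
    is_power_of_2 = (d & (d - 1)) == 0 and d > 0
    # Assumed form of the multiple-of check, which the diff does not show.
    is_multiple = d % input_shape_mult == 0
    return is_power_of_2 or is_multiple


# With N = 4 levels, 24 is a multiple of 8 and is accepted. Under the old
# 2**N rule it would have been rejected, since 24 is neither a power of 2
# nor a multiple of 16.
print(is_valid_spatial_dim(24, 4))
print(is_valid_spatial_dim(20, 4))
```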
