Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions physicsnemo/models/diffusion/song_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ class SongUNet(Module):
architectures. Despite the name, these embeddings encode temporal information about the
diffusion process rather than spatial position information.
• Limitations on input image resolution: for a model that has :math:`N` levels,
- the latent state :math:`\mathbf{x}` must have resolution that is a multiple of :math:`2^N` in each dimension.
+ the latent state :math:`\mathbf{x}` must have resolution that is a multiple of :math:`2^{N-1}` in each dimension.
This is due to a limitation in the decoder that does not support shape mismatch
in the residual connections from the encoder to the decoder. For images that do not match
this requirement, it is recommended to interpolate your data on a grid of the required resolution
Expand Down Expand Up @@ -337,7 +337,7 @@ def __init__(
self.img_shape_x = img_resolution[1]

self._num_levels = len(channel_mult)
- self._input_shape_mult = 2**self._num_levels
+ self._input_shape_mult = 2 ** (self._num_levels - 1)

# set the threshold for checkpointing based on image resolution
self.checkpoint_threshold = (
Expand Down Expand Up @@ -534,7 +534,7 @@ def forward(self, x, noise_labels, class_labels, augment_labels=None):
f"got {x.ndim}D tensor with shape {tuple(x.shape)}"
)

- # Check spatial dimensions are powers of 2 or multiples of 2^N
+ # Check spatial dimensions are powers of 2 or multiples of 2^{N-1}
for d in x.shape[-2:]:
# Check if d is a power of 2
is_power_of_2 = (d & (d - 1)) == 0 and d > 0
Expand All @@ -545,7 +545,7 @@ def forward(self, x, noise_labels, class_labels, augment_labels=None):
):
raise ValueError(
f"Input spatial dimensions ({x.shape[-2:]}) must be "
- f"either powers of 2 or multiples of 2**N where "
+ f"either powers of 2 or multiples of 2**(N-1) where "
f"N (={self._num_levels}) is the number of levels "
f"in the U-Net."
)
Expand Down