@@ -234,7 +234,7 @@ class SongUNet(Module):
234234 architectures. Despite the name, these embeddings encode temporal information about the
235235 diffusion process rather than spatial position information.
236236 • Limitations on input image resolution: for a model that has :math:`N` levels,
237- the latent state :math:`\mathbf{x}` must have resolution that is a multiple of :math:`2^N ` in each dimension.
237+ the latent state :math:`\mathbf{x}` must have resolution that is a multiple of :math:`2^{N-1} ` in each dimension.
238238 This is due to a limitation in the decoder that does not support shape mismatch
239239 in the residual connections from the encoder to the decoder. For images that do not match
240240 this requirement, it is recommended to interpolate your data on a grid of the required resolution
@@ -337,7 +337,7 @@ def __init__(
337337 self .img_shape_x = img_resolution [1 ]
338338
339339 self ._num_levels = len (channel_mult )
340- self ._input_shape_mult = 2 ** self ._num_levels
340+ self ._input_shape_mult = 2 ** ( self ._num_levels - 1 )
341341
342342 # set the threshold for checkpointing based on image resolution
343343 self .checkpoint_threshold = (
@@ -534,7 +534,7 @@ def forward(self, x, noise_labels, class_labels, augment_labels=None):
534534 f"got { x .ndim } D tensor with shape { tuple (x .shape )} "
535535 )
536536
537- # Check spatial dimensions are powers of 2 or multiples of 2^N
537+ # Check spatial dimensions are powers of 2 or multiples of 2^{N-1}
538538 for d in x .shape [- 2 :]:
539539 # Check if d is a power of 2
540540 is_power_of_2 = (d & (d - 1 )) == 0 and d > 0
@@ -545,7 +545,7 @@ def forward(self, x, noise_labels, class_labels, augment_labels=None):
545545 ):
546546 raise ValueError (
547547 f"Input spatial dimensions ({ x .shape [- 2 :]} ) must be "
548- f"either powers of 2 or multiples of 2**N where "
548+ f"either powers of 2 or multiples of 2**(N-1) where "
549549 f"N (={ self ._num_levels } ) is the number of levels "
550550 f"in the U-Net."
551551 )
0 commit comments