From b94a50037dd9125b01ea1c3cb3f5af79fae0553c Mon Sep 17 00:00:00 2001 From: Muhammad Nabi Yasinzai Date: Mon, 29 Sep 2025 19:23:48 +1300 Subject: [PATCH 1/3] Fix Bug in HaarWaveletTransform3D, Added correct output packing with k=8 --- causalvideovae/model/modules/wavelet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/causalvideovae/model/modules/wavelet.py b/causalvideovae/model/modules/wavelet.py index 891660a..8cf6366 100644 --- a/causalvideovae/model/modules/wavelet.py +++ b/causalvideovae/model/modules/wavelet.py @@ -77,7 +77,7 @@ def forward(self, x): outputs.append(self.gh_v_conv(y)) outputs = torch.cat(outputs, dim=0) - outputs = rearrange(outputs, "(b k c) 1 t h w -> b (c k) t h w", b=b, k=c) + outputs = rearrange(outputs, "(b k c) 1 t h w -> b (c k) t h w", b=b, c=c, k=8) return outputs class InverseHaarWaveletTransform3D(nn.Module): From 4cd0a7652d29acb08ab00329d58e60396d0cc0f5 Mon Sep 17 00:00:00 2001 From: Muhammad Nabi Yasinzai Date: Sat, 4 Oct 2025 20:34:58 +1300 Subject: [PATCH 2/3] Fixed incorrect order of (c k) packing --- causalvideovae/model/modules/wavelet.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/causalvideovae/model/modules/wavelet.py b/causalvideovae/model/modules/wavelet.py index 8cf6366..23269e3 100644 --- a/causalvideovae/model/modules/wavelet.py +++ b/causalvideovae/model/modules/wavelet.py @@ -6,6 +6,9 @@ from einops import rearrange +def conv(): + return nn.Conv3d(1, 1, kernel_size=2, stride=2, padding=0, bias=False) + class HaarWaveletTransform3D(nn.Module): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -77,7 +80,7 @@ def forward(self, x): outputs.append(self.gh_v_conv(y)) outputs = torch.cat(outputs, dim=0) - outputs = rearrange(outputs, "(b k c) 1 t h w -> b (c k) t h w", b=b, c=c, k=8) + outputs = rearrange(outputs, "(b c k) 1 t h w -> b (k c) t h w", b=b, c=c, k=8) return outputs class InverseHaarWaveletTransform3D(nn.Module): From e95c9ff3ec40fff338be9ef5c94927c3c4964024 Mon Sep 17 00:00:00 2001 From: Muhammad Nabi Yasinzai Date: Sat, 4 Oct 2025 20:39:04 +1300 Subject: [PATCH 3/3] Removed unused conv fun --- causalvideovae/model/modules/wavelet.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/causalvideovae/model/modules/wavelet.py b/causalvideovae/model/modules/wavelet.py index 23269e3..cbe39cf 100644 --- a/causalvideovae/model/modules/wavelet.py +++ b/causalvideovae/model/modules/wavelet.py @@ -5,9 +5,6 @@ from ..modules.ops import video_to_image from einops import rearrange - -def conv(): - return nn.Conv3d(1, 1, kernel_size=2, stride=2, padding=0, bias=False) class HaarWaveletTransform3D(nn.Module): def __init__(self, *args, **kwargs) -> None: