Skip to content

Commit a88788d

Browse files
Wan 2.2 support. (Comfy-Org#9080)
1 parent d0210fe commit a88788d

8 files changed

Lines changed: 926 additions & 19 deletions

File tree

comfy/latent_formats.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,82 @@ def process_out(self, latent):
457457
latents_std = self.latents_std.to(latent.device, latent.dtype)
458458
return latent * latents_std / self.scale_factor + latents_mean
459459

460+
class Wan22(Wan21):
    """Latent format description for Wan 2.2 latents.

    Overrides the channel count, the latent-to-RGB projection table and the
    per-channel normalisation statistics of ``Wan21``; every other member is
    inherited unchanged.
    """

    latent_channels = 48
    latent_dimensions = 3

    # One [R, G, B] weight row per latent channel (48 rows in total).
    # NOTE(review): presumably consumed by the latent preview code - confirm
    # against the callers, which are not part of this class.
    latent_rgb_factors = [
        [ 0.0119,  0.0103,  0.0046],
        [-0.1062, -0.0504,  0.0165],
        [ 0.0140,  0.0409,  0.0491],
        [-0.0813, -0.0677,  0.0607],
        [ 0.0656,  0.0851,  0.0808],
        [ 0.0264,  0.0463,  0.0912],
        [ 0.0295,  0.0326,  0.0590],
        [-0.0244, -0.0270,  0.0025],
        [ 0.0443, -0.0102,  0.0288],
        [-0.0465, -0.0090, -0.0205],
        [ 0.0359,  0.0236,  0.0082],
        [-0.0776,  0.0854,  0.1048],
        [ 0.0564,  0.0264,  0.0561],
        [ 0.0006,  0.0594,  0.0418],
        [-0.0319, -0.0542, -0.0637],
        [-0.0268,  0.0024,  0.0260],
        [ 0.0539,  0.0265,  0.0358],
        [-0.0359, -0.0312, -0.0287],
        [-0.0285, -0.1032, -0.1237],
        [ 0.1041,  0.0537,  0.0622],
        [-0.0086, -0.0374, -0.0051],
        [ 0.0390,  0.0670,  0.2863],
        [ 0.0069,  0.0144,  0.0082],
        [ 0.0006, -0.0167,  0.0079],
        [ 0.0313, -0.0574, -0.0232],
        [-0.1454, -0.0902, -0.0481],
        [ 0.0714,  0.0827,  0.0447],
        [-0.0304, -0.0574, -0.0196],
        [ 0.0401,  0.0384,  0.0204],
        [-0.0758, -0.0297, -0.0014],
        [ 0.0568,  0.1307,  0.1372],
        [-0.0055, -0.0310, -0.0380],
        [ 0.0239, -0.0305,  0.0325],
        [-0.0663, -0.0673, -0.0140],
        [-0.0416, -0.0047, -0.0023],
        [ 0.0166,  0.0112, -0.0093],
        [-0.0211,  0.0011,  0.0331],
        [ 0.1833,  0.1466,  0.2250],
        [-0.0368,  0.0370,  0.0295],
        [-0.3441, -0.3543, -0.2008],
        [-0.0479, -0.0489, -0.0420],
        [-0.0660, -0.0153,  0.0800],
        [-0.0101,  0.0068,  0.0156],
        [-0.0690, -0.0452, -0.0927],
        [-0.0145,  0.0041,  0.0015],
        [ 0.0421,  0.0451,  0.0373],
        [ 0.0504, -0.0483, -0.0356],
        [-0.0837,  0.0168,  0.0055]
    ]

    # Constant [R, G, B] offset that accompanies latent_rgb_factors.
    latent_rgb_factors_bias = [0.0317, -0.0878, -0.1388]

    def __init__(self):
        """Set the scale factor and the per-channel mean/std tensors.

        The parent initializer is deliberately not invoked here; all three
        attributes are assigned directly.
        """
        self.scale_factor = 1.0

        # Both statistics hold one value per latent channel and are reshaped
        # to (1, C, 1, 1, 1) so that only the second axis is non-singleton.
        stats_shape = (1, self.latent_channels, 1, 1, 1)

        mean_values = [
            -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
            -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
            -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
            -0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
            -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
            0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
        ]
        std_values = [
            0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
            0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
            0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
            0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
            0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
            0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744,
        ]

        self.latents_mean = torch.tensor(mean_values).view(stats_shape)
        self.latents_std = torch.tensor(std_values).view(stats_shape)
535+
460536
class Hunyuan3Dv2(LatentFormat):
461537
latent_channels = 64
462538
latent_dimensions = 1

comfy/ldm/wan/model.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,10 @@ def forward(
201201
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
202202
"""
203203
# assert e.dtype == torch.float32
204-
205-
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
204+
if e.ndim < 4:
205+
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
206+
else:
207+
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
206208
# assert e[0].dtype == torch.float32
207209

208210
# self-attention
@@ -325,7 +327,10 @@ def forward(self, x, e):
325327
e(Tensor): Shape [B, C]
326328
"""
327329
# assert e.dtype == torch.float32
328-
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
330+
if e.ndim < 3:
331+
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
332+
else:
333+
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2)
329334
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
330335
return x
331336

@@ -506,8 +511,9 @@ def forward_orig(
506511

507512
# time embeddings
508513
e = self.time_embedding(
509-
sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
510-
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
514+
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
515+
e = e.reshape(t.shape[0], -1, e.shape[-1])
516+
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
511517

512518
# context
513519
context = self.text_embedding(context)

0 commit comments

Comments
 (0)