Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 31 additions & 6 deletions wan/distributed/sequence_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,13 @@ def sp_dit_forward(
context,
seq_len,
y=None,
mask=None,
):
"""
x: A list of videos each with shape [C, T, H, W].
t: [B].
context: A list of text embeddings each with shape [L, C].
mask: Time embeddings mask tensor
"""
if self.model_type == 'i2v':
assert y is not None
Expand All @@ -97,17 +99,42 @@ def sp_dit_forward(
])

# time embeddings
if t.dim() == 1:
t = t.expand(t.size(0), seq_len)
assert t.dim() == 1
assert t.size(0) == 1
world_size = get_world_size()
with torch.amp.autocast('cuda', dtype=torch.float32):
bt = t.size(0)
t = t.flatten()
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim,
t).unflatten(0, (bt, seq_len)).float())
t).unflatten(0, (bt, 1)).float())
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
assert e.dtype == torch.float32 and e0.dtype == torch.float32

if mask is None:
seq_len_tmp = seq_len // world_size
e = e.repeat(1, seq_len_tmp, 1)
e0 = e0.repeat(1, seq_len_tmp, 1, 1)
else:
if self.e_zero is None or self.e0_zero is None:
t_zero = torch.tensor([0], device=device)
with torch.amp.autocast('cuda', dtype=torch.float32):
self.e_zero = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim,
t_zero).unflatten(0, (1, 1)).float())
self.e0_zero = self.time_projection(self.e_zero).unflatten(2, (6, self.dim))
assert self.e_zero.dtype == torch.float32 and self.e0_zero.dtype == torch.float32

e = e.repeat(1, seq_len, 1)
e0 = e0.repeat(1, seq_len, 1, 1)

zero_mask = (mask == 0)
if zero_mask.any():
e[:, zero_mask, :] = self.e_zero
e0[:, zero_mask, :] = self.e0_zero

e = torch.chunk(e, world_size, dim=1)[get_rank()]
e0 = torch.chunk(e0, world_size, dim=1)[get_rank()]

# context
context_lens = None
context = self.text_embedding(
Expand All @@ -118,8 +145,6 @@ def sp_dit_forward(

# Context Parallel
x = torch.chunk(x, get_world_size(), dim=1)[get_rank()]
e = torch.chunk(e, get_world_size(), dim=1)[get_rank()]
e0 = torch.chunk(e0, get_world_size(), dim=1)[get_rank()]

# arguments
kwargs = dict(
Expand Down
30 changes: 26 additions & 4 deletions wan/modules/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,9 @@ def __init__(self,
],
dim=1)

self.e_zero = None
self.e0_zero = None

# initialize weights
self.init_weights()

Expand All @@ -414,6 +417,7 @@ def forward(
context,
seq_len,
y=None,
mask=None,
):
r"""
Forward pass through the diffusion model
Expand All @@ -429,6 +433,8 @@ def forward(
Maximum sequence length for positional encoding
y (List[Tensor], *optional*):
Conditional video inputs for image-to-video mode, same shape as x
mask (Tensor):
Time embeddings mask tensor

Returns:
List[Tensor]:
Expand Down Expand Up @@ -457,17 +463,33 @@ def forward(
])

# time embeddings
if t.dim() == 1:
t = t.expand(t.size(0), seq_len)
assert t.dim() == 1
assert t.size(0) == 1
with torch.amp.autocast('cuda', dtype=torch.float32):
bt = t.size(0)
t = t.flatten()
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim,
t).unflatten(0, (bt, seq_len)).float())
t).unflatten(0, (bt, 1)).float())
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
assert e.dtype == torch.float32 and e0.dtype == torch.float32

e = e.repeat(1, seq_len, 1)
e0 = e0.repeat(1, seq_len, 1, 1)

if mask is not None:
if self.e_zero is None or self.e0_zero is None:
t_zero = torch.tensor([0], device=device)
with torch.amp.autocast('cuda', dtype=torch.float32):
self.e_zero = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim,
t_zero).unflatten(0, (1, 1)).float())
self.e0_zero = self.time_projection(self.e_zero).unflatten(2, (6, self.dim))
assert self.e_zero.dtype == torch.float32 and self.e0_zero.dtype == torch.float32

zero_mask = (mask == 0)
if zero_mask.any():
e[:, zero_mask, :] = self.e_zero
e0[:, zero_mask, :] = self.e0_zero
# context
context_lens = None
context = self.text_embedding(
Expand Down
21 changes: 6 additions & 15 deletions wan/textimage2video.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,6 @@ def noop_no_sync():

# sample videos
latents = noise
mask1, mask2 = masks_like(noise, zero=False)

arg_c = {'context': context, 'seq_len': seq_len}
arg_null = {'context': context_null, 'seq_len': seq_len}
Expand All @@ -370,13 +369,6 @@ def noop_no_sync():

timestep = torch.stack(timestep)

temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten()
temp_ts = torch.cat([
temp_ts,
temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep
])
timestep = temp_ts.unsqueeze(0)

noise_pred_cond = self.model(
latent_model_input, t=timestep, **arg_c)[0]
noise_pred_uncond = self.model(
Expand Down Expand Up @@ -570,19 +562,18 @@ def noop_no_sync():

timestep = torch.stack(timestep).to(self.device)

temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten()
temp_ts = torch.cat([
temp_ts,
temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep
mask = (mask2[0][0][:, ::2, ::2]).flatten()
mask = torch.cat([
mask,
mask.new_ones(seq_len - mask.size(0))
])
timestep = temp_ts.unsqueeze(0)

noise_pred_cond = self.model(
latent_model_input, t=timestep, **arg_c)[0]
latent_model_input, t=timestep, mask=mask, **arg_c)[0]
if offload_model:
torch.cuda.empty_cache()
noise_pred_uncond = self.model(
latent_model_input, t=timestep, **arg_null)[0]
latent_model_input, t=timestep, mask=mask,**arg_null)[0]
if offload_model:
torch.cuda.empty_cache()
noise_pred = noise_pred_uncond + guide_scale * (
Expand Down