diff --git a/wan/distributed/sequence_parallel.py b/wan/distributed/sequence_parallel.py
index 9c1ad786..90dbbc7c 100644
--- a/wan/distributed/sequence_parallel.py
+++ b/wan/distributed/sequence_parallel.py
@@ -68,11 +68,13 @@ def sp_dit_forward(
     context,
     seq_len,
     y=None,
+    mask=None,
 ):
     """
     x: A list of videos each with shape [C, T, H, W].
     t: [B].
     context: A list of text embeddings each with shape [L, C].
+    mask: Optional per-token mask of length seq_len; tokens where mask == 0 use the zero-timestep embedding.
     """
     if self.model_type == 'i2v':
         assert y is not None
@@ -97,17 +99,42 @@ def sp_dit_forward(
     ])
 
     # time embeddings
-    if t.dim() == 1:
-        t = t.expand(t.size(0), seq_len)
+    assert t.dim() == 1
+    assert t.size(0) == 1
+    world_size = get_world_size()
     with torch.amp.autocast('cuda', dtype=torch.float32):
         bt = t.size(0)
-        t = t.flatten()
         e = self.time_embedding(
             sinusoidal_embedding_1d(self.freq_dim,
-                                    t).unflatten(0, (bt, seq_len)).float())
+                                    t).unflatten(0, (bt, 1)).float())
         e0 = self.time_projection(e).unflatten(2, (6, self.dim))
         assert e.dtype == torch.float32 and e0.dtype == torch.float32
 
+    if mask is None:
+        seq_len_tmp = seq_len // world_size
+        e = e.repeat(1, seq_len_tmp, 1)
+        e0 = e0.repeat(1, seq_len_tmp, 1, 1)
+    else:
+        if self.e_zero is None or self.e0_zero is None:
+            t_zero = torch.tensor([0], device=device)
+            with torch.amp.autocast('cuda', dtype=torch.float32):
+                self.e_zero = self.time_embedding(
+                    sinusoidal_embedding_1d(self.freq_dim,
+                                            t_zero).unflatten(0, (1, 1)).float())
+                self.e0_zero = self.time_projection(self.e_zero).unflatten(2, (6, self.dim))
+                assert self.e_zero.dtype == torch.float32 and self.e0_zero.dtype == torch.float32
+
+        e = e.repeat(1, seq_len, 1)
+        e0 = e0.repeat(1, seq_len, 1, 1)
+
+        zero_mask = (mask == 0)
+        if zero_mask.any():
+            e[:, zero_mask, :] = self.e_zero
+            e0[:, zero_mask, :] = self.e0_zero
+
+        e = torch.chunk(e, world_size, dim=1)[get_rank()]
+        e0 = torch.chunk(e0, world_size, dim=1)[get_rank()]
+
     # context
     context_lens = None
     context = self.text_embedding(
@@ -118,8 +145,6 @@ def sp_dit_forward(
 
     # Context Parallel
     x = torch.chunk(x, get_world_size(), dim=1)[get_rank()]
-    e = torch.chunk(e, get_world_size(), dim=1)[get_rank()]
-    e0 = torch.chunk(e0, get_world_size(), dim=1)[get_rank()]
 
     # arguments
     kwargs = dict(
diff --git a/wan/modules/model.py b/wan/modules/model.py
index 6982fa15..3bccb331 100644
--- a/wan/modules/model.py
+++ b/wan/modules/model.py
@@ -404,6 +404,9 @@ def __init__(self,
             ], dim=1)
 
+        self.e_zero = None
+        self.e0_zero = None
+
         # initialize weights
         self.init_weights()
 
@@ -414,6 +417,7 @@ def forward(
         context,
         seq_len,
         y=None,
+        mask=None,
    ):
         r"""
         Forward pass through the diffusion model
@@ -429,6 +433,8 @@
                 Maximum sequence length for positional encoding
             y (List[Tensor], *optional*):
                 Conditional video inputs for image-to-video mode, same shape as x
+            mask (Tensor, *optional*):
+                Per-token mask of length seq_len; tokens where mask == 0 use the zero-timestep embedding
 
         Returns:
             List[Tensor]:
@@ -457,17 +463,33 @@
        ])
 
         # time embeddings
-        if t.dim() == 1:
-            t = t.expand(t.size(0), seq_len)
+        assert t.dim() == 1
+        assert t.size(0) == 1
         with torch.amp.autocast('cuda', dtype=torch.float32):
             bt = t.size(0)
-            t = t.flatten()
             e = self.time_embedding(
                 sinusoidal_embedding_1d(self.freq_dim,
-                                        t).unflatten(0, (bt, seq_len)).float())
+                                        t).unflatten(0, (bt, 1)).float())
             e0 = self.time_projection(e).unflatten(2, (6, self.dim))
             assert e.dtype == torch.float32 and e0.dtype == torch.float32
+        e = e.repeat(1, seq_len, 1)
+        e0 = e0.repeat(1, seq_len, 1, 1)
+
+        if mask is not None:
+            if self.e_zero is None or self.e0_zero is None:
+                t_zero = torch.tensor([0], device=device)
+                with torch.amp.autocast('cuda', dtype=torch.float32):
+                    self.e_zero = self.time_embedding(
+                        sinusoidal_embedding_1d(self.freq_dim,
+                                                t_zero).unflatten(0, (1, 1)).float())
+                    self.e0_zero = self.time_projection(self.e_zero).unflatten(2, (6, self.dim))
+                    assert self.e_zero.dtype == torch.float32 and self.e0_zero.dtype == torch.float32
+
+            zero_mask = (mask == 0)
+            if zero_mask.any():
+                e[:, zero_mask, :] = self.e_zero
+                e0[:, zero_mask, :] = self.e0_zero
 
         # context
         context_lens = None
         context = self.text_embedding(
diff --git a/wan/textimage2video.py b/wan/textimage2video.py
index 67e9fd29..696385a1 100644
--- a/wan/textimage2video.py
+++ b/wan/textimage2video.py
@@ -355,7 +355,6 @@ def noop_no_sync():
 
             # sample videos
             latents = noise
-            mask1, mask2 = masks_like(noise, zero=False)
 
             arg_c = {'context': context, 'seq_len': seq_len}
             arg_null = {'context': context_null, 'seq_len': seq_len}
@@ -370,13 +369,6 @@ def noop_no_sync():
 
                 timestep = torch.stack(timestep)
 
-                temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten()
-                temp_ts = torch.cat([
-                    temp_ts,
-                    temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep
-                ])
-                timestep = temp_ts.unsqueeze(0)
-
                 noise_pred_cond = self.model(
                     latent_model_input, t=timestep, **arg_c)[0]
                 noise_pred_uncond = self.model(
@@ -570,19 +562,18 @@ def noop_no_sync():
 
                 timestep = torch.stack(timestep).to(self.device)
 
-                temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten()
-                temp_ts = torch.cat([
-                    temp_ts,
-                    temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep
+                mask = (mask2[0][0][:, ::2, ::2]).flatten()
+                mask = torch.cat([
+                    mask,
+                    mask.new_ones(seq_len - mask.size(0))
                 ])
-                timestep = temp_ts.unsqueeze(0)
 
                 noise_pred_cond = self.model(
-                    latent_model_input, t=timestep, **arg_c)[0]
+                    latent_model_input, t=timestep, mask=mask, **arg_c)[0]
                 if offload_model:
                     torch.cuda.empty_cache()
                 noise_pred_uncond = self.model(
-                    latent_model_input, t=timestep, **arg_null)[0]
+                    latent_model_input, t=timestep, mask=mask, **arg_null)[0]
                 if offload_model:
                     torch.cuda.empty_cache()
                 noise_pred = noise_pred_uncond + guide_scale * (
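The core of the change: the model now takes a single scalar timestep plus a per-token mask instead of a pre-expanded per-token timestep tensor. The timestep is embedded once, broadcast across the token sequence, and tokens whose mask entry is 0 (the already-clean conditioning latents) are overwritten with a lazily cached zero-timestep modulation. Below is a minimal, self-contained sketch of that logic; MaskedTimeEmbed and the simplified sinusoidal_embedding_1d are illustrative stand-ins rather than the actual WanModel submodules, and only the shape handling and the e_zero/e0_zero caching mirror the patch.

import torch
import torch.nn as nn


def sinusoidal_embedding_1d(dim, t):
    # simplified stand-in for the model's sinusoidal timestep embedding
    half = dim // 2
    freqs = torch.exp(
        -torch.arange(half, dtype=torch.float32) * torch.log(torch.tensor(10000.0)) / half)
    args = t.float()[:, None] * freqs[None]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)


class MaskedTimeEmbed(nn.Module):
    # hypothetical module: time_embedding / time_projection stand in for the
    # corresponding WanModel submodules
    def __init__(self, freq_dim=256, dim=128):
        super().__init__()
        self.freq_dim, self.dim = freq_dim, dim
        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim))
        # lazily cached zero-timestep embeddings, as in the patch
        self.e_zero = None
        self.e0_zero = None

    def forward(self, t, seq_len, mask=None):
        assert t.dim() == 1 and t.size(0) == 1
        # embed the single timestep once ...
        e = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, t)).unflatten(0, (1, 1))
        e0 = self.time_projection(e).unflatten(2, (6, self.dim))
        # ... then broadcast it over the token sequence
        e = e.repeat(1, seq_len, 1)        # [1, seq_len, dim]
        e0 = e0.repeat(1, seq_len, 1, 1)   # [1, seq_len, 6, dim]
        if mask is not None:
            if self.e_zero is None or self.e0_zero is None:
                t_zero = torch.zeros(1, device=t.device)
                self.e_zero = self.time_embedding(
                    sinusoidal_embedding_1d(self.freq_dim, t_zero)).unflatten(0, (1, 1))
                self.e0_zero = self.time_projection(self.e_zero).unflatten(2, (6, self.dim))
            # tokens with mask == 0 keep the t=0 modulation on every denoising step
            zero = (mask == 0)
            if zero.any():
                e[:, zero, :] = self.e_zero
                e0[:, zero, :] = self.e0_zero
        return e, e0

On the caller side (wan/textimage2video.py), the mask is built once from the latent-frame mask, flattened over the patchified token grid and padded with ones up to seq_len, so the same tensor serves both the conditional and unconditional passes while t stays a plain [1]-shaped timestep. In the sequence-parallel path, the broadcast and masked e/e0 are chunked along the sequence dimension per rank inside the time-embedding branch, which is why the old post-hoc torch.chunk of e/e0 in the Context Parallel section is removed.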