10 changes: 9 additions & 1 deletion .gitignore
@@ -4,4 +4,12 @@ __pycache__/
tmp_examples*
new_checkpoint*
batch_test*
nohup*
nohup*
wan.egg-info
build
Wan2.2-T2V-A14B
2.4.0
1.23.5
Wan2.2-TI2V-5B
input
out
2 changes: 1 addition & 1 deletion requirements.txt
@@ -12,5 +12,5 @@ easydict
ftfy
dashscope
imageio-ffmpeg
flash_attn
#flash_attn
numpy>=1.23.5,<2
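
With flash_attn commented out in requirements.txt, the attention module has to detect at runtime whether FlashAttention is installed and otherwise take the new SDPA path. A minimal sketch of that import guard (editor's illustration, not part of the diff): FLASH_ATTN_2_AVAILABLE matches the flag used in attention.py below, while the v3 flag name is an assumption.

# Sketch: runtime guard for optional FlashAttention.
try:
    import flash_attn_interface  # FlashAttention v3
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn  # FlashAttention v2
    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False

On a default install both flags end up False, so the SDPA fallback added in attention.py is the branch that actually runs.
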
57 changes: 38 additions & 19 deletions wan/modules/attention.py
@@ -1,5 +1,6 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.nn.functional as F

try:
import flash_attn_interface
@@ -109,25 +110,43 @@ def half(x):
causal=causal,
deterministic=deterministic)[0].unflatten(0, (b, lq))
else:
assert FLASH_ATTN_2_AVAILABLE
x = flash_attn.flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
0, dtype=torch.int32).to(q.device, non_blocking=True),
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
0, dtype=torch.int32).to(q.device, non_blocking=True),
max_seqlen_q=lq,
max_seqlen_k=lk,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
deterministic=deterministic).unflatten(0, (b, lq))

# output
return x.type(out_dtype)
# ---- SDPA fallback (without Flash-Attn v2) ----
# q, k, v: (b*lq, h, d), (b*lk, h, d)
bld_q, h, d = q.shape
bld_k = k.shape[0]
lq = bld_q // b
lk = bld_k // b

# -> (b, h, L, d)
q_ = q.unflatten(0, (b, lq)).transpose(1, 2)  # (b, h, lq, d)
k_ = k.unflatten(0, (b, lk)).transpose(1, 2)  # (b, h, lk, d)
v_ = v.unflatten(0, (b, lk)).transpose(1, 2)  # (b, h, lk, d)

# Boolean mask for SDPA: True = attend, False = masked out.
attn_mask = None
if k_lens is not None:
    attn_mask = torch.ones((b, 1, lq, lk), dtype=torch.bool, device=q_.device)
    for i in range(b):
        if k_lens[i] < lk:
            attn_mask[i, 0, :, k_lens[i]:] = False
# Padded query rows (q_lens[i] < lq) are left unmasked: masking an entire
# row would make the softmax produce NaNs, and their outputs are ignored
# downstream anyway.

# SDPA rejects an explicit attn_mask together with is_causal=True.
assert not (causal and attn_mask is not None), \
    'SDPA fallback does not support causal attention together with k_lens'

x_sdpa = F.scaled_dot_product_attention(
    q_, k_, v_,
    attn_mask=attn_mask,
    dropout_p=dropout_p,
    is_causal=bool(causal),
    scale=softmax_scale,  # None -> default 1/sqrt(d); requires PyTorch >= 2.1
)  # (b, h, lq, d)

x = x_sdpa.permute(0, 2, 1, 3).contiguous()  # (b, lq, h, d)

# output
return x.type(out_dtype)


def attention(
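
As a sanity check for the mask convention used in the SDPA fallback (a boolean attn_mask in F.scaled_dot_product_attention means True = attend, False = masked out), a minimal standalone sketch with made-up shapes; all names and values here are illustrative and not taken from the repository.

import torch
import torch.nn.functional as F

b, h, lq, lk, d = 2, 4, 8, 8, 16
q = torch.randn(b, h, lq, d)
k = torch.randn(b, h, lk, d)
v = torch.randn(b, h, lk, d)

# Per-sample valid key lengths (illustrative).
k_lens = torch.tensor([8, 5])

# True = attend, False = masked out.
mask = torch.ones(b, 1, lq, lk, dtype=torch.bool)
for i in range(b):
    mask[i, 0, :, k_lens[i]:] = False

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Keys beyond k_lens must not influence the result: perturbing the padded
# positions leaves the output unchanged.
k2, v2 = k.clone(), v.clone()
k2[1, :, 5:] = 1e3
v2[1, :, 5:] = 1e3
out2 = F.scaled_dot_product_attention(q, k2, v2, attn_mask=mask)
assert torch.allclose(out, out2, atol=1e-6)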