Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,19 @@
*.html
*.pdf
*.whl
cache
*cache/
__pycache__/
storage/
samples/
!.gitignore
!requirements.txt
.DS_Store
*DS_Store
.vscode
google/
Wan2.1-T2V-14B/
Wan2.1-T2V-1.3B/
Wan2.1-I2V-14B-480P/
Wan2.1-I2V-14B-720P/
poetry.lock
poetry.lock
logs/
118 changes: 114 additions & 4 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,25 @@

import torch
import torch.distributed as dist
from torch.cuda import set_device
from PIL import Image

try:
import torch_musa
from torch_musa.core.device import set_device
except ModuleNotFoundError:
torch_musa = None

import wan
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_image, cache_video, str2bool
from wan.utils.platform import (
get_device_type,
get_torch_distributed_backend,
get_torch_profiler_activities,
)



EXAMPLE_PROMPT = {
Expand Down Expand Up @@ -243,6 +256,11 @@ def _parse_args():
type=float,
default=5.0,
help="Classifier free guidance scale.")
parser.add_argument(
"--profile",
action="store_true",
default=False,
help="profile the generating procedure.")

args = parser.parse_args()

Expand All @@ -263,6 +281,30 @@ def _init_logging(rank):
logging.basicConfig(level=logging.ERROR)


def _init_profiler():
    """Create and start a torch profiler for the generation run.

    The profiler records op shapes, memory usage, and Python stacks, and
    emits TensorBoard-compatible traces under ``./logs``.

    Returns:
        The started ``torch.profiler.profile`` instance; the caller is
        responsible for stopping it (see ``_finalize_profiler``).
    """
    trace_handler = torch.profiler.tensorboard_trace_handler('./logs')
    prof = torch.profiler.profile(
        activities=get_torch_profiler_activities(),
        on_trace_ready=trace_handler,
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    )
    prof.start()
    return prof


def _finalize_profiler(profiler):
    """Stop a running profiler and persist its summary table.

    Stops *profiler*, renders the top-20 ops sorted by total device time
    for the active backend (e.g. ``cuda_time_total``), and writes the
    table to ``logs/profiling-<timestamp>.txt``.

    Args:
        profiler: A started ``torch.profiler.profile`` instance, as
            returned by ``_init_profiler``.
    """
    profiler.stop()
    table = profiler.key_averages().table(
        sort_by=f"{get_device_type()}_time_total",
        row_limit=20,
    )
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    # The TensorBoard trace handler normally creates ./logs, but do not
    # depend on it having fired before we write the summary file.
    os.makedirs("logs", exist_ok=True)
    with open(f"logs/profiling-{timestamp}.txt", "w") as f:
        f.write(table)


def generate(args):
rank = int(os.getenv("RANK", 0))
world_size = int(os.getenv("WORLD_SIZE", 1))
Expand All @@ -275,9 +317,9 @@ def generate(args):
logging.info(
f"offload_model is not specified, set to {args.offload_model}.")
if world_size > 1:
torch.cuda.set_device(local_rank)
set_device(local_rank)
dist.init_process_group(
backend="nccl",
backend=get_torch_distributed_backend(),
init_method="env://",
rank=rank,
world_size=world_size)
Expand Down Expand Up @@ -329,6 +371,10 @@ def generate(args):
base_seed = [args.base_seed] if rank == 0 else [None]
dist.broadcast_object_list(base_seed, src=0)
args.base_seed = base_seed[0]

profiler = None
if args.profile and rank == 0:
profiler = _init_profiler()

if "t2v" in args.task or "t2i" in args.task:
if args.prompt is None:
Expand Down Expand Up @@ -366,10 +412,23 @@ def generate(args):
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
profiler=profiler,
)

logging.info(
f"Generating {'image' if 't2i' in args.task else 'video'} ...")
logging.info("Warming up WanT2V pipeline ...")
with torch.no_grad():
_ = wan_t2v.generate(
args.prompt,
size=SIZE_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=3,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)

logging.info(f"Generating {'image' if 't2i' in args.task else 'video'} ...")
video = wan_t2v.generate(
args.prompt,
size=SIZE_CONFIGS[args.size],
Expand Down Expand Up @@ -423,8 +482,23 @@ def generate(args):
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
profiler=profiler,
)

logging.info("Warming up WanI2V pipeline ...")
with torch.no_grad():
_ = wan_i2v.generate(
args.prompt,
img,
max_area=MAX_AREA_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=3,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)

logging.info("Generating video ...")
video = wan_i2v.generate(
args.prompt,
Expand Down Expand Up @@ -481,8 +555,24 @@ def generate(args):
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
profiler=profiler
)

logging.info("Warming up WanFLF2V pipeline ...")
with torch.no_grad():
_ = wan_flf2v.generate(
args.prompt,
first_frame,
last_frame,
max_area=MAX_AREA_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=3,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)

logging.info("Generating video ...")
video = wan_flf2v.generate(
args.prompt,
Expand Down Expand Up @@ -529,6 +619,7 @@ def generate(args):
dit_fsdp=args.dit_fsdp,
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
t5_cpu=args.t5_cpu,
profiler=profiler
)

src_video, src_mask, src_ref_images = wan_vace.prepare_source(
Expand All @@ -537,6 +628,22 @@ def generate(args):
args.src_ref_images.split(',')
], args.frame_num, SIZE_CONFIGS[args.size], device)

logging.info("Warming up VACE pipeline ...")
with torch.no_grad():
_ = wan_vace.generate(
args.prompt,
src_video,
src_mask,
src_ref_images,
size=SIZE_CONFIGS[args.size],
frame_num=args.frame_num,
shift=args.sample_shift,
sample_solver=args.sample_solver,
sampling_steps=3,
guide_scale=args.sample_guide_scale,
seed=args.base_seed,
offload_model=args.offload_model)

logging.info(f"Generating video...")
video = wan_vace.generate(
args.prompt,
Expand All @@ -554,6 +661,9 @@ def generate(args):
else:
raise ValueError(f"Unkown task type: {args.task}")

if args.profile and rank == 0:
_finalize_profiler(profiler)

if rank == 0:
if args.save_file is None:
formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
Expand Down
8 changes: 7 additions & 1 deletion wan/distributed/fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,17 @@
from functools import partial

import torch
from torch.cuda import empty_cache
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
from torch.distributed.utils import _free_storage

try:
import torch_musa
from torch_musa.core.memory import empty_cache
except ModuleNotFoundError:
torch_musa = None

def shard_model(
model,
Expand Down Expand Up @@ -40,4 +46,4 @@ def free_model(model):
_free_storage(m._handle.flat_param.data)
del model
gc.collect()
torch.cuda.empty_cache()
empty_cache()
45 changes: 30 additions & 15 deletions wan/distributed/xdit_context_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,18 @@
get_sequence_parallel_world_size,
get_sp_group,
)
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
from xfuser.core.long_ctx_attention import xFuserLongContextAttention, AttnType
attn_type:AttnType = AttnType.FA

from wan.modules.rope import rope_apply_pytorch, rope_apply_triton

try:
import torch_musa
import torch_musa.core.amp as amp
attn_type = AttnType.TORCH
torch.backends.mudnn.allow_tf32 = True
except ImportError:
torch_musa = None

from ..modules.model import sinusoidal_embedding_1d

Expand All @@ -25,7 +36,7 @@ def pad_freqs(original_tensor, target_len):


@amp.autocast(enabled=False)
def rope_apply(x, grid_sizes, freqs):
def rope_apply(x, grid_sizes, freqs, sp_size, sp_rank):
"""
x: [B, L, N, C].
grid_sizes: [B, 3].
Expand All @@ -51,8 +62,6 @@ def rope_apply(x, grid_sizes, freqs):
dim=-1).reshape(seq_len, 1, -1)

# apply rotary embedding
sp_size = get_sequence_parallel_world_size()
sp_rank = get_sequence_parallel_rank()
freqs_i = pad_freqs(freqs_i, s * sp_size)
s_per_rank = s
freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
Expand Down Expand Up @@ -109,9 +118,13 @@ def usp_dit_forward(
if self.model_type == 'i2v':
assert clip_fea is not None and y is not None
# params
dtype = self.patch_embedding.weight.dtype
device = self.patch_embedding.weight.device
if self.freqs.device != device:
self.freqs = self.freqs.to(device)
if self.freqs[0].dtype != dtype or self.freqs[0].device != device:
self.freqs = (
self.freqs[0].to(dtype=dtype, device=device),
self.freqs[-1].to(dtype=dtype, device=device)
)

if self.model_type != 'vace' and y is not None:
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
Expand All @@ -129,11 +142,9 @@ def usp_dit_forward(
])

# time embeddings
with amp.autocast(dtype=torch.float32):
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).float())
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
assert e.dtype == torch.float32 and e0.dtype == torch.float32
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t))
e0 = self.time_projection(e).unflatten(1, (6, self.dim))

# context
context_lens = None
Expand Down Expand Up @@ -177,7 +188,7 @@ def usp_dit_forward(

# unpatchify
x = self.unpatchify(x, grid_sizes)
return [u.float() for u in x]
return x


def usp_attn_forward(self,
Expand All @@ -200,8 +211,12 @@ def qkv_fn(x):
return q, k, v

q, k, v = qkv_fn(x)
q = rope_apply(q, grid_sizes, freqs)
k = rope_apply(k, grid_sizes, freqs)
if torch_musa is None:
q = rope_apply(q, grid_sizes, freqs, get_sequence_parallel_world_size(), get_sequence_parallel_rank())
k = rope_apply(k, grid_sizes, freqs, get_sequence_parallel_world_size(), get_sequence_parallel_rank())
else:
q = rope_apply_pytorch(q, grid_sizes, freqs, get_sequence_parallel_world_size(), get_sequence_parallel_rank())
k = rope_apply_pytorch(k, grid_sizes, freqs, get_sequence_parallel_world_size(), get_sequence_parallel_rank())

# TODO: We should use unpaded q,k,v for attention.
# k_lens = seq_lens // get_sequence_parallel_world_size()
Expand All @@ -210,7 +225,7 @@ def qkv_fn(x):
# k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
# v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)

x = xFuserLongContextAttention()(
x = xFuserLongContextAttention(attn_type=attn_type)(
None,
query=half(q),
key=half(k),
Expand Down
Loading