Skip to content

[feature]: add SpargeAttn search in the infer stage #36

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions configs/wan_t2v_sparge_tune.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"infer_steps": 50,
"target_video_length": 81,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"attention_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": true,
"cpu_offload": false,
"sparge": true,
"sparse_tune": true,
"sparge_ckpt": "sparge_wan2.1_t2v_1.3B.pt"
}
17 changes: 17 additions & 0 deletions lightx2v/models/networks/wan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
from lightx2v.utils.envs import *
from loguru import logger
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER


class WanModel:
Expand Down Expand Up @@ -187,6 +188,22 @@ def infer(self, inputs):
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]

# Final infer stage
# sparge region start
self.sparge_tune = self.config.get("sparse_tune", False)
if self.sparge_tune:
saved_state_dict = {}
for k, v in self.transformer_weights.named_parameters():
if isinstance(v, ATTN_WEIGHT_REGISTER["Sparge"]):
for model_key, model_param in v.inner_cls.state_dict().items():
if k in model_key:
saved_state_dict[model_key] = model_param
# save to file
torch.save(saved_state_dict, self.config.get("sparse_ckpt", "sparse_tune.pt"))
else:
pass
# sparge region end

if self.config["feature_caching"] == "Tea":
self.scheduler.cnt += 1
if self.scheduler.cnt >= self.scheduler.num_steps:
Expand Down
8 changes: 7 additions & 1 deletion lightx2v/models/networks/wan/weights/transformer_weights.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import torch
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER, TENSOR_REGISTER, ATTN_WEIGHT_REGISTER
from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
Expand Down Expand Up @@ -26,6 +27,7 @@ def __init__(self, block_index, task, mm_type, config):
self.config = config
self.quant_method = config["mm_config"].get("quant_method", None)
self.sparge = config.get("sparge", False)
self.sparge_tune = config.get("sparse_tune", False)

self.add_module("self_attn_q", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.q.weight", f"blocks.{self.block_index}.self_attn.q.bias"))
self.add_module("self_attn_k", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.k.weight", f"blocks.{self.block_index}.self_attn.k.bias"))
Expand Down Expand Up @@ -62,10 +64,14 @@ def __init__(self, block_index, task, mm_type, config):
self.add_module("cross_attn_2", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())

# load attn weights
if self.sparge:
if self.sparge and not self.sparge_tune:
assert self.config["sparge_ckpt"], "sparge_ckpt must be set when sparge is True"
sparge_ckpt = torch.load(self.config["sparge_ckpt"])
self.self_attn_1.load(sparge_ckpt)
elif self.sparge_tune:
# enable tune mode
if not os.getenv("TUNE_MODE"):
os.environ["TUNE_MODE"] = "True"
else:
# do not load weights
pass
Expand Down
39 changes: 39 additions & 0 deletions scripts/run_wan_t2v_sparge_tune.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#!/bin/bash

# Launcher for SpargeAttn tuning on the Wan2.1 T2V pipeline.
# Set both paths below before running.
#! test with 1.3B
lightx2v_path=
model_path=

# ---- sanity checks ----
# Fall back to GPU 0 when the caller has not pinned a device.
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    cuda_devices=0
    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
    export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi

if [ -z "${lightx2v_path}" ]; then
    echo "Error: lightx2v_path is not set. Please set this variable first."
    exit 1
fi

if [ -z "${model_path}" ]; then
    echo "Error: model_path is not set. Please set this variable first."
    exit 1
fi

export TOKENIZERS_PARALLELISM=false

export PYTHONPATH=${lightx2v_path}:$PYTHONPATH

export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false

# All path/prompt arguments are quoted so whitespace or shell metacharacters
# in the paths or in the (non-ASCII) prompts cannot be word-split or globbed.
python -m lightx2v.infer \
    --model_cls wan2.1 \
    --task t2v \
    --model_path "${model_path}" \
    --config_json "${lightx2v_path}/configs/wan_t2v_sparge_tune.json" \
    --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
    --negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
    --save_video_path "${lightx2v_path}/save_results/output_lightx2v_wan_t2v.mp4"