From 8822d2f06160ff8e9d1b03df3083139003e5160d Mon Sep 17 00:00:00 2001
From: "zhiwei.dong" <dongz.cn@outlook.com>
Date: Sun, 11 May 2025 17:07:32 +0000
Subject: [PATCH 1/4] [feature]: add spargeattn search in infer stage

---
 configs/wan_t2v_sparge_tune.json        | 16 ++++++++
 lightx2v/models/networks/wan/model.py   | 13 +++++++
 .../wan/weights/transformer_weights.py  |  8 +++-
 scripts/run_wan_t2v_sparge_tune.sh      | 39 +++++++++++++++++++
 4 files changed, 75 insertions(+), 1 deletion(-)
 create mode 100755 configs/wan_t2v_sparge_tune.json
 create mode 100755 scripts/run_wan_t2v_sparge_tune.sh

diff --git a/configs/wan_t2v_sparge_tune.json b/configs/wan_t2v_sparge_tune.json
new file mode 100755
index 00000000..353cc9b2
--- /dev/null
+++ b/configs/wan_t2v_sparge_tune.json
@@ -0,0 +1,16 @@
+{
+    "infer_steps": 50,
+    "target_video_length": 81,
+    "text_len": 512,
+    "target_height": 480,
+    "target_width": 832,
+    "attention_type": "flash_attn3",
+    "seed": 42,
+    "sample_guide_scale": 6,
+    "sample_shift": 8,
+    "enable_cfg": true,
+    "cpu_offload": false,
+    "sparge": true,
+    "sparse_tune": true,
+    "sparge_ckpt": "sparge_wan2.1_t2v_1.3B.pt"
+}
diff --git a/lightx2v/models/networks/wan/model.py b/lightx2v/models/networks/wan/model.py
index 8fcb298e..df7448af 100755
--- a/lightx2v/models/networks/wan/model.py
+++ b/lightx2v/models/networks/wan/model.py
@@ -187,6 +187,19 @@ def infer(self, inputs):
         x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
         noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
 
+        # Final infer stage
+        # 1. extract super parameters
+        self.sparge_tune = self.config.get("sparse_tune", False)
+        if self.sparge_tune:
+            saved_state_dict = {}
+            for k, v in self.transformer_weights.named_parameters():
+                if isinstance(v, SparseAttentionMeansim):
+                    for model_key, model_param in v.state_dict().items():
+                        if k in model_key:
+                            saved_state_dict[model_key] = model_param
+            # save to file
+            torch.save(saved_state_dict, self.config.get("sparse_ckpt", "sparse_tune.pth"))
+
         if self.config["feature_caching"] == "Tea":
             self.scheduler.cnt += 1
             if self.scheduler.cnt >= self.scheduler.num_steps:
diff --git a/lightx2v/models/networks/wan/weights/transformer_weights.py b/lightx2v/models/networks/wan/weights/transformer_weights.py
index 31831b68..3764ae9a 100755
--- a/lightx2v/models/networks/wan/weights/transformer_weights.py
+++ b/lightx2v/models/networks/wan/weights/transformer_weights.py
@@ -1,3 +1,4 @@
+import os
 import torch
 from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER, TENSOR_REGISTER, ATTN_WEIGHT_REGISTER
 from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
@@ -26,6 +27,7 @@ def __init__(self, block_index, task, mm_type, config):
         self.config = config
         self.quant_method = config["mm_config"].get("quant_method", None)
         self.sparge = config.get("sparge", False)
+        self.sparge_tune = config.get("sparse_tune", False)
 
         self.add_module("self_attn_q", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.q.weight", f"blocks.{self.block_index}.self_attn.q.bias"))
         self.add_module("self_attn_k", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.k.weight", f"blocks.{self.block_index}.self_attn.k.bias"))
@@ -62,10 +64,14 @@ def __init__(self, block_index, task, mm_type, config):
         self.add_module("cross_attn_2", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
 
         # load attn weights
-        if self.sparge:
+        if self.sparge and not self.config["sparge_tune"]:
             assert self.config["sparge_ckpt"], "sparge_ckpt must be set when sparge is True"
             sparge_ckpt = torch.load(self.config["sparge_ckpt"])
             self.self_attn_1.load(sparge_ckpt)
+        elif self.config["sparge_tune"]:
+            # enable tune mode
+            if not os.getenv("TUNE_MODE"):
+                os.environ["TUNE_MODE"] = "True"
         else:
             # do not load weights
             pass
diff --git a/scripts/run_wan_t2v_sparge_tune.sh b/scripts/run_wan_t2v_sparge_tune.sh
new file mode 100755
index 00000000..96198af7
--- /dev/null
+++ b/scripts/run_wan_t2v_sparge_tune.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+
+# set paths first
+#! test with 1.3B
+lightx2v_path=
+model_path=
+
+# check section
+if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
+    cuda_devices=0
+    echo "Warning: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}. Change it in this script or set the environment variable."
+    export CUDA_VISIBLE_DEVICES=${cuda_devices}
+fi
+
+if [ -z "${lightx2v_path}" ]; then
+    echo "Error: lightx2v_path is not set. Please set this variable first."
+    exit 1
+fi
+
+if [ -z "${model_path}" ]; then
+    echo "Error: model_path is not set. Please set this variable first."
+    exit 1
+fi
+
+export TOKENIZERS_PARALLELISM=false
+
+export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
+
+export ENABLE_PROFILING_DEBUG=true
+export ENABLE_GRAPH_MODE=false
+
+python -m lightx2v.infer \
+--model_cls wan2.1 \
+--task t2v \
+--model_path $model_path \
+--config_json ${lightx2v_path}/configs/wan_t2v_sparge_tune.json \
+--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
+--negative_prompt "Garish colors, overexposed, static, blurry details, subtitles, style, artwork, painting, still frame, overall grayish, worst quality, low quality, JPEG compression artifacts, ugly, mutilated, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, fused fingers, motionless frame, cluttered background, three legs, many people in the background, walking backwards" \
+--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v.mp4

From 875cf15d5d74a4dda58602f9bb16c3dd03af39e3 Mon Sep 17 00:00:00 2001
From: "zhiwei.dong" <dongz.cn@outlook.com>
Date: Sun, 11 May 2025 17:11:29 +0000
Subject: [PATCH 2/4] [minor]: fix dict

---
 lightx2v/models/networks/wan/weights/transformer_weights.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/lightx2v/models/networks/wan/weights/transformer_weights.py b/lightx2v/models/networks/wan/weights/transformer_weights.py
index 3764ae9a..d57d270d 100755
--- a/lightx2v/models/networks/wan/weights/transformer_weights.py
+++ b/lightx2v/models/networks/wan/weights/transformer_weights.py
@@ -64,11 +64,11 @@ def __init__(self, block_index, task, mm_type, config):
         self.add_module("cross_attn_2", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
 
         # load attn weights
-        if self.sparge and not self.config["sparge_tune"]:
+        if self.sparge and not self.sparge_tune:
             assert self.config["sparge_ckpt"], "sparge_ckpt must be set when sparge is True"
             sparge_ckpt = torch.load(self.config["sparge_ckpt"])
             self.self_attn_1.load(sparge_ckpt)
-        elif self.config["sparge_tune"]:
+        elif self.sparge_tune:
             # enable tune mode
             if not os.getenv("TUNE_MODE"):
                 os.environ["TUNE_MODE"] = "True"

From 2852f228bfa99b3dfe2c53b1d227702f44e66cab Mon Sep 17 00:00:00 2001
From: "zhiwei.dong" <dongz.cn@outlook.com>
Date: Sun, 11 May 2025 17:43:10 +0000
Subject: [PATCH 3/4] [minor]: fix import

---
 lightx2v/models/networks/wan/model.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/lightx2v/models/networks/wan/model.py b/lightx2v/models/networks/wan/model.py
index df7448af..7b108a8d 100755
--- a/lightx2v/models/networks/wan/model.py
+++ b/lightx2v/models/networks/wan/model.py
@@ -21,6 +21,7 @@
 import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
 from lightx2v.utils.envs import *
 from loguru import logger
+from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
 
 
 class WanModel:
@@ -188,17 +189,20 @@ def infer(self, inputs):
         noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
 
         # Final infer stage
-        # 1. extract super parameters
+        # sparge region start
         self.sparge_tune = self.config.get("sparse_tune", False)
         if self.sparge_tune:
             saved_state_dict = {}
             for k, v in self.transformer_weights.named_parameters():
-                if isinstance(v, SparseAttentionMeansim):
-                    for model_key, model_param in v.state_dict().items():
+                if isinstance(v, ATTN_WEIGHT_REGISTER['Sparge']):
+                    for model_key, model_param in v.inner_cls.state_dict().items():
                         if k in model_key:
                             saved_state_dict[model_key] = model_param
             # save to file
-            torch.save(saved_state_dict, self.config.get("sparse_ckpt", "sparse_tune.pth"))
+            torch.save(saved_state_dict, self.config.get("sparse_ckpt", "sparse_tune.pt"))
+        else:
+            pass
+        # sparge region end
 
         if self.config["feature_caching"] == "Tea":
             self.scheduler.cnt += 1

From 62cae041a957d549fc0476a1351cc5a251f9db62 Mon Sep 17 00:00:00 2001
From: "zhiwei.dong" <dongz.cn@outlook.com>
Date: Sun, 11 May 2025 17:49:31 +0000
Subject: [PATCH 4/4] [minor]: fix format

---
 lightx2v/models/networks/wan/model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lightx2v/models/networks/wan/model.py b/lightx2v/models/networks/wan/model.py
index 7b108a8d..b1bac3a0 100755
--- a/lightx2v/models/networks/wan/model.py
+++ b/lightx2v/models/networks/wan/model.py
@@ -194,7 +194,7 @@ def infer(self, inputs):
         if self.sparge_tune:
             saved_state_dict = {}
             for k, v in self.transformer_weights.named_parameters():
-                if isinstance(v, ATTN_WEIGHT_REGISTER['Sparge']):
+                if isinstance(v, ATTN_WEIGHT_REGISTER["Sparge"]):
                     for model_key, model_param in v.inner_cls.state_dict().items():
                         if k in model_key:
                             saved_state_dict[model_key] = model_param
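
For reference, a minimal sketch of working with the checkpoint that the tune run above produces. It relies only on details visible in these patches: the tuned SpargeAttn parameters are collected into a plain dict and written with torch.save to the path given by the "sparse_ckpt" config key (default "sparse_tune.pt"), and a later non-tuning run with "sparge": true loads the file named by "sparge_ckpt". The file name below is just that default, not a required one.

import torch

# Default output path used by the tune branch in WanModel.infer
# (config key "sparse_ckpt"); adjust if the config overrides it.
ckpt_path = "sparse_tune.pt"

# The tuned parameters are saved as a flat {name: tensor} dict,
# so they load back the same way.
state = torch.load(ckpt_path, map_location="cpu")

for name, tensor in state.items():
    print(name, tuple(tensor.shape))

# To reuse the tuned values, point "sparge_ckpt" at this file in a config
# with "sparge": true and "sparse_tune": false, which takes the load branch
# in transformer_weights.py.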