From 8822d2f06160ff8e9d1b03df3083139003e5160d Mon Sep 17 00:00:00 2001
From: "zhiwei.dong" <dongz.cn@outlook.com>
Date: Sun, 11 May 2025 17:07:32 +0000
Subject: [PATCH 1/4] [feature]: add spargeattn search in infer stage

---
 configs/wan_t2v_sparge_tune.json              | 16 ++++++++
 lightx2v/models/networks/wan/model.py         | 13 +++++++
 .../wan/weights/transformer_weights.py        |  8 +++-
 scripts/run_wan_t2v_sparge_tune.sh            | 39 +++++++++++++++++++
 4 files changed, 75 insertions(+), 1 deletion(-)
 create mode 100755 configs/wan_t2v_sparge_tune.json
 create mode 100755 scripts/run_wan_t2v_sparge_tune.sh

diff --git a/configs/wan_t2v_sparge_tune.json b/configs/wan_t2v_sparge_tune.json
new file mode 100755
index 00000000..353cc9b2
--- /dev/null
+++ b/configs/wan_t2v_sparge_tune.json
@@ -0,0 +1,16 @@
+{
+    "infer_steps": 50,
+    "target_video_length": 81,
+    "text_len": 512,
+    "target_height": 480,
+    "target_width": 832,
+    "attention_type": "flash_attn3",
+    "seed": 42,
+    "sample_guide_scale": 6,
+    "sample_shift": 8,
+    "enable_cfg": true,
+    "cpu_offload": false,
+    "sparge": true,
+    "sparse_tune": true,
+    "sparge_ckpt": "sparge_wan2.1_t2v_1.3B.pt"
+}
diff --git a/lightx2v/models/networks/wan/model.py b/lightx2v/models/networks/wan/model.py
index 8fcb298e..df7448af 100755
--- a/lightx2v/models/networks/wan/model.py
+++ b/lightx2v/models/networks/wan/model.py
@@ -187,6 +187,19 @@ def infer(self, inputs):
         x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out)
         noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
 
+        # Final infer stage
+        # 1. extract super parameters
+        self.sparge_tune = self.config.get("sparse_tune", False)
+        if self.sparge_tune:
+            saved_state_dict = {}
+            for k, v in self.transformer_weights.named_parameters():
+                if isinstance(v, SparseAttentionMeansim):
+                    for model_key, model_param in v.state_dict().items():
+                        if k in model_key:
+                            saved_state_dict[model_key] = model_param
+            # save to file
+            torch.save(saved_state_dict, self.config.get("sparse_ckpt", "sparse_tune.pth"))
+
         if self.config["feature_caching"] == "Tea":
             self.scheduler.cnt += 1
             if self.scheduler.cnt >= self.scheduler.num_steps:
diff --git a/lightx2v/models/networks/wan/weights/transformer_weights.py b/lightx2v/models/networks/wan/weights/transformer_weights.py
index 31831b68..3764ae9a 100755
--- a/lightx2v/models/networks/wan/weights/transformer_weights.py
+++ b/lightx2v/models/networks/wan/weights/transformer_weights.py
@@ -1,3 +1,4 @@
+import os
 import torch
 from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER, TENSOR_REGISTER, ATTN_WEIGHT_REGISTER
 from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
@@ -26,6 +27,7 @@ def __init__(self, block_index, task, mm_type, config):
         self.config = config
         self.quant_method = config["mm_config"].get("quant_method", None)
         self.sparge = config.get("sparge", False)
+        self.sparge_tune = config.get("sparse_tune", False)
 
         self.add_module("self_attn_q", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.q.weight", f"blocks.{self.block_index}.self_attn.q.bias"))
         self.add_module("self_attn_k", MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.k.weight", f"blocks.{self.block_index}.self_attn.k.bias"))
@@ -62,10 +64,14 @@ def __init__(self, block_index, task, mm_type, config):
             self.add_module("cross_attn_2", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
 
         # load attn weights
-        if self.sparge:
+        if self.sparge and not self.config["sparge_tune"]:
             assert self.config["sparge_ckpt"], "sparge_ckpt must be set when sparge is True"
             sparge_ckpt = torch.load(self.config["sparge_ckpt"])
             self.self_attn_1.load(sparge_ckpt)
+        elif self.config["sparge_tune"]:
+            # enable tune mode
+            if not os.getenv("TUNE_MODE"):
+                os.environ["TUNE_MODE"] = "True"
         else:
             # do not load weights
             pass
diff --git a/scripts/run_wan_t2v_sparge_tune.sh b/scripts/run_wan_t2v_sparge_tune.sh
new file mode 100755
index 00000000..96198af7
--- /dev/null
+++ b/scripts/run_wan_t2v_sparge_tune.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+
+# Set the two paths below before running.
+# NOTE: tested with the 1.3B model.
+lightx2v_path=
+model_path=
+
+# check section
+if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
+    cuda_devices=0
+    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
+    export CUDA_VISIBLE_DEVICES=${cuda_devices}
+fi
+
+if [ -z "${lightx2v_path}" ]; then
+    echo "Error: lightx2v_path is not set. Please set this variable first."
+    exit 1
+fi
+
+if [ -z "${model_path}" ]; then
+    echo "Error: model_path is not set. Please set this variable first."
+    exit 1
+fi
+
+export TOKENIZERS_PARALLELISM=false
+
+export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
+
+export ENABLE_PROFILING_DEBUG=true
+export ENABLE_GRAPH_MODE=false
+
+python -m lightx2v.infer \
+--model_cls wan2.1 \
+--task t2v \
+--model_path $model_path \
+--config_json ${lightx2v_path}/configs/wan_t2v_sparge_tune.json \
+--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
+--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
+--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v.mp4

From 875cf15d5d74a4dda58602f9bb16c3dd03af39e3 Mon Sep 17 00:00:00 2001
From: "zhiwei.dong" <dongz.cn@outlook.com>
Date: Sun, 11 May 2025 17:11:29 +0000
Subject: [PATCH 2/4] [minor]: fix KeyError by using the cached sparge_tune flag instead of config["sparge_tune"]

---
 lightx2v/models/networks/wan/weights/transformer_weights.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/lightx2v/models/networks/wan/weights/transformer_weights.py b/lightx2v/models/networks/wan/weights/transformer_weights.py
index 3764ae9a..d57d270d 100755
--- a/lightx2v/models/networks/wan/weights/transformer_weights.py
+++ b/lightx2v/models/networks/wan/weights/transformer_weights.py
@@ -64,11 +64,11 @@ def __init__(self, block_index, task, mm_type, config):
             self.add_module("cross_attn_2", ATTN_WEIGHT_REGISTER[self.config["attention_type"]]())
 
         # load attn weights
-        if self.sparge and not self.config["sparge_tune"]:
+        if self.sparge and not self.sparge_tune:
             assert self.config["sparge_ckpt"], "sparge_ckpt must be set when sparge is True"
             sparge_ckpt = torch.load(self.config["sparge_ckpt"])
             self.self_attn_1.load(sparge_ckpt)
-        elif self.config["sparge_tune"]:
+        elif self.sparge_tune:
             # enable tune mode
             if not os.getenv("TUNE_MODE"):
                 os.environ["TUNE_MODE"] = "True"

From 2852f228bfa99b3dfe2c53b1d227702f44e66cab Mon Sep 17 00:00:00 2001
From: "zhiwei.dong" <dongz.cn@outlook.com>
Date: Sun, 11 May 2025 17:43:10 +0000
Subject: [PATCH 3/4] [minor]: fix import and check against ATTN_WEIGHT_REGISTER["Sparge"] when saving tuned params

---
 lightx2v/models/networks/wan/model.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/lightx2v/models/networks/wan/model.py b/lightx2v/models/networks/wan/model.py
index df7448af..7b108a8d 100755
--- a/lightx2v/models/networks/wan/model.py
+++ b/lightx2v/models/networks/wan/model.py
@@ -21,6 +21,7 @@
 import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
 from lightx2v.utils.envs import *
 from loguru import logger
+from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
 
 
 class WanModel:
@@ -188,17 +189,20 @@ def infer(self, inputs):
         noise_pred_cond = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
 
         # Final infer stage
-        # 1. extract super parameters
+        # sparge region start
         self.sparge_tune = self.config.get("sparse_tune", False)
         if self.sparge_tune:
             saved_state_dict = {}
             for k, v in self.transformer_weights.named_parameters():
-                if isinstance(v, SparseAttentionMeansim):
-                    for model_key, model_param in v.state_dict().items():
+                if isinstance(v, ATTN_WEIGHT_REGISTER['Sparge']):
+                    for model_key, model_param in v.inner_cls.state_dict().items():
                         if k in model_key:
                             saved_state_dict[model_key] = model_param
             # save to file
-            torch.save(saved_state_dict, self.config.get("sparse_ckpt", "sparse_tune.pth"))
+            torch.save(saved_state_dict, self.config.get("sparse_ckpt", "sparse_tune.pt"))
+        else:
+            pass
+        # sparge region end
 
         if self.config["feature_caching"] == "Tea":
             self.scheduler.cnt += 1

From 62cae041a957d549fc0476a1351cc5a251f9db62 Mon Sep 17 00:00:00 2001
From: "zhiwei.dong" <dongz.cn@outlook.com>
Date: Sun, 11 May 2025 17:49:31 +0000
Subject: [PATCH 4/4] [minor]: fix format

---
 lightx2v/models/networks/wan/model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lightx2v/models/networks/wan/model.py b/lightx2v/models/networks/wan/model.py
index 7b108a8d..b1bac3a0 100755
--- a/lightx2v/models/networks/wan/model.py
+++ b/lightx2v/models/networks/wan/model.py
@@ -194,7 +194,7 @@ def infer(self, inputs):
         if self.sparge_tune:
             saved_state_dict = {}
             for k, v in self.transformer_weights.named_parameters():
-                if isinstance(v, ATTN_WEIGHT_REGISTER['Sparge']):
+                if isinstance(v, ATTN_WEIGHT_REGISTER["Sparge"]):
                     for model_key, model_param in v.inner_cls.state_dict().items():
                         if k in model_key:
                             saved_state_dict[model_key] = model_param