
Commit 1741e95

add attn_implementation in model config; fix deepseekv3.1-terminus load error (#130)

1 parent: 70e4cbc

8 files changed (+56 −12)


angelslim/engine.py (+4 −0)

```diff
@@ -75,6 +75,7 @@ def prepare_model(
         deploy_backend="vllm",
         using_multi_nodes=False,
         use_audio_in_video=False,
+        attn_implementation="default",
     ) -> Any:
         """Load pretrained model and tokenizer
         Args:
@@ -92,6 +93,8 @@ def prepare_model(
             cache_dir (str, optional): Directory to cache the model.
             deploy_backend (str): Backend for deployment, e.g., "torch", "vllm".
             using_multi_nodes (bool): Whether to use multi-nodes for calibration.
+            use_audio_in_video (bool): Whether to use the audio track of the input video.
+            attn_implementation (str): The attention implementation to use in the model.
         """
         assert model_name, "model_name must be specified."
         assert model_path, "model_path must be specified."
@@ -126,6 +129,7 @@ def prepare_model(
                 device_map=device_map,
                 trust_remote_code=trust_remote_code,
                 use_audio_in_video=use_audio_in_video,
+                attn_implementation=attn_implementation,
             )
            self.model_path = model_path
        else:
```
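For orientation, a minimal sketch of how the new parameter arrives from caller code. The `prepare_model` keywords are taken from the diff above; the `Engine` class name is assumed from the `angelslim/engine.py` path, and the model path is a placeholder:

```python
# Hedged usage sketch: Engine and the model path are assumptions;
# prepare_model and its keywords come from the diff above.
from angelslim.engine import Engine

engine = Engine()
model = engine.prepare_model(
    model_name="Qwen_Omni",                   # model name as used in the configs
    model_path="/path/to/model",              # placeholder local path
    attn_implementation="flash_attention_2",  # new kwarg; "default" leaves it unset
)
```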

angelslim/models/llm/modeling_deepseek.py (+22 −6)

```diff
@@ -27,7 +27,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
-from safetensors.torch import load_model, safe_open, save_file
+from safetensors.torch import load_file, load_model, safe_open, save_file
 from torch import nn
 from tqdm import tqdm, trange
 from transformers.generation import GenerationMixin
@@ -1037,12 +1037,28 @@ def from_pretrained(
         dist.barrier()
         with torch.device("cuda"):
             model = cls(config)
-        load_model(
-            model,
-            os.path.join(
+        try:
+            load_model(
+                model,
+                os.path.join(
+                    tp_model_path, f"model{rank}-mp{cls.world_size}.safetensors"
+                ),
+            )
+        except RuntimeError:
+            file_path = os.path.join(
                 tp_model_path, f"model{rank}-mp{cls.world_size}.safetensors"
-            ),
-        )
+            )
+            file_state_dict = load_file(file_path)
+            model_state_dict = model.state_dict()
+            for key in model_state_dict:
+                if (
+                    key in file_state_dict
+                    and file_state_dict[key].dtype != model_state_dict[key].dtype
+                ):
+                    file_state_dict[key] = file_state_dict[key].to(
+                        model_state_dict[key].dtype
+                    )
+            model.load_state_dict(file_state_dict, strict=False)
         return model
     return super().from_pretrained(
         model_path,
```
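The new `except RuntimeError` branch is what fixes the deepseekv3.1-terminus load error from the commit title: when a checkpoint's tensor dtypes disagree with the freshly built module, the strict `load_model` raises, and the fallback casts each mismatched tensor before a non-strict load. A self-contained sketch of the same pattern, with an illustrative function name and placeholder path:

```python
# Sketch of the dtype-reconciling fallback introduced above; the function
# name and path are illustrative, the logic mirrors the committed code.
import torch
from safetensors.torch import load_file, load_model


def load_with_dtype_fallback(model: torch.nn.Module, path: str) -> None:
    try:
        # Fast path: load the checkpoint directly into the module.
        load_model(model, path)
    except RuntimeError:
        # Fallback: read a plain state dict, cast any tensor whose dtype
        # disagrees with the model's parameters, then load non-strictly.
        file_state_dict = load_file(path)
        model_state_dict = model.state_dict()
        for key in model_state_dict:
            if (
                key in file_state_dict
                and file_state_dict[key].dtype != model_state_dict[key].dtype
            ):
                file_state_dict[key] = file_state_dict[key].to(
                    model_state_dict[key].dtype
                )
        model.load_state_dict(file_state_dict, strict=False)
```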

angelslim/models/omni/qwen3_omni.py (+14 −6)

```diff
@@ -47,14 +47,22 @@ def from_pretrained(
         device_map="auto",
         trust_remote_code=True,
         use_audio_in_video=False,
+        attn_implementation="default",
     ):
         self.use_audio_in_video = use_audio_in_video
-        self.model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
-            model_path,
-            torch_dtype=torch_dtype,
-            device_map=device_map,
-            attn_implementation="flash_attention_2",
-        )
+        if attn_implementation == "default":
+            self.model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
+                model_path,
+                torch_dtype=torch_dtype,
+                device_map=device_map,
+            )
+        else:
+            self.model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
+                model_path,
+                torch_dtype=torch_dtype,
+                device_map=device_map,
+                attn_implementation=attn_implementation,
+            )
 
         # Load tokenizer
         self.tokenizer = AutoTokenizer.from_pretrained(
```
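Note the behavioral change: `flash_attention_2` was previously hard-coded, whereas `"default"` now means the kwarg is not forwarded at all, letting transformers choose its own implementation. The two branches differ only in that one kwarg; a hypothetical helper (not the committed code) expresses the same sentinel logic more compactly:

```python
from typing import Any, Dict


def build_from_pretrained_kwargs(
    torch_dtype: Any, device_map: Any, attn_implementation: str
) -> Dict[str, Any]:
    # Hypothetical helper: forward attn_implementation only when the caller
    # overrides the "default" sentinel; otherwise let transformers decide.
    kwargs: Dict[str, Any] = {"torch_dtype": torch_dtype, "device_map": device_map}
    if attn_implementation != "default":
        kwargs["attn_implementation"] = attn_implementation
    return kwargs
```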

angelslim/utils/config_parser.py (+3 −0)

```diff
@@ -121,6 +121,8 @@ class ModelConfig:
         low_cpu_mem_usage: Use low memory loading for large models
         use_cache: Whether to use cache during model loading
         cache_dir: Directory for caching model files
+        use_audio_in_video: Whether to use the audio track of the input video
+        attn_implementation: The attention implementation to use in the model
     """
 
     name: str
@@ -132,6 +134,7 @@ class ModelConfig:
     use_cache: bool = field(default=False)
     cache_dir: Optional[str] = field(default=None)
     use_audio_in_video: bool = field(default=False)
+    attn_implementation: str = field(default="default")
 
 
 @dataclass
```
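Because the field carries a dataclass default, existing YAML configs that omit the key keep working unchanged. A runnable mini stand-in, reproducing only the fields visible in the hunk (the real `ModelConfig` has more):

```python
# Hypothetical stand-in reproducing just the defaults shown in the hunk.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ModelConfigSketch:
    name: str
    use_cache: bool = field(default=False)
    cache_dir: Optional[str] = field(default=None)
    use_audio_in_video: bool = field(default=False)
    attn_implementation: str = field(default="default")


cfg = ModelConfigSketch(name="Qwen_Omni")
print(cfg.attn_implementation)  # -> "default" when a config omits the key
```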

configs/qwen3_omni/fp8_dynamic/qwen3_omni_fp8_dynamic.yaml (+1 −0)

```diff
@@ -12,6 +12,7 @@ model:
   torch_dtype: auto
   device_map: auto
   use_audio_in_video: false
+  attn_implementation: default
 
 # Compression configuration
 compression:
```

configs/qwen3_omni/fp8_static/qwen3_omni_fp8_static.yaml (+1 −0)

```diff
@@ -12,6 +12,7 @@ model:
   torch_dtype: auto
   device_map: auto
   use_audio_in_video: false
+  attn_implementation: default
 
 # Compression configuration
 compression:
```
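Both shipped configs keep the conservative `default`. Per the docs change below, opting into FlashAttention 2 is a one-key edit of the `model` block (only the keys visible in the diff are shown; other keys stay as in the shipped files):

```yaml
model:
  torch_dtype: auto
  device_map: auto
  use_audio_in_video: false
  attn_implementation: flash_attention_2  # requires flash-attn to be installed
```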

docs/source/models/qwen3_omni/qwen3_omni_quant.md (+9 −0)

````diff
@@ -15,6 +15,7 @@ Reference config files for FP8 quantization: `configs/qwen3_omni/fp8_static` and `c
 - `name`: the model name; always `Qwen_Omni`
 - `model_path`: a Hugging Face model card name or a local path.
 - `use_audio_in_video`: controls whether the audio track of the source video is used
+- `attn_implementation`: the attention implementation to use in the model; defaults to `default`, and setting it to `flash_attention_2` reduces GPU memory usage
 
 #### compression configuration
 - `name`: the compression strategy type; always the quantization mode `PTQ`
@@ -28,6 +29,14 @@ Reference config files for FP8 quantization: `configs/qwen3_omni/fp8_static` and `c
 
 ### Launching the quantization pipeline
 
+If `attn_implementation` is set to `flash_attention_2` in the `model` config, `FlashAttention 2` must additionally be installed:
+```shell
+pip install -U flash-attn --no-build-isolation
+
+# If `ldd --version` reports < 2.32, downgrade to flash-attn 2.7.4.post1 or lower
+pip install flash-attn==2.7.4.post1 --no-build-isolation
+```
+
 Launch FP8 quantization calibration with the following command:
 
 ```shell
````
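After installing, a quick import check confirms the wheel loads against the local toolchain before committing to a long calibration run (this assumes the `flash_attn` package exposes `__version__`, which current releases do):

```shell
python -c "import flash_attn; print(flash_attn.__version__)"
```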

tools/run.py (+2 −0)

```diff
@@ -91,6 +91,7 @@ def multi_nodes_run(config):
         use_cache=model_config.use_cache,
         cache_dir=model_config.cache_dir,
         use_audio_in_video=model_config.use_audio_in_video,
+        attn_implementation=model_config.attn_implementation,
         deploy_backend=global_config.deploy_backend,
         using_multi_nodes=True,
     )
@@ -151,6 +152,7 @@ def run(config):
         use_cache=model_config.use_cache,
         cache_dir=model_config.cache_dir,
         use_audio_in_video=model_config.use_audio_in_video,
+        attn_implementation=model_config.attn_implementation,
         deploy_backend=global_config.deploy_backend,
     )
```
