This repository was archived by the owner on May 11, 2025. It is now read-only.

Commit f14d0fd

Added Qwen 3 support. (#751)
1 parent a0ed8cb commit f14d0fd

9 files changed: +335 lines added, 0 lines removed
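Before the per-file diffs, a hedged usage sketch of what this commit enables: quantizing a Qwen3 checkpoint through the usual AutoAWQ entry points. The model id, output path, and quant_config values below are illustrative assumptions, not part of the commit.

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

# Hedged sketch of the standard AutoAWQ quantization flow, which now dispatches
# to Qwen3AWQForCausalLM for model_type "qwen3". Paths and settings are illustrative.
model_path = "Qwen/Qwen3-8B"
quant_path = "qwen3-8b-awq"
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

model.quantize(tokenizer, quant_config=quant_config)
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)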

awq/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,8 @@
 from .llava import LlavaAWQForCausalLM
 from .mixtral import MixtralAWQForCausalLM
 from .qwen2 import Qwen2AWQForCausalLM
+from .qwen3 import Qwen3AWQForCausalLM
+from .qwen3_moe import Qwen3MoeAWQForCausalLM
 from .gemma import GemmaAWQForCausalLM
 from .gemma2 import Gemma2AWQForCausalLM
 from .stablelm import StableLmAWQForCausalLM

awq/models/auto.py

Lines changed: 2 additions & 0 deletions
@@ -26,6 +26,8 @@
     "baichuan": BaichuanAWQForCausalLM,
     "llava": LlavaAWQForCausalLM,
     "qwen2": Qwen2AWQForCausalLM,
+    "qwen3": Qwen3AWQForCausalLM,
+    "qwen3_moe": Qwen3MoeAWQForCausalLM,
     "gemma": GemmaAWQForCausalLM,
     "gemma2": Gemma2AWQForCausalLM,
     "stablelm": StableLmAWQForCausalLM,

awq/models/base.py

Lines changed: 2 additions & 0 deletions
@@ -73,6 +73,8 @@
     "llava": "AutoModelForVision2Seq",
     "qwen2": "AutoModelForCausalLM",
     "qwen2_vl": "AutoModelForVision2Seq",
+    "qwen3": "AutoModelForCausalLM",
+    "qwen3_moe": "AutoModelForCausalLM",
     "gemma": "AutoModelForCausalLM",
     "gemma2": "AutoModelForCausalLM",
     "stablelm": "AutoModelForCausalLM",

awq/models/qwen3.py

Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from transformers.models.qwen3.modeling_qwen3 import (
    Qwen3DecoderLayer as OldQwen3DecoderLayer,
    Qwen3ForCausalLM as OldQwen3ForCausalLM,
)
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import QwenBlock
from awq.modules.fused.model import LlamaLikeModel
from awq.modules.fused.norm import FasterTransformerRMSNorm


class Qwen3AWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "Qwen3DecoderLayer"
    max_seq_len_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model: OldQwen3ForCausalLM):
        fuser = Qwen3Fuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: OldQwen3ForCausalLM):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module: OldQwen3DecoderLayer):
        return dict(is_scalable=False)

    @staticmethod
    def move_embed(model: OldQwen3ForCausalLM, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)
        model.model.rotary_emb = model.model.rotary_emb.to(device)

    @staticmethod
    def get_layers_for_scaling(module: OldQwen3DecoderLayer, input_feat, module_kwargs):
        layers = []

        # attention input
        layers.append(
            dict(
                prev_op=module.input_layernorm,
                layers=[
                    module.self_attn.q_proj,
                    module.self_attn.k_proj,
                    module.self_attn.v_proj,
                ],
                inp=input_feat["self_attn.q_proj"],
                module2inspect=module.self_attn,
                kwargs=module_kwargs,
            )
        )

        # attention out
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(
                dict(
                    prev_op=module.self_attn.v_proj,
                    layers=[module.self_attn.o_proj],
                    inp=input_feat["self_attn.o_proj"],
                )
            )

        # linear 1
        layers.append(
            dict(
                prev_op=module.post_attention_layernorm,
                layers=[module.mlp.gate_proj, module.mlp.up_proj],
                inp=input_feat["mlp.gate_proj"],
                module2inspect=module.mlp,
            )
        )

        # linear 2
        layers.append(
            dict(
                prev_op=module.mlp.up_proj,
                layers=[module.mlp.down_proj],
                inp=input_feat["mlp.down_proj"],
            )
        )

        return layers

class Qwen3Fuser:
    def __init__(self, model: OldQwen3ForCausalLM):
        self.model = model

        self.qwen3_blocks: List[Tuple[str, OldQwen3DecoderLayer]] = [
            (name, module)
            for name, module in self.model.named_modules()
            if "Qwen3DecoderLayer".lower() in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
        blocks = []

        module: OldQwen3DecoderLayer
        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            device = next(iter(module.state_dict().values())).device
            qkv = fuse_qkv(
                module,
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj,
            )
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight, module.input_layernorm.variance_epsilon
            )
            norm_2 = FasterTransformerRMSNorm(
                module.post_attention_layernorm.weight,
                module.post_attention_layernorm.variance_epsilon,
            )
            blocks.append(
                QwenBlock(
                    hidden_size=self.model.config.hidden_size,
                    n_heads=self.model.config.num_attention_heads,
                    n_kv_heads=self.model.config.num_key_value_heads,
                    qkv_layer=qkv,
                    o_proj=module.self_attn.o_proj,
                    mlp=module.mlp,
                    norm_1=norm_1,
                    norm_2=norm_2,
                    dev=device,
                    max_seq_len=self.model.config.max_seq_len,
                    rope_theta=self.model.config.rope_theta,
                    q_norm=module.self_attn.q_norm,
                    k_norm=module.self_attn.k_norm,
                    head_dim=self.model.config.head_dim,
                )
            )

        self.model.model = LlamaLikeModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )
        setattr(self.model.model, "blocks", self.model.model.blocks)
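A hedged usage sketch of where Qwen3Fuser comes in: fuse_layers is invoked when a quantized model is loaded with fusing enabled, which in turn builds the QwenBlock-based fused layers. The local path and keyword usage below assume the standard AutoAWQ loading API and are illustrative.

from awq import AutoAWQForCausalLM

# Hedged sketch: loading an already-quantized Qwen3 checkpoint with fusing on,
# which triggers Qwen3AWQForCausalLM.fuse_layers -> Qwen3Fuser.fuse_transformer.
model = AutoAWQForCausalLM.from_quantized(
    "qwen3-8b-awq",    # placeholder path to a quantized checkpoint
    fuse_layers=True,  # build the fused QwenBlock layers for inference
)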

awq/models/qwen3_moe.py

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM


class Qwen3MoeAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "Qwen3MoeDecoderLayer"
    max_seq_len_key = "max_position_embeddings"

    @staticmethod
    def get_model_layers(model):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module):
        return dict(is_scalable=False)

    @staticmethod
    def move_embed(model, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)
        model.model.rotary_emb = model.model.rotary_emb.to(device)

    @staticmethod
    def get_layers_for_scaling(module, input_feat, module_kwargs):
        layers = []

        # attention input
        layers.append(
            dict(
                prev_op=module.input_layernorm,
                layers=[
                    module.self_attn.q_proj,
                    module.self_attn.k_proj,
                    module.self_attn.v_proj,
                ],
                inp=input_feat["self_attn.q_proj"],
                module2inspect=module.self_attn,
                kwargs=module_kwargs,
            )
        )

        # attention out
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(
                dict(
                    prev_op=module.self_attn.v_proj,
                    layers=[module.self_attn.o_proj],
                    inp=input_feat["self_attn.o_proj"],
                )
            )

        if hasattr(module.mlp, "gate"):
            # linear in
            layers.append(
                dict(
                    prev_op=module.post_attention_layernorm,
                    layers=[
                        w
                        for expert in module.mlp.experts
                        for w in [expert.gate_proj, expert.up_proj]
                    ],
                    inp=input_feat["mlp"],
                    module2inspect=module.mlp,
                )
            )

            # linear out
            for i, expert in enumerate(module.mlp.experts):
                layers.append(
                    dict(
                        prev_op=expert.up_proj,
                        layers=[expert.down_proj],
                        inp=input_feat[f"mlp.experts.{i}.down_proj"],
                    )
                )

        else:
            # linear 1
            layers.append(
                dict(
                    prev_op=module.post_attention_layernorm,
                    layers=[module.mlp.gate_proj, module.mlp.up_proj],
                    inp=input_feat["mlp.gate_proj"],
                    module2inspect=module.mlp,
                )
            )

            # linear 2
            layers.append(
                dict(
                    prev_op=module.mlp.up_proj,
                    layers=[module.mlp.down_proj],
                    inp=input_feat["mlp.down_proj"],
                )
            )

        return layers
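In the MoE branch above, every expert's gate_proj and up_proj are scaled together against the shared post-attention layernorm, which is why they are flattened into a single list by the nested comprehension. A toy sketch of that flattening (the _Expert class and the string names are placeholders, not AutoAWQ objects):

# Toy illustration of the expert flattening used above; names are placeholders.
class _Expert:
    def __init__(self, i: int):
        self.gate_proj = f"experts.{i}.gate_proj"
        self.up_proj = f"experts.{i}.up_proj"

experts = [_Expert(i) for i in range(3)]
flat = [w for expert in experts for w in [expert.gate_proj, expert.up_proj]]
print(flat)
# ['experts.0.gate_proj', 'experts.0.up_proj', 'experts.1.gate_proj', ...]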

awq/modules/fused/attn.py

Lines changed: 9 additions & 0 deletions
@@ -140,6 +140,8 @@ def __init__(
         partial_rotary_factor=1.0,
         head_dim=None,
         attn_logit_softcapping=0.0,
+        q_norm=None,
+        k_norm=None,
         **kwargs
     ):
         super().__init__()
@@ -154,6 +156,8 @@ def __init__(
 
         self.qkv_proj = qkv_layer
         self.o_proj = o_proj
+        self.q_norm = q_norm
+        self.k_norm = k_norm
         self.start_pos = 0
         self.use_alibi = use_alibi
         self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
@@ -243,6 +247,11 @@ def forward(self, hidden_states: torch.Tensor, *args, **kwargs):
         xk = self.attention_shapes["xk_slice"](xqkv)
         xv = self.attention_shapes["xv_slice"](xqkv)
 
+        if self.q_norm is not None:
+            xq = self.q_norm(xq)
+        if self.k_norm is not None:
+            xk = self.k_norm(xk)
+
         if not self.use_alibi:
             xq, xk = self.rope.forward(
                 xq, xk, self.start_pos, seqlen, partial=self.partial_rotary_factor < 1
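For context, Qwen3 applies an RMSNorm over each attention head's query and key vectors before RoPE, which is what the optional q_norm/k_norm hooks enable in the fused path. A minimal hedged sketch of that shape convention (the tensor sizes and the use of torch.nn.RMSNorm are illustrative; the fused path reuses the norms taken from the Hugging Face module):

import torch

# Minimal sketch: per-head RMSNorm over head_dim, applied before RoPE.
# Shapes are illustrative; torch.nn.RMSNorm requires torch >= 2.4.
batch, seqlen, n_heads, head_dim = 1, 4, 8, 128
xq = torch.randn(batch, seqlen, n_heads, head_dim)
q_norm = torch.nn.RMSNorm(head_dim, eps=1e-6)
xq = q_norm(xq)  # normalizes each head's query vector independently; shape unchanged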

awq/modules/fused/block.py

Lines changed: 68 additions & 0 deletions
@@ -119,6 +119,74 @@ def forward(
 
         return out
 
+class QwenBlock(nn.Module):
+    """
+    QwenBlock is intended to be reused across blocks whose architecture
+    closely resembles Qwen2/Qwen3, e.g. blocks that use q_norm and k_norm.
+    """
+
+    def __init__(
+        self,
+        hidden_size,
+        n_heads,
+        n_kv_heads,
+        qkv_layer,
+        o_proj,
+        mlp,
+        norm_1,
+        norm_2,
+        dev,
+        max_seq_len,
+        rope_theta=10000,
+        partial_rotary_factor=1.0,
+        use_alibi=False,
+        head_dim=None,
+        q_norm=None,
+        k_norm=None,
+    ):
+        super().__init__()
+        self.n_heads = n_heads
+        self.n_kv_heads = n_kv_heads
+        self.head_dim = hidden_size // n_heads
+
+        # To support Qwen3, whose head_dim is set separately in the config
+        if head_dim:
+            self.head_dim = head_dim
+
+        self.hidden_size = hidden_size
+        self.norm_1 = norm_1.to(dev)
+        self.attn = QuantAttentionFused(
+            self.hidden_size,
+            self.n_heads,
+            self.n_kv_heads,
+            qkv_layer,
+            o_proj,
+            dev=dev,
+            max_seq_len=max_seq_len,
+            use_alibi=use_alibi,
+            rope_theta=rope_theta,
+            partial_rotary_factor=partial_rotary_factor,
+            head_dim=head_dim,
+            q_norm=q_norm,
+            k_norm=k_norm,
+        ).to(dev)
+        self.norm_2 = norm_2.to(dev)
+        self.mlp = mlp.to(dev)
+        self.device = dev
+
+    def forward(
+        self,
+        hidden_states,
+    ):
+        norm_out = self.norm_1(hidden_states)
+        attn_output, _, _ = self.attn.forward(
+            hidden_states=norm_out,
+        )
+
+        h = hidden_states.to(attn_output.device) + attn_output
+        out = h + self.mlp.forward(self.norm_2(h))
+
+        return out
 
 class Gemma2LikeBlock(nn.Module):
     def __init__(
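One detail worth noting: in Qwen3 the configured head_dim need not equal hidden_size // n_heads, which is why QwenBlock accepts head_dim explicitly and only falls back to the derived value. A tiny illustration with hypothetical config numbers:

# Hypothetical config values, only to show why head_dim is passed explicitly.
hidden_size, n_heads, config_head_dim = 2560, 32, 128

derived = hidden_size // n_heads       # 80
head_dim = config_head_dim or derived  # 128: the configured value wins when set
print(derived, head_dim)               # 80 128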

awq/modules/linear/gemm.py

Lines changed: 2 additions & 0 deletions
@@ -41,6 +41,8 @@ def forward(
 
         out_shape = x.shape[:-1] + (out_features,)
         x = x.to(torch.float16)
+        if x.shape[0] == 0:
+            return torch.zeros(out_shape, dtype=x.dtype, device=x.device)
 
         if awq_ext is not None:
             FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024
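The two-line guard above handles a quantized linear receiving an empty batch, which can happen when an MoE router sends zero tokens to an expert; without it the GEMM kernels would be called on a zero-row input. A hedged sketch of the shape logic only, in plain PyTorch with no AWQ kernels involved and illustrative dimensions:

import torch

# Sketch of the early return: a zero-row input yields a zero-row output of the
# expected shape instead of invoking the matmul kernels.
x = torch.empty(0, 4096, dtype=torch.float16)
out_features = 11008
out_shape = x.shape[:-1] + (out_features,)
out = torch.zeros(out_shape, dtype=x.dtype, device=x.device)
print(out.shape)  # torch.Size([0, 11008])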

0 commit comments
