
Commit 75e3985

Fix breaking change of AWQ FusedModules due to Attention Refactor (#41909)
* fix awq bc due to attention refactor
* feat: support more rope_types for awq fusion
* feat: add test for llama3
* fix ruff format
* propagate changes in modeling_llama
1 parent 61cafd9 commit 75e3985
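
For context, fused AWQ modules are the load-time path this commit repairs. Below is a minimal, hedged sketch of how fusing is turned on, assuming a recent transformers install; the checkpoint name is illustrative only and not part of this commit:

# Hedged sketch: enabling AutoAWQ fused modules through the quantization config.
# This is the code path that ends up in integrations/awq.py (get_modules_to_fuse,
# fuse_awq_modules) shown in the diff below.
from transformers import AutoModelForCausalLM, AwqConfig

quant_config = AwqConfig(
    bits=4,
    do_fuse=True,          # enable AutoAWQ fused attention/MLP/layernorm modules
    fuse_max_seq_len=512,  # feeds modules_to_fuse["max_seq_len"] via get_modules_to_fuse()
)

model = AutoModelForCausalLM.from_pretrained(
    "TheBloke/Mistral-7B-OpenOrca-AWQ",  # illustrative AWQ checkpoint (assumption)
    quantization_config=quant_config,
    device_map="auto",
)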

File tree

23 files changed: +149 -5 lines changed

src/transformers/integrations/awq.py

Lines changed: 97 additions & 5 deletions

@@ -18,6 +18,7 @@
 from packaging import version
 
 from ..activations import ACT2FN
+from ..modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ..modeling_utils import PreTrainedModel
 from ..utils import is_auto_awq_available, is_ipex_available, is_torch_available, logging
 from ..utils.quantization_config import (
@@ -46,7 +47,6 @@
         "mlp": ["w1", "w3", "w2"],
         "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
         "use_alibi": False,
-        "rope_theta": 1000000.0,
     },
     "llama": {
         "attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
@@ -60,6 +60,18 @@
         "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
         "use_alibi": False,
     },
+    "qwen2": {
+        "attention": ["q_proj", "k_proj", "v_proj", "o_proj"],
+        "mlp": ["gate_proj", "up_proj", "down_proj"],
+        "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
+        "use_alibi": False,
+    },
+    "qwen3": {
+        "attention": ["q_proj", "k_proj", "v_proj", "o_proj", "q_norm", "k_norm"],
+        "mlp": ["gate_proj", "up_proj", "down_proj"],
+        "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"],
+        "use_alibi": False,
+    },
 }
 
 AWQ_SCALES_MAPPINGS = {
@@ -74,6 +86,53 @@
 }
 
 
+if is_auto_awq_available():
+    from awq.modules.fused.attn import RoPE
+
+    class AWQRoPE(RoPE):
+        """
+        AWQRoPE module for hacking rope implementation in AWQ fused attention modules to support more models.
+
+        Args:
+            rope_type (`str`):
+                The rope type to use.
+            head_dim (`int`):
+                The head dimension.
+            max_seq_len (`int`):
+                The maximum sequence length.
+            config (`PreTrainedConfig`):
+                The model config object.
+            device (`torch.device`):
+                The device to put the module on.
+        """
+
+        def __init__(self, rope_type, head_dim, max_seq_len, config, device):
+            rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
+            self.inv_freq, self.attention_scaling = rope_init_fn(config, device)
+            # Use fake rope_theta to initialize the parent class
+            super().__init__(head_dim=head_dim, max_seq_len=max_seq_len, device=device, rope_theta=-1)
+
+        def precompute_freqs_cis(self, dim: int, end: int, theta=-1):
+            t = torch.arange(end, device=self.inv_freq.device)
+            freqs = torch.outer(t, self.inv_freq).float()
+            freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+            del self.inv_freq  # free the memory
+            return freqs_cis
+
+        def forward(
+            self,
+            xq: torch.Tensor,
+            xk: torch.Tensor,
+            start_pos: int,
+            seqlen: int,
+            partial: bool = False,
+        ):
+            xq_out, xk_out = super().forward(xq, xk, start_pos, seqlen, partial)
+            xq_out = (xq_out * self.attention_scaling).type_as(xq)
+            xk_out = (xk_out * self.attention_scaling).type_as(xk)
+            return xq_out, xk_out
+
+
 def replace_quantization_scales(model, model_type):
     from awq.modules.act import ScaledActivation
 
@@ -219,15 +278,17 @@ def get_modules_to_fuse(model, quantization_config)
         # Properly deal with the case where we have a multi-modal model as well (e.g. Llava)
         config = model.config.get_text_config(decoder=True)
 
-        # Handle hidden_size, num_attention_heads, num_key_value_heads on our own.
+        # Handle hidden_size, num_attention_heads, num_key_value_heads, rope_parameters on our own.
         hidden_size = config.hidden_size
         num_attention_heads = config.num_attention_heads
        num_key_value_heads = getattr(config, "num_key_value_heads", num_attention_heads)
+        rope_parameters = config.rope_parameters
 
         # Fill `current_fused_mapping` with the expected values
         current_fused_mapping["hidden_size"] = hidden_size
         current_fused_mapping["num_attention_heads"] = num_attention_heads
         current_fused_mapping["num_key_value_heads"] = num_key_value_heads
+        current_fused_mapping["rope_parameters"] = rope_parameters
         current_fused_mapping["max_seq_len"] = quantization_config.fuse_max_seq_len
     else:
         raise ValueError(
@@ -261,6 +322,15 @@ def fuse_awq_modules(model, quantization_config)
         from awq.modules.fused.attn import QuantAttentionFused
         from awq.modules.fused.mlp import QuantFusedMLP
         from awq.modules.fused.norm import FasterTransformerRMSNorm
+
+        # Hack QuantAttentionFused to modify the return value of forward function to avoid returning past_key_value
+        old_quant_attention_fused_forward = QuantAttentionFused.forward
+
+        def new_quant_attention_fused_forward(self, *args, **kwargs):
+            attn_output, attention_weight, _ = old_quant_attention_fused_forward(self, *args, **kwargs)
+            return attn_output, attention_weight
+
+        QuantAttentionFused.forward = new_quant_attention_fused_forward
     else:
         raise ValueError("Fusing is only supported for the AutoAWQ backend")
 
@@ -376,7 +446,7 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
            The pytorch parent module that has layernorm modules to fuse
        modules_to_fuse (`list[str]`):
            The module fusing mapping. The dictionary has to contain a field `attention` with attention module names
-            in the correct order: q, k, v, o layer
+            in the correct order: q, k, v, o layer, (q_norm, k_norm) optional
        current_module_name (`str`):
            The current submodule name
        target_cls (`~autoawq.QuantAttentionFused`):
@@ -415,6 +485,14 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
         v_proj = getattr(module, modules_to_fuse["attention"][2])
         o_proj = getattr(module, modules_to_fuse["attention"][3])
 
+        # maybe there are q_norm and k_norm layers
+        if len(modules_to_fuse["attention"]) > 4:
+            q_norm = getattr(module, modules_to_fuse["attention"][4])
+            k_norm = getattr(module, modules_to_fuse["attention"][5])
+        else:
+            q_norm = None
+            k_norm = None
+
         bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
 
         qkv_layer = linear_target_cls(
@@ -445,16 +523,30 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
             modules_to_fuse["max_seq_len"],
             use_alibi=modules_to_fuse["use_alibi"],
             # The default value in autoawq is set to 10000.0
-            rope_theta=modules_to_fuse.get("rope_theta", 10000.0),
+            rope_theta=modules_to_fuse["rope_parameters"].get("rope_theta", 10000.0),
+            q_norm=q_norm,
+            k_norm=k_norm,
         )
 
+        # Hack the rope module if not using alibi and rope_type is not default
+        # As the default rope implementation in autoawq only supports the "default" rope type
+        rope_type = modules_to_fuse["rope_parameters"].get("rope_type", "default")
+        if not modules_to_fuse["use_alibi"] and rope_type != "default":
+            fused_attention_layer.rope = AWQRoPE(
+                rope_type,
+                modules_to_fuse["hidden_size"] // modules_to_fuse["num_attention_heads"],
+                modules_to_fuse["max_seq_len"],
+                model.config.get_text_config(decoder=True),
+                previous_device,
+            )
+
         fused_attention_layer.is_hf_transformers = True
 
         parent_name, child_name = current_module_name.rsplit(".", 1)
         parent = model.get_submodule(parent_name)
         setattr(parent, child_name, fused_attention_layer.to(previous_device))
 
-        del q_proj, k_proj, v_proj, o_proj
+        del q_proj, k_proj, v_proj, o_proj, q_norm, k_norm
         module_has_been_fused = True
 
     return module_has_been_fused
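
To see how the new rope_parameters plumbing feeds AWQRoPE, here is a hedged, standalone sketch (not from the commit). It assumes a transformers build new enough to expose config.rope_parameters, and the checkpoint name is only illustrative:

# Hedged sketch: detect a non-default rope type the same way the fused path above does,
# and call the same ROPE_INIT_FUNCTIONS entry that AWQRoPE.__init__ uses.
import torch
from transformers import AutoConfig
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

config = AutoConfig.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")  # illustrative checkpoint (assumption)
rope_parameters = config.rope_parameters                 # what get_modules_to_fuse() now records
rope_type = rope_parameters.get("rope_type", "default")  # e.g. "llama3" for rope-scaled Llama 3.1

if rope_type != "default":
    # Same call AWQRoPE.__init__ makes: returns the inverse frequencies plus the
    # attention_scaling factor that AWQRoPE.forward multiplies into the rotated q/k.
    inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, torch.device("cpu"))
    print(rope_type, inv_freq.shape, attention_scaling)

This mirrors the check in _fuse_awq_attention_layers above: alibi models and the "default" rope type keep AutoAWQ's built-in RoPE, everything else gets the AWQRoPE override.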

src/transformers/models/apertus/modeling_apertus.py

Lines changed: 1 addition & 0 deletions

@@ -416,6 +416,7 @@ def forward(
             position_embeddings=position_embeddings,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )

src/transformers/models/arcee/modeling_arcee.py

Lines changed: 1 addition & 0 deletions

@@ -421,6 +421,7 @@ def forward(
             position_embeddings=position_embeddings,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )

src/transformers/models/aria/modeling_aria.py

Lines changed: 1 addition & 0 deletions

@@ -750,6 +750,7 @@ def forward(
             position_embeddings=position_embeddings,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )

src/transformers/models/bitnet/modeling_bitnet.py

Lines changed: 1 addition & 0 deletions

@@ -420,6 +420,7 @@ def forward(
             position_embeddings=position_embeddings,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )

src/transformers/models/cohere/modeling_cohere.py

Lines changed: 1 addition & 0 deletions

@@ -453,6 +453,7 @@ def forward(
             position_embeddings=position_embeddings,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )

src/transformers/models/csm/modeling_csm.py

Lines changed: 1 addition & 0 deletions

@@ -754,6 +754,7 @@ def forward(
             position_embeddings=position_embeddings,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )

src/transformers/models/deepseek_v2/modeling_deepseek_v2.py

Lines changed: 1 addition & 0 deletions

@@ -537,6 +537,7 @@ def forward(
             position_embeddings=position_embeddings,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )

src/transformers/models/deepseek_v3/modeling_deepseek_v3.py

Lines changed: 1 addition & 0 deletions

@@ -626,6 +626,7 @@ def forward(
             position_embeddings=position_embeddings,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )

src/transformers/models/diffllama/modeling_diffllama.py

Lines changed: 1 addition & 0 deletions

@@ -676,6 +676,7 @@ def forward(
             position_embeddings=position_embeddings,
             position_ids=position_ids,
             past_key_values=past_key_values,
+            use_cache=use_cache,
             cache_position=cache_position,
             **kwargs,
         )
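
Each of the modeling_*.py diffs above makes the same one-line change: the decoder layer now forwards use_cache to its attention module instead of dropping it. A hedged, self-contained sketch of that pattern, with illustrative class names rather than the real modeling code:

# Toy illustration (assumption: simplified stand-ins, not the actual transformers classes).
import torch

class DummyAttention(torch.nn.Module):
    # Stand-in for a model's attention block; it only records whether `use_cache` arrived.
    def forward(self, hidden_states, use_cache=False, **kwargs):
        self.saw_use_cache = use_cache
        return hidden_states, None

class DecoderLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.self_attn = DummyAttention()

    def forward(self, hidden_states, use_cache=False, **kwargs):
        # The fix repeated across the files above: pass `use_cache` through explicitly.
        hidden_states, _ = self.self_attn(hidden_states=hidden_states, use_cache=use_cache, **kwargs)
        return hidden_states

layer = DecoderLayer()
layer(torch.zeros(1, 4, 8), use_cache=True)
assert layer.self_attn.saw_use_cache is True  # the attention module sees the flag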
