|
18 | 18 | from packaging import version |
19 | 19 |
|
20 | 20 | from ..activations import ACT2FN |
| 21 | +from ..modeling_rope_utils import ROPE_INIT_FUNCTIONS |
21 | 22 | from ..modeling_utils import PreTrainedModel |
22 | 23 | from ..utils import is_auto_awq_available, is_ipex_available, is_torch_available, logging |
23 | 24 | from ..utils.quantization_config import ( |
|
46 | 47 | "mlp": ["w1", "w3", "w2"], |
47 | 48 | "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"], |
48 | 49 | "use_alibi": False, |
49 | | - "rope_theta": 1000000.0, |
50 | 50 | }, |
51 | 51 | "llama": { |
52 | 52 | "attention": ["q_proj", "k_proj", "v_proj", "o_proj"], |
|
60 | 60 | "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"], |
61 | 61 | "use_alibi": False, |
62 | 62 | }, |
| 63 | + "qwen2": { |
| 64 | + "attention": ["q_proj", "k_proj", "v_proj", "o_proj"], |
| 65 | + "mlp": ["gate_proj", "up_proj", "down_proj"], |
| 66 | + "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"], |
| 67 | + "use_alibi": False, |
| 68 | + }, |
| 69 | + "qwen3": { |
| 70 | + "attention": ["q_proj", "k_proj", "v_proj", "o_proj", "q_norm", "k_norm"], |
| 71 | + "mlp": ["gate_proj", "up_proj", "down_proj"], |
| 72 | + "layernorm": ["input_layernorm", "post_attention_layernorm", "norm"], |
| 73 | + "use_alibi": False, |
| 74 | + }, |
63 | 75 | } |
64 | 76 |
|
65 | 77 | AWQ_SCALES_MAPPINGS = { |
|
74 | 86 | } |
75 | 87 |
|
76 | 88 |
|
| 89 | +if is_auto_awq_available(): |
| 90 | + from awq.modules.fused.attn import RoPE |
| 91 | + |
| 92 | + class AWQRoPE(RoPE): |
| 93 | + """ |
| 94 | + AWQRoPE module that overrides the rope implementation of AWQ fused attention modules in order to support more models and rope types. |
| 95 | +
|
| 96 | + Args: |
| 97 | + rope_type (`str`): |
| 98 | + The rope type to use. |
| 99 | + head_dim (`int`): |
| 100 | + The head dimension. |
| 101 | + max_seq_len (`int`): |
| 102 | + The maximum sequence length. |
| 103 | + config (`PretrainedConfig`): |
| 104 | + The model config object. |
| 105 | + device (`torch.device`): |
| 106 | + The device to put the module on. |
| 107 | + """ |
| 108 | + |
| 109 | + def __init__(self, rope_type, head_dim, max_seq_len, config, device): |
| 110 | + rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type] |
| 111 | + self.inv_freq, self.attention_scaling = rope_init_fn(config, device) |
| 112 | + # Initialize the parent class with a placeholder rope_theta; the real frequencies come from self.inv_freq |
| 113 | + super().__init__(head_dim=head_dim, max_seq_len=max_seq_len, device=device, rope_theta=-1) |
| 114 | + |
| 115 | + def precompute_freqs_cis(self, dim: int, end: int, theta=-1): |
| 116 | + t = torch.arange(end, device=self.inv_freq.device) |
| 117 | + freqs = torch.outer(t, self.inv_freq).float() |
| 118 | + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) |
| 119 | + del self.inv_freq # free the memory |
| 120 | + return freqs_cis |
| 121 | + |
| 122 | + def forward( |
| 123 | + self, |
| 124 | + xq: torch.Tensor, |
| 125 | + xk: torch.Tensor, |
| 126 | + start_pos: int, |
| 127 | + seqlen: int, |
| 128 | + partial: bool = False, |
| 129 | + ): |
| 130 | + xq_out, xk_out = super().forward(xq, xk, start_pos, seqlen, partial) |
| 131 | + xq_out = (xq_out * self.attention_scaling).type_as(xq) |
| 132 | + xk_out = (xk_out * self.attention_scaling).type_as(xk) |
| 133 | + return xq_out, xk_out |
| 134 | + |
| 135 | + |
77 | 136 | def replace_quantization_scales(model, model_type): |
78 | 137 | from awq.modules.act import ScaledActivation |
79 | 138 |
|
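The AWQRoPE shim above consumes the shared rope init functions from `modeling_rope_utils`. A minimal sketch of that contract, where the checkpoint name and rope values are assumptions rather than part of this diff:

    # Sketch only: checkpoint name and rope values are illustrative assumptions.
    import torch
    from transformers import AutoConfig
    from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

    config = AutoConfig.from_pretrained("Qwen/Qwen2-7B-Instruct")  # hypothetical checkpoint
    rope_params = getattr(config, "rope_parameters", None) or {}
    rope_type = rope_params.get("rope_type", "default")
    inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, torch.device("cpu"))
    # AWQRoPE keeps inv_freq, turns it into the complex freqs_cis table that autoawq's
    # fused attention expects (precompute_freqs_cis), and rescales the rotated q/k by
    # attention_scaling in forward().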
@@ -219,15 +278,17 @@ def get_modules_to_fuse(model, quantization_config): |
219 | 278 | # Properly deal with the case where we have a multi-modal model as well (e.g. Llava) |
220 | 279 | config = model.config.get_text_config(decoder=True) |
221 | 280 |
|
222 | | - # Handle hidden_size, num_attention_heads, num_key_value_heads on our own. |
| 281 | + # Handle hidden_size, num_attention_heads, num_key_value_heads and rope_parameters on our own. |
223 | 282 | hidden_size = config.hidden_size |
224 | 283 | num_attention_heads = config.num_attention_heads |
225 | 284 | num_key_value_heads = getattr(config, "num_key_value_heads", num_attention_heads) |
| 285 | + rope_parameters = config.rope_parameters |
226 | 286 |
|
227 | 287 | # Fill `current_fused_mapping` with the expected values |
228 | 288 | current_fused_mapping["hidden_size"] = hidden_size |
229 | 289 | current_fused_mapping["num_attention_heads"] = num_attention_heads |
230 | 290 | current_fused_mapping["num_key_value_heads"] = num_key_value_heads |
| 291 | + current_fused_mapping["rope_parameters"] = rope_parameters |
231 | 292 | current_fused_mapping["max_seq_len"] = quantization_config.fuse_max_seq_len |
232 | 293 | else: |
233 | 294 | raise ValueError( |
@@ -261,6 +322,15 @@ def fuse_awq_modules(model, quantization_config): |
261 | 322 | from awq.modules.fused.attn import QuantAttentionFused |
262 | 323 | from awq.modules.fused.mlp import QuantFusedMLP |
263 | 324 | from awq.modules.fused.norm import FasterTransformerRMSNorm |
| 325 | + |
| 326 | + # Patch QuantAttentionFused.forward so that it no longer returns past_key_value |
| 327 | + old_quant_attention_fused_forward = QuantAttentionFused.forward |
| 328 | + |
| 329 | + def new_quant_attention_fused_forward(self, *args, **kwargs): |
| 330 | + attn_output, attention_weight, _ = old_quant_attention_fused_forward(self, *args, **kwargs) |
| 331 | + return attn_output, attention_weight |
| 332 | + |
| 333 | + QuantAttentionFused.forward = new_quant_attention_fused_forward |
264 | 334 | else: |
265 | 335 | raise ValueError("Fusing is only supported for the AutoAWQ backend") |
266 | 336 |
|
@@ -376,7 +446,7 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_name |
376 | 446 | The pytorch parent module that has layernorm modules to fuse |
377 | 447 | modules_to_fuse (`list[str]`): |
378 | 448 | The module fusing mapping. The dictionary has to contain a field `attention` with attention module names |
379 | | - in the correct order: q, k, v, o layer |
| 449 | + in the correct order: q, k, v, o layer, optionally followed by q_norm and k_norm |
380 | 450 | current_module_name (`str`): |
381 | 451 | The current submodule name |
382 | 452 | target_cls (`~autoawq.QuantAttentionFused`): |
@@ -415,6 +485,14 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_name |
415 | 485 | v_proj = getattr(module, modules_to_fuse["attention"][2]) |
416 | 486 | o_proj = getattr(module, modules_to_fuse["attention"][3]) |
417 | 487 |
|
| 488 | + # Some architectures (e.g. Qwen3) also provide q_norm and k_norm layers |
| 489 | + if len(modules_to_fuse["attention"]) > 4: |
| 490 | + q_norm = getattr(module, modules_to_fuse["attention"][4]) |
| 491 | + k_norm = getattr(module, modules_to_fuse["attention"][5]) |
| 492 | + else: |
| 493 | + q_norm = None |
| 494 | + k_norm = None |
| 495 | + |
418 | 496 | bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None |
419 | 497 |
|
420 | 498 | qkv_layer = linear_target_cls( |
@@ -445,16 +523,30 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_name |
445 | 523 | modules_to_fuse["max_seq_len"], |
446 | 524 | use_alibi=modules_to_fuse["use_alibi"], |
447 | 525 | # The default value in autoawq is set to 10000.0 |
448 | | - rope_theta=modules_to_fuse.get("rope_theta", 10000.0), |
| 526 | + rope_theta=modules_to_fuse["rope_parameters"].get("rope_theta", 10000.0), |
| 527 | + q_norm=q_norm, |
| 528 | + k_norm=k_norm, |
449 | 529 | ) |
450 | 530 |
|
| 531 | + # Replace the rope module when alibi is not used and rope_type is not "default", |
| 532 | + # since the rope implementation in autoawq only supports the "default" rope type |
| 533 | + rope_type = modules_to_fuse["rope_parameters"].get("rope_type", "default") |
| 534 | + if not modules_to_fuse["use_alibi"] and rope_type != "default": |
| 535 | + fused_attention_layer.rope = AWQRoPE( |
| 536 | + rope_type, |
| 537 | + modules_to_fuse["hidden_size"] // modules_to_fuse["num_attention_heads"], |
| 538 | + modules_to_fuse["max_seq_len"], |
| 539 | + model.config.get_text_config(decoder=True), |
| 540 | + previous_device, |
| 541 | + ) |
| 542 | + |
451 | 543 | fused_attention_layer.is_hf_transformers = True |
452 | 544 |
|
453 | 545 | parent_name, child_name = current_module_name.rsplit(".", 1) |
454 | 546 | parent = model.get_submodule(parent_name) |
455 | 547 | setattr(parent, child_name, fused_attention_layer.to(previous_device)) |
456 | 548 |
|
457 | | - del q_proj, k_proj, v_proj, o_proj |
| 549 | + del q_proj, k_proj, v_proj, o_proj, q_norm, k_norm |
458 | 550 | module_has_been_fused = True |
459 | 551 |
|
460 | 552 | return module_has_been_fused |
|
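With the Qwen2/Qwen3 mappings and the AWQRoPE override in place, fusing is enabled exactly as for the previously supported architectures. A minimal usage sketch, assuming a hypothetical AWQ checkpoint name:

    from transformers import AutoModelForCausalLM, AwqConfig

    model_id = "Qwen/Qwen2.5-7B-Instruct-AWQ"  # hypothetical AWQ checkpoint
    quantization_config = AwqConfig(
        bits=4,
        fuse_max_seq_len=512,  # copied into modules_to_fuse["max_seq_len"]
        do_fuse=True,          # runs fuse_awq_modules() at load time
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=quantization_config,
        device_map="cuda:0",
    )

On load, get_modules_to_fuse() pulls hidden_size, num_attention_heads, num_key_value_heads and rope_parameters from the text config, and _fuse_awq_attention_layers() swaps in AWQRoPE whenever alibi is off and the rope type is not "default".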