|
| 1 | +import tqdm |
| 2 | +from typing import List, Tuple |
| 3 | +from .base import BaseAWQForCausalLM |
| 4 | +from transformers.models.qwen3.modeling_qwen3 import ( |
| 5 | + Qwen3DecoderLayer as OldQwen3DecoderLayer, |
| 6 | + Qwen3ForCausalLM as OldQwen3ForCausalLM, |
| 7 | +) |
| 8 | +from awq.utils.fused_utils import fuse_qkv |
| 9 | +from awq.modules.fused.block import QwenBlock |
| 10 | +from awq.modules.fused.model import LlamaLikeModel |
| 11 | +from awq.modules.fused.norm import FasterTransformerRMSNorm |
| 12 | + |
| 13 | + |
| 14 | +class Qwen3AWQForCausalLM(BaseAWQForCausalLM): |
| 15 | + layer_type = "Qwen3DecoderLayer" |
| 16 | + max_seq_len_key = "max_position_embeddings" |
| 17 | + |
| 18 | + @staticmethod |
| 19 | + def fuse_layers(model: OldQwen3ForCausalLM): |
| 20 | + fuser = Qwen3Fuser(model) |
| 21 | + fuser.fuse_transformer() |
| 22 | + |
| 23 | + @staticmethod |
| 24 | + def get_model_layers(model: OldQwen3ForCausalLM): |
| 25 | + return model.model.layers |
| 26 | + |
| 27 | + @staticmethod |
| 28 | + def get_act_for_scaling(module: OldQwen3DecoderLayer): |
| 29 | + return dict(is_scalable=False) |
| 30 | + |
| 31 | + @staticmethod |
| 32 | + def move_embed(model: OldQwen3ForCausalLM, device: str): |
| 33 | + model.model.embed_tokens = model.model.embed_tokens.to(device) |
| 34 | + model.model.rotary_emb = model.model.rotary_emb.to(device) |
| 35 | + |
| 36 | + @staticmethod |
| 37 | + def get_layers_for_scaling(module: OldQwen3DecoderLayer, input_feat, module_kwargs): |
| 38 | + layers = [] |
| 39 | + |
| 40 | + # attention input |
| 41 | + layers.append( |
| 42 | + dict( |
| 43 | + prev_op=module.input_layernorm, |
| 44 | + layers=[ |
| 45 | + module.self_attn.q_proj, |
| 46 | + module.self_attn.k_proj, |
| 47 | + module.self_attn.v_proj, |
| 48 | + ], |
| 49 | + inp=input_feat["self_attn.q_proj"], |
| 50 | + module2inspect=module.self_attn, |
| 51 | + kwargs=module_kwargs, |
| 52 | + ) |
| 53 | + ) |
| 54 | + |
| 55 | + # attention out |
| 56 | + # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696 |
| 57 | + if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape: |
| 58 | + layers.append( |
| 59 | + dict( |
| 60 | + prev_op=module.self_attn.v_proj, |
| 61 | + layers=[module.self_attn.o_proj], |
| 62 | + inp=input_feat["self_attn.o_proj"], |
| 63 | + ) |
| 64 | + ) |
| 65 | + |
| 66 | + # linear 1 |
| 67 | + layers.append( |
| 68 | + dict( |
| 69 | + prev_op=module.post_attention_layernorm, |
| 70 | + layers=[module.mlp.gate_proj, module.mlp.up_proj], |
| 71 | + inp=input_feat["mlp.gate_proj"], |
| 72 | + module2inspect=module.mlp, |
| 73 | + ) |
| 74 | + ) |
| 75 | + |
| 76 | + # linear 2 |
| 77 | + layers.append( |
| 78 | + dict( |
| 79 | + prev_op=module.mlp.up_proj, |
| 80 | + layers=[module.mlp.down_proj], |
| 81 | + inp=input_feat["mlp.down_proj"], |
| 82 | + ) |
| 83 | + ) |
| 84 | + |
| 85 | + return layers |
| 86 | + |
| 87 | +class Qwen3Fuser: |
| 88 | + def __init__(self, model: OldQwen3ForCausalLM): |
| 89 | + self.model = model |
| 90 | + |
| 91 | + self.qwen3_blocks: List[Tuple[str, OldQwen3DecoderLayer]] = [ |
| 92 | + (name, module) |
| 93 | + for name, module in self.model.named_modules() |
| 94 | + if "Qwen3DecoderLayer".lower() in module.__class__.__name__.lower() |
| 95 | + ] |
| 96 | + |
| 97 | + def fuse_transformer(self): |
| 98 | + blocks = [] |
| 99 | + |
| 100 | + module: OldQwen3DecoderLayer |
| 101 | + for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."): |
| 102 | + device = next(iter(module.state_dict().values())).device |
| 103 | + qkv = fuse_qkv( |
| 104 | + module, |
| 105 | + module.self_attn.q_proj, |
| 106 | + module.self_attn.k_proj, |
| 107 | + module.self_attn.v_proj, |
| 108 | + ) |
| 109 | + norm_1 = FasterTransformerRMSNorm( |
| 110 | + module.input_layernorm.weight, module.input_layernorm.variance_epsilon |
| 111 | + ) |
| 112 | + norm_2 = FasterTransformerRMSNorm( |
| 113 | + module.post_attention_layernorm.weight, |
| 114 | + module.post_attention_layernorm.variance_epsilon, |
| 115 | + ) |
| 116 | + blocks.append( |
| 117 | + QwenBlock( |
| 118 | + hidden_size=self.model.config.hidden_size, |
| 119 | + n_heads=self.model.config.num_attention_heads, |
| 120 | + n_kv_heads=self.model.config.num_key_value_heads, |
| 121 | + qkv_layer=qkv, |
| 122 | + o_proj=module.self_attn.o_proj, |
| 123 | + mlp=module.mlp, |
| 124 | + norm_1=norm_1, |
| 125 | + norm_2=norm_2, |
| 126 | + dev=device, |
| 127 | + max_seq_len=self.model.config.max_seq_len, |
| 128 | + rope_theta=self.model.config.rope_theta, |
| 129 | + q_norm=module.self_attn.q_norm, |
| 130 | + k_norm=module.self_attn.k_norm, |
| 131 | + head_dim=self.model.config.head_dim, |
| 132 | + ) |
| 133 | + ) |
| 134 | + |
| 135 | + self.model.model = LlamaLikeModel( |
| 136 | + self.model.config.vocab_size, |
| 137 | + blocks, |
| 138 | + self.model.model.embed_tokens, |
| 139 | + self.model.model.norm, |
| 140 | + ) |
| 141 | + setattr(self.model.model, "blocks", self.model.model.blocks) |
| 142 | + |
0 commit comments