Skip to content

Commit 5f8d736

Browse files
shihaobai and wangzaijun authored
Awq support and mm refactor (#1084)
Co-authored-by: wangzaijun <[email protected]>
1 parent 98d385a commit 5f8d736

File tree

32 files changed

+1754
-577
lines changed

32 files changed

+1754
-577
lines changed

lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,10 @@ def _bind_rotary_emb_fwd(self):
4444
def _get_qkv(
4545
self, input, infer_state: InferStateInfo, layer_weight
4646
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
47-
q = torch.mm(input.view(-1, self.embed_dim_), layer_weight.q_weight_)
48-
cache_kv = torch.mm(
49-
input.view(-1, self.embed_dim_),
50-
layer_weight.kv_weight_,
51-
).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
47+
q = layer_weight.q_proj.mm(input.view(-1, self.embed_dim_))
48+
cache_kv = layer_weight.kv_proj.mm(input.view(-1, self.embed_dim_)).view(
49+
-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_
50+
)
5251

5352
if self.use_qk_norm_:
5453
q = q.view(-1, self.tp_q_head_num_, self.head_dim_)
Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
from .base_weight import BaseWeight
22
from .mm_weight import (
3+
MMWeightPack,
34
MMWeightTpl,
4-
MultiMMWeightTpl,
55
ROWMMWeight,
66
COLMMWeight,
7-
MultiROWMMWeight,
87
ROWBMMWeight,
98
)
109
from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight
11-
from .fused_moe_weight_tp import FusedMoeWeightTP
10+
from .fused_moe_weight_tp import create_tp_moe_wegiht_obj
1211
from .fused_moe_weight_ep import FusedMoeWeightEP

lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def load_hf_weights(self, weights):
1313
pass
1414

1515
@abstractmethod
16-
def verify_load(self):
16+
def verify_load(self) -> bool:
1717
pass
1818

1919

@@ -24,30 +24,8 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, data_type: to
2424
self.device_id_ = get_current_device_id()
2525
self.data_type_ = data_type
2626

27-
def _slice_weight(self, weight: torch.Tensor):
28-
# slice weight
29-
return weight.to(self.data_type_)
30-
31-
def _slice_bias(self, bias: torch.Tensor):
32-
# slice bias
33-
return bias.to(self.data_type_)
34-
35-
def _slice_weight_scale(self, weight_scale: torch.Tensor):
36-
# slice weight scale and zero point
37-
return weight_scale
38-
39-
def _load_weights(self, weights: Dict[str, torch.Tensor]) -> None:
40-
# load weight
41-
pass
42-
43-
def _load_scales(self, weights: Dict[str, torch.Tensor]) -> None:
44-
# load quantization scale
45-
pass
46-
4727
def load_hf_weights(self, weights):
48-
self._load_weights(weights)
49-
self._load_scales(weights)
50-
return
28+
raise NotImplementedError("load_hf_weights must implement this method")
5129

52-
def verify_load(self):
53-
pass
30+
def verify_load(self) -> bool:
31+
raise NotImplementedError("verify_load must implement this method")

0 commit comments

Comments
 (0)