Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,10 @@ def _bind_rotary_emb_fwd(self):
def _get_qkv(
self, input, infer_state: InferStateInfo, layer_weight
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
q = torch.mm(input.view(-1, self.embed_dim_), layer_weight.q_weight_)
cache_kv = torch.mm(
input.view(-1, self.embed_dim_),
layer_weight.kv_weight_,
).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
q = layer_weight.q_proj.mm(input.view(-1, self.embed_dim_))
cache_kv = layer_weight.kv_proj.mm(input.view(-1, self.embed_dim_)).view(
-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_
)

if self.use_qk_norm_:
q = q.view(-1, self.tp_q_head_num_, self.head_dim_)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from .base_weight import BaseWeight
from .mm_weight import (
MMWeightPack,
MMWeightTpl,
MultiMMWeightTpl,
ROWMMWeight,
COLMMWeight,
MultiROWMMWeight,
ROWBMMWeight,
)
from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight
from .fused_moe_weight_tp import FusedMoeWeightTP
from .fused_moe_weight_tp import create_tp_moe_wegiht_obj
from .fused_moe_weight_ep import FusedMoeWeightEP
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def load_hf_weights(self, weights):
pass

@abstractmethod
def verify_load(self):
def verify_load(self) -> bool:
pass


Expand All @@ -24,30 +24,8 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, data_type: to
self.device_id_ = get_current_device_id()
self.data_type_ = data_type

def _slice_weight(self, weight: torch.Tensor):
# slice weight
return weight.to(self.data_type_)

def _slice_bias(self, bias: torch.Tensor):
# slice bias
return bias.to(self.data_type_)

def _slice_weight_scale(self, weight_scale: torch.Tensor):
# slice weight scale and zero point
return weight_scale

def _load_weights(self, weights: Dict[str, torch.Tensor]) -> None:
# load weight
pass

def _load_scales(self, weights: Dict[str, torch.Tensor]) -> None:
# load quantization scale
pass

def load_hf_weights(self, weights):
self._load_weights(weights)
self._load_scales(weights)
return
raise NotImplementedError("load_hf_weights must implement this method")

def verify_load(self):
pass
def verify_load(self) -> bool:
raise NotImplementedError("verify_load must implement this method")
Loading