
Commit 36862fe

committed
refactor mm_weight
1 parent 684f4c0 commit 36862fe

File tree: 10 files changed, +750 -561 lines changed
Lines changed: 4 additions & 30 deletions

@@ -1,3 +1,4 @@
+from multiprocessing import parent_process
 import torch
 from abc import ABC, abstractmethod
 from typing import Dict
@@ -14,7 +15,7 @@ def load_hf_weights(self, weights):
 
     @abstractmethod
     def verify_load(self):
-        pass
+        parent_process
 
 
 class BaseWeightTpl(BaseWeight):
@@ -24,35 +25,8 @@ def __init__(self, tp_rank: int = None, tp_world_size: int = None, data_type: to
         self.device_id_ = get_current_device_id()
         self.data_type_ = data_type
 
-    def _slice_weight(self, weight: torch.Tensor):
-        # slice weight
-        return weight.to(self.data_type_)
-
-    def _slice_bias(self, bias: torch.Tensor):
-        # slice bias
-        return bias.to(self.data_type_)
-
-    def _slice_weight_scale(self, weight_scale: torch.Tensor):
-        # slice weight scale and zero point
-        return weight_scale
-
-    def _load_weights(self, weights: Dict[str, torch.Tensor]) -> None:
-        # load weight
-        pass
-
-    def _load_scales(self, weights: Dict[str, torch.Tensor]) -> None:
-        # load quantization scale
-        pass
-
-    def _load_zero_points(self, weights: Dict[str, torch.Tensor]) -> None:
-        # load quantization zero points
-        pass
-
     def load_hf_weights(self, weights):
-        self._load_weights(weights)
-        self._load_scales(weights)
-        self._load_zero_points(weights)
-        return
+        raise NotImplementedError("load_hf_weights must implement this method")
 
     def verify_load(self):
-        pass
+        raise NotImplementedError("verify_load must implement this method")

lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/__init__.py

Lines changed: 3 additions & 8 deletions

@@ -3,15 +3,10 @@
     MultiMMWeightTpl,
     AWQMultiMMWeightTpl,
 )
-from .rowmm_weight import (
+from .mm_factory import (
+    MMWeight,
     ROWMMWeight,
-    ROWBMMWeight,
     MultiROWMMWeight,
-    W8A8B128ROWMMWeight,
-    W8A8B128ROWBMMWeight,
-    W8A8B128MultiROWMMWeight,
-)
-from .colmm_weight import (
+    ROWBMMWeight,
     COLMMWeight,
-    W8A8B128COLMMWeight,
 )
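The net effect on the package surface: MMWeight and the four factory entry points (ROWMMWeight, MultiROWMMWeight, ROWBMMWeight, COLMMWeight) are now re-exported from the new mm_factory module, while the W8A8B128* classes disappear from __init__ because the factories resolve quantized variants internally through the *_CLS_MAP tables. Existing call sites that import the public names keep working unchanged; an illustrative import, not taken from this commit:

from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight import (
    MMWeight,
    ROWMMWeight,
    COLMMWeight,
)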
lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py

Lines changed: 33 additions & 99 deletions
@@ -1,25 +1,17 @@
 import torch
 from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import (
-    MMWeight,
-    MMWeightTpl,
-    generate_scale_name,
+    SingleMMWeightTpl,
+    DeepGemmFP8W8A8B128MMWeight,
     AWQMMWeightTpl,
 )
 from lightllm.common.quantization import Quantcfg
 from lightllm.utils.dist_utils import get_current_device_id
 from lightllm.common.quantization.quantize_method import QuantizationMethod
 from typing import Dict, List, Optional
+from .mm_slicer import ColSliceMixin, QuantizedRowSliceMixin, QuantizedColSliceMixin
 
 
-class COLMMWeight(MMWeight):
-    @classmethod
-    def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool):
-        if quant_method is None or not quantized_weight:
-            return UnquantizedCOLMMWeight
-        return COLBMM_WEIGHT_CLS_MAP[quant_method.method_name]
-
-
-class UnquantizedCOLMMWeight(MMWeightTpl):
+class UnquantizedCOLMMWeight(SingleMMWeightTpl):
     def __init__(
         self,
         weight_name: str,
@@ -29,24 +21,18 @@ def __init__(
         tp_rank: int = None,
         tp_world_size: int = None,
     ) -> None:
-        super().__init__(data_type, quant_method, tp_rank, tp_world_size)
-        self.weight_name = weight_name
-        self.bias_name = bias_name
-        self.has_bias = bias_name is not None
-
-    def _slice_weight(self, tensor):
-        assert tensor.shape[1] % self.tp_world_size_ == 0, f"tp slice error {tensor.shape[1]} % {self.tp_world_size_}"
-        tp_size = tensor.shape[1] // self.tp_world_size_
-        return tensor[:, tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)].to(self.data_type_)
-
-    def _slice_bias(self, bias):
-        """
-        Column-TP (Colmm) computation ends with a reduce, so dividing the bias by tp_world_size here saves one step of computation later.
-        """
-        return (bias / self.tp_world_size_).to(self.data_type_)
+        super().__init__(
+            weight_name=weight_name,
+            data_type=data_type,
+            bias_name=bias_name,
+            quant_method=quant_method,
+            tp_rank=tp_rank,
+            tp_world_size=tp_world_size,
+        )
+        self.param_slicer = ColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size)
 
 
-class W8A8B128COLMMWeight(MMWeightTpl):
+class DeepGemmFP8W8A8B128COLMMWeight(DeepGemmFP8W8A8B128MMWeight):
     def __init__(
         self,
         weight_name: str,
@@ -56,47 +42,15 @@ def __init__(
         tp_rank: int = None,
         tp_world_size: int = None,
     ) -> None:
-        super().__init__(data_type, quant_method, tp_rank, tp_world_size)
-        self.weight_name = weight_name
-        self.bias_name = bias_name
-        self.has_bias = bias_name is not None
-
-        self.weight_scale_name, self.act_scale_name = generate_scale_name(
-            weight_name, quant_method.weight_scale_suffix, quant_method.act_scale_suffix
+        super().__init__(
+            weight_name=weight_name,
+            data_type=data_type,
+            bias_name=bias_name,
+            quant_method=quant_method,
+            tp_rank=tp_rank,
+            tp_world_size=tp_world_size,
         )
-        self.weight_scale: Optional[torch.Tensor] = None
-        self.block_size = self.quant_method.block_size
-        self.quantized_weight = True
-
-    def _slice_weight(self, tensor):
-        assert tensor.shape[1] % self.tp_world_size_ == 0, f"tp slice error {tensor.shape[1]} % {self.tp_world_size_}"
-        tp_size = tensor.shape[1] // self.tp_world_size_
-        return tensor[:, tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]
-
-    def _slice_weight_scale(self, weight_scale: torch.Tensor):
-        assert (
-            weight_scale.shape[1] % self.tp_world_size_ == 0
-        ), f"tp slice error {weight_scale.shape[1]} % {self.tp_world_size_}"
-        tp_size = weight_scale.shape[1] // self.tp_world_size_
-        scale_start = tp_size * self.tp_rank_
-        scale_end = tp_size * (self.tp_rank_ + 1)
-        return weight_scale[:, scale_start:scale_end].to(torch.float)
-
-    def _process_weight_scale(self, weight_scale) -> None:
-        self.weight_scale = weight_scale.cuda(get_current_device_id()).transpose(0, 1)
-
-    def _load_scales(self, weights: Dict[str, torch.Tensor]) -> None:
-        if self.weight_scale_name in weights:
-            weight_scale = self._slice_weight_scale(weights[self.weight_scale_name])
-            self._process_weight_scale(weight_scale)
-        if self.weight_scale is not None and isinstance(self.weight, torch.Tensor):
-            # The None kept in weight is a slot reserved for a static activation quantization scale.
-            self.weight = [
-                self.weight,
-                self.weight_scale,
-                None,
-            ]
-        return
+        self.param_slicer = QuantizedColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size)
 
 
 class AWQCOLMMWeight(AWQMMWeightTpl):
@@ -110,35 +64,8 @@ def __init__(
         tp_world_size: int = None,
     ) -> None:
         super().__init__(data_type, quant_method, tp_rank, tp_world_size)
-        self.weight_name = weight_name.replace("weight", quant_method.weight_suffix)
-        self.weight_scale_name = weight_name.replace("weight", quant_method.weight_scale_suffix)
-        self.weight_zero_point_name = weight_name.replace("weight", quant_method.weight_zero_point_suffix)
-        self.bias_name = bias_name
-        self.weight_scale: Optional[torch.Tensor] = None
-        self.quantized_weight = True
-        self.weight = [None, None, None]
-
-    def _slice_weight(self, weight: torch.Tensor):
-        assert weight.shape[0] % self.tp_world_size_ == 0, f"tp slice error {weight.shape[0]} % {self.tp_world_size_}"
-        tp_size = weight.shape[0] // self.tp_world_size_
-        return weight[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1), :]
-
-    def _slice_bias(self, bias):
-        assert bias.shape[0] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[0]} % {self.tp_world_size_}"
-        tp_size = bias.shape[0] // self.tp_world_size_
-        return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1), :]
-
-    def _slice_weight_scale(self, weight_scale: torch.Tensor):
-        tp_size = weight_scale.shape[0] // self.tp_world_size_
-        scale_start = tp_size * self.tp_rank_
-        scale_end = tp_size * (self.tp_rank_ + 1)
-        return weight_scale[scale_start:scale_end, :].to(torch.half)
-
-    def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor):
-        tp_size = weight_zero_point.shape[0] // self.tp_world_size_
-        zero_point_start = tp_size * self.tp_rank_
-        zero_point_end = tp_size * (self.tp_rank_ + 1)
-        return weight_zero_point[zero_point_start:zero_point_end, :]
+        # Note: this is intentional, not a mistake; AWQ stores the weight as in x out.
+        self.param_slicer = QuantizedRowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size)
 
 
 class AWQMARLINCOLMMWeight(AWQCOLMMWeight):
@@ -151,7 +78,14 @@ def __init__(
         tp_rank: int = None,
         tp_world_size: int = None,
     ) -> None:
-        super().__init__(weight_name, data_type, bias_name, quant_method, tp_rank, tp_world_size)
+        super().__init__(
+            weight_name=weight_name,
+            data_type=data_type,
+            bias_name=bias_name,
+            quant_method=quant_method,
+            tp_rank=tp_rank,
+            tp_world_size=tp_world_size,
+        )
 
     def _process_weight(self, weight: torch.Tensor) -> torch.Tensor:
         return self.quant_method._process_weight_after_loading(weight.cuda(get_current_device_id()))
@@ -168,7 +102,7 @@ def _process_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.T
 
 
 COLBMM_WEIGHT_CLS_MAP = {
-    "deepgemm-fp8w8a8-b128": W8A8B128COLMMWeight,
+    "deepgemm-fp8w8a8-b128": DeepGemmFP8W8A8B128COLMMWeight,
     "awq": AWQCOLMMWeight,
     "awq_marlin": AWQMARLINCOLMMWeight,
 }
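The slicing logic deleted from these classes (the _slice_weight/_slice_bias/_slice_weight_scale methods) now lives in slicer objects composed in as self.param_slicer, so row/column and quantized/unquantized slicing is written once in mm_slicer instead of per weight class. mm_slicer itself is not part of this excerpt; the following sketch of a column slicer is an assumption about its shape, with names borrowed from the imports above but bodies reconstructed from the deleted code:

import torch

class ColSliceMixin:  # assumed structure; the real mm_slicer.py may differ
    def __init__(self, tp_rank: int, tp_world_size: int):
        self.tp_rank_ = tp_rank
        self.tp_world_size_ = tp_world_size

    def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor:
        # Column TP splits dim 1, mirroring the deleted UnquantizedCOLMMWeight._slice_weight.
        assert weight.shape[1] % self.tp_world_size_ == 0, f"tp slice error {weight.shape[1]} % {self.tp_world_size_}"
        tp_size = weight.shape[1] // self.tp_world_size_
        return weight[:, tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]

    def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        # Column TP ends with a reduce, so pre-dividing the bias by world size saves a step.
        return bias / self.tp_world_size_

This composition also explains the comment on AWQCOLMMWeight above: AWQ checkpoints store the weight transposed (in x out), so a column-parallel AWQ weight is sliced along dim 0, which is exactly what the quantized row slicer does.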
lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_factory.py

Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+from lightllm.common.quantization import Quantcfg
+from lightllm.common.quantization.quantize_method import QuantizationMethod
+from typing import Type, Union, Dict
+from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import (
+    MMWeightTpl,
+    MultiMMWeightTpl,
+    BMMWeightTpl,
+)
+from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.rowmm_weight import (
+    UnquantizedROWMMWeight,
+    UnquantizedROWBMMWeight,
+    UnquantizedMultiROWMMWeight,
+    ROWMM_WEIGHT_CLS_MAP,
+    MULTI_ROWMM_WEIGHT_CLS_MAP,
+)
+from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.colmm_weight import (
+    UnquantizedCOLMMWeight,
+    COLBMM_WEIGHT_CLS_MAP,
+)
+
+
+class MMWeight:
+    def __new__(cls, **kwargs):
+        quant_cfg = kwargs.pop("quant_cfg", None)
+        layer_num_ = kwargs.pop("layer_num", None)
+        name = kwargs.pop("name", None)
+        quant_method, quantized_weight = cls._get_quant_method(quant_cfg, layer_num_, name)
+        kwargs["quant_method"] = quant_method
+        mmcls = cls._get_mmcls(quant_method, quantized_weight)
+        return mmcls(**kwargs)
+
+    @classmethod
+    def _get_quant_method(cls, quant_cfg: Quantcfg, layer_num_: int, name: str) -> QuantizationMethod:
+        if quant_cfg is None:
+            return None, False
+        quant_method = quant_cfg.get_quant_method(layer_num_, name)
+        if quant_method is None:
+            return None, False
+        quant_method.hf_quantization_config = quant_cfg.hf_quantization_config
+        quantized_weight = quant_cfg.quantized_weight
+        return quant_method, quantized_weight
+
+    @classmethod
+    def _get_mmcls(
+        cls, quant_method: QuantizationMethod, quantized_weight: bool
+    ) -> Type[Union[MMWeightTpl, MultiMMWeightTpl, BMMWeightTpl]]:
+        raise NotImplementedError("Subclasses must implement _get_mmcls method")
+
+
+class ROWMMWeight(MMWeight):
+    @classmethod
+    def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool):
+        if quant_method is None or not quantized_weight:
+            return UnquantizedROWMMWeight
+
+        return ROWMM_WEIGHT_CLS_MAP[quant_method.method_name]
+
+
+class MultiROWMMWeight(MMWeight):
+    @classmethod
+    def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool):
+        if quant_method is None or not quantized_weight:
+            return UnquantizedMultiROWMMWeight
+
+        return MULTI_ROWMM_WEIGHT_CLS_MAP[quant_method.method_name]
+
+
+class ROWBMMWeight(MMWeight):
+    @classmethod
+    def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool):
+        if quant_method is None or not quantized_weight:
+            return UnquantizedROWBMMWeight
+        else:
+            # TODO: Implement more quantization weight
+            raise NotImplementedError("ROWBMMWeight is not implemented")
+
+
+class COLMMWeight(MMWeight):
+    @classmethod
+    def _get_mmcls(cls, quant_method: QuantizationMethod, quantized_weight: bool):
+        if quant_method is None or not quantized_weight:
+            return UnquantizedCOLMMWeight
+        return COLBMM_WEIGHT_CLS_MAP[quant_method.method_name]
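A usage sketch for the factory dispatch above; the keyword names match MMWeight.__new__, but the concrete weight key, dtype, and quant_cfg value are illustrative only:

import torch

# quant_cfg would be a Quantcfg instance, or None for an unquantized checkpoint.
q_proj = ROWMMWeight(
    weight_name="model.layers.0.self_attn.q_proj.weight",  # hypothetical checkpoint key
    data_type=torch.bfloat16,
    bias_name=None,
    quant_cfg=None,
    layer_num=0,
    name="q_proj",
)

Because __new__ returns an instance of the class chosen by _get_mmcls rather than of MMWeight itself, the caller gets back an UnquantizedROWMMWeight here, or a class looked up in ROWMM_WEIGHT_CLS_MAP when the layer is quantized, without ever naming the concrete type.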

0 commit comments

Comments
 (0)