import torch
from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_weight import (
    SingleMMWeightTpl,
    DeepGemmFP8W8A8B128MMWeight,
    AWQMMWeightTpl,
)
from lightllm.common.quantization import Quantcfg
from lightllm.utils.dist_utils import get_current_device_id
from lightllm.common.quantization.quantize_method import QuantizationMethod
from typing import Dict, List, Optional
from .mm_slicer import ColSliceMixin, QuantizedRowSliceMixin, QuantizedColSliceMixin


class UnquantizedCOLMMWeight(SingleMMWeightTpl):
    """Column-TP matmul weight kept in its original (unquantized) dtype.

    The tensor-parallel slicing policy lives in ``self.param_slicer`` (a
    ``ColSliceMixin``): the removed inline ``_slice_weight`` sliced dim 1 of
    the weight, and divided the bias by ``tp_world_size`` so the post-matmul
    all-reduce sums the bias back to its full value.
    """

    def __init__(
        self,
        weight_name: str,
        data_type: torch.dtype,
        # NOTE(review): the data_type/bias_name/quant_method parameter lines were
        # elided in the reviewed diff hunk; defaults reconstructed from the
        # keyword super() call below — confirm against upstream.
        bias_name: Optional[str] = None,
        quant_method: QuantizationMethod = None,
        tp_rank: int = None,
        tp_world_size: int = None,
    ) -> None:
        super().__init__(
            weight_name=weight_name,
            data_type=data_type,
            bias_name=bias_name,
            quant_method=quant_method,
            tp_rank=tp_rank,
            tp_world_size=tp_world_size,
        )
        # Column split: each rank keeps a contiguous slice along dim 1.
        self.param_slicer = ColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size)


class DeepGemmFP8W8A8B128COLMMWeight(DeepGemmFP8W8A8B128MMWeight):
    """Column-TP matmul weight for the deepgemm fp8 w8a8 (block-128) scheme.

    Scale-name generation, weight/scale loading and post-processing are
    inherited from ``DeepGemmFP8W8A8B128MMWeight``; this class only selects
    the quantized column slicer (weight and its block scales are both sliced
    along dim 1, as the removed inline ``_slice_weight``/``_slice_weight_scale``
    helpers did).
    """

    def __init__(
        self,
        weight_name: str,
        data_type: torch.dtype,
        # NOTE(review): the data_type/bias_name/quant_method parameter lines were
        # elided in the reviewed diff hunk; confirm defaults against upstream.
        bias_name: Optional[str] = None,
        quant_method: QuantizationMethod = None,
        tp_rank: int = None,
        tp_world_size: int = None,
    ) -> None:
        super().__init__(
            weight_name=weight_name,
            data_type=data_type,
            bias_name=bias_name,
            quant_method=quant_method,
            tp_rank=tp_rank,
            tp_world_size=tp_world_size,
        )
        self.param_slicer = QuantizedColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size)


class AWQCOLMMWeight(AWQMMWeightTpl):
    """Column-TP matmul weight stored in AWQ quantized format."""

    def __init__(
        self,
        weight_name: str,
        data_type: torch.dtype,
        # NOTE(review): the parameter lines above `tp_rank` were elided in the
        # reviewed diff hunk; confirm defaults against upstream. Also note that
        # weight_name/bias_name are not forwarded to super() in the visible new
        # code — presumably the base class resolves names elsewhere; verify.
        bias_name: Optional[str] = None,
        quant_method: QuantizationMethod = None,
        tp_rank: int = None,
        tp_world_size: int = None,
    ) -> None:
        super().__init__(data_type, quant_method, tp_rank, tp_world_size)
        # This is not a mistake: AWQ stores the weight as (in x out), so a
        # column-TP split of this layer slices rows of the stored tensor.
        self.param_slicer = QuantizedRowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size)


class AWQMARLINCOLMMWeight(AWQCOLMMWeight):
    """AWQ column-TP weight repacked for the Marlin kernel after loading."""

    def __init__(
        self,
        weight_name: str,
        data_type: torch.dtype,
        # NOTE(review): the parameter lines above `tp_rank` were elided in the
        # reviewed diff hunk; confirm defaults against upstream.
        bias_name: Optional[str] = None,
        quant_method: QuantizationMethod = None,
        tp_rank: int = None,
        tp_world_size: int = None,
    ) -> None:
        super().__init__(
            weight_name=weight_name,
            data_type=data_type,
            bias_name=bias_name,
            quant_method=quant_method,
            tp_rank=tp_rank,
            tp_world_size=tp_world_size,
        )

    def _process_weight(self, weight: torch.Tensor) -> torch.Tensor:
        # Marlin repacks the AWQ weight once it is resident on the current device.
        return self.quant_method._process_weight_after_loading(weight.cuda(get_current_device_id()))
    # NOTE(review): additional `_process_weight_scale` / `_process_weight_zero_point`
    # overrides of AWQMARLINCOLMMWeight were elided from this view (diff hunk
    # "@@ -168,7 +102,7"); restore them from the upstream file.


# Maps a quantization method name to the column-TP weight implementation
# that handles it (looked up by the factory when building layer weights).
COLBMM_WEIGHT_CLS_MAP = {
    "deepgemm-fp8w8a8-b128": DeepGemmFP8W8A8B128COLMMWeight,
    "awq": AWQCOLMMWeight,
    "awq_marlin": AWQMARLINCOLMMWeight,
}