 class SliceMixinBase(ABC):
     """Mixin base class for slicing operations."""

-    def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = False):
+    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
         self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp()
         self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size()
-        self.bias_div_world_size_ = bias_div_world_size

     @abstractmethod
     def _slice_weight(self, weight: torch.Tensor):
@@ -22,8 +21,8 @@ def _slice_bias(self, bias):


 class SliceMixinTpl(SliceMixinBase):
-    def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = False):
-        super().__init__(tp_rank, tp_world_size, bias_div_world_size)
+    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
+        super().__init__(tp_rank, tp_world_size)

     def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError("slice_weight must implement this method")
@@ -41,27 +40,25 @@ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor:
 # By default, weight has shape out x in, which is currently the most common convention,
 # so row-wise slicing cuts along dim=0 and col-wise slicing cuts along dim=1 (see the shape sketch after the diff).
 class RowSliceMixin(SliceMixinTpl):
-    def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = False):
-        super().__init__(tp_rank, tp_world_size, bias_div_world_size)
+    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
+        super().__init__(tp_rank, tp_world_size)

     def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor:
         assert weight.shape[0] % self.tp_world_size_ == 0, f"tp slice error {weight.shape[0]} % {self.tp_world_size_}"
         tp_size = weight.shape[0] // self.tp_world_size_
         return weight[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]

-    def _slice_bias(self, bias) -> torch.Tensor:
+    def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
         assert bias.shape[0] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[0]} % {self.tp_world_size_}"
         tp_size = bias.shape[0] // self.tp_world_size_
-        if self.bias_div_world_size_:
-            return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] / self.tp_world_size_
         return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]


 # Quantized slicing assumes group-wise quantization by default, so weight_scale and weight_zero_point have the same ndims as weight (a shape-only sketch follows the diff).
 # Per-tensor and per-channel quantization can be added later as needed.
 class QuantizedRowSliceMixin(RowSliceMixin):
-    def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = False):
-        super().__init__(tp_rank, tp_world_size, bias_div_world_size)
+    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
+        super().__init__(tp_rank, tp_world_size)

     def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor:
         assert (
@@ -83,25 +80,21 @@ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor:


 class ColSliceMixin(SliceMixinTpl):
-    def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = True):
-        super().__init__(tp_rank, tp_world_size, bias_div_world_size)
+    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
+        super().__init__(tp_rank, tp_world_size)

     def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor:
         assert weight.shape[1] % self.tp_world_size_ == 0, f"tp slice error {weight.shape[1]} % {self.tp_world_size_}"
         tp_size = weight.shape[1] // self.tp_world_size_
         return weight[:, tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]

-    def _slice_bias(self, bias) -> torch.Tensor:
-        assert bias.shape[0] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[0]} % {self.tp_world_size_}"
-        tp_size = bias.shape[0] // self.tp_world_size_
-        if self.bias_div_world_size_:
-            return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] / self.tp_world_size_
-        return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]
+    def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
+        return bias / self.tp_world_size_


 class QuantizedColSliceMixin(ColSliceMixin):
-    def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = True):
-        super().__init__(tp_rank, tp_world_size, bias_div_world_size)
+    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
+        super().__init__(tp_rank, tp_world_size)

     def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor:
         assert (
@@ -120,3 +113,22 @@ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Tensor:
         zero_point_start = tp_size * self.tp_rank_
         zero_point_end = tp_size * (self.tp_rank_ + 1)
         return weight_zero_point[:, zero_point_start:zero_point_end]
+
+
+# AWQ stores its quantized weights in an in x out layout, so a dedicated implementation is needed (see the layout sketch after the diff).
+class AwqQuantizedRowSliceMixin(QuantizedColSliceMixin):
+    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
+        super().__init__(tp_rank, tp_world_size)
+
+    def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
+        assert bias.shape[0] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[0]} % {self.tp_world_size_}"
+        tp_size = bias.shape[0] // self.tp_world_size_
+        return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]
+
+
+class AwqQuantizedColSliceMixin(QuantizedRowSliceMixin):
+    def __init__(self, tp_rank: int = None, tp_world_size: int = None):
+        super().__init__(tp_rank, tp_world_size)
+
+    def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
+        return bias / self.tp_world_size_
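
To make the out x in convention concrete, here is a minimal standalone sketch (plain torch with made-up shapes, not the mixin classes themselves) of the shards each rank ends up with for a hypothetical 8 x 4 weight and tp_world_size = 2:

import torch

# Hypothetical example: logical weight of shape (out=8, in=4), 2-way tensor parallelism.
weight = torch.randn(8, 4)
tp_world_size = 2

for tp_rank in range(tp_world_size):
    # Row-wise slicing: split the output dimension (dim=0).
    out_per_rank = weight.shape[0] // tp_world_size
    row_shard = weight[out_per_rank * tp_rank : out_per_rank * (tp_rank + 1)]

    # Col-wise slicing: split the input dimension (dim=1).
    in_per_rank = weight.shape[1] // tp_world_size
    col_shard = weight[:, in_per_rank * tp_rank : in_per_rank * (tp_rank + 1)]

    print(tp_rank, tuple(row_shard.shape), tuple(col_shard.shape))
    # each rank gets (4, 4) from the row-wise split and (8, 2) from the col-wise split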
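
The new ColSliceMixin._slice_bias simply returns bias / self.tp_world_size_ because a col-wise (input-dimension) split produces partial matmul outputs that are summed across ranks (typically by an all-reduce); if every rank added the full bias, it would be counted tp_world_size times. A small self-contained check of that reasoning, using an ordinary Python sum in place of the real all-reduce:

import torch

torch.manual_seed(0)
x = torch.randn(3, 4)          # (batch, in)
weight = torch.randn(8, 4)     # (out, in)
bias = torch.randn(8)
tp_world_size = 2
reference = x @ weight.t() + bias

partials = []
in_per_rank = weight.shape[1] // tp_world_size
for tp_rank in range(tp_world_size):
    sl = slice(in_per_rank * tp_rank, in_per_rank * (tp_rank + 1))
    w_shard = weight[:, sl]            # col-wise weight slice
    b_shard = bias / tp_world_size     # each rank contributes only a fraction of the bias
    partials.append(x[:, sl] @ w_shard.t() + b_shard)

# Summing the per-rank partials (the role of the all-reduce) recovers the full result.
assert torch.allclose(sum(partials), reference, atol=1e-5)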
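
For the group-wise assumption behind QuantizedRowSliceMixin, a shape-only sketch: with a hypothetical group_size of 2 along the input dimension, weight_scale keeps the same number of dims as weight, so slicing both along dim=0 leaves each rank's weight shard paired with its own scales. The (out, in // group_size) scale layout below is only illustrative, not a statement about any particular quantization kernel:

import torch

out_features, in_features, group_size = 8, 4, 2
tp_world_size, tp_rank = 2, 0

weight = torch.randint(-8, 8, (out_features, in_features))           # stand-in for a quantized weight
weight_scale = torch.rand(out_features, in_features // group_size)   # one scale per input group

out_per_rank = out_features // tp_world_size
sl = slice(out_per_rank * tp_rank, out_per_rank * (tp_rank + 1))

# Row-wise slicing cuts weight and weight_scale along the same dim=0.
weight_shard, scale_shard = weight[sl], weight_scale[sl]
print(tuple(weight_shard.shape), tuple(scale_shard.shape))   # (4, 4) (4, 2)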
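
The AWQ mixins swap their parents because the stored tensor is transposed: a row-wise (dim=0) split of the logical out x in weight corresponds to a dim=1 split of the in x out AWQ storage, which is exactly what the col-wise quantized mixin already does; only the bias handling keeps its original row/col behaviour. A layout-only sketch that ignores AWQ's integer packing and group structure:

import torch

w_logical = torch.randn(8, 4)               # logical (out, in)
w_awq_layout = w_logical.t().contiguous()   # AWQ-style (in, out) storage; packing/groups ignored here

tp_world_size, tp_rank = 2, 1
out_per_rank = w_logical.shape[0] // tp_world_size
sl = slice(out_per_rank * tp_rank, out_per_rank * (tp_rank + 1))

row_shard_logical = w_logical[sl]       # what a row-wise slice of the logical weight returns
row_shard_awq = w_awq_layout[:, sl]     # the same shard, taken as a dim=1 slice of the storage

assert torch.equal(row_shard_awq.t(), row_shard_logical)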