Skip to content

Commit 8fd19c2

Browse files
hiworldwzj and wangzaijun authored
fix mm slicer (#1104)
Co-authored-by: wangzaijun <[email protected]>
1 parent 5f8d736 commit 8fd19c2

File tree

3 files changed

+38
-30
lines changed

3 files changed

+38
-30
lines changed

lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/colmm_weight.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from lightllm.utils.dist_utils import get_current_device_id
99
from lightllm.common.quantization.quantize_method import QuantizationMethod
1010
from typing import Dict, List, Optional, Union
11-
from .mm_slicer import ColSliceMixin, QuantizedRowSliceMixin, QuantizedColSliceMixin
11+
from .mm_slicer import ColSliceMixin, QuantizedColSliceMixin, AwqQuantizedColSliceMixin
1212

1313

1414
class StandardCOLMMWeight(MMWeightTpl):
@@ -72,9 +72,7 @@ def __init__(
7272
tp_world_size=tp_world_size,
7373
)
7474
# 注意这里不是错误,因为awq的weight是按inxout存的
75-
self.param_slicer = QuantizedRowSliceMixin(
76-
tp_rank=tp_rank, tp_world_size=tp_world_size, bias_div_world_size=True
77-
)
75+
self.param_slicer = AwqQuantizedColSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size)
7876

7977

8078
class AWQMARLINCOLMMWeight(AWQCOLMMWeight):

lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,9 @@
77
class SliceMixinBase(ABC):
88
"""切片操作的Mixin基类"""
99

10-
def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = False):
10+
def __init__(self, tp_rank: int = None, tp_world_size: int = None):
1111
self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp()
1212
self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size()
13-
self.bias_div_world_size_ = bias_div_world_size
1413

1514
@abstractmethod
1615
def _slice_weight(self, weight: torch.Tensor):
@@ -22,8 +21,8 @@ def _slice_bias(self, bias):
2221

2322

2423
class SliceMixinTpl(SliceMixinBase):
25-
def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = False):
26-
super().__init__(tp_rank, tp_world_size, bias_div_world_size)
24+
def __init__(self, tp_rank: int = None, tp_world_size: int = None):
25+
super().__init__(tp_rank, tp_world_size)
2726

2827
def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor:
2928
raise NotImplementedError("slice_weight must implement this method")
@@ -41,27 +40,25 @@ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Ten
4140
# 默认weight 的shape是 outxin,这也是目前最通用的约定。
4241
# 所以row-wise是沿着dim=0进行切分,col-wise是沿着dim=1进行切分。
4342
class RowSliceMixin(SliceMixinTpl):
44-
def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = False):
45-
super().__init__(tp_rank, tp_world_size, bias_div_world_size)
43+
def __init__(self, tp_rank: int = None, tp_world_size: int = None):
44+
super().__init__(tp_rank, tp_world_size)
4645

4746
def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor:
4847
assert weight.shape[0] % self.tp_world_size_ == 0, f"tp slice error {weight.shape[0]} % {self.tp_world_size_}"
4948
tp_size = weight.shape[0] // self.tp_world_size_
5049
return weight[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]
5150

52-
def _slice_bias(self, bias) -> torch.Tensor:
51+
def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
5352
assert bias.shape[0] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[0]} % {self.tp_world_size_}"
5453
tp_size = bias.shape[0] // self.tp_world_size_
55-
if self.bias_div_world_size_:
56-
return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] / self.tp_world_size_
5754
return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]
5855

5956

6057
# 量化切片默认实现方式是group-wise的量化,所以weight_scale 和weight_zero_point ndims跟weight一样。
6158
# 后续按需要,扩展per-tensor、per-channel的量化方式。
6259
class QuantizedRowSliceMixin(RowSliceMixin):
63-
def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = False):
64-
super().__init__(tp_rank, tp_world_size, bias_div_world_size)
60+
def __init__(self, tp_rank: int = None, tp_world_size: int = None):
61+
super().__init__(tp_rank, tp_world_size)
6562

6663
def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor:
6764
assert (
@@ -83,25 +80,21 @@ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Ten
8380

8481

8582
class ColSliceMixin(SliceMixinTpl):
86-
def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = True):
87-
super().__init__(tp_rank, tp_world_size, bias_div_world_size)
83+
def __init__(self, tp_rank: int = None, tp_world_size: int = None):
84+
super().__init__(tp_rank, tp_world_size)
8885

8986
def _slice_weight(self, weight: torch.Tensor) -> torch.Tensor:
9087
assert weight.shape[1] % self.tp_world_size_ == 0, f"tp slice error {weight.shape[1]} % {self.tp_world_size_}"
9188
tp_size = weight.shape[1] // self.tp_world_size_
9289
return weight[:, tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]
9390

94-
def _slice_bias(self, bias) -> torch.Tensor:
95-
assert bias.shape[0] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[0]} % {self.tp_world_size_}"
96-
tp_size = bias.shape[0] // self.tp_world_size_
97-
if self.bias_div_world_size_:
98-
return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)] / self.tp_world_size_
99-
return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]
91+
def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
92+
return bias / self.tp_world_size_
10093

10194

10295
class QuantizedColSliceMixin(ColSliceMixin):
103-
def __init__(self, tp_rank: int = None, tp_world_size: int = None, bias_div_world_size: bool = True):
104-
super().__init__(tp_rank, tp_world_size, bias_div_world_size)
96+
def __init__(self, tp_rank: int = None, tp_world_size: int = None):
97+
super().__init__(tp_rank, tp_world_size)
10598

10699
def _slice_weight_scale(self, weight_scale: torch.Tensor) -> torch.Tensor:
107100
assert (
@@ -120,3 +113,22 @@ def _slice_weight_zero_point(self, weight_zero_point: torch.Tensor) -> torch.Ten
120113
zero_point_start = tp_size * self.tp_rank_
121114
zero_point_end = tp_size * (self.tp_rank_ + 1)
122115
return weight_zero_point[:, zero_point_start:zero_point_end]
116+
117+
118+
# awq 的量化权重是inxout存储格式,需要定制实现。
119+
class AwqQuantizedRowSliceMixin(QuantizedColSliceMixin):
120+
def __init__(self, tp_rank: int = None, tp_world_size: int = None):
121+
super().__init__(tp_rank, tp_world_size)
122+
123+
def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
124+
assert bias.shape[0] % self.tp_world_size_ == 0, f"tp slice error {bias.shape[0]} % {self.tp_world_size_}"
125+
tp_size = bias.shape[0] // self.tp_world_size_
126+
return bias[tp_size * self.tp_rank_ : tp_size * (self.tp_rank_ + 1)]
127+
128+
129+
class AwqQuantizedColSliceMixin(QuantizedRowSliceMixin):
130+
def __init__(self, tp_rank: int = None, tp_world_size: int = None):
131+
super().__init__(tp_rank, tp_world_size)
132+
133+
def _slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
134+
return bias / self.tp_world_size_

lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from lightllm.utils.dist_utils import get_current_device_id
1010
from lightllm.common.quantization.quantize_method import QuantizationMethod
1111
from typing import Dict, List, Optional, Union
12-
from .mm_slicer import RowSliceMixin, QuantizedRowSliceMixin, QuantizedColSliceMixin
12+
from .mm_slicer import RowSliceMixin, QuantizedRowSliceMixin, AwqQuantizedRowSliceMixin
1313

1414

1515
class StandardROWMMWeight(MMWeightTpl):
@@ -94,10 +94,8 @@ def __init__(
9494
tp_rank=tp_rank,
9595
tp_world_size=tp_world_size,
9696
)
97-
# 注意这里不是错误,因为awq的weight是按inxout存的
98-
self.param_slicer = QuantizedColSliceMixin(
99-
tp_rank=tp_rank, tp_world_size=tp_world_size, bias_div_world_size=False
100-
)
97+
98+
self.param_slicer = AwqQuantizedRowSliceMixin(tp_rank=tp_rank, tp_world_size=tp_world_size)
10199

102100

103101
class AWQMARLINROWMMWeight(AWQROWMMWeight):

0 commit comments

Comments (0)