
Commit 2fa25ff
Author: wangzaijun
Commit message: fix moe create way
1 parent: 1a8d67c

This commit replaces the __new__-based FusedMoeWeightTP factory class with a module-level factory function, create_tp_moe_wegiht_obj, renames FusedBaseMoeWeightTP back to FusedMoeWeightTP, and updates all call sites and imports accordingly.

File tree: 6 files changed, +54 −84 lines

lightllm/common/basemodel/layer_weights/meta_weights/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -7,5 +7,5 @@
     ROWBMMWeight,
 )
 from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight
-from .fused_moe_weight_tp import FusedMoeWeightTP
+from .fused_moe_weight_tp import create_tp_moe_wegiht_obj
 from .fused_moe_weight_ep import FusedMoeWeightEP

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py
Lines changed: 47 additions & 49 deletions

@@ -1,62 +1,60 @@
 import os
 import torch
 import threading
-from typing import Optional, Tuple, List, Dict, Any
+from typing import Optional, Tuple, List, Dict, Any, Union
 from .base_weight import BaseWeight
 from lightllm.utils.dist_utils import get_current_rank_in_dp, get_current_device_id
 from lightllm.common.quantization import Quantcfg


-class FusedMoeWeightTP:
-    def __new__(
-        cls,
-        gate_proj_name: str,
-        down_proj_name: str,
-        up_proj_name: str,
-        e_score_correction_bias_name: str,
-        weight_prefix: str,
-        n_routed_experts: int,
-        num_fused_shared_experts: int,
-        split_inter_size: int,
-        data_type: torch.dtype,
-        network_config: Dict[str, Any],
-        layer_num: int,
-        quant_cfg: Quantcfg = None,
-    ):
-        quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe")
-        if quant_method is not None and quant_method.method_name == "awq_marlin":
-            return FusedAWQMARLINMoeWeightTP(
-                gate_proj_name=gate_proj_name,
-                down_proj_name=down_proj_name,
-                up_proj_name=up_proj_name,
-                e_score_correction_bias_name=e_score_correction_bias_name,
-                weight_prefix=weight_prefix,
-                n_routed_experts=n_routed_experts,
-                num_fused_shared_experts=num_fused_shared_experts,
-                split_inter_size=split_inter_size,
-                data_type=data_type,
-                network_config=network_config,
-                layer_num=layer_num,
-                quant_cfg=quant_cfg,
-            )
-        else:
-            return FusedBaseMoeWeightTP(
-                gate_proj_name=gate_proj_name,
-                down_proj_name=down_proj_name,
-                up_proj_name=up_proj_name,
-                e_score_correction_bias_name=e_score_correction_bias_name,
-                weight_prefix=weight_prefix,
-                n_routed_experts=n_routed_experts,
-                num_fused_shared_experts=num_fused_shared_experts,
-                split_inter_size=split_inter_size,
-                data_type=data_type,
-                network_config=network_config,
-                layer_num=layer_num,
-                quant_cfg=quant_cfg,
-            )
+def create_tp_moe_wegiht_obj(
+    gate_proj_name: str,
+    down_proj_name: str,
+    up_proj_name: str,
+    e_score_correction_bias_name: str,
+    weight_prefix: str,
+    n_routed_experts: int,
+    num_fused_shared_experts: int,
+    split_inter_size: int,
+    data_type: torch.dtype,
+    network_config: Dict[str, Any],
+    layer_num: int,
+    quant_cfg: Quantcfg = None,
+) -> Union["FusedMoeWeightTP", "FusedAWQMARLINMoeWeightTP"]:
+    quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe")
+    if quant_method is not None and quant_method.method_name == "awq_marlin":
+        return FusedAWQMARLINMoeWeightTP(
+            gate_proj_name=gate_proj_name,
+            down_proj_name=down_proj_name,
+            up_proj_name=up_proj_name,
+            e_score_correction_bias_name=e_score_correction_bias_name,
+            weight_prefix=weight_prefix,
+            n_routed_experts=n_routed_experts,
+            num_fused_shared_experts=num_fused_shared_experts,
+            split_inter_size=split_inter_size,
+            data_type=data_type,
+            network_config=network_config,
+            layer_num=layer_num,
+            quant_cfg=quant_cfg,
+        )
+    else:
+        return FusedMoeWeightTP(
+            gate_proj_name=gate_proj_name,
+            down_proj_name=down_proj_name,
+            up_proj_name=up_proj_name,
+            e_score_correction_bias_name=e_score_correction_bias_name,
+            weight_prefix=weight_prefix,
+            n_routed_experts=n_routed_experts,
+            num_fused_shared_experts=num_fused_shared_experts,
+            split_inter_size=split_inter_size,
+            data_type=data_type,
+            network_config=network_config,
+            layer_num=layer_num,
+            quant_cfg=quant_cfg,
+        )


-class FusedBaseMoeWeightTP(BaseWeight):
+class FusedMoeWeightTP(BaseWeight):
     def __init__(
         self,
         gate_proj_name: str,
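The core of the change: the old code used __new__ as a factory, returning instances of unrelated classes (FusedAWQMARLINMoeWeightTP or FusedBaseMoeWeightTP). That works, but it is confusing: the dispatching class can never produce an instance of itself, and Python skips __init__ whenever __new__ returns an object that is not an instance of cls. A plain module-level factory function makes the dispatch explicit and lets the return type be annotated. A minimal, self-contained sketch of the pattern (BaseExpertWeight, QuantExpertWeight, and create_expert_weight are illustrative names, not lightllm APIs):

# Illustrative sketch only -- these names are hypothetical, not from lightllm.


class BaseExpertWeight:
    def __init__(self, weight_prefix: str):
        self.weight_prefix = weight_prefix


class QuantExpertWeight(BaseExpertWeight):
    def __init__(self, weight_prefix: str):
        super().__init__(weight_prefix)
        self.quantized = True


def create_expert_weight(weight_prefix: str, quant: bool = False) -> BaseExpertWeight:
    # Dispatch happens in an ordinary function, so each class keeps its
    # normal __new__/__init__ semantics and the return type is explicit.
    if quant:
        return QuantExpertWeight(weight_prefix)
    return BaseExpertWeight(weight_prefix)


if __name__ == "__main__":
    w = create_expert_weight("model.layers.0.mlp.experts", quant=True)
    print(type(w).__name__)  # -> QuantExpertWeight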

lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py
Lines changed: 2 additions & 2 deletions

@@ -8,9 +8,9 @@
     ROWMMWeight,
     COLMMWeight,
     NormWeight,
-    FusedMoeWeightTP,
     FusedMoeWeightEP,
     ROWBMMWeight,
+    create_tp_moe_wegiht_obj,
 )
 from functools import partial
 from ..triton_kernel.weight_dequant import weight_dequant

@@ -265,7 +265,7 @@ def _init_moe(self):
         moe_mode = os.getenv("MOE_MODE", "TP")
         assert moe_mode in ["EP", "TP"]
         if moe_mode == "TP":
-            self.experts = FusedMoeWeightTP(
+            self.experts = create_tp_moe_wegiht_obj(
                 gate_proj_name="gate_proj",
                 down_proj_name="down_proj",
                 up_proj_name="up_proj",

lightllm/models/mixtral/layer_weights/transformer_layer_weight.py
Lines changed: 2 additions & 6 deletions

@@ -2,11 +2,7 @@
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.envs_utils import enable_env_vars
 from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight
-from lightllm.common.basemodel.layer_weights.meta_weights import (
-    ROWMMWeight,
-    FusedMoeWeightTP,
-    FusedMoeWeightEP,
-)
+from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, FusedMoeWeightEP, create_tp_moe_wegiht_obj

 logger = init_logger(__name__)

@@ -53,7 +49,7 @@ def _init_moe(self):
         assert moe_mode in ["TP"], f"Unsupported moe mode: {moe_mode}"

         if moe_mode == "TP":
-            self.experts = FusedMoeWeightTP(
+            self.experts = create_tp_moe_wegiht_obj(
                 gate_proj_name="w1",
                 down_proj_name="w2",
                 up_proj_name="w3",

lightllm/models/qwen3/layer_weights/transformer_layer_weight.py
Lines changed: 0 additions & 11 deletions

@@ -1,18 +1,7 @@
-import os
-import torch
-import math
-import numpy as np
-from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight
 from lightllm.models.qwen2.layer_weights.transformer_layer_weight import Qwen2TransformerLayerWeight
 from lightllm.common.basemodel.layer_weights.meta_weights import (
-    ROWMMWeight,
-    COLMMWeight,
     NormWeight,
-    FusedMoeWeightTP,
-    FusedMoeWeightEP,
-    ROWBMMWeight,
 )
-from functools import partial


 class Qwen3TransformerLayerWeight(Qwen2TransformerLayerWeight):

lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py
Lines changed: 2 additions & 15 deletions

@@ -1,19 +1,6 @@
 import os
-import torch
-import math
-import numpy as np
-from lightllm.common.basemodel import TransformerLayerWeight
 from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight
-from lightllm.utils.envs_utils import enable_env_vars
-from lightllm.common.basemodel.layer_weights.meta_weights import (
-    ROWMMWeight,
-    COLMMWeight,
-    NormWeight,
-    FusedMoeWeightTP,
-    FusedMoeWeightEP,
-    ROWBMMWeight,
-)
-from functools import partial
+from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, FusedMoeWeightEP, create_tp_moe_wegiht_obj


 class Qwen3MOETransformerLayerWeight(Qwen3TransformerLayerWeight):

@@ -76,7 +63,7 @@ def _init_moe(self):
         moe_mode = os.getenv("MOE_MODE", "TP")
         assert moe_mode in ["EP", "TP"]
         if moe_mode == "TP":
-            self.experts = FusedMoeWeightTP(
+            self.experts = create_tp_moe_wegiht_obj(
                 gate_proj_name="gate_proj",
                 down_proj_name="down_proj",
                 up_proj_name="up_proj",
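All of the touched call sites follow the same pattern: a MOE_MODE environment variable (defaulting to "TP") chooses between tensor-parallel and expert-parallel expert weights, and the TP branch now goes through the new factory. A condensed sketch of that selection, assuming the imports shown in the diffs above and passing the constructor keyword arguments through unchanged (FusedMoeWeightEP's exact signature is not shown in this commit, so the EP call here is an assumption):

import os

from lightllm.common.basemodel.layer_weights.meta_weights import (
    FusedMoeWeightEP,
    create_tp_moe_wegiht_obj,
)


def build_experts(**kwargs):
    # MOE_MODE defaults to "TP"; the call sites in this commit
    # accept only "EP" or "TP".
    moe_mode = os.getenv("MOE_MODE", "TP")
    assert moe_mode in ["EP", "TP"]
    if moe_mode == "TP":
        return create_tp_moe_wegiht_obj(**kwargs)
    return FusedMoeWeightEP(**kwargs)  # assumption: EP path takes compatible kwargs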
