Commit 9241cca

[None][feat] Enable EPLB for trtllm-gen and cutlass backend (#8886)
Signed-off-by: Dongxu Yang <[email protected]>
1 parent 5f26c31 commit 9241cca

17 files changed: +871 / -334 lines

tensorrt_llm/_torch/models/modeling_gpt_oss.py

Lines changed: 6 additions & 3 deletions
@@ -140,6 +140,9 @@ def __init__(
     self.config = config  # Store config as instance variable
     pretrained_config = config.pretrained_config
     self.num_experts = pretrained_config.num_local_experts
+    moe_load_balancer_config = config.moe_load_balancer
+    self.num_slots = moe_load_balancer_config.num_slots if moe_load_balancer_config and moe_load_balancer_config.num_slots else self.num_experts
+
     self.layer_idx = layer_idx
     self.enable_attention_dp = config.mapping.enable_attention_dp
     self.mapping = config.mapping
@@ -162,13 +165,13 @@ def __init__(
         if config.moe_backend.upper() == "TRTLLM" else torch.float32)

     self.swiglu_alpha = torch.tensor(
-        [1.702] * (self.num_experts // config.mapping.moe_ep_size),
+        [1.702] * (self.num_slots // config.mapping.moe_ep_size),
         dtype=torch.float32).cuda()
     self.swiglu_beta = torch.tensor(
-        [1.0] * (self.num_experts // config.mapping.moe_ep_size),
+        [1.0] * (self.num_slots // config.mapping.moe_ep_size),
         dtype=torch.float32).cuda()
     self.swiglu_limit = torch.tensor(
-        [7.0] * (self.num_experts // config.mapping.moe_ep_size),
+        [7.0] * (self.num_slots // config.mapping.moe_ep_size),
         dtype=torch.float32).cuda()
     # Prepare MoE creation parameters
     moe_params = {
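
The slot count now drives the per-rank swiglu parameter shapes. A minimal standalone sketch of that sizing rule (hypothetical helper and SimpleNamespace stand-ins for the real config objects; the actual code reads them from ModelConfig as shown above):

from types import SimpleNamespace

import torch


def swiglu_params_per_rank(pretrained_config, moe_load_balancer_config, moe_ep_size):
    """Sketch of the sizing logic above: slots fall back to experts when EPLB is off."""
    num_experts = pretrained_config.num_local_experts
    # Same fallback as the diff: use num_slots only if a load-balancer config provides it.
    num_slots = (moe_load_balancer_config.num_slots
                 if moe_load_balancer_config and moe_load_balancer_config.num_slots
                 else num_experts)
    per_rank = num_slots // moe_ep_size
    swiglu_alpha = torch.full((per_rank,), 1.702, dtype=torch.float32)
    swiglu_beta = torch.full((per_rank,), 1.0, dtype=torch.float32)
    swiglu_limit = torch.full((per_rank,), 7.0, dtype=torch.float32)
    return swiglu_alpha, swiglu_beta, swiglu_limit


# Example: 128 experts spread redundantly over 144 slots on 8 EP ranks -> 18 entries per rank.
alpha, beta, limit = swiglu_params_per_rank(
    SimpleNamespace(num_local_experts=128),
    SimpleNamespace(num_slots=144),
    moe_ep_size=8)
assert alpha.shape == (18,)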

tensorrt_llm/_torch/modules/fused_moe/create_moe.py

Lines changed: 4 additions & 1 deletion
@@ -79,7 +79,9 @@ def create_moe(

     moe_load_balancer = get_moe_load_balancer()
     if moe_load_balancer is not None:
-        assert moe_cls == WideEPMoE, "MoE Load Balance is only supported in WideEPMoE now."
+        assert moe_cls in [
+            WideEPMoE, CutlassFusedMoE, TRTLLMGenFusedMoE
+        ], "MoE Load Balance is only supported in WideEPMoE, CutlassFusedMoE and TRTLLMGenFusedMoE now."

     if bias:
         assert moe_cls in [CutlassFusedMoE, TritonFusedMoE, TRTLLMGenFusedMoE
@@ -106,6 +108,7 @@ def create_moe(
             dtype=dtype,
             reduce_results=reduce_results,
             model_config=model_config,
+            aux_stream_dict=aux_stream_dict,
             weight_loading_mode=weight_loading_mode,
             bias=bias,
             layer_idx=layer_idx,
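
A small self-contained sketch of the relaxed gate (placeholder classes for illustration only; in the real function the classes are the imported backends and the load balancer comes from get_moe_load_balancer() as shown above):

class WideEPMoE: ...
class CutlassFusedMoE: ...
class TRTLLMGenFusedMoE: ...
class TritonFusedMoE: ...

EPLB_CAPABLE_BACKENDS = (WideEPMoE, CutlassFusedMoE, TRTLLMGenFusedMoE)


def check_load_balancer_support(moe_cls, moe_load_balancer):
    """Mirror of the assertion above: EPLB now accepts three backends instead of one."""
    if moe_load_balancer is not None:
        assert moe_cls in EPLB_CAPABLE_BACKENDS, (
            "MoE Load Balance is only supported in WideEPMoE, CutlassFusedMoE "
            "and TRTLLMGenFusedMoE now.")


check_load_balancer_support(TRTLLMGenFusedMoE, moe_load_balancer=object())  # passes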

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py

Lines changed: 7 additions & 6 deletions
@@ -141,12 +141,13 @@ def __init__(
         )

     def forward_chunk(
-            self,
-            x: Union[torch.Tensor, Fp4QuantizedTensor],
-            router_logits: torch.Tensor,
-            output_dtype: Optional[torch.dtype] = None,
-            all_rank_num_tokens: Optional[List[int]] = None,
-            use_dp_padding: Optional[bool] = None,
+        self,
+        x: Union[torch.Tensor, Fp4QuantizedTensor],
+        router_logits: torch.Tensor,
+        output_dtype: Optional[torch.dtype] = None,
+        all_rank_num_tokens: Optional[List[int]] = None,
+        use_dp_padding: Optional[bool] = None,
+        repeating_info: tuple = (True, True),
     ) -> torch.Tensor:
         if isinstance(x, Fp4QuantizedTensor):
             assert output_dtype is not None
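
The new repeating_info argument carries (is_first_call, is_last_call) flags; the default (True, True) treats a single invocation as both first and last. A minimal sketch, assuming the same repeat_idx/repeat_count bookkeeping used by TRTLLMGenFusedMoE further below, of how a caller derives the flags:

def repeating_info(repeat_idx: int, repeat_count: int) -> tuple:
    """(is_first_call, is_last_call) for the current invocation of a layer."""
    return repeat_idx == 0, repeat_idx == repeat_count - 1


# When the module is invoked more than once per step (repeat_count > 1), the flags evolve as:
#   call 0 -> (True, False), call 1 -> (False, True), then the index wraps back to 0.
assert repeating_info(0, 2) == (True, False)
assert repeating_info(1, 2) == (False, True)
assert repeating_info(0, 1) == (True, True)  # matches the default (True, True)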

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py

Lines changed: 101 additions & 59 deletions
Large diffs are not rendered by default.

tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py

Lines changed: 63 additions & 17 deletions
@@ -13,8 +13,9 @@
 from ...custom_ops.trtllm_gen_custom_ops import \
     fp4_block_scale_fake_output_without_finalize
 from ...distributed import allgather
+from ...expert_statistic import ExpertStatistic
 from ...model_config import ModelConfig
-from ...utils import Fp4QuantizedTensor, ceil_div
+from ...utils import AuxStreamType, Fp4QuantizedTensor, ceil_div
 from .interface import AlltoallMethodType, MoE, MoEWeightLoadingMode
 from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod,
                            NVFP4TRTLLMGenFusedMoEMethod,
@@ -37,6 +38,7 @@ class TRTLLMGenFusedMoE(MoE):
         dtype (Optional[torch.dtype]): Data type for the weights.
         reduce_results (bool): Whether to reduce the results across devices.
         model_config (ModelConfig): Configuration object for the model.
+        aux_stream_dict (Optional[Dict[AuxStreamType, torch.cuda.Stream]]): Auxiliary CUDA streams for overlapping.

     MoE torch custom op:
       Only support min-latency mode now (SM100 Blackwell only).
@@ -66,6 +68,8 @@ def __init__(
         dtype: Optional[torch.dtype] = None,
         reduce_results: bool = False,
         model_config: ModelConfig = ModelConfig(),
+        aux_stream_dict: Optional[Dict[AuxStreamType,
+                                       torch.cuda.Stream]] = None,
         weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.
         VANILLA,
         layer_idx: Optional[int] = None,
@@ -82,6 +86,7 @@ def __init__(
             dtype=dtype,
             reduce_results=reduce_results,
             model_config=model_config,
+            aux_stream_dict=aux_stream_dict,
             weight_loading_mode=weight_loading_mode,
             bias=bias,
             swiglu_alpha=swiglu_alpha,
@@ -97,19 +102,11 @@ def __init__(

         assert not self.smart_router, "Smart router is not supported in TRTLLMGenFusedMoE."

-        self.num_slots = self.num_experts
-        self.expert_size_per_partition = self.num_experts // self.ep_size
-        self.initial_global_assignments = [
-            (ep_rank * self.num_experts // self.ep_size + local_slot_id) %
-            self.num_experts for ep_rank in range(self.ep_size)
-            for local_slot_id in range(self.expert_size_per_partition)
-        ]
-        self.slot_start = self.ep_rank * self.expert_size_per_partition
-        self.slot_end = self.slot_start + self.expert_size_per_partition
-        self.initial_local_expert_ids = self.initial_global_assignments[
-            self.slot_start:self.slot_end]
-        assert len(
-            self.initial_local_expert_ids) == self.expert_size_per_partition
+        # Note: Load balancer initialization is handled by base class _init_load_balancer()
+        # If no load balancer is available, the base class will set:
+        #   - self.num_slots = self.num_experts
+        #   - self.expert_size_per_partition = self.num_experts // self.ep_size
+        #   - self.initial_global_assignments, self.slot_start, self.slot_end, etc.

         # TODO: AlltoAll code is largely duplicated with WideEPMoE. Consider refactor and reuse in the future.
         self.alltoall_method_type = self.select_alltoall_method_type()
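
For reference, a standalone sketch of the default assignment the removed block computed, and which the base class now provides when no load balancer is configured: with num_slots == num_experts, slots map round-robin onto experts and each EP rank starts with a contiguous expert range.

def default_initial_assignments(num_experts: int, ep_size: int, ep_rank: int):
    """Reproduces the removed per-backend initialization for the no-EPLB case."""
    expert_size_per_partition = num_experts // ep_size
    initial_global_assignments = [
        (rank * num_experts // ep_size + local_slot_id) % num_experts
        for rank in range(ep_size)
        for local_slot_id in range(expert_size_per_partition)
    ]
    slot_start = ep_rank * expert_size_per_partition
    slot_end = slot_start + expert_size_per_partition
    return initial_global_assignments[slot_start:slot_end]


# 8 experts on 4 EP ranks: rank 1 starts out owning experts [2, 3].
assert default_initial_assignments(8, ep_size=4, ep_rank=1) == [2, 3]
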
@@ -136,7 +133,7 @@ def __init__(
                 mapping=self.mapping,
                 max_num_tokens=model_config.max_num_tokens,
                 top_k=self.routing_method.experts_per_token,
-                num_experts=self.num_experts,
+                num_experts=self.num_slots,
                 workspace_size_per_rank=workspace_mb * 1024 * 1024,
             )
         else:
@@ -183,6 +180,10 @@ def select_alltoall_method_type(self) -> AlltoallMethodType:

         return AlltoallMethodType.MNNVL

+    def _supports_load_balancer(self) -> bool:
+        """TRTLLMGenFusedMoE supports load balancer."""
+        return True
+
     @cached_property
     def enable_alltoall(self):
         """ enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter
@@ -340,14 +341,39 @@ def forward_impl(
         x_col = x.shape[1]
         token_count = x.shape[0]
         alltoall_info = None
+        # Determine if this is first/last call (TRTLLMGenFusedMoE doesn't use chunking)
+        is_first_call = self.repeat_idx == 0
+        is_last_call = self.repeat_idx == self.repeat_count - 1

         if post_quant_comm:
+            # Start GPU stage for first call
+            self._load_balancer_start_wait_gpu_stage(is_first_call)
             token_selected_experts, token_final_scales = self.routing_method.apply(
                 router_logits)
             token_selected_experts = token_selected_experts.to(torch.int32)
             if token_final_scales is not None:
                 token_final_scales = token_final_scales.to(torch.bfloat16)

+            self._load_balancer_done_wait_gpu_stage(is_first_call)
+
+            ignore_allreduce = self.enable_alltoall and self.alltoall_method_type == AlltoallMethodType.MNNVL and self.moe_alltoall_backend == "mnnvllatency"
+            self._load_balancer_update_statistic(
+                token_selected_experts,
+                is_first_call,
+                is_last_call,
+                ignore_allreduce=ignore_allreduce)
+
+            # Route tokens to slots
+            token_selected_slots = self._load_balancer_route(
+                token_selected_experts, self.use_dp)
+
+            # Update expert statistics
+            ExpertStatistic.set_layer(self.layer_idx)
+            ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots)
+
+            # Use routed slots for subsequent processing
+            token_selected_experts = token_selected_slots
+
             x, x_sf, x_row, x_col = self._quantize_for_post_quant_comm(x)

             if self.enable_alltoall:
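
The block above fixes the call order around routing: wait for the GPU weight-update stage on the first call, collect routing statistics, then map expert ids to slots so the downstream kernels work in slot space. A framework-free sketch of that sequence (LoadBalancerHooks is a hypothetical stub; the real hooks are the _load_balancer_* methods on the MoE base class):

class LoadBalancerHooks:
    """Stub illustrating the call order used above; real hooks live on the MoE base class."""

    def start_wait_gpu_stage(self, is_first_call): ...
    def done_wait_gpu_stage(self, is_first_call): ...
    def update_statistic(self, experts, is_first_call, is_last_call, ignore_allreduce): ...
    def route(self, experts):
        return experts  # identity routing when no rebalancing has happened yet


def route_with_load_balancer(lb, router_apply, router_logits, repeat_idx, repeat_count,
                             ignore_allreduce=False):
    is_first_call = repeat_idx == 0
    is_last_call = repeat_idx == repeat_count - 1

    lb.start_wait_gpu_stage(is_first_call)          # wait for slot weight updates (first call only)
    experts, scales = router_apply(router_logits)   # ordinary top-k routing to expert ids
    lb.done_wait_gpu_stage(is_first_call)

    # Skip the allreduce when the MNNVL latency alltoall prepare step gathers statistics instead.
    lb.update_statistic(experts, is_first_call, is_last_call,
                        ignore_allreduce=ignore_allreduce)
    slots = lb.route(experts)                       # map expert ids to (possibly replicated) slots
    return slots, scales

Overwriting token_selected_experts with the routed slots, as the diff does, keeps the rest of forward_impl unchanged while every per-expert buffer is indexed by slot.
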
@@ -364,9 +390,14 @@ def forward_impl(

             if self.moe_alltoall_backend == "mnnvllatency":
                 assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized"
-                alltoall_info, _ = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
+                if is_last_call:
+                    loadbalancer_local_statistic_info = self._load_balancer_get_local_statistic_tensor(
+                    )
+                else:
+                    loadbalancer_local_statistic_info = None
+                alltoall_info, gathered_loadbalancer_local_statistic_info = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
                     token_selected_experts,
-                    None,
+                    loadbalancer_local_statistic_info,
                     self.alltoall_prepare_workspace,
                     runtime_max_tokens_per_rank,
                     self.ep_rank,
@@ -375,6 +406,11 @@ def forward_impl(
                     self.num_slots,
                     top_k,
                 )
+                if gathered_loadbalancer_local_statistic_info is not None:
+                    gathered_loadbalancer_local_statistic_info = gathered_loadbalancer_local_statistic_info.view(
+                        (self.mapping.moe_ep_size, self.num_experts))
+                self._load_balancer_update_statistic_with_gathered_statistic(
+                    gathered_loadbalancer_local_statistic_info)

                 if x_sf is not None:
                     x_sf = x_sf.view(x_row,
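
With the MNNVL latency backend, each rank's expert statistics ride along with the all-to-all prepare call on the last invocation instead of a separate allreduce, and the gathered buffer is then reinterpreted as one row per EP rank. A shape-only sketch of the view above (plain torch; the int32 counter dtype is an assumption):

import torch

moe_ep_size, num_experts = 4, 128
# Flat buffer returned by the prepare step: ep_size * num_experts counters.
gathered = torch.zeros(moe_ep_size * num_experts, dtype=torch.int32)
# Same view as above: one row of per-expert token counts for every EP rank.
per_rank_stats = gathered.view((moe_ep_size, num_experts))
assert per_rank_stats.shape == (4, 128)
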
@@ -716,6 +752,9 @@ def forward_impl(
                 "TRTLLMGenFusedMoE only supports fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_mxfp8 and w4a8_mxfp4_fp8 dtypes."
             )

+        # Handle load balancer CPU stage if needed
+        self._load_balancer_start_set_cpu_stage(is_last_call)
+
         # Combine results if using alltoall
         if self.enable_alltoall:
             if self.moe_alltoall_backend == "mnnvllatency":
@@ -763,10 +802,17 @@ def forward_impl(
                 use_dp_padding=use_dp_padding,
             )

+        self._load_balancer_done_set_cpu_stage(is_last_call)
+
         if use_dp_padding:
             rank = self.mapping.tp_rank
             final_hidden_states = final_hidden_states[:
                                                       all_rank_num_tokens[rank]]
+
+        # Update repeat index for load balancer
+        if self.layer_load_balancer:
+            self.repeat_idx = 0 if self.repeat_idx == self.repeat_count - 1 else self.repeat_idx + 1
+
         return final_hidden_states

     def forward_fake(
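
Finally, the repeat index wraps after the last call so the first/last-call flags derived at the top of forward_impl stay correct across repeated invocations of the same layer. A minimal sketch of that update:

def advance_repeat_idx(repeat_idx: int, repeat_count: int) -> int:
    """Same update as above: wrap to 0 after the last call, otherwise advance."""
    return 0 if repeat_idx == repeat_count - 1 else repeat_idx + 1


# With repeat_count == 3 the index cycles 0 -> 1 -> 2 -> 0 -> ...
idx = 0
assert [idx := advance_repeat_idx(idx, 3) for _ in range(4)] == [1, 2, 0, 1]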
