 from ...custom_ops.trtllm_gen_custom_ops import \
     fp4_block_scale_fake_output_without_finalize
 from ...distributed import allgather
+from ...expert_statistic import ExpertStatistic
 from ...model_config import ModelConfig
-from ...utils import Fp4QuantizedTensor, ceil_div
+from ...utils import AuxStreamType, Fp4QuantizedTensor, ceil_div
 from .interface import AlltoallMethodType, MoE, MoEWeightLoadingMode
 from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod,
                            NVFP4TRTLLMGenFusedMoEMethod,
@@ -37,6 +38,7 @@ class TRTLLMGenFusedMoE(MoE):
         dtype (Optional[torch.dtype]): Data type for the weights.
         reduce_results (bool): Whether to reduce the results across devices.
         model_config (ModelConfig): Configuration object for the model.
+        aux_stream_dict (Optional[Dict[AuxStreamType, torch.cuda.Stream]]): Auxiliary CUDA streams for overlapping.
 
     MoE torch custom op:
         Only support min-latency mode now (SM100 Blackwell only).
@@ -66,6 +68,8 @@ def __init__(
         dtype: Optional[torch.dtype] = None,
         reduce_results: bool = False,
         model_config: ModelConfig = ModelConfig(),
+        aux_stream_dict: Optional[Dict[AuxStreamType,
+                                       torch.cuda.Stream]] = None,
         weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.
         VANILLA,
         layer_idx: Optional[int] = None,
@@ -82,6 +86,7 @@ def __init__(
             dtype=dtype,
             reduce_results=reduce_results,
             model_config=model_config,
+            aux_stream_dict=aux_stream_dict,
             weight_loading_mode=weight_loading_mode,
             bias=bias,
             swiglu_alpha=swiglu_alpha,
@@ -97,19 +102,11 @@ def __init__(
 
         assert not self.smart_router, "Smart router is not supported in TRTLLMGenFusedMoE."
 
-        self.num_slots = self.num_experts
-        self.expert_size_per_partition = self.num_experts // self.ep_size
-        self.initial_global_assignments = [
-            (ep_rank * self.num_experts // self.ep_size + local_slot_id) %
-            self.num_experts for ep_rank in range(self.ep_size)
-            for local_slot_id in range(self.expert_size_per_partition)
-        ]
-        self.slot_start = self.ep_rank * self.expert_size_per_partition
-        self.slot_end = self.slot_start + self.expert_size_per_partition
-        self.initial_local_expert_ids = self.initial_global_assignments[
-            self.slot_start:self.slot_end]
-        assert len(
-            self.initial_local_expert_ids) == self.expert_size_per_partition
+        # Note: Load balancer initialization is handled by base class _init_load_balancer()
+        # If no load balancer is available, the base class will set:
+        # - self.num_slots = self.num_experts
+        # - self.expert_size_per_partition = self.num_experts // self.ep_size
+        # - self.initial_global_assignments, self.slot_start, self.slot_end, etc.
 
         # TODO: AlltoAll code is largely duplicated with WideEPMoE. Consider refactor and reuse in the future.
         self.alltoall_method_type = self.select_alltoall_method_type()
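
> Editor's note: the deleted block above implemented a static round-robin expert-to-slot assignment that the base class now reproduces when no load balancer is configured. A minimal standalone sketch of that computation, using example values that are not taken from the commit, looks like this:

```python
# Sketch of the removed round-robin assignment (example values, hypothetical).
num_experts, ep_size, ep_rank = 8, 4, 1
expert_size_per_partition = num_experts // ep_size  # 2 local slots per EP rank

initial_global_assignments = [
    (rank * num_experts // ep_size + local_slot_id) % num_experts
    for rank in range(ep_size)
    for local_slot_id in range(expert_size_per_partition)
]
# -> [0, 1, 2, 3, 4, 5, 6, 7]: without a load balancer, slot i simply holds expert i.

slot_start = ep_rank * expert_size_per_partition                             # 2
slot_end = slot_start + expert_size_per_partition                            # 4
initial_local_expert_ids = initial_global_assignments[slot_start:slot_end]   # [2, 3]
assert len(initial_local_expert_ids) == expert_size_per_partition
```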
@@ -136,7 +133,7 @@ def __init__(
                 mapping=self.mapping,
                 max_num_tokens=model_config.max_num_tokens,
                 top_k=self.routing_method.experts_per_token,
-                num_experts=self.num_experts,
+                num_experts=self.num_slots,
                 workspace_size_per_rank=workspace_mb * 1024 * 1024,
             )
         else:
@@ -183,6 +180,10 @@ def select_alltoall_method_type(self) -> AlltoallMethodType:
 
         return AlltoallMethodType.MNNVL
 
+    def _supports_load_balancer(self) -> bool:
+        """TRTLLMGenFusedMoE supports load balancer."""
+        return True
+
     @cached_property
     def enable_alltoall(self):
         """ enable_alltoall (bool): whether to enable alltoall instead of allgather/reducescatter
@@ -340,14 +341,39 @@ def forward_impl(
         x_col = x.shape[1]
         token_count = x.shape[0]
         alltoall_info = None
+        # Determine if this is first/last call (TRTLLMGenFusedMoE doesn't use chunking)
+        is_first_call = self.repeat_idx == 0
+        is_last_call = self.repeat_idx == self.repeat_count - 1
 
         if post_quant_comm:
+            # Start GPU stage for first call
+            self._load_balancer_start_wait_gpu_stage(is_first_call)
             token_selected_experts, token_final_scales = self.routing_method.apply(
                 router_logits)
             token_selected_experts = token_selected_experts.to(torch.int32)
             if token_final_scales is not None:
                 token_final_scales = token_final_scales.to(torch.bfloat16)
 
+            self._load_balancer_done_wait_gpu_stage(is_first_call)
+
+            ignore_allreduce = self.enable_alltoall and self.alltoall_method_type == AlltoallMethodType.MNNVL and self.moe_alltoall_backend == "mnnvllatency"
+            self._load_balancer_update_statistic(
+                token_selected_experts,
+                is_first_call,
+                is_last_call,
+                ignore_allreduce=ignore_allreduce)
+
+            # Route tokens to slots
+            token_selected_slots = self._load_balancer_route(
+                token_selected_experts, self.use_dp)
+
+            # Update expert statistics
+            ExpertStatistic.set_layer(self.layer_idx)
+            ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots)
+
+            # Use routed slots for subsequent processing
+            token_selected_experts = token_selected_slots
+
             x, x_sf, x_row, x_col = self._quantize_for_post_quant_comm(x)
 
             if self.enable_alltoall:
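
> Editor's note on the new is_first_call / is_last_call flags: since this backend does not chunk inputs, they are driven entirely by the repeat_idx counter that this commit also updates at the end of forward_impl. A small illustrative sketch of that cycle, with a hypothetical repeat_count value:

```python
# Hypothetical repeat_count; in the real class it is set elsewhere.
repeat_count = 3
repeat_idx = 0

for call in range(2 * repeat_count):
    is_first_call = repeat_idx == 0
    is_last_call = repeat_idx == repeat_count - 1
    # GPU-stage wait runs on the first call; the statistic flush and CPU stage run on the last.
    print(f"call={call} repeat_idx={repeat_idx} first={is_first_call} last={is_last_call}")
    # Same update the commit appends before returning from forward_impl:
    repeat_idx = 0 if repeat_idx == repeat_count - 1 else repeat_idx + 1
```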
@@ -364,9 +390,14 @@ def forward_impl(
 
                 if self.moe_alltoall_backend == "mnnvllatency":
                     assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized"
-                    alltoall_info, _ = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
+                    if is_last_call:
+                        loadbalancer_local_statistic_info = self._load_balancer_get_local_statistic_tensor(
+                        )
+                    else:
+                        loadbalancer_local_statistic_info = None
+                    alltoall_info, gathered_loadbalancer_local_statistic_info = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
                         token_selected_experts,
-                        None,
+                        loadbalancer_local_statistic_info,
                         self.alltoall_prepare_workspace,
                         runtime_max_tokens_per_rank,
                         self.ep_rank,
@@ -375,6 +406,11 @@ def forward_impl(
                         self.num_slots,
                         top_k,
                     )
+                    if gathered_loadbalancer_local_statistic_info is not None:
+                        gathered_loadbalancer_local_statistic_info = gathered_loadbalancer_local_statistic_info.view(
+                            (self.mapping.moe_ep_size, self.num_experts))
+                        self._load_balancer_update_statistic_with_gathered_statistic(
+                            gathered_loadbalancer_local_statistic_info)
 
                     if x_sf is not None:
                         x_sf = x_sf.view(x_row,
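
> Editor's note on the gathered statistics handling above: the tensor returned by the prepare step comes back flat and is viewed as a (moe_ep_size, num_experts) matrix before being handed to the load balancer. A shape-only sketch, with example sizes not taken from the commit:

```python
import torch

moe_ep_size, num_experts = 4, 8                       # example values
flat = torch.zeros(moe_ep_size * num_experts, dtype=torch.int32)
per_rank_counts = flat.view((moe_ep_size, num_experts))
print(per_rank_counts.shape)                          # torch.Size([4, 8]); one row per EP rank
```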
@@ -716,6 +752,9 @@ def forward_impl(
                 "TRTLLMGenFusedMoE only supports fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_mxfp8 and w4a8_mxfp4_fp8 dtypes."
             )
 
+        # Handle load balancer CPU stage if needed
+        self._load_balancer_start_set_cpu_stage(is_last_call)
+
         # Combine results if using alltoall
         if self.enable_alltoall:
             if self.moe_alltoall_backend == "mnnvllatency":
@@ -763,10 +802,17 @@ def forward_impl(
                 use_dp_padding=use_dp_padding,
             )
 
+        self._load_balancer_done_set_cpu_stage(is_last_call)
+
         if use_dp_padding:
             rank = self.mapping.tp_rank
             final_hidden_states = final_hidden_states[:
                                                       all_rank_num_tokens[rank]]
+
+        # Update repeat index for load balancer
+        if self.layer_load_balancer:
+            self.repeat_idx = 0 if self.repeat_idx == self.repeat_count - 1 else self.repeat_idx + 1
+
         return final_hidden_states
 
     def forward_fake(