
Commit d8b0589

[None][perf] Adjust select_alltoall_method_type. (#8950)
Signed-off-by: Bo Li <[email protected]>
1 parent 46dd988 commit d8b0589

7 files changed: +62 −48 lines


cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu

Lines changed: 6 additions & 0 deletions
@@ -51,6 +51,12 @@ namespace tensorrt_llm::kernels::mnnvl_throughput
         __VA_ARGS__; \
         break; \
     } \
+    case 6: \
+    { \
+        constexpr int TOP_K = 6; \
+        __VA_ARGS__; \
+        break; \
+    } \
     case 4: \
     { \
         constexpr int TOP_K = 4; \
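
The new case extends the macro's compile-time top-k dispatch so that a top-k of 6 gets its own TOP_K instantiation. As a schematic Python analogue of that dispatch pattern (purely illustrative; the names SUPPORTED_TOP_K, dispatch_top_k, and kernel_body are not from the source):

# Illustrative sketch only: mirrors how the CUDA macro switches on the runtime
# top-k value and binds a compile-time TOP_K inside each case before running the body.
SUPPORTED_TOP_K = (4, 6, 8)  # assumed set; the diff shows only case 4 and the new case 6


def dispatch_top_k(top_k, kernel_body):
    if top_k not in SUPPORTED_TOP_K:
        raise ValueError(f"no specialized kernel for top_k={top_k}")
    # analogous to `constexpr int TOP_K = ...; __VA_ARGS__; break;`
    return kernel_body(top_k)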

cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ namespace torch_ext
 namespace mnnvl_throughput
 {
 
-// TODO: Is Alignment necessary?obu guo
+// TODO: Is Alignment necessary?
 // Helper function to align offset to specified byte boundary
 inline size_t alignOffset(size_t offset, size_t alignment)
 {

cpp/tensorrt_llm/thop/mxFp4BlockScaleMoe.cpp

Lines changed: 1 addition & 1 deletion
@@ -554,7 +554,7 @@ class Bf16MxE2m1BlockScaleMoeRunner : public torch::CustomClassHolder
         topk_group, intermediate_size, valid_hidden_size, valid_intermediate_size, local_expert_offset,
         local_num_experts, routed_scaling_factor, tileN, routing_method_type, mDtypeAct, *mRunners[tileN], config,
         topk_weights, topk_ids,
-        /*output=*/torch::nullopt); // TODO: Support user-provided output
+        /*out_tensor=*/torch::nullopt); // TODO: Support user-provided output
 }
 
 private:

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py

Lines changed: 27 additions & 31 deletions
@@ -151,7 +151,7 @@ def __init__(
                 model_config.mapping)
         elif self.moe_alltoall_backend == "mnnvlthroughput":
             workspace_mb = int(
-                os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "512"))
+                os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "2048"))
             self.moe_a2a = MoeAlltoAll(
                 mapping=self.mapping,
                 max_num_tokens=model_config.max_num_tokens,
@@ -213,6 +213,17 @@ def has_int8_woq_per_channel(self):
         ) and not self.quant_config.layer_quant_mode.has_per_group_scaling()
 
     def select_alltoall_method_type(self) -> AlltoallMethodType:
+        # If no attention DP, no need to use AlltoAll.
+        if self.mapping.dp_size == 1:
+            return AlltoallMethodType.NotEnabled
+
+        # AlltoAll cannot support MoE TP.
+        if self.mapping.moe_tp_size != 1:
+            return AlltoallMethodType.NotEnabled
+
+        if not MnnvlMemory.supports_mnnvl():
+            return AlltoallMethodType.NotEnabled
+
         all2all_method_type = os.environ.get("TRTLLM_FORCE_ALLTOALL_METHOD")
         if all2all_method_type is not None:
             if AlltoallMethodType[all2all_method_type] in [
@@ -224,18 +235,13 @@ def select_alltoall_method_type(self) -> AlltoallMethodType:
                 )
             return AlltoallMethodType[all2all_method_type]
 
-        if not self.mapping.enable_attention_dp:
-            return AlltoallMethodType.NotEnabled
-
-        if self.mapping.tp_size == 1:
-            return AlltoallMethodType.NotEnabled
-
         if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1":
             return AlltoallMethodType.NotEnabled
 
-        if not (self.mapping.moe_ep_size > self.routing_method.experts_per_token
-                and MnnvlMemory.supports_mnnvl()):
-            return AlltoallMethodType.NotEnabled
+        # TODO: We found that MNNVL performs better than NCCL AllGather/ReduceScatter,
+        # regardless of the relationship between EP size and topK. We favor AlltoAll for now.
+        # if not self.mapping.moe_ep_size > self.routing_method.experts_per_token:
+        #     return AlltoallMethodType.NotEnabled
 
         return AlltoallMethodType.MNNVL
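
Pieced together from the two hunks above, the resulting selection logic reads roughly as follows. This is a consolidated sketch for readability, not a verbatim copy of the file; the validation inside the forced-override branch is abbreviated because the diff only shows its edges.

def select_alltoall_method_type(self) -> AlltoallMethodType:
    # Hard requirements come first: attention DP enabled, no MoE TP, MNNVL available.
    if self.mapping.dp_size == 1:
        return AlltoallMethodType.NotEnabled
    if self.mapping.moe_tp_size != 1:
        return AlltoallMethodType.NotEnabled
    if not MnnvlMemory.supports_mnnvl():
        return AlltoallMethodType.NotEnabled

    # An explicit override is honored only once the hard requirements pass.
    forced = os.environ.get("TRTLLM_FORCE_ALLTOALL_METHOD")
    if forced is not None:
        # (the real code validates the forced type here; elided in this sketch)
        return AlltoallMethodType[forced]

    if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1":
        return AlltoallMethodType.NotEnabled

    # The old "EP size must exceed top-k" gate is now commented out upstream:
    # MNNVL was measured to beat NCCL AllGather/ReduceScatter regardless of it.
    return AlltoallMethodType.MNNVL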

@@ -247,9 +253,9 @@ def enable_alltoall(self):
 
     @cached_property
     def moe_alltoall_backend(self):
-        # "mnnvllatency" (default) or "mnnvlthroughput"
+        # "mnnvlthroughput" (default) or "mnnvllatency"
         return os.environ.get("TRTLLM_MOE_ALLTOALL_BACKEND",
-                              "mnnvllatency").strip().lower()
+                              "mnnvlthroughput").strip().lower()
 
     def _supports_load_balancer(self) -> bool:
         """CutlassFusedMoE supports load balancer."""
@@ -751,25 +757,15 @@ def forward_fake(
         use_dp_padding: Optional[bool] = None,
         **kwargs,
     ) -> Union[torch.Tensor, List[torch.Tensor]]:
-        if not self.enable_alltoall:
-            return super().forward_fake(
-                x,
-                router_logits,
-                do_finalize=do_finalize,
-                output_dtype=output_dtype,
-                all_rank_num_tokens=all_rank_num_tokens,
-                use_dp_padding=use_dp_padding,
-                **kwargs,
-            )
-        else:
-            is_nvfp4_input = isinstance(x, Fp4QuantizedTensor)
-            data_type = output_dtype if is_nvfp4_input else x.dtype
-            num_tokens = all_rank_num_tokens[
-                self.parallel_rank] if all_rank_num_tokens else x.shape[0]
-            hidden_size = x.shape[1] * (2 if is_nvfp4_input else 1)
-            top_k = self.routing_method.experts_per_token
-            return x.new_empty((num_tokens, top_k, hidden_size),
-                               dtype=data_type)
+        return super().forward_fake(
+            x,
+            router_logits,
+            do_finalize=do_finalize,
+            output_dtype=output_dtype,
+            all_rank_num_tokens=all_rank_num_tokens,
+            use_dp_padding=use_dp_padding,
+            **kwargs,
+        )
 
     def load_weights(self, weights: List[Dict]):
         assert self._weights_created

tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

Lines changed: 5 additions & 0 deletions
@@ -13,6 +13,7 @@
 from ...model_config import ModelConfig
 from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor
 from .fused_moe_cutlass import CutlassFusedMoE
+from .interface import AlltoallMethodType
 from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm,
                            MoEWeightLoadingMode, UnquantizedFusedMoEMethod)
 from .routing import BaseMoeRoutingMethod
@@ -462,6 +463,10 @@ def _get_quant_method(self):
         else:
             return UnquantizedFusedMoEMethod()
 
+    def select_alltoall_method_type(self) -> AlltoallMethodType:
+        """DeepGEMM backend currently doesn't support alltoall; honor overrides but default to disabled."""
+        return AlltoallMethodType.NotEnabled
+
     @nvtx_range("[DG] forward")
     def forward_chunk(
         self,

tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py

Lines changed: 20 additions & 13 deletions
@@ -128,7 +128,7 @@ def __init__(
                 model_config.mapping)
         elif self.moe_alltoall_backend == "mnnvlthroughput":
             workspace_mb = int(
-                os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "512"))
+                os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "2048"))
             self.moe_a2a = MoeAlltoAll(
                 mapping=self.mapping,
                 max_num_tokens=model_config.max_num_tokens,
@@ -154,6 +154,17 @@ def __init__(
         self.create_weights()
 
     def select_alltoall_method_type(self) -> AlltoallMethodType:
+        # If no attention DP, no need to use AlltoAll.
+        if self.mapping.dp_size == 1:
+            return AlltoallMethodType.NotEnabled
+
+        # AlltoAll cannot support MoE TP.
+        if self.mapping.moe_tp_size != 1:
+            return AlltoallMethodType.NotEnabled
+
+        if not MnnvlMemory.supports_mnnvl():
+            return AlltoallMethodType.NotEnabled
+
         all2all_method_type = os.environ.get("TRTLLM_FORCE_ALLTOALL_METHOD")
         if all2all_method_type is not None:
             if AlltoallMethodType[all2all_method_type] in [
@@ -165,18 +176,13 @@ def select_alltoall_method_type(self) -> AlltoallMethodType:
                 )
             return AlltoallMethodType[all2all_method_type]
 
-        if not self.mapping.enable_attention_dp:
-            return AlltoallMethodType.NotEnabled
-
-        if self.mapping.tp_size == 1:
-            return AlltoallMethodType.NotEnabled
-
         if os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") == "1":
             return AlltoallMethodType.NotEnabled
 
-        if not (self.mapping.moe_ep_size > self.routing_method.experts_per_token
-                and MnnvlMemory.supports_mnnvl()):
-            return AlltoallMethodType.NotEnabled
+        # TODO: We found that MNNVL performs better than NCCL AllGather/ReduceScatter,
+        # regardless of the relationship between EP size and topK. We favor AlltoAll for now.
+        # if not self.mapping.moe_ep_size > self.routing_method.experts_per_token:
+        #     return AlltoallMethodType.NotEnabled
 
         return AlltoallMethodType.MNNVL
 
@@ -192,9 +198,9 @@ def enable_alltoall(self):
 
     @cached_property
     def moe_alltoall_backend(self):
-        # "mnnvllatency" (default) or "mnnvlthroughput"
+        # "mnnvlthroughput" (default) or "mnnvllatency"
        return os.environ.get("TRTLLM_MOE_ALLTOALL_BACKEND",
-                              "mnnvllatency").strip().lower()
+                              "mnnvlthroughput").strip().lower()
 
     def _check_configs(self):
         assert self.has_deepseek_fp8_block_scales \
@@ -503,7 +509,8 @@ def forward_impl(
 
         moe_output: Optional[torch.Tensor] = None
         use_workspace_output = False
-        if self.enable_alltoall and self.moe_alltoall_backend == "mnnvlthroughput":
+        # TODO: use_workspace_output only supports w4a8_mxfp4_mxfp8 (gpt-oss) for now
+        if self.enable_alltoall and self.moe_alltoall_backend == "mnnvlthroughput" and self.has_w4a8_mxfp4_mxfp8:
             moe_output = self.moe_a2a.get_combine_payload_tensor_in_workspace(
                 runtime_max_tokens_per_rank, self.hidden_size, torch.bfloat16)
             use_workspace_output = True

tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_serve.py

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,7 @@ def _run_serve_with_click(args):
         raise SystemExit(result.exit_code)
 
 
-@pytest.mark.timeout(360)
+@pytest.mark.timeout(500)
 def test_trtllm_serve_openai_chat_completion(tmp_path):
     # Prepare small model config and extra options yaml
     config = get_small_model_config("meta-llama/Meta-Llama-3.1-8B-Instruct")
@@ -58,7 +58,7 @@ def test_trtllm_serve_openai_chat_completion(tmp_path):
 
     start_time = time.time()
    last_err = None
-    while time.time() - start_time < 90:
+    while time.time() - start_time < 300:
        if not server.is_alive():
            raise RuntimeError("Server process exited prematurely")
        try:
