Skip to content

Commit fcfec93

Browse files
authored
[TRTLLM-9389][chore] Rename AlltoAll backend names (#9329)
Signed-off-by: Bo Li <[email protected]>
1 parent e1c9aa7 commit fcfec93

File tree

10 files changed

+51
-51
lines changed

10 files changed

+51
-51
lines changed

cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@
2323
#include <cstdint>
2424
#include <type_traits>
2525

26-
namespace tensorrt_llm::kernels::mnnvl_throughput
26+
namespace tensorrt_llm::kernels::moe_comm
2727
{
2828

2929
#define ENABLE_DEBUG_PRINT 0
@@ -964,4 +964,4 @@ void moe_a2a_sanitize_expert_ids_launch(int32_t* expert_ids, int32_t const* recv
964964
expert_ids, recv_counters, ep_size, max_tokens_per_rank, top_k, invalid_id);
965965
}
966966

967-
} // namespace tensorrt_llm::kernels::mnnvl_throughput
967+
} // namespace tensorrt_llm::kernels::moe_comm

cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919
#include <cuda_bf16.h>
2020
#include <cuda_fp16.h>
2121

22-
namespace tensorrt_llm::kernels::mnnvl_throughput
22+
namespace tensorrt_llm::kernels::moe_comm
2323
{
2424

2525
// Configuration constants
@@ -177,4 +177,4 @@ void moe_a2a_prepare_combine_launch(MoeA2ACombineParams const& params);
177177
void moe_a2a_sanitize_expert_ids_launch(int32_t* expert_ids, int32_t const* recv_counters, int32_t invalid_id,
178178
int ep_size, int max_tokens_per_rank, int top_k, cudaStream_t stream);
179179

180-
} // namespace tensorrt_llm::kernels::mnnvl_throughput
180+
} // namespace tensorrt_llm::kernels::moe_comm

cpp/tensorrt_llm/nanobind/thop/bindings.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@ namespace tensorrt_llm::nanobind::thop
3030
void initBindings(nb::module_& m)
3131
{
3232
// Export MoE A2A constants
33-
for (auto const& kv : torch_ext::mnnvl_throughput::getMoeA2AMetaInfoIndexPairs())
33+
for (auto const& kv : torch_ext::moe_comm::getMoeA2AMetaInfoIndexPairs())
3434
{
3535
m.attr(kv.first) = kv.second;
3636
}

cpp/tensorrt_llm/pybind/thop/bindings.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@ namespace tensorrt_llm::pybind::thop
3030
void initBindings(pybind11::module_& m)
3131
{
3232
// Export MoE A2A constants
33-
for (auto const& kv : torch_ext::mnnvl_throughput::getMoeA2AMetaInfoIndexPairs())
33+
for (auto const& kv : torch_ext::moe_comm::getMoeA2AMetaInfoIndexPairs())
3434
{
3535
m.attr(kv.first) = py::int_(kv.second);
3636
}

cpp/tensorrt_llm/thop/moeAlltoAllMeta.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@
2323

2424
namespace torch_ext
2525
{
26-
namespace mnnvl_throughput
26+
namespace moe_comm
2727
{
2828

2929
// Enum for indexing into moe_a2a_metainfo tensor
@@ -61,5 +61,5 @@ inline std::vector<std::pair<char const*, int64_t>> getMoeA2AMetaInfoIndexPairs(
6161
};
6262
}
6363

64-
} // namespace mnnvl_throughput
64+
} // namespace moe_comm
6565
} // namespace torch_ext

cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp

Lines changed: 19 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@
2828
namespace torch_ext
2929
{
3030

31-
namespace mnnvl_throughput
31+
namespace moe_comm
3232
{
3333

3434
// TODO: Is Alignment necessary?
@@ -78,13 +78,13 @@ MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens)
7878
// topk_target_ranks: [maxNumTokens, kMaxTopK]
7979
offset = alignOffset(offset, CACHELINE_ALIGNMENT);
8080
offsets[TOPK_TARGET_RANKS_OFFSET_INDEX] = offset;
81-
offset += static_cast<size_t>(maxNumTokens) * static_cast<size_t>(tensorrt_llm::kernels::mnnvl_throughput::kMaxTopK)
81+
offset += static_cast<size_t>(maxNumTokens) * static_cast<size_t>(tensorrt_llm::kernels::moe_comm::kMaxTopK)
8282
* SIZEOF_INT32;
8383

8484
// topk_send_indices: [maxNumTokens, kMaxTopK]
8585
offset = alignOffset(offset, CACHELINE_ALIGNMENT);
8686
offsets[TOPK_SEND_INDICES_OFFSET_INDEX] = offset;
87-
offset += static_cast<size_t>(maxNumTokens) * static_cast<size_t>(tensorrt_llm::kernels::mnnvl_throughput::kMaxTopK)
87+
offset += static_cast<size_t>(maxNumTokens) * static_cast<size_t>(tensorrt_llm::kernels::moe_comm::kMaxTopK)
8888
* SIZEOF_INT32;
8989

9090
// payload data
@@ -165,11 +165,11 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
165165
std::vector<torch::Tensor> const& inputPayloads, torch::Tensor const& workspace, torch::Tensor const& metainfo,
166166
int64_t runtimeMaxTokensPerRank, int64_t epRank, int64_t epSize, int64_t topK, int64_t numExperts)
167167
{
168-
using tensorrt_llm::kernels::mnnvl_throughput::PayloadDescriptor;
169-
using tensorrt_llm::kernels::mnnvl_throughput::MoeA2ADispatchParams;
170-
using tensorrt_llm::kernels::mnnvl_throughput::moe_a2a_dispatch_launch;
171-
using tensorrt_llm::kernels::mnnvl_throughput::kMaxTopK;
172-
using tensorrt_llm::kernels::mnnvl_throughput::kMaxPayloads;
168+
using tensorrt_llm::kernels::moe_comm::PayloadDescriptor;
169+
using tensorrt_llm::kernels::moe_comm::MoeA2ADispatchParams;
170+
using tensorrt_llm::kernels::moe_comm::moe_a2a_dispatch_launch;
171+
using tensorrt_llm::kernels::moe_comm::kMaxTopK;
172+
using tensorrt_llm::kernels::moe_comm::kMaxPayloads;
173173

174174
// Validate inputs
175175
CHECK_INPUT(tokenSelectedExperts, torch::kInt32);
@@ -344,9 +344,9 @@ torch::Tensor moeA2ACombineOp(torch::Tensor const& payload, int64_t localNumToke
344344
torch::Tensor const& metainfo, int64_t runtimeMaxTokensPerRank, int64_t epRank, int64_t epSize, int64_t topK,
345345
int64_t combinePayloadOffset, bool payloadInWorkspace)
346346
{
347-
using tensorrt_llm::kernels::mnnvl_throughput::MoeA2ACombineParams;
348-
using tensorrt_llm::kernels::mnnvl_throughput::moe_a2a_combine_launch;
349-
using tensorrt_llm::kernels::mnnvl_throughput::kMaxTopK;
347+
using tensorrt_llm::kernels::moe_comm::MoeA2ACombineParams;
348+
using tensorrt_llm::kernels::moe_comm::moe_a2a_combine_launch;
349+
using tensorrt_llm::kernels::moe_comm::kMaxTopK;
350350

351351
// Validate inputs
352352
CHECK_TH_CUDA(payload);
@@ -474,8 +474,8 @@ void moeA2ASanitizeExpertIdsOp(torch::Tensor& expert_ids, torch::Tensor& workspa
474474
uint8_t* rankWorkSpacePtr = workspace.data_ptr<uint8_t>() + epRank * workspace.stride(0);
475475
int* recv_counters = reinterpret_cast<int*>(rankWorkSpacePtr + offsets[RECV_COUNTERS_OFFSET_INDEX]);
476476

477-
tensorrt_llm::kernels::mnnvl_throughput::moe_a2a_sanitize_expert_ids_launch(expert_ids.data_ptr<int32_t>(),
478-
recv_counters, static_cast<int32_t>(invalid_expert_id), ep_size, runtime_max_tokens_per_rank, top_k,
477+
tensorrt_llm::kernels::moe_comm::moe_a2a_sanitize_expert_ids_launch(expert_ids.data_ptr<int32_t>(), recv_counters,
478+
static_cast<int32_t>(invalid_expert_id), ep_size, runtime_max_tokens_per_rank, top_k,
479479
at::cuda::getCurrentCUDAStream());
480480
}
481481

@@ -508,7 +508,7 @@ torch::Tensor moeA2AGetCombinePayloadTensorOp(torch::Tensor const& workspace, in
508508
return t;
509509
}
510510

511-
} // namespace mnnvl_throughput
511+
} // namespace moe_comm
512512

513513
} // namespace torch_ext
514514

@@ -540,9 +540,9 @@ TORCH_LIBRARY_FRAGMENT(trtllm, module)
540540

541541
TORCH_LIBRARY_IMPL(trtllm, CUDA, module)
542542
{
543-
module.impl("moe_a2a_dispatch", &torch_ext::mnnvl_throughput::moeA2ADispatchOp);
544-
module.impl("moe_a2a_combine", &torch_ext::mnnvl_throughput::moeA2ACombineOp);
545-
module.impl("moe_a2a_initialize", &torch_ext::mnnvl_throughput::moeA2AInitializeOp);
546-
module.impl("moe_a2a_sanitize_expert_ids", &torch_ext::mnnvl_throughput::moeA2ASanitizeExpertIdsOp);
547-
module.impl("moe_a2a_get_combine_payload_tensor", &torch_ext::mnnvl_throughput::moeA2AGetCombinePayloadTensorOp);
543+
module.impl("moe_a2a_dispatch", &torch_ext::moe_comm::moeA2ADispatchOp);
544+
module.impl("moe_a2a_combine", &torch_ext::moe_comm::moeA2ACombineOp);
545+
module.impl("moe_a2a_initialize", &torch_ext::moe_comm::moeA2AInitializeOp);
546+
module.impl("moe_a2a_sanitize_expert_ids", &torch_ext::moe_comm::moeA2ASanitizeExpertIdsOp);
547+
module.impl("moe_a2a_get_combine_payload_tensor", &torch_ext::moe_comm::moeA2AGetCombinePayloadTensorOp);
548548
}

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -143,13 +143,13 @@ def __init__(
143143
self.use_low_precision_combine = model_config.use_low_precision_moe_combine
144144

145145
if self.alltoall_method_type == AlltoallMethodType.MNNVL:
146-
if self.moe_alltoall_backend == "mnnvllatency":
146+
if self.moe_alltoall_backend == "NVLINK_TWO_SIDED":
147147
MnnvlMemory.initialize()
148148
self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
149149
model_config.mapping)
150150
self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
151151
model_config.mapping)
152-
elif self.moe_alltoall_backend == "mnnvlthroughput":
152+
elif self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
153153
workspace_mb = int(
154154
os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "2048"))
155155
self.moe_a2a = MoeAlltoAll(
@@ -253,9 +253,9 @@ def enable_alltoall(self):
253253

254254
@cached_property
255255
def moe_alltoall_backend(self):
256-
# "mnnvlthroughput" (default) or "mnnvllatency"
256+
# "NVLINK_ONE_SIDED" (default) or "NVLINK_TWO_SIDED"
257257
return os.environ.get("TRTLLM_MOE_ALLTOALL_BACKEND",
258-
"mnnvlthroughput").strip().lower()
258+
"NVLINK_ONE_SIDED").strip().upper()
259259

260260
def _supports_load_balancer(self) -> bool:
261261
"""CutlassFusedMoE supports load balancer."""
@@ -328,7 +328,7 @@ def forward_chunk(
328328

329329
if self.layer_load_balancer:
330330
self._load_balancer_done_wait_gpu_stage(is_first_call)
331-
ignore_allreduce = self.enable_alltoall and self.alltoall_method_type == AlltoallMethodType.MNNVL and self.moe_alltoall_backend == "mnnvllatency"
331+
ignore_allreduce = self.enable_alltoall and self.alltoall_method_type == AlltoallMethodType.MNNVL and self.moe_alltoall_backend == "NVLINK_TWO_SIDED"
332332
self._load_balancer_update_statistic(
333333
token_selected_experts,
334334
is_first_call,
@@ -439,7 +439,7 @@ def forward_chunk(
439439
token_final_scales = torch.ones_like(token_selected_slots,
440440
dtype=torch.float32)
441441

442-
if self.moe_alltoall_backend == "mnnvllatency":
442+
if self.moe_alltoall_backend == "NVLINK_TWO_SIDED":
443443
assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized"
444444
if is_last_call:
445445
loadbalancer_local_statistic_info = self._load_balancer_get_local_statistic_tensor(
@@ -472,7 +472,7 @@ def forward_chunk(
472472
token_selected_slots, alltoall_info.recv_rank_count_cumsum,
473473
runtime_max_tokens_per_rank, top_k, self.num_slots,
474474
self.ep_size)
475-
elif self.moe_alltoall_backend == "mnnvlthroughput":
475+
elif self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
476476
# Python MoeAlltoAll path
477477
if x_sf is not None:
478478
x_sf = x_sf.view(x_row,
@@ -532,7 +532,7 @@ def forward_chunk(
532532

533533
# Optionally provide an output tensor to fused_moe so it writes directly to our buffer
534534
moe_output: Optional[torch.Tensor] = None
535-
if self.enable_alltoall and self.moe_alltoall_backend == "mnnvlthroughput":
535+
if self.enable_alltoall and self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
536536
# Retrieve a workspace-backed output tensor sized by runtime tokens
537537
runtime_max_tokens_per_rank = max(
538538
all_rank_num_tokens) if all_rank_num_tokens else x.shape[0]
@@ -583,7 +583,7 @@ def forward_chunk(
583583

584584
# Combine results if using alltoall
585585
if self.enable_alltoall:
586-
if self.moe_alltoall_backend == "mnnvllatency":
586+
if self.moe_alltoall_backend == "NVLINK_TWO_SIDED":
587587
if alltoall_info is not None:
588588
top_k = self.routing_method.experts_per_token
589589
final_hidden_states = MnnvlMoe.mnnvl_moe_alltoallv_combine(
@@ -596,7 +596,7 @@ def forward_chunk(
596596
use_low_precision_combine=self.
597597
use_low_precision_combine,
598598
token_count=token_count)
599-
elif self.moe_alltoall_backend == "mnnvlthroughput":
599+
elif self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
600600
output_hidden_size = final_hidden_states.shape[-1]
601601
runtime_max_tokens_per_rank = max(
602602
all_rank_num_tokens) if all_rank_num_tokens else token_count

tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -120,13 +120,13 @@ def __init__(
120120
self.use_low_precision_combine = model_config.use_low_precision_moe_combine
121121

122122
if self.alltoall_method_type == AlltoallMethodType.MNNVL:
123-
if self.moe_alltoall_backend == "mnnvllatency":
123+
if self.moe_alltoall_backend == "NVLINK_TWO_SIDED":
124124
MnnvlMemory.initialize()
125125
self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
126126
model_config.mapping)
127127
self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
128128
model_config.mapping)
129-
elif self.moe_alltoall_backend == "mnnvlthroughput":
129+
elif self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
130130
workspace_mb = int(
131131
os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "2048"))
132132
self.moe_a2a = MoeAlltoAll(
@@ -198,9 +198,9 @@ def enable_alltoall(self):
198198

199199
@cached_property
200200
def moe_alltoall_backend(self):
201-
# "mnnvlthroughput" (default) or "mnnvllatency"
201+
# "NVLINK_ONE_SIDED" (default) or "NVLINK_TWO_SIDED"
202202
return os.environ.get("TRTLLM_MOE_ALLTOALL_BACKEND",
203-
"mnnvlthroughput").strip().lower()
203+
"NVLINK_ONE_SIDED").strip().upper()
204204

205205
def _check_configs(self):
206206
assert self.has_deepseek_fp8_block_scales \
@@ -362,7 +362,7 @@ def forward_impl(
362362

363363
self._load_balancer_done_wait_gpu_stage(is_first_call)
364364

365-
ignore_allreduce = self.enable_alltoall and self.alltoall_method_type == AlltoallMethodType.MNNVL and self.moe_alltoall_backend == "mnnvllatency"
365+
ignore_allreduce = self.enable_alltoall and self.alltoall_method_type == AlltoallMethodType.MNNVL and self.moe_alltoall_backend == "NVLINK_TWO_SIDED"
366366
self._load_balancer_update_statistic(
367367
token_selected_experts,
368368
is_first_call,
@@ -394,7 +394,7 @@ def forward_impl(
394394
else:
395395
token_final_scales = token_final_scales.to(torch.float32)
396396

397-
if self.moe_alltoall_backend == "mnnvllatency":
397+
if self.moe_alltoall_backend == "NVLINK_TWO_SIDED":
398398
assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized"
399399
if is_last_call:
400400
loadbalancer_local_statistic_info = self._load_balancer_get_local_statistic_tensor(
@@ -444,7 +444,7 @@ def forward_impl(
444444

445445
if token_final_scales is not None:
446446
token_final_scales = token_final_scales.to(torch.bfloat16)
447-
elif self.moe_alltoall_backend == "mnnvlthroughput":
447+
elif self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
448448
if x_sf is not None:
449449
x_sf = x_sf.view(x_row,
450450
ceil_div(x_col, self.scaling_vector_size))
@@ -510,7 +510,7 @@ def forward_impl(
510510
moe_output: Optional[torch.Tensor] = None
511511
use_workspace_output = False
512512
# TODO: use_workspace_output only supports w4a8_mxfp4_mxfp8 (gpt-oss) for now
513-
if self.enable_alltoall and self.moe_alltoall_backend == "mnnvlthroughput" and self.has_w4a8_mxfp4_mxfp8:
513+
if self.enable_alltoall and self.moe_alltoall_backend == "NVLINK_ONE_SIDED" and self.has_w4a8_mxfp4_mxfp8:
514514
moe_output = self.moe_a2a.get_combine_payload_tensor_in_workspace(
515515
runtime_max_tokens_per_rank, self.hidden_size, torch.bfloat16)
516516
use_workspace_output = True
@@ -774,7 +774,7 @@ def forward_impl(
774774

775775
# Combine results if using alltoall
776776
if self.enable_alltoall:
777-
if self.moe_alltoall_backend == "mnnvllatency":
777+
if self.moe_alltoall_backend == "NVLINK_TWO_SIDED":
778778
if alltoall_info is not None:
779779
final_hidden_states = MnnvlMoe.mnnvl_moe_alltoallv_combine(
780780
final_hidden_states,
@@ -787,7 +787,7 @@ def forward_impl(
787787
use_low_precision_combine,
788788
token_count=token_count,
789789
)
790-
elif self.moe_alltoall_backend == "mnnvlthroughput":
790+
elif self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
791791
# If use_workspace_output=True, the MoE result is already in workspace
792792
# Otherwise, we need to reshape and pass it
793793
if use_workspace_output:

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -248,9 +248,9 @@ def enable_alltoall(self):
248248

249249
@cached_property
250250
def moe_alltoall_backend(self):
251-
# "mnnvllatency" (default) or "mnnvlthroughput"
251+
# "NVLINK_TWO_SIDED" (default) or "NVLINK_ONE_SIDED"
252252
return os.environ.get("TRTLLM_MOE_ALLTOALL_BACKEND",
253-
"mnnvllatency").strip().lower()
253+
"NVLINK_TWO_SIDED").strip().upper()
254254

255255
def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int:
256256
num_rows = sum(all_rank_num_tokens)
@@ -436,7 +436,7 @@ def forward_chunk(
436436

437437
if self.layer_load_balancer:
438438
self._load_balancer_done_wait_gpu_stage(is_first_call)
439-
ignore_allreduce = self.enable_alltoall and self.alltoall_method_type == AlltoallMethodType.MNNVL and self.moe_alltoall_backend == "mnnvllatency"
439+
ignore_allreduce = self.enable_alltoall and self.alltoall_method_type == AlltoallMethodType.MNNVL and self.moe_alltoall_backend == "NVLINK_TWO_SIDED"
440440
self._load_balancer_update_statistic(token_selected_experts,
441441
is_first_call, is_last_call,
442442
ignore_allreduce)

tensorrt_llm/_torch/modules/fused_moe/interface.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -346,7 +346,7 @@ def _load_balancer_update_statistic(self,
346346
token_selected_experts: The selected experts of all tokens, has shape of [tokenCount * topK]
347347
is_first_call: Whether this is the first call for the same weights
348348
is_last_call: Whether this is the last call for the same weights
349-
ignore_allreduce: Whether to ignore allreduce, if True, only update local statistics, need call _load_balancer_get_local_statistic_tensor to get the local statistic tensor and then do external allgather and then call _load_balancer_update_statistic_with_gathered_statistic to update the global statistics. MnnvlLatency supports this.
349+
ignore_allreduce: Whether to ignore allreduce, if True, only update local statistics, need call _load_balancer_get_local_statistic_tensor to get the local statistic tensor and then do external allgather and then call _load_balancer_update_statistic_with_gathered_statistic to update the global statistics. NVLINK_TWO_SIDED supports this.
350350
"""
351351
if self._using_dynamic_load_balancer():
352352
if ignore_allreduce:

0 commit comments

Comments
 (0)