Commit 8a69425

Merge branch 'main' into add_ci_test
2 parents 06f9bb3 + 34a6d2d commit 8a69425

File tree: 21 files changed, +323 −328 lines

cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.cu

Lines changed: 2 additions & 2 deletions

@@ -23,7 +23,7 @@
 #include <cstdint>
 #include <type_traits>

-namespace tensorrt_llm::kernels::mnnvl_throughput
+namespace tensorrt_llm::kernels::moe_comm
 {

 #define ENABLE_DEBUG_PRINT 0
@@ -964,4 +964,4 @@ void moe_a2a_sanitize_expert_ids_launch(int32_t* expert_ids, int32_t const* recv
         expert_ids, recv_counters, ep_size, max_tokens_per_rank, top_k, invalid_id);
 }

-} // namespace tensorrt_llm::kernels::mnnvl_throughput
+} // namespace tensorrt_llm::kernels::moe_comm

cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h

Lines changed: 2 additions & 2 deletions

@@ -19,7 +19,7 @@
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>

-namespace tensorrt_llm::kernels::mnnvl_throughput
+namespace tensorrt_llm::kernels::moe_comm
 {

 // Configuration constants
@@ -177,4 +177,4 @@ void moe_a2a_prepare_combine_launch(MoeA2ACombineParams const& params);
 void moe_a2a_sanitize_expert_ids_launch(int32_t* expert_ids, int32_t const* recv_counters, int32_t invalid_id,
     int ep_size, int max_tokens_per_rank, int top_k, cudaStream_t stream);

-} // namespace tensorrt_llm::kernels::mnnvl_throughput
+} // namespace tensorrt_llm::kernels::moe_comm

cpp/tensorrt_llm/nanobind/thop/bindings.cpp

Lines changed: 1 addition & 1 deletion

@@ -30,7 +30,7 @@ namespace tensorrt_llm::nanobind::thop
 void initBindings(nb::module_& m)
 {
     // Export MoE A2A constants
-    for (auto const& kv : torch_ext::mnnvl_throughput::getMoeA2AMetaInfoIndexPairs())
+    for (auto const& kv : torch_ext::moe_comm::getMoeA2AMetaInfoIndexPairs())
     {
         m.attr(kv.first) = kv.second;
     }

cpp/tensorrt_llm/pybind/thop/bindings.cpp

Lines changed: 1 addition & 1 deletion

@@ -30,7 +30,7 @@ namespace tensorrt_llm::pybind::thop
 void initBindings(pybind11::module_& m)
 {
     // Export MoE A2A constants
-    for (auto const& kv : torch_ext::mnnvl_throughput::getMoeA2AMetaInfoIndexPairs())
+    for (auto const& kv : torch_ext::moe_comm::getMoeA2AMetaInfoIndexPairs())
     {
         m.attr(kv.first) = py::int_(kv.second);
     }

cpp/tensorrt_llm/thop/moeAlltoAllMeta.h

Lines changed: 2 additions & 2 deletions

@@ -23,7 +23,7 @@

 namespace torch_ext
 {
-namespace mnnvl_throughput
+namespace moe_comm
 {

 // Enum for indexing into moe_a2a_metainfo tensor
@@ -61,5 +61,5 @@ inline std::vector<std::pair<char const*, int64_t>> getMoeA2AMetaInfoIndexPairs(
     };
 }

-} // namespace mnnvl_throughput
+} // namespace moe_comm
 } // namespace torch_ext

cpp/tensorrt_llm/thop/moeAlltoAllOp.cpp

Lines changed: 19 additions & 19 deletions

@@ -28,7 +28,7 @@
 namespace torch_ext
 {

-namespace mnnvl_throughput
+namespace moe_comm
 {

 // TODO: Is Alignment necessary?
@@ -78,13 +78,13 @@ MoeA2ADataOffsets calculateOffsets(int epSize, int maxNumTokens)
     // topk_target_ranks: [maxNumTokens, kMaxTopK]
     offset = alignOffset(offset, CACHELINE_ALIGNMENT);
     offsets[TOPK_TARGET_RANKS_OFFSET_INDEX] = offset;
-    offset += static_cast<size_t>(maxNumTokens) * static_cast<size_t>(tensorrt_llm::kernels::mnnvl_throughput::kMaxTopK)
+    offset += static_cast<size_t>(maxNumTokens) * static_cast<size_t>(tensorrt_llm::kernels::moe_comm::kMaxTopK)
         * SIZEOF_INT32;

     // topk_send_indices: [maxNumTokens, kMaxTopK]
     offset = alignOffset(offset, CACHELINE_ALIGNMENT);
     offsets[TOPK_SEND_INDICES_OFFSET_INDEX] = offset;
-    offset += static_cast<size_t>(maxNumTokens) * static_cast<size_t>(tensorrt_llm::kernels::mnnvl_throughput::kMaxTopK)
+    offset += static_cast<size_t>(maxNumTokens) * static_cast<size_t>(tensorrt_llm::kernels::moe_comm::kMaxTopK)
         * SIZEOF_INT32;

     // payload data
@@ -165,11 +165,11 @@ std::tuple<std::vector<torch::Tensor>, int64_t> moeA2ADispatchOp(torch::Tensor c
     std::vector<torch::Tensor> const& inputPayloads, torch::Tensor const& workspace, torch::Tensor const& metainfo,
     int64_t runtimeMaxTokensPerRank, int64_t epRank, int64_t epSize, int64_t topK, int64_t numExperts)
 {
-    using tensorrt_llm::kernels::mnnvl_throughput::PayloadDescriptor;
-    using tensorrt_llm::kernels::mnnvl_throughput::MoeA2ADispatchParams;
-    using tensorrt_llm::kernels::mnnvl_throughput::moe_a2a_dispatch_launch;
-    using tensorrt_llm::kernels::mnnvl_throughput::kMaxTopK;
-    using tensorrt_llm::kernels::mnnvl_throughput::kMaxPayloads;
+    using tensorrt_llm::kernels::moe_comm::PayloadDescriptor;
+    using tensorrt_llm::kernels::moe_comm::MoeA2ADispatchParams;
+    using tensorrt_llm::kernels::moe_comm::moe_a2a_dispatch_launch;
+    using tensorrt_llm::kernels::moe_comm::kMaxTopK;
+    using tensorrt_llm::kernels::moe_comm::kMaxPayloads;

     // Validate inputs
     CHECK_INPUT(tokenSelectedExperts, torch::kInt32);
@@ -344,9 +344,9 @@ torch::Tensor moeA2ACombineOp(torch::Tensor const& payload, int64_t localNumToke
     torch::Tensor const& metainfo, int64_t runtimeMaxTokensPerRank, int64_t epRank, int64_t epSize, int64_t topK,
     int64_t combinePayloadOffset, bool payloadInWorkspace)
 {
-    using tensorrt_llm::kernels::mnnvl_throughput::MoeA2ACombineParams;
-    using tensorrt_llm::kernels::mnnvl_throughput::moe_a2a_combine_launch;
-    using tensorrt_llm::kernels::mnnvl_throughput::kMaxTopK;
+    using tensorrt_llm::kernels::moe_comm::MoeA2ACombineParams;
+    using tensorrt_llm::kernels::moe_comm::moe_a2a_combine_launch;
+    using tensorrt_llm::kernels::moe_comm::kMaxTopK;

     // Validate inputs
     CHECK_TH_CUDA(payload);
@@ -474,8 +474,8 @@ void moeA2ASanitizeExpertIdsOp(torch::Tensor& expert_ids, torch::Tensor& workspa
     uint8_t* rankWorkSpacePtr = workspace.data_ptr<uint8_t>() + epRank * workspace.stride(0);
     int* recv_counters = reinterpret_cast<int*>(rankWorkSpacePtr + offsets[RECV_COUNTERS_OFFSET_INDEX]);

-    tensorrt_llm::kernels::mnnvl_throughput::moe_a2a_sanitize_expert_ids_launch(expert_ids.data_ptr<int32_t>(),
-        recv_counters, static_cast<int32_t>(invalid_expert_id), ep_size, runtime_max_tokens_per_rank, top_k,
+    tensorrt_llm::kernels::moe_comm::moe_a2a_sanitize_expert_ids_launch(expert_ids.data_ptr<int32_t>(), recv_counters,
+        static_cast<int32_t>(invalid_expert_id), ep_size, runtime_max_tokens_per_rank, top_k,
         at::cuda::getCurrentCUDAStream());
 }

@@ -508,7 +508,7 @@ torch::Tensor moeA2AGetCombinePayloadTensorOp(torch::Tensor const& workspace, in
     return t;
 }

-} // namespace mnnvl_throughput
+} // namespace moe_comm

 } // namespace torch_ext

@@ -540,9 +540,9 @@ TORCH_LIBRARY_FRAGMENT(trtllm, module)

 TORCH_LIBRARY_IMPL(trtllm, CUDA, module)
 {
-    module.impl("moe_a2a_dispatch", &torch_ext::mnnvl_throughput::moeA2ADispatchOp);
-    module.impl("moe_a2a_combine", &torch_ext::mnnvl_throughput::moeA2ACombineOp);
-    module.impl("moe_a2a_initialize", &torch_ext::mnnvl_throughput::moeA2AInitializeOp);
-    module.impl("moe_a2a_sanitize_expert_ids", &torch_ext::mnnvl_throughput::moeA2ASanitizeExpertIdsOp);
-    module.impl("moe_a2a_get_combine_payload_tensor", &torch_ext::mnnvl_throughput::moeA2AGetCombinePayloadTensorOp);
+    module.impl("moe_a2a_dispatch", &torch_ext::moe_comm::moeA2ADispatchOp);
+    module.impl("moe_a2a_combine", &torch_ext::moe_comm::moeA2ACombineOp);
+    module.impl("moe_a2a_initialize", &torch_ext::moe_comm::moeA2AInitializeOp);
+    module.impl("moe_a2a_sanitize_expert_ids", &torch_ext::moe_comm::moeA2ASanitizeExpertIdsOp);
+    module.impl("moe_a2a_get_combine_payload_tensor", &torch_ext::moe_comm::moeA2AGetCombinePayloadTensorOp);
 }
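
Note: the rename above touches only C++ namespace names; the Torch op names registered under the trtllm library ("moe_a2a_dispatch", "moe_a2a_combine", and so on) are unchanged, so Python-side callers need no updates. A minimal sketch, assuming the TensorRT-LLM extension that registers these ops has already been imported:

    import torch

    # Public op names are unaffected by the C++ namespace rename.
    for op_name in ("moe_a2a_dispatch", "moe_a2a_combine", "moe_a2a_initialize",
                    "moe_a2a_sanitize_expert_ids", "moe_a2a_get_combine_payload_tensor"):
        assert hasattr(torch.ops.trtllm, op_name), f"{op_name} is not registered"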

tensorrt_llm/_torch/autotuner.py

Lines changed: 59 additions & 38 deletions

@@ -99,7 +99,7 @@ class TuningConfig:
     constraint_specs: Tuple[ConstraintSpec, ...] = ()
     tune_max_num_tokens: int = None
     inputs_pre_hook: Callable = None
-    use_cuda_graph: bool = False
+    use_cuda_graph: bool = True


 @dataclass(unsafe_hash=True)
@@ -526,7 +526,7 @@ class AutoTuner:
     _CUDA_GRAPH_DELAY_MICRO_SECS = 100
     _instance = None

-    def __init__(self, warmup=3, repeat=10, stream_delay_micro_secs=1000):
+    def __init__(self, warmup=2, repeat=10, stream_delay_micro_secs=1000):
         self.repeat = repeat
         self.warmup = warmup
         self.stream_delay_micro_secs = stream_delay_micro_secs
@@ -698,23 +698,25 @@ def choose_one(
             })

         input_shapes = tuple(self._get_input_sizes(inputs))
+        is_cache_hit, best_runner_id, best_tactic, min_time = self.profiling_cache.search_cache(
+            custom_op, runners, input_shapes, tuning_config)
+
         # Early return if it's not tuning, use cache found one or fallback one
         if not self.is_tuning_mode:
-            is_cache_hit, best_runner_id, best_tactic, min_time = self.profiling_cache.search_cache(
-                custom_op, runners, input_shapes, tuning_config)
             best_runner = runners[best_runner_id]
             # TODO: check the stored runner and tactic can implement this shape here
-            # Should not directly try (runner, tactic) here, or it will hurt a lot of inference perf.
-
-            # Record the cache miss config.
-            # Expect no cache miss in inference. Thus, any cache miss should be recorded.
+            # Log the cache miss. Expect no cache miss in inference.
             if not is_cache_hit:
                 logger.warning_once(
                     f"[AutoTunner] Using the fallback tactic, due to cache miss on input shapes={input_shapes}",
                     key=(custom_op, "warning_autotuning_cache_miss_fallback"))

             return (best_runner, best_tactic)

+        # If it's tuning mode and cache hit, return the best runner and tactic to avoid redundant profiling.
+        if self.is_tuning_mode and is_cache_hit:
+            return (runners[best_runner_id], best_tactic)
+
         assert len(runners) > 0, "At least one runner is required"
         assert all([isinstance(r, TunableRunner) for r in runners]), \
             "All Given runners must be subclass of TunableRunner"
@@ -881,43 +883,62 @@ def _profile_single_kernel(
         are used to ensure accurate timing.
         """
         stream = torch.cuda.current_stream()
-        graph = torch.cuda.CUDAGraph()
-        start = torch.cuda.Event(enable_timing=True)
-        end = torch.cuda.Event(enable_timing=True)
-
-        with torch.cuda.stream(stream):
-            # warm up, no timing
-            for _ in range(self.warmup):
-                runner(inputs, tactic=tactic, **kwargs)
-
-            if use_cuda_graph:
-                with torch.cuda.graph(graph):
-                    for _ in range(self.repeat):
-                        runner(inputs, tactic=tactic, **kwargs)
+        # Estimate with a few repeats first; if the estimated time exceeds
+        # short_profile_threshold_ms, keep that estimate instead of running the full repeat count.
+        profile_fewer_repeat = 2
+        short_profile_threshold_ms = 1
+
+        avg_time = float('inf')
+
+        def pure_profile(stream: torch.cuda.Stream, repeat: int):
+            start = torch.cuda.Event(enable_timing=True)
+            end = torch.cuda.Event(enable_timing=True)
+            graph = torch.cuda.CUDAGraph()
+
+            with torch.cuda.stream(stream):
+                if use_cuda_graph:
+                    with torch.cuda.graph(graph):
+                        for _ in range(repeat):
+                            runner(inputs, tactic=tactic, **kwargs)
+
+                stream.synchronize()
+
+                # Delay the profiled kernel launch to eliminate affects of host time overhead in profiling.
+                # TODO: This is build time sensitive, O(tactic_num * impl_num * num_profile * tunable_ops)
+                # Consider apply a preprofiling to estimate the kernel execution time, then decide the necessity.
+                if use_cuda_graph:
+                    delay_kernel(self._CUDA_GRAPH_DELAY_MICRO_SECS, stream)
+                else:
+                    delay_kernel(self.stream_delay_micro_secs, stream)

-            stream.synchronize()
+                start.record()

-            # Delay the profiled kernel launch to eliminate affects of host time overhead in profiling.
-            # TODO: This is build time sensitive, O(tactic_num * impl_num * num_profile * tunable_ops)
-            # Consider apply a preprofiling to estimate the kernel execution time, then decide the necessity.
-            if use_cuda_graph:
-                delay_kernel(self._CUDA_GRAPH_DELAY_MICRO_SECS, stream)
-            else:
-                delay_kernel(self.stream_delay_micro_secs, stream)
+                if use_cuda_graph:
+                    graph.replay()
+                else:
+                    for _ in range(repeat):
+                        runner(inputs, tactic=tactic, **kwargs)

-            start.record()
+                end.record()
+                stream.synchronize()

-            if use_cuda_graph:
-                graph.replay()
-            else:
-                for _ in range(self.repeat):
-                    runner(inputs, tactic=tactic, **kwargs)
+            return start.elapsed_time(end) / repeat

-            end.record()
+        for _ in range(self.warmup):
+            runner(inputs, tactic=tactic, **kwargs)

-            stream.synchronize()
+        fewer_repeat_avg_time = pure_profile(stream, profile_fewer_repeat)

-        avg_time = start.elapsed_time(end) / self.repeat
+        disable_short_profile = os.environ.get(
+            "TLLM_AUTOTUNER_DISABLE_SHORT_PROFILE", "0") == "1"
+        if fewer_repeat_avg_time > short_profile_threshold_ms and not disable_short_profile:
+            print(
+                f"[Autotuner] Few-repeat estimated time is longer than {short_profile_threshold_ms}ms; using the few-repeat estimate to avoid redundant profiling."
+            )
+            # Directly use the few-repeat estimate to avoid redundant profiling.
+            avg_time = fewer_repeat_avg_time
+        else:
+            # Profile the kernel with the full repeat count to get a precise time.
+            avg_time = pure_profile(stream, self.repeat)

         shapes = self._get_input_sizes(inputs)
         logger.debug(
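
In short, _profile_single_kernel now takes a cheap estimate first and only runs the full repeat count when the kernel is fast enough for the extra precision to matter. A condensed sketch of that control flow, using hypothetical run_once() and profile(n) callables in place of the CUDA-graph/event timing above:

    import os

    def measure_avg_time_ms(run_once, profile, warmup=2, repeat=10,
                            profile_fewer_repeat=2, short_profile_threshold_ms=1.0):
        # run_once() launches the kernel once (untimed warm-up);
        # profile(n) is assumed to return the average kernel time in ms over n timed runs.
        for _ in range(warmup):
            run_once()
        estimate = profile(profile_fewer_repeat)  # cheap two-repeat estimate
        disable_short = os.environ.get("TLLM_AUTOTUNER_DISABLE_SHORT_PROFILE", "0") == "1"
        if estimate > short_profile_threshold_ms and not disable_short:
            return estimate        # slow kernel: the short estimate is accurate enough
        return profile(repeat)     # fast kernel: re-profile with the full repeat count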

tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py

Lines changed: 0 additions & 1 deletion

@@ -40,7 +40,6 @@ class CuteDSLNVFP4BlackwellRunner(TunableRunner):
             0, 0, get_last_power_of_2_num_tokens_buckets,
             last_positive_power_of_2), ),
         constraint_specs=(ConstraintSpec(2, 0, fp4_scale_infer_shape), ),
-        use_cuda_graph=True,
     )

     def __init__(self, alpha: float, output_dtype: torch.dtype):
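
This explicit flag could be dropped because TuningConfig.use_cuda_graph now defaults to True (see the autotuner diff above). A tiny illustrative sketch with a simplified stand-in for the real dataclass:

    from dataclasses import dataclass

    @dataclass
    class TuningConfigSketch:  # simplified stand-in, not the real TuningConfig
        use_cuda_graph: bool = True  # new default

    cfg = TuningConfigSketch()  # runners no longer need to pass use_cuda_graph=True
    assert cfg.use_cuda_graph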

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py

Lines changed: 10 additions & 10 deletions

@@ -143,13 +143,13 @@ def __init__(
         self.use_low_precision_combine = model_config.use_low_precision_moe_combine

         if self.alltoall_method_type == AlltoallMethodType.MNNVL:
-            if self.moe_alltoall_backend == "mnnvllatency":
+            if self.moe_alltoall_backend == "NVLINK_TWO_SIDED":
                 MnnvlMemory.initialize()
                 self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
                     model_config.mapping)
                 self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
                     model_config.mapping)
-            elif self.moe_alltoall_backend == "mnnvlthroughput":
+            elif self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
                 workspace_mb = int(
                     os.environ.get("TRTLLM_MOE_A2A_WORKSPACE_MB", "2048"))
                 self.moe_a2a = MoeAlltoAll(
@@ -253,9 +253,9 @@ def enable_alltoall(self):

     @cached_property
     def moe_alltoall_backend(self):
-        # "mnnvlthroughput" (default) or "mnnvllatency"
+        # "NVLINK_ONE_SIDED" (default) or "NVLINK_TWO_SIDED"
         return os.environ.get("TRTLLM_MOE_ALLTOALL_BACKEND",
-                              "mnnvlthroughput").strip().lower()
+                              "NVLINK_ONE_SIDED").strip().upper()

     def _supports_load_balancer(self) -> bool:
         """CutlassFusedMoE supports load balancer."""
@@ -328,7 +328,7 @@ def forward_chunk(

         if self.layer_load_balancer:
             self._load_balancer_done_wait_gpu_stage(is_first_call)
-            ignore_allreduce = self.enable_alltoall and self.alltoall_method_type == AlltoallMethodType.MNNVL and self.moe_alltoall_backend == "mnnvllatency"
+            ignore_allreduce = self.enable_alltoall and self.alltoall_method_type == AlltoallMethodType.MNNVL and self.moe_alltoall_backend == "NVLINK_TWO_SIDED"
             self._load_balancer_update_statistic(
                 token_selected_experts,
                 is_first_call,
@@ -439,7 +439,7 @@ def forward_chunk(
                 token_final_scales = torch.ones_like(token_selected_slots,
                                                      dtype=torch.float32)

-            if self.moe_alltoall_backend == "mnnvllatency":
+            if self.moe_alltoall_backend == "NVLINK_TWO_SIDED":
                 assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized"
                 if is_last_call:
                     loadbalancer_local_statistic_info = self._load_balancer_get_local_statistic_tensor(
@@ -472,7 +472,7 @@ def forward_chunk(
                     token_selected_slots, alltoall_info.recv_rank_count_cumsum,
                     runtime_max_tokens_per_rank, top_k, self.num_slots,
                     self.ep_size)
-            elif self.moe_alltoall_backend == "mnnvlthroughput":
+            elif self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
                 # Python MoeAlltoAll path
                 if x_sf is not None:
                     x_sf = x_sf.view(x_row,
@@ -532,7 +532,7 @@ def forward_chunk(

         # Optionally provide an output tensor to fused_moe so it writes directly to our buffer
         moe_output: Optional[torch.Tensor] = None
-        if self.enable_alltoall and self.moe_alltoall_backend == "mnnvlthroughput":
+        if self.enable_alltoall and self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
             # Retrieve a workspace-backed output tensor sized by runtime tokens
             runtime_max_tokens_per_rank = max(
                 all_rank_num_tokens) if all_rank_num_tokens else x.shape[0]
@@ -583,7 +583,7 @@ def forward_chunk(

         # Combine results if using alltoall
         if self.enable_alltoall:
-            if self.moe_alltoall_backend == "mnnvllatency":
+            if self.moe_alltoall_backend == "NVLINK_TWO_SIDED":
                 if alltoall_info is not None:
                     top_k = self.routing_method.experts_per_token
                     final_hidden_states = MnnvlMoe.mnnvl_moe_alltoallv_combine(
@@ -596,7 +596,7 @@ def forward_chunk(
                         use_low_precision_combine=self.
                         use_low_precision_combine,
                         token_count=token_count)
-            elif self.moe_alltoall_backend == "mnnvlthroughput":
+            elif self.moe_alltoall_backend == "NVLINK_ONE_SIDED":
                 output_hidden_size = final_hidden_states.shape[-1]
                 runtime_max_tokens_per_rank = max(
                     all_rank_num_tokens) if all_rank_num_tokens else token_count
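
For reference, the backend string is resolved from an environment variable; a minimal sketch of the selection logic as renamed here (written as a module-level function for illustration, whereas the real code is a cached_property on the class):

    import os

    def moe_alltoall_backend() -> str:
        # "NVLINK_ONE_SIDED" (default) or "NVLINK_TWO_SIDED"
        return os.environ.get("TRTLLM_MOE_ALLTOALL_BACKEND",
                              "NVLINK_ONE_SIDED").strip().upper()

    if moe_alltoall_backend() == "NVLINK_TWO_SIDED":
        ...  # MnnvlMoe two-sided path (formerly "mnnvllatency")
    else:
        ...  # MoeAlltoAll one-sided path (formerly "mnnvlthroughput")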
