Commit b6bced8

[TRTLLM-7963][feat] Use CUDAGraph to improve the tuning accuracy for AutoTuner. (#9089)
Signed-off-by: Yukun He <[email protected]>
1 parent 41e5870 commit b6bced8
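
In short: TuningConfig gains a use_cuda_graph flag (default False). When a runner opts in, _profile_single_kernel captures the `repeat` timed invocations into a torch.cuda.CUDAGraph after warmup and then times a single graph.replay() between CUDA events, keeping per-iteration CPU launch overhead out of the measurement; a much shorter pre-timing delay (_CUDA_GRAPH_DELAY_MICRO_SECS = 100 vs. the default 1000 microseconds) then suffices. The CuteDSL NVFP4 Blackwell linear op opts in in this change.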

3 files changed: +45 -20 lines

tensorrt_llm/_torch/autotuner.py

Lines changed: 41 additions & 17 deletions
@@ -99,6 +99,7 @@ class TuningConfig:
     constraint_specs: Tuple[ConstraintSpec, ...] = ()
     tune_max_num_tokens: int = None
     inputs_pre_hook: Callable = None
+    use_cuda_graph: bool = False


 @dataclass(unsafe_hash=True)
@@ -522,6 +523,7 @@ class AutoTuner:
        repeat (int): Number of profiling iterations for averaging (default: 10)
        stream_delay_micro_secs (int): Delay on CUDA stream before the profiled kernel runs in microseconds (default: 1000)
    """
+    _CUDA_GRAPH_DELAY_MICRO_SECS = 100
     _instance = None

     def __init__(self, warmup=3, repeat=10, stream_delay_micro_secs=1000):
@@ -534,8 +536,6 @@ def __init__(self, warmup=3, repeat=10, stream_delay_micro_secs=1000):
         # Add statistics tracking
         self.stats = AutoTunerStatistics()

-        self.profiling_debug = True
-
         # Current captured choose_one() contexts
         self._active_capture: Optional['AutoTuner.TacticsCapture'] = None
         # Last captured choose_one() contexts
@@ -727,10 +727,10 @@ def choose_one(
         new_tuning_failure_occured = False

         for p in profiles:
+            tensors = self._prepare_input_tensors(p, inputs)
             is_cache_hit, *_ = self.profiling_cache.search_cache(
                 custom_op, runners, p.get_opt_shapes(), tuning_config)
             if not is_cache_hit:
-                tensors = self._prepare_input_tensors(p, inputs)
                 # Initialize runner and tactic as None in case of no valid tactic or runners are found
                 best_runner_id, best_tactic, min_time, has_tuning_failure_occured = self._profile_runners(
                     custom_op, runners, tensors, p, tuning_config, **kwargs)
@@ -811,7 +811,12 @@ def _profile_runners(
             for tac in valid_tactics:
                 try:
                     time_measured = self._profile_single_kernel(
-                        runner, input_tensors, tac, **kwargs)
+                        runner=runner,
+                        inputs=input_tensors,
+                        tactic=tac,
+                        use_cuda_graph=tuning_config.use_cuda_graph,
+                        **kwargs,
+                    )
                 except Exception as e:
                     # Handle None tensors for optional inputs
                     shapes = self._get_input_sizes(input_tensors)
@@ -857,6 +862,7 @@ def _profile_single_kernel(
         runner: TunableRunner,
         inputs: List[torch.Tensor],
         tactic: Any,
+        use_cuda_graph: bool = False,
         **kwargs,
     ) -> float:
         """Profile a single kernel implementation for performance measurement.
@@ -875,22 +881,40 @@ def _profile_single_kernel(
         are used to ensure accurate timing.
         """
         stream = torch.cuda.current_stream()
-        # warm up, no timing
-        for _ in range(self.warmup):
-            runner(inputs, tactic=tactic, **kwargs)
-        stream.synchronize()
-
-        # Delay the profiled kernel launch to eliminate affects of host time overhead in profiling.
-        # TODO: This is build time sensitive, O(tactic_num * impl_num * num_profile * tunable_ops)
-        # Consider apply a preprofiling to estimate the kernel execution time, then decide the necessity.
-        delay_kernel(self.stream_delay_micro_secs, stream)
+        graph = torch.cuda.CUDAGraph()
         start = torch.cuda.Event(enable_timing=True)
         end = torch.cuda.Event(enable_timing=True)

-        start.record(stream=stream)
-        for _ in range(self.repeat):
-            runner(inputs, tactic=tactic, **kwargs)
-        end.record(stream=stream)
+        with torch.cuda.stream(stream):
+            # warm up, no timing
+            for _ in range(self.warmup):
+                runner(inputs, tactic=tactic, **kwargs)
+
+            if use_cuda_graph:
+                with torch.cuda.graph(graph):
+                    for _ in range(self.repeat):
+                        runner(inputs, tactic=tactic, **kwargs)
+
+            stream.synchronize()
+
+            # Delay the profiled kernel launch to eliminate effects of host time overhead in profiling.
+            # TODO: This is build time sensitive, O(tactic_num * impl_num * num_profile * tunable_ops)
+            # Consider applying a preprofiling to estimate the kernel execution time, then decide the necessity.
+            if use_cuda_graph:
+                delay_kernel(self._CUDA_GRAPH_DELAY_MICRO_SECS, stream)
+            else:
+                delay_kernel(self.stream_delay_micro_secs, stream)
+
+            start.record()
+
+            if use_cuda_graph:
+                graph.replay()
+            else:
+                for _ in range(self.repeat):
+                    runner(inputs, tactic=tactic, **kwargs)
+
+            end.record()
+
         stream.synchronize()

         avg_time = start.elapsed_time(end) / self.repeat
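
For readers outside the diff, the essence of the new timing path as a minimal standalone sketch. A plain matmul stands in for the TunableRunner call; the helper name time_with_cuda_graph and its defaults are illustrative, not part of the TensorRT-LLM API, and the delay_kernel spacing between warmup and timing is omitted for brevity.

import torch


def time_with_cuda_graph(fn, warmup=3, repeat=10):
    """Average per-call GPU time of fn() in ms, timed via CUDA graph replay.

    Capturing `repeat` calls into one graph and replaying it issues all the
    work with a single CPU-side launch, so per-call host overhead cannot
    inflate the measurement the way it can in an eager-mode timing loop.
    """
    stream = torch.cuda.Stream()  # capture cannot run on the default stream
    graph = torch.cuda.CUDAGraph()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    with torch.cuda.stream(stream):
        for _ in range(warmup):  # warm up outside the graph, no timing
            fn()
        with torch.cuda.graph(graph):  # record `repeat` calls into the graph
            for _ in range(repeat):
                fn()
        stream.synchronize()

        start.record()
        graph.replay()  # one launch replays all captured iterations
        end.record()
    stream.synchronize()
    return start.elapsed_time(end) / repeat


if __name__ == "__main__":
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    print(f"matmul: {time_with_cuda_graph(lambda: a @ b):.4f} ms/call")

This mirrors the use_cuda_graph branch above: warm up, capture the repeat loop, synchronize, then time a single graph.replay() between CUDA events, which is presumably why the pre-timing delay can drop from stream_delay_micro_secs (1000) to _CUDA_GRAPH_DELAY_MICRO_SECS (100).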

tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py

Lines changed: 1 addition & 0 deletions
@@ -37,6 +37,7 @@ class CuteDSLNVFP4BlackwellLinear(TunableRunner):
             0, 0, get_last_power_of_2_num_tokens_buckets,
             last_positive_power_of_2), ),
         constraint_specs=(ConstraintSpec(2, 0, fp4_scale_infer_shape), ),
+        use_cuda_graph=True,
     )

     def __init__(self, alpha: float, output_dtype: torch.dtype):

tests/unittest/_torch/misc/test_autotuner.py

Lines changed: 3 additions & 3 deletions
@@ -57,19 +57,19 @@ def test_multi_dynamic_dims():
 # add sleep to simulate bad perf
 def gemm_0(x, w):
     if x.shape[0] > M // 2:
-        delay_kernel(10000, torch.cuda.current_stream())
+        delay_kernel(100, torch.cuda.current_stream())
     return x @ w


 def gemm_1(x, w):
     if x.shape[0] <= M // 2:
-        delay_kernel(10000, torch.cuda.current_stream())
+        delay_kernel(100, torch.cuda.current_stream())
     return x @ w


 def gemm_fallback(x, w) -> torch.Tensor:
     # always the slowest
-    delay_kernel(100000, torch.cuda.current_stream())
+    delay_kernel(500, torch.cuda.current_stream())
     return x @ w
