@@ -99,6 +99,7 @@ class TuningConfig:
     constraint_specs: Tuple[ConstraintSpec, ...] = ()
     tune_max_num_tokens: int = None
     inputs_pre_hook: Callable = None
+    use_cuda_graph: bool = False


 @dataclass(unsafe_hash=True)
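For reference, a minimal sketch of how an op author might opt in to the new flag. Only `TuningConfig` and the fields visible in this hunk come from the diff; the values are illustrative:

```python
# Hypothetical usage sketch: only TuningConfig and the fields shown in the
# hunk above are real; the concrete values are placeholders.
tuning_config = TuningConfig(
    tune_max_num_tokens=8192,
    use_cuda_graph=True,  # profile each tactic via CUDA graph capture + replay
)
```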
@@ -522,6 +523,7 @@ class AutoTuner:
         repeat (int): Number of profiling iterations used for averaging (default: 10)
         stream_delay_micro_secs (int): Delay inserted on the CUDA stream before the profiled kernel runs, in microseconds (default: 1000)
     """
+    _CUDA_GRAPH_DELAY_MICRO_SECS = 100
    _instance = None

    def __init__(self, warmup=3, repeat=10, stream_delay_micro_secs=1000):
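Putting the constructor defaults together with the new class constant, the cost of one measurement per (runner, tactic) pair can be sketched as follows, assuming only what the diff shows:

```python
# Sketch of the per-tactic measurement budget, assuming the defaults above.
tuner = AutoTuner(warmup=3, repeat=10, stream_delay_micro_secs=1000)

# Each measurement is roughly:
#   3 untimed warm-up calls,
#   one delay kernel (1000 us in eager mode; only 100 us when replaying a
#   CUDA graph, presumably because a single replay launch incurs far less
#   host-side overhead),
#   then 10 timed calls averaged with CUDA events.
```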
@@ -534,8 +536,6 @@ def __init__(self, warmup=3, repeat=10, stream_delay_micro_secs=1000):
        # Add statistics tracking
        self.stats = AutoTunerStatistics()

-        self.profiling_debug = True
-
        # Current captured choose_one() contexts
        self._active_capture: Optional['AutoTuner.TacticsCapture'] = None
        # Last captured choose_one() contexts
@@ -727,10 +727,10 @@ def choose_one(
        new_tuning_failure_occured = False

        for p in profiles:
+            tensors = self._prepare_input_tensors(p, inputs)
            is_cache_hit, *_ = self.profiling_cache.search_cache(
                custom_op, runners, p.get_opt_shapes(), tuning_config)
            if not is_cache_hit:
-                tensors = self._prepare_input_tensors(p, inputs)
                # Initialize runner and tactic as None in case no valid tactic or runner is found
                best_runner_id, best_tactic, min_time, has_tuning_failure_occured = self._profile_runners(
                    custom_op, runners, tensors, p, tuning_config, **kwargs)
@@ -811,7 +811,12 @@ def _profile_runners(
            for tac in valid_tactics:
                try:
                    time_measured = self._profile_single_kernel(
-                        runner, input_tensors, tac, **kwargs)
+                        runner=runner,
+                        inputs=input_tensors,
+                        tactic=tac,
+                        use_cuda_graph=tuning_config.use_cuda_graph,
+                        **kwargs,
+                    )
                except Exception as e:
                    # Handle None tensors for optional inputs
                    shapes = self._get_input_sizes(input_tensors)
@@ -857,6 +862,7 @@ def _profile_single_kernel(
        runner: TunableRunner,
        inputs: List[torch.Tensor],
        tactic: Any,
+        use_cuda_graph: bool = False,
        **kwargs,
    ) -> float:
        """Profile a single kernel implementation for performance measurement.
@@ -875,22 +881,40 @@ def _profile_single_kernel(
        are used to ensure accurate timing.
        """
        stream = torch.cuda.current_stream()
-        # warm up, no timing
-        for _ in range(self.warmup):
-            runner(inputs, tactic=tactic, **kwargs)
-        stream.synchronize()
-
-        # Delay the profiled kernel launch to eliminate the effect of host-side overhead on profiling.
-        # TODO: This is build-time sensitive, O(tactic_num * impl_num * num_profile * tunable_ops).
-        # Consider applying a pre-profiling pass to estimate the kernel execution time, then decide whether the delay is necessary.
-        delay_kernel(self.stream_delay_micro_secs, stream)
+        graph = torch.cuda.CUDAGraph()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

-        start.record(stream=stream)
-        for _ in range(self.repeat):
-            runner(inputs, tactic=tactic, **kwargs)
-        end.record(stream=stream)
+        with torch.cuda.stream(stream):
+            # warm up, no timing
+            for _ in range(self.warmup):
+                runner(inputs, tactic=tactic, **kwargs)
+
+            if use_cuda_graph:
+                with torch.cuda.graph(graph):
+                    for _ in range(self.repeat):
+                        runner(inputs, tactic=tactic, **kwargs)
+
+            stream.synchronize()
+
+            # Delay the profiled kernel launch to eliminate the effect of host-side overhead on profiling.
+            # TODO: This is build-time sensitive, O(tactic_num * impl_num * num_profile * tunable_ops).
+            # Consider applying a pre-profiling pass to estimate the kernel execution time, then decide whether the delay is necessary.
+            if use_cuda_graph:
+                delay_kernel(self._CUDA_GRAPH_DELAY_MICRO_SECS, stream)
+            else:
+                delay_kernel(self.stream_delay_micro_secs, stream)
+
+            start.record()
+
+            if use_cuda_graph:
+                graph.replay()
+            else:
+                for _ in range(self.repeat):
+                    runner(inputs, tactic=tactic, **kwargs)
+
+            end.record()
+
        stream.synchronize()

        avg_time = start.elapsed_time(end) / self.repeat
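The reworked path above captures all `repeat` iterations into a CUDA graph once, then replays the graph between the timing events, so the measured window excludes per-iteration host launch overhead. A self-contained sketch of the same pattern, using a plain matmul in place of the runner (the names here are illustrative and not part of the diff):

```python
import torch

def time_with_cuda_graph(fn, repeat=10, warmup=3):
    """Time `fn` by capturing `repeat` calls into a CUDA graph, then replaying once.

    `fn` must be capture-safe: fixed shapes and input tensors at stable addresses.
    """
    stream = torch.cuda.current_stream()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    graph = torch.cuda.CUDAGraph()

    with torch.cuda.stream(stream):
        for _ in range(warmup):  # eager warm-up; also materializes workspaces
            fn()
        with torch.cuda.graph(graph):  # kernels are recorded, not executed
            for _ in range(repeat):
                fn()
        stream.synchronize()
        start.record()
        graph.replay()  # one host-side call launches all captured kernels
        end.record()
    stream.synchronize()
    return start.elapsed_time(end) / repeat  # milliseconds per call

a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")
print(f"{time_with_cuda_graph(lambda: a @ b):.4f} ms/iter")
```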