@@ -99,7 +99,7 @@ class TuningConfig:
     constraint_specs: Tuple[ConstraintSpec, ...] = ()
     tune_max_num_tokens: int = None
     inputs_pre_hook: Callable = None
-    use_cuda_graph: bool = False
+    use_cuda_graph: bool = True


 @dataclass(unsafe_hash=True)
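Note on this hunk: with the default flipped, any TuningConfig that leaves the field unset now profiles through CUDA-graph capture and replay. A minimal caller-side sketch (only the field name comes from the hunk; the surrounding usage is illustrative):

    config = TuningConfig()                            # now defaults to use_cuda_graph=True
    eager_config = TuningConfig(use_cuda_graph=False)  # explicit opt-out for kernels
                                                       # that cannot be graph-captured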
@@ -526,7 +526,7 @@ class AutoTuner:
     _CUDA_GRAPH_DELAY_MICRO_SECS = 100
     _instance = None

-    def __init__(self, warmup=3, repeat=10, stream_delay_micro_secs=1000):
+    def __init__(self, warmup=2, repeat=10, stream_delay_micro_secs=1000):
         self.repeat = repeat
         self.warmup = warmup
         self.stream_delay_micro_secs = stream_delay_micro_secs
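Note on this hunk: warmup drops from 3 to 2 untimed iterations per (runner, tactic) candidate. A back-of-envelope sketch of the launch count, assuming one full profiling pass per candidate (the helper name is hypothetical):

    def profiling_launches(num_candidates: int, warmup: int = 2, repeat: int = 10) -> int:
        # Each candidate pays warmup + repeat launches: 12 now, versus 13 with warmup=3.
        return num_candidates * (warmup + repeat)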
@@ -698,23 +698,25 @@ def choose_one(
             })

         input_shapes = tuple(self._get_input_sizes(inputs))
+        is_cache_hit, best_runner_id, best_tactic, min_time = self.profiling_cache.search_cache(
+            custom_op, runners, input_shapes, tuning_config)
+
         # Early return if it's not tuning, use cache found one or fallback one
         if not self.is_tuning_mode:
-            is_cache_hit, best_runner_id, best_tactic, min_time = self.profiling_cache.search_cache(
-                custom_op, runners, input_shapes, tuning_config)
             best_runner = runners[best_runner_id]
             # TODO: check the stored runner and tactic can implement this shape here
-            # Should not directly try (runner, tactic) here, or it will hurt a lot of inference perf.
-
-            # Record the cache miss config.
-            # Expect no cache miss in inference. Thus, any cache miss should be recorded.
+            # Log the cache miss. Expect no cache miss in inference.
             if not is_cache_hit:
                 logger.warning_once(
                     f"[AutoTuner] Using the fallback tactic due to a cache miss on input shapes={input_shapes}",
                     key=(custom_op, "warning_autotuning_cache_miss_fallback"))

             return (best_runner, best_tactic)

+        # In tuning mode with a cache hit, return the cached runner and tactic to avoid redundant profiling.
+        if self.is_tuning_mode and is_cache_hit:
+            return (runners[best_runner_id], best_tactic)
+
         assert len(runners) > 0, "At least one runner is required"
         assert all([isinstance(r, TunableRunner) for r in runners]), \
             "All given runners must be subclasses of TunableRunner"
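Note on this hunk: the cache lookup is hoisted above the mode check, so both branches share a single search_cache call, and tuning mode now short-circuits on a hit. A condensed sketch of the resulting control flow (profiling, warning, and miss-fallback details elided):

    hit, rid, tactic, _ = self.profiling_cache.search_cache(
        custom_op, runners, input_shapes, tuning_config)
    if not self.is_tuning_mode:    # inference: never profile; warn and fall back on a miss
        return runners[rid], tactic
    if hit:                        # tuning mode with a hit: skip redundant profiling
        return runners[rid], tactic
    # tuning mode with a miss: fall through to full profiling of all (runner, tactic) pairs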
@@ -881,43 +883,62 @@ def _profile_single_kernel(
         are used to ensure accurate timing.
         """
         stream = torch.cuda.current_stream()
-        graph = torch.cuda.CUDAGraph()
-        start = torch.cuda.Event(enable_timing=True)
-        end = torch.cuda.Event(enable_timing=True)
-
-        with torch.cuda.stream(stream):
-            # warm up, no timing
-            for _ in range(self.warmup):
-                runner(inputs, tactic=tactic, **kwargs)
-
-            if use_cuda_graph:
-                with torch.cuda.graph(graph):
-                    for _ in range(self.repeat):
-                        runner(inputs, tactic=tactic, **kwargs)
+        # If the few-repeat average time exceeds short_profile_threshold_ms, use that estimate directly.
+        profile_fewer_repeat = 2
+        short_profile_threshold_ms = 1
+
+        avg_time = float('inf')
+
+        def pure_profile(stream: torch.cuda.Stream, repeat: int):
+            start = torch.cuda.Event(enable_timing=True)
+            end = torch.cuda.Event(enable_timing=True)
+            graph = torch.cuda.CUDAGraph()
+
+            with torch.cuda.stream(stream):
+                if use_cuda_graph:
+                    with torch.cuda.graph(graph):
+                        for _ in range(repeat):
+                            runner(inputs, tactic=tactic, **kwargs)
+
+                stream.synchronize()
+
+                # Delay the profiled kernel launch to eliminate the effects of host-side overhead in profiling.
+                # TODO: This is build-time sensitive, O(tactic_num * impl_num * num_profile * tunable_ops).
+                # Consider applying a pre-profiling pass to estimate the kernel execution time, then decide the necessity.
+                if use_cuda_graph:
+                    delay_kernel(self._CUDA_GRAPH_DELAY_MICRO_SECS, stream)
+                else:
+                    delay_kernel(self.stream_delay_micro_secs, stream)

-            stream.synchronize()
+                start.record()

-            # Delay the profiled kernel launch to eliminate affects of host time overhead in profiling.
-            # TODO: This is build time sensitive, O(tactic_num * impl_num * num_profile * tunable_ops)
-            # Consider apply a preprofiling to estimate the kernel execution time, then decide the necessity.
-            if use_cuda_graph:
-                delay_kernel(self._CUDA_GRAPH_DELAY_MICRO_SECS, stream)
-            else:
-                delay_kernel(self.stream_delay_micro_secs, stream)
+                if use_cuda_graph:
+                    graph.replay()
+                else:
+                    for _ in range(repeat):
+                        runner(inputs, tactic=tactic, **kwargs)

-            start.record()
+                end.record()
+                stream.synchronize()

-            if use_cuda_graph:
-                graph.replay()
-            else:
-                for _ in range(self.repeat):
-                    runner(inputs, tactic=tactic, **kwargs)
+            return start.elapsed_time(end) / repeat

-            end.record()
+        for _ in range(self.warmup):
+            runner(inputs, tactic=tactic, **kwargs)

-            stream.synchronize()
+        fewer_repeat_avg_time = pure_profile(stream, profile_fewer_repeat)

-        avg_time = start.elapsed_time(end) / self.repeat
+        disable_short_profile = os.environ.get(
+            "TLLM_AUTOTUNER_DISABLE_SHORT_PROFILE", "0") == "1"
+        if fewer_repeat_avg_time > short_profile_threshold_ms and not disable_short_profile:
+            print(
+                f"[AutoTuner] The few-repeat estimated time exceeds {short_profile_threshold_ms} ms; using it directly to avoid redundant profiling."
+            )
+            # Directly use the few-repeat estimate to avoid redundant profiling.
+            avg_time = fewer_repeat_avg_time
+        else:
+            # Profile the kernel with the full repeat count to get a precise time.
+            avg_time = pure_profile(stream, self.repeat)

         shapes = self._get_input_sizes(inputs)
         logger.debug(
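Note on this hunk: profiling is now two-phase; a cheap few-repeat estimate gates the full-repeat measurement. A self-contained sketch of the heuristic using plain CUDA events (the CUDA-graph and delay-kernel paths and the runner plumbing are elided; names are hypothetical):

    import torch

    def time_kernel_ms(fn, warmup=2, few_repeat=2, full_repeat=10, threshold_ms=1.0):
        for _ in range(warmup):                      # warm up, untimed
            fn()

        def timed(repeat):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            torch.cuda.synchronize()
            start.record()
            for _ in range(repeat):
                fn()
            end.record()
            torch.cuda.synchronize()
            return start.elapsed_time(end) / repeat  # average ms per call

        estimate = timed(few_repeat)                 # cheap estimate first
        if estimate > threshold_ms:                  # slow kernel: the estimate suffices
            return estimate
        return timed(full_repeat)                    # fast kernel: measure precisely

Setting TLLM_AUTOTUNER_DISABLE_SHORT_PROFILE=1 forces the full-repeat path for every kernel, per the environment check added in this hunk.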