 from tensorrt_llm._torch.autotuner import AutoTuner, autotune
 from tensorrt_llm._torch.modules.multi_stream_utils import with_multi_stream
 from tensorrt_llm._utils import local_mpi_rank, mpi_rank, mpi_world_size
-from tensorrt_llm.tools.layer_wise_benchmarks.deepseekv3_runner import (
-    BalanceMethod, DeepSeekV3Runner)
+from tensorrt_llm.tools.layer_wise_benchmarks import BalanceMethod, get_runner_cls
 
 
 def comma_separated_ints(s):
@@ -23,30 +22,25 @@ def comma_separated_ints(s):
 parser.add_argument(
     "--layer-indices",
     type=comma_separated_ints,
-    help="Comma separated indices of layers, should be a contiguous range")
+    help="Comma separated indices of layers, should be a contiguous range",
+)
 parser.add_argument("--run-type", type=str, choices=["CTX", "GEN"])
 parser.add_argument("--scaled-from", type=int)
 # KV cache related args
+parser.add_argument("--max-batch-size", type=int)
 parser.add_argument("--tokens-per-block", type=int)
 parser.add_argument("--max-seq-len", type=int)
 group = parser.add_mutually_exclusive_group(required=False)
-group.add_argument("--enable-attention-dp",
-                   action="store_true",
-                   dest="enable_attention_dp")
-group.add_argument("--no-enable-attention-dp",
-                   action="store_false",
-                   dest="enable_attention_dp")
+group.add_argument("--enable-attention-dp", action="store_true", dest="enable_attention_dp")
+group.add_argument("--no-enable-attention-dp", action="store_false", dest="enable_attention_dp")
 parser.set_defaults(enable_attention_dp=None)
 # Model init args
 parser.add_argument("--max-num-tokens", type=int)
 parser.add_argument("--moe-backend", type=str)
+parser.add_argument("--moe-max-num-tokens", type=int)
 group = parser.add_mutually_exclusive_group(required=False)
-group.add_argument("--use-cuda-graph",
-                   action="store_true",
-                   dest="use_cuda_graph")
-group.add_argument("--no-use-cuda-graph",
-                   action="store_false",
-                   dest="use_cuda_graph")
+group.add_argument("--use-cuda-graph", action="store_true", dest="use_cuda_graph")
+group.add_argument("--no-use-cuda-graph", action="store_false", dest="use_cuda_graph")
 parser.set_defaults(use_cuda_graph=None)
 # Per iteration args
 parser.add_argument("--batch-size", type=int)
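
Note: the paired "--enable-attention-dp" / "--no-enable-attention-dp" and "--use-cuda-graph" / "--no-use-cuda-graph" flags, together with set_defaults(...=None), implement a tri-state switch: True, False, or None meaning "not given on the command line", which lets the YAML config fill the value in later. A minimal self-contained sketch of the same pattern, using a hypothetical --feature flag:

    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_mutually_exclusive_group(required=False)
    group.add_argument("--feature", action="store_true", dest="feature")
    group.add_argument("--no-feature", action="store_false", dest="feature")
    parser.set_defaults(feature=None)  # None distinguishes "unset" from an explicit choice

    assert parser.parse_args([]).feature is None
    assert parser.parse_args(["--feature"]).feature is True
    assert parser.parse_args(["--no-feature"]).feature is False
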
@@ -59,8 +53,12 @@ def comma_separated_ints(s):
     config = yaml.safe_load(f)
 del args.config_path
 for k, v in vars(args).items():
-    if v is None:
+    if v is None and k in config:
         setattr(args, k, config[k])
+if args.max_batch_size is None:
+    args.max_batch_size = args.batch_size
+if args.max_num_tokens is None:
+    args.max_num_tokens = args.max_batch_size * args.seq_len_q
 print(args)
 
 # MPI args
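
Note: the added "and k in config" guard is what lets newly introduced flags (such as --max-batch-size and --moe-max-num-tokens) coexist with older YAML files that do not define them; without the guard, config[k] would raise KeyError. A runnable sketch of the merge rule, with hypothetical field names:

    import argparse

    import yaml

    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-size", type=int)
    parser.add_argument("--new-flag", type=int)  # hypothetical flag absent from old configs
    args = parser.parse_args([])

    config = yaml.safe_load("batch_size: 32")  # an older config file without new_flag
    for k, v in vars(args).items():
        if v is None and k in config:  # CLI wins; config fills gaps; missing keys are skipped
            setattr(args, k, config[k])
    assert args.batch_size == 32 and args.new_flag is None
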
@@ -70,43 +68,49 @@ def comma_separated_ints(s):
 torch.cuda.set_device(local_rank)
 
 # Create KV cache manager
-mapping = DeepSeekV3Runner.create_mapping(
-    enable_attention_dp=args.enable_attention_dp)
-max_batch_size = 2048
-kv_cache_manager = DeepSeekV3Runner.create_kv_cache_manager(
+Runner = get_runner_cls(args.model)
+mapping = Runner.create_mapping(enable_attention_dp=args.enable_attention_dp)
+kv_cache_manager = Runner.create_kv_cache_manager(
     args.model,
     mapping,
     tokens_per_block=args.tokens_per_block,
-    max_batch_size=max_batch_size,
+    max_batch_size=args.max_batch_size,
     max_seq_len=args.max_seq_len,
-    layer_indices=args.layer_indices)
-attn_workspace = torch.empty((0, ), device="cuda", dtype=torch.int8)
+    layer_indices=args.layer_indices,
+)
+attn_workspace = torch.empty((0,), device="cuda", dtype=torch.int8)
 
 # Create other global objects
 AutoTuner.get().clear_cache()
 capture_stream = torch.cuda.Stream()
 
 # Create Runner
-runner = DeepSeekV3Runner(args.model,
-                          mapping,
-                          moe_backend=args.moe_backend,
-                          layer_indices=args.layer_indices,
-                          scaled_from=args.scaled_from,
-                          max_seq_len=args.max_seq_len,
-                          max_num_tokens=args.max_num_tokens,
-                          use_cuda_graph=args.use_cuda_graph)
+runner = Runner(
+    args.model,
+    mapping,
+    moe_backend=args.moe_backend,
+    layer_indices=args.layer_indices,
+    scaled_from=args.scaled_from,
+    max_seq_len=args.max_seq_len,
+    max_num_tokens=args.max_num_tokens,
+    moe_max_num_tokens=args.moe_max_num_tokens,
+    use_cuda_graph=args.use_cuda_graph,
+)
 
 # Warm up
-assert args.batch_size <= max_batch_size
+assert args.batch_size <= args.max_batch_size
 assert args.seq_len_q + args.seq_len_kv_cache <= args.max_seq_len
-run_pack = runner.create_run_pack(args.run_type,
-                                  batch_size=args.batch_size,
-                                  seq_len_q=args.seq_len_q,
-                                  seq_len_kv_cache=args.seq_len_kv_cache,
-                                  kv_cache_manager=kv_cache_manager,
-                                  attn_workspace=attn_workspace)
-runner.replace_routing_method(balance_method=BalanceMethod[args.balance_method],
-                              balance_ratio=args.balance_ratio)
+run_pack = runner.create_run_pack(
+    args.run_type,
+    batch_size=args.batch_size,
+    seq_len_q=args.seq_len_q,
+    seq_len_kv_cache=args.seq_len_kv_cache,
+    kv_cache_manager=kv_cache_manager,
+    attn_workspace=attn_workspace,
+)
+runner.replace_routing_method(
+    balance_method=BalanceMethod[args.balance_method], balance_ratio=args.balance_ratio
+)
 capture_stream.wait_stream(torch.cuda.current_stream())
 with torch.cuda.stream(capture_stream):
     run_pack()
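
Note: the warm-up runs on capture_stream, the same side stream later used for CUDA graph capture, and only after that stream has been synchronized with the default stream. This mirrors the warm-up idiom from the PyTorch CUDA graphs documentation. A minimal sketch of the stream handoff (a matmul stands in for run_pack):

    import torch

    capture_stream = torch.cuda.Stream()
    x = torch.randn(1024, 1024, device="cuda")

    # The side stream must not start until pending work on the default stream is done.
    capture_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(capture_stream):
        y = x @ x  # warm-up work, issued on the side stream
    # Rejoin: the default stream waits for the side stream before reusing y.
    torch.cuda.current_stream().wait_stream(capture_stream)
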
@@ -120,21 +124,15 @@ def comma_separated_ints(s):
 if args.use_cuda_graph:
     with with_multi_stream(True):
         g = torch.cuda.CUDAGraph()
-        with torch.cuda.graph(g,
-                              stream=capture_stream,
-                              capture_error_mode="global"):
+        with torch.cuda.graph(g, stream=capture_stream, capture_error_mode="global"):
             run_pack()
 
 warmup_times = 20
 run_times = 100
-events = [
-    torch.cuda.Event(enable_timing=True)
-    for _ in range(warmup_times + run_times + 1)
-]
+events = [torch.cuda.Event(enable_timing=True) for _ in range(warmup_times + run_times + 1)]
 for i in range(warmup_times + run_times):
     events[i].record()
-    with nvtx.annotate(
-            f"b={args.batch_size} s={args.seq_len_q} EP{world_size}"):
+    with nvtx.annotate(f"b={args.batch_size} s={args.seq_len_q} EP{world_size}"):
         if args.use_cuda_graph:
             g.replay()
         else:
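
Note: torch.cuda.graph records the kernels launched inside the with block into g instead of executing them eagerly; each later g.replay() re-launches the captured kernel sequence with a single CPU call, reading and writing the same memory used during capture. A minimal standalone sketch (static tensors stand in for the runner's buffers):

    import torch

    static_in = torch.randn(1024, device="cuda")
    static_out = torch.empty_like(static_in)

    # Warm up once so lazy initialization happens outside of capture.
    static_out.copy_(static_in * 2.0)
    torch.cuda.synchronize()

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out.copy_(static_in * 2.0)

    static_in.fill_(3.0)  # mutate inputs in place; the graph reuses the same storage
    g.replay()
    torch.cuda.synchronize()
    assert static_out.eq(6.0).all()
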
@@ -144,16 +142,16 @@ def comma_separated_ints(s):
 
 # Print statistics
 # Print before `cudaProfilerStop` to ensure messages are included in the profile
-time_list = [
-    start.elapsed_time(stop) for start, stop in zip(events, events[1:])
-]
+time_list = [start.elapsed_time(stop) for start, stop in zip(events, events[1:])]
 time_list = time_list[warmup_times:]
-print(f"[RANK {rank}]"
-      f" min {np.min(time_list) * 1000:.1f}"
-      f" max {np.max(time_list) * 1000:.1f}"
-      f" mean {np.mean(time_list) * 1000:.1f}"
-      f" median {np.median(time_list) * 1000:.1f}"
-      f" P90 {np.percentile(time_list, 90) * 1000:.1f}"
-      f" (us)")
+print(
+    f"[RANK {rank}]"
+    f" min {np.min(time_list) * 1000:.1f}"
+    f" max {np.max(time_list) * 1000:.1f}"
+    f" mean {np.mean(time_list) * 1000:.1f}"
+    f" median {np.median(time_list) * 1000:.1f}"
+    f" P90 {np.percentile(time_list, 90) * 1000:.1f}"
+    f" (us)"
+)
 
 torch.cuda.cudart().cudaProfilerStop()
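
Note: Event.elapsed_time returns milliseconds, so the "* 1000" above converts each duration to microseconds, matching the "(us)" label. Recording an event before every iteration plus one final event yields per-iteration durations from the pairwise zip, and slicing off the first warmup_times entries excludes the warm-up iterations from the statistics. A compact sketch of the same timing scheme:

    import numpy as np
    import torch

    n_iters = 10
    x = torch.randn(2048, 2048, device="cuda")
    events = [torch.cuda.Event(enable_timing=True) for _ in range(n_iters + 1)]
    for i in range(n_iters):
        events[i].record()
        x = x @ x  # the workload being timed
    events[-1].record()
    torch.cuda.synchronize()  # both events of a pair must have completed

    time_ms = [a.elapsed_time(b) for a, b in zip(events, events[1:])]
    print(f"mean {np.mean(time_ms) * 1000:.1f} P90 {np.percentile(time_ms, 90) * 1000:.1f} (us)")
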