@@ -49,7 +49,6 @@ def get_model(args):
49
49
torch.nn.Module: The loaded and configured model ready for inference,
50
50
moved to CUDA device with the specified precision
51
51
"""
52
-
53
52
with torch .no_grad ():
54
53
model = (
55
54
AutoModelForCausalLM .from_pretrained (
@@ -112,23 +111,7 @@ def compile_torchtrt(model, input_ids, args):
112
111
else :
113
112
enabled_precisions = {torch .float32 }
114
113
115
- qformat = "_q_" + args .qformat if args .qformat else ""
116
-
117
- logging_dir = f"./{ args .model } _{ args .precision } { qformat } "
118
- # with torch_tensorrt.logging.debug() if args.debug else nullcontext():
119
- with (
120
- torch_tensorrt .dynamo .Debugger (
121
- "debug" ,
122
- logging_dir = logging_dir ,
123
- # capture_fx_graph_after=["constant_fold"],
124
- # save_engine_profile=True,
125
- # profile_format="trex",
126
- engine_builder_monitor = False ,
127
- # save_layer_info=True,
128
- )
129
- if args .debug
130
- else nullcontext ()
131
- ):
114
+ with torch_tensorrt .logging .debug () if args .debug else nullcontext ():
132
115
trt_model = torch_tensorrt .dynamo .compile (
133
116
ep ,
134
117
inputs = [input_ids , position_ids ],
@@ -151,14 +134,12 @@ def print_outputs(backend_name, gen_tokens, tokenizer):
151
134
"""
152
135
Print the generated tokens from the model.
153
136
"""
154
- out = tokenizer .decode (gen_tokens [0 ], skip_special_tokens = True )
155
137
print (f"========= { backend_name } =========" )
156
138
print (
157
139
f"{ backend_name } model generated text: " ,
158
- out ,
140
+ tokenizer . decode ( gen_tokens [ 0 ], skip_special_tokens = True ) ,
159
141
)
160
142
print ("===================================" )
161
- return out
162
143
163
144
164
145
def measure_perf (trt_model , input_signature , backend_name ):
@@ -260,13 +241,13 @@ def measure_perf(trt_model, input_signature, backend_name):
260
241
)
261
242
arg_parser .add_argument (
262
243
"--qformat" ,
263
- help = ("Apply quantization format. Options: fp8 (default: None)" ),
244
+ help = ("Apply quantization format. Options: fp8, nvfp4 (default: None)" ),
264
245
default = None ,
265
246
)
266
247
arg_parser .add_argument (
267
248
"--pre_quantized" ,
268
249
action = "store_true" ,
269
- help = "Use pre-quantized model weights (default: False)" ,
250
+ help = "Use pre-quantized hf model weights (default: False)" ,
270
251
)
271
252
args = arg_parser .parse_args ()
272
253
@@ -300,6 +281,7 @@ def measure_perf(trt_model, input_signature, backend_name):
300
281
pyt_gen_tokens = None
301
282
pyt_timings = None
302
283
pyt_stats = None
284
+
303
285
if args .qformat != None :
304
286
model = quantize_model (model , args , tokenizer )
305
287
if args .enable_pytorch_run :
@@ -380,43 +362,19 @@ def measure_perf(trt_model, input_signature, backend_name):
380
362
batch_size = args .batch_size ,
381
363
compile_time_s = None ,
382
364
)
383
- match_result = "N/A"
384
- torch_out = "N/A"
385
- model_name = args .model .replace ("/" , "_" )
386
- qformat = args .qformat if args .qformat else "no_quant"
387
365
388
366
if not args .benchmark :
389
367
if args .enable_pytorch_run :
390
- torch_out = print_outputs ("PyTorch" , pyt_gen_tokens , tokenizer )
368
+ print_outputs ("PyTorch" , pyt_gen_tokens , tokenizer )
391
369
392
- trt_out = print_outputs ("TensorRT" , trt_gen_tokens , tokenizer )
370
+ print_outputs ("TensorRT" , trt_gen_tokens , tokenizer )
393
371
394
372
if args .enable_pytorch_run :
395
373
print (
396
374
f"PyTorch and TensorRT outputs match: { torch .equal (pyt_gen_tokens , trt_gen_tokens )} "
397
375
)
398
- match_result = str (torch .equal (pyt_gen_tokens , trt_gen_tokens ))
399
- out_json_file = f"{ model_name } _{ qformat } _match.json"
400
- result = {}
401
- args_dict = vars (args )
402
- result ["args" ] = args_dict
403
- result ["match" ] = match_result
404
- result ["torch_out" ] = torch_out
405
- result ["trt_out" ] = trt_out
406
- with open (os .path .join ("result" , out_json_file ), "w" ) as f :
407
- json .dump (result , f , indent = 4 )
408
- print (f"Results saved to { out_json_file } " )
376
+
409
377
if args .benchmark :
410
- result = {}
411
- args_dict = vars (args )
412
-
413
- result ["args" ] = args_dict
414
- result ["pyt_stats" ] = pyt_stats if args .enable_pytorch_run else None
415
- result ["trt_stats" ] = trt_stats if args .benchmark else None
416
- out_json_file = f"{ model_name } _{ qformat } _benchmark.json"
417
- with open (os .path .join ("result" , out_json_file ), "w" ) as f :
418
- json .dump (result , f , indent = 4 )
419
- print (f"Results saved to { out_json_file } " )
420
378
if args .enable_pytorch_run :
421
379
print ("=========PyTorch PERFORMANCE============ \n " )
422
380
print (pyt_stats )
0 commit comments