
Commit df34e8b

chore: clean up
1 parent 0b0c3fc commit df34e8b

File tree:
  tools/llm/quantize_utils.py
  tools/llm/run_llm.py
  tools/llm/utils.py

3 files changed: +14 additions, -63 deletions

tools/llm/quantize_utils.py

Lines changed: 6 additions & 12 deletions
@@ -68,11 +68,6 @@ def quantize_model(model, args, tokenizer):
 class TensorRTQuantizedLinear(torch.nn.Module):
     """
     TensorRT quantized linear layer that applies quantization to both input and weight tensors.
-
-    This class implements a quantized linear layer that:
-    1. Applies quantization to input tensor using TensorQuantizer
-    2. Applies quantization to weight tensor using TensorQuantizer
-    3. Performs linear operation with quantized tensors
     """

     def __init__(
@@ -114,7 +109,7 @@ def forward(self, input):

 def convert_linear_to_tensorrt_quantized(model, model_name):
     """
-    Convert linear layers in a model to TensorRT quantized versions using pre-quantized weights.
+    Convert linear layers in a model to TensorRT quantized versions from pre-quantized weights.

     This function is specifically designed for Hugging Face quantized models and only
     applies quantization to linear operations. It loads pre-quantized models from
@@ -172,7 +167,7 @@ def convert_linear_to_tensorrt_quantized(model, model_name):

         hf_quant_algo = hf_quant_config.pop("quant_algo", None)
         if hf_quant_algo != "FP8" and hf_quant_algo != "NVFP4":
-            raise RuntimeError("Only FP8 and NVFP4 quantization is supported")
+            raise RuntimeError("Only FP8 or NVFP4 quantization is supported")
     else:
         raise RuntimeError("No quantization config found")
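For context on the hunk above: the config's quant_algo field gates which formats the converter accepts. A minimal standalone sketch of that check, with a hypothetical config dict standing in for the one loaded from the Hugging Face checkpoint:

# Hypothetical example value; the real dict comes from the checkpoint.
hf_quant_config = {"quant_algo": "FP8"}

hf_quant_algo = hf_quant_config.pop("quant_algo", None)
if hf_quant_algo != "FP8" and hf_quant_algo != "NVFP4":
    raise RuntimeError("Only FP8 or NVFP4 quantization is supported")
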
@@ -186,7 +181,6 @@ def convert_linear_to_tensorrt_quantized(model, model_name):
         weight_scale_name = name + ".weight_scale"
         input_scale_name = name + ".input_scale"

-        # Verify that required scale tensors exist in the loaded data
         if weight_scale_name not in tensors:
             print(f"Weight scale tensor {weight_scale_name} not found")
             continue
@@ -202,7 +196,7 @@ def convert_linear_to_tensorrt_quantized(model, model_name):
             input_amax = tensors.pop(input_scale_name) * 448.0

             # Dequantize the weight using the scale factor
-            dequantized_weight_data = module.weight.to(torch.float16) * weight_scale
+            dequantized_weight_data = module.weight.to(torch.float32) * weight_scale

             # Configure quantizer for FP8 format (4 exponent bits, 3 mantissa bits)
             quantizer_attribute_config = QuantizerAttributeConfig(
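
The 448.0 multiplier above is the largest magnitude representable in FP8 E4M3, so scale * 448.0 recovers the amax (absolute-maximum) calibration value the quantizer expects. A minimal sketch of the arithmetic on stand-in tensors (names, shapes, and values are hypothetical):

import torch

# Stand-ins for tensors loaded from a pre-quantized checkpoint.
fp8_weight = torch.randn(16, 16).to(torch.float8_e4m3fn)
weight_scale = torch.tensor(0.02)
input_scale = torch.tensor(0.01)

E4M3_MAX = 448.0  # largest representable FP8 E4M3 value
input_amax = input_scale * E4M3_MAX
weight_amax = weight_scale * E4M3_MAX

# Dequantize in float32; the float16 intermediate this commit replaces
# can lose precision when the scale is large.
dequantized_weight = fp8_weight.to(torch.float32) * weight_scale
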
@@ -226,7 +220,7 @@ def convert_linear_to_tensorrt_quantized(model, model_name):
             original_shape = list(weight_data.shape)
             original_shape[-1] *= 2  # NVFP4 packs 2 values per element
             nvfp4_tensor = NVFP4QTensor(
-                torch.Size(original_shape), torch.float16, weight_data
+                torch.Size(original_shape), torch.float32, weight_data
             )

             # Dequantize using both scales and block size configuration
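
Shape note for the hunk above: NVFP4 stores two 4-bit values in each uint8 element, so the logical weight has twice as many columns as the packed buffer, hence original_shape[-1] *= 2. A rough sketch of nibble unpacking (the nibble order is an assumption; the actual decode, including block scales, is handled by NVFP4QTensor):

import torch

packed = torch.randint(0, 256, (16, 8), dtype=torch.uint8)  # hypothetical buffer

low = packed & 0x0F          # first 4-bit code in each byte
high = (packed >> 4) & 0x0F  # second 4-bit code in each byte

original_shape = list(packed.shape)
original_shape[-1] *= 2      # two values per packed element
unpacked = torch.stack((low, high), dim=-1).reshape(original_shape)
assert unpacked.shape == (16, 16)
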
@@ -242,8 +236,8 @@ def convert_linear_to_tensorrt_quantized(model, model_name):
                 enable=True,
             )

-            # Apply dequantization to the original quantized weight using the scale
-            # This ensures the weight is in the correct range for the quantized layer
+            # Restore the weight to its original full-precision format so that QDQ nodes
+            # can be properly inserted and optimized during TensorRT compilation
             module.weight.data = dequantized_weight_data

             # Create the quantized linear layer with calculated amax values
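
The rewritten comment points at why the weight goes back to full precision: TensorRT recognizes quantize/dequantize (QDQ) pairs around full-precision tensors and fuses them into quantized kernels at compile time. A conceptual sketch of such a QDQ round trip, not TensorRT's actual lowering:

import torch

def fake_quant_fp8(x: torch.Tensor, amax: float) -> torch.Tensor:
    """Quantize-dequantize round trip that simulates FP8 E4M3 precision."""
    scale = amax / 448.0                     # map amax onto the FP8 max magnitude
    q = (x / scale).to(torch.float8_e4m3fn)  # quantize (lossy cast)
    return q.to(torch.float32) * scale       # dequantize back to float32

w = torch.randn(8, 8)
w_qdq = fake_quant_fp8(w, float(w.abs().max()))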

tools/llm/run_llm.py

Lines changed: 8 additions & 50 deletions
@@ -49,7 +49,6 @@ def get_model(args):
         torch.nn.Module: The loaded and configured model ready for inference,
             moved to CUDA device with the specified precision
     """
-
     with torch.no_grad():
         model = (
             AutoModelForCausalLM.from_pretrained(
@@ -112,23 +111,7 @@ def compile_torchtrt(model, input_ids, args):
     else:
         enabled_precisions = {torch.float32}

-    qformat = "_q_" + args.qformat if args.qformat else ""
-
-    logging_dir = f"./{args.model}_{args.precision}{qformat}"
-    # with torch_tensorrt.logging.debug() if args.debug else nullcontext():
-    with (
-        torch_tensorrt.dynamo.Debugger(
-            "debug",
-            logging_dir=logging_dir,
-            # capture_fx_graph_after=["constant_fold"],
-            # save_engine_profile=True,
-            # profile_format="trex",
-            engine_builder_monitor=False,
-            # save_layer_info=True,
-        )
-        if args.debug
-        else nullcontext()
-    ):
+    with torch_tensorrt.logging.debug() if args.debug else nullcontext():
         trt_model = torch_tensorrt.dynamo.compile(
             ep,
             inputs=[input_ids, position_ids],
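
The restored one-liner uses a common idiom: pick between a real context manager and contextlib.nullcontext at runtime, so the with body stays identical either way. A self-contained sketch (debug_logging is a hypothetical stand-in for torch_tensorrt.logging.debug()):

from contextlib import contextmanager, nullcontext

@contextmanager
def debug_logging():
    print("verbose logging on")   # hypothetical stand-in
    try:
        yield
    finally:
        print("verbose logging off")

debug = False  # stand-in for args.debug
with debug_logging() if debug else nullcontext():
    result = sum(range(10))  # the compile call would go here
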
@@ -151,14 +134,12 @@ def print_outputs(backend_name, gen_tokens, tokenizer):
     """
     Print the generated tokens from the model.
     """
-    out = tokenizer.decode(gen_tokens[0], skip_special_tokens=True)
     print(f"========= {backend_name} =========")
     print(
         f"{backend_name} model generated text: ",
-        out,
+        tokenizer.decode(gen_tokens[0], skip_special_tokens=True),
     )
     print("===================================")
-    return out


 def measure_perf(trt_model, input_signature, backend_name):
@@ -260,13 +241,13 @@ def measure_perf(trt_model, input_signature, backend_name):
     )
     arg_parser.add_argument(
         "--qformat",
-        help=("Apply quantization format. Options: fp8 (default: None)"),
+        help=("Apply quantization format. Options: fp8, nvfp4 (default: None)"),
         default=None,
     )
     arg_parser.add_argument(
         "--pre_quantized",
         action="store_true",
-        help="Use pre-quantized model weights (default: False)",
+        help="Use pre-quantized hf model weights (default: False)",
     )
     args = arg_parser.parse_args()
@@ -300,6 +281,7 @@ def measure_perf(trt_model, input_signature, backend_name):
     pyt_gen_tokens = None
     pyt_timings = None
     pyt_stats = None
+
     if args.qformat != None:
         model = quantize_model(model, args, tokenizer)
     if args.enable_pytorch_run:
@@ -380,43 +362,19 @@ def measure_perf(trt_model, input_signature, backend_name):
             batch_size=args.batch_size,
             compile_time_s=None,
         )
-    match_result = "N/A"
-    torch_out = "N/A"
-    model_name = args.model.replace("/", "_")
-    qformat = args.qformat if args.qformat else "no_quant"

     if not args.benchmark:
         if args.enable_pytorch_run:
-            torch_out = print_outputs("PyTorch", pyt_gen_tokens, tokenizer)
+            print_outputs("PyTorch", pyt_gen_tokens, tokenizer)

-        trt_out = print_outputs("TensorRT", trt_gen_tokens, tokenizer)
+        print_outputs("TensorRT", trt_gen_tokens, tokenizer)

         if args.enable_pytorch_run:
             print(
                 f"PyTorch and TensorRT outputs match: {torch.equal(pyt_gen_tokens, trt_gen_tokens)}"
             )
-            match_result = str(torch.equal(pyt_gen_tokens, trt_gen_tokens))
-        out_json_file = f"{model_name}_{qformat}_match.json"
-        result = {}
-        args_dict = vars(args)
-        result["args"] = args_dict
-        result["match"] = match_result
-        result["torch_out"] = torch_out
-        result["trt_out"] = trt_out
-        with open(os.path.join("result", out_json_file), "w") as f:
-            json.dump(result, f, indent=4)
-        print(f"Results saved to {out_json_file}")
+
     if args.benchmark:
-        result = {}
-        args_dict = vars(args)
-
-        result["args"] = args_dict
-        result["pyt_stats"] = pyt_stats if args.enable_pytorch_run else None
-        result["trt_stats"] = trt_stats if args.benchmark else None
-        out_json_file = f"{model_name}_{qformat}_benchmark.json"
-        with open(os.path.join("result", out_json_file), "w") as f:
-            json.dump(result, f, indent=4)
-        print(f"Results saved to {out_json_file}")
         if args.enable_pytorch_run:
             print("=========PyTorch PERFORMANCE============ \n")
             print(pyt_stats)

tools/llm/utils.py

Lines changed: 0 additions & 1 deletion
@@ -1,4 +1,3 @@
-import os
 import timeit

 import numpy as np
