-
Notifications
You must be signed in to change notification settings - Fork 372
Feat: Pre-quantized LLM model support #3740
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,267 @@ | ||
import json | ||
import logging | ||
import os | ||
|
||
import huggingface_hub | ||
import torch | ||
from huggingface_hub import snapshot_download | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
try: | ||
import modelopt.torch.quantization as mtq # noqa: F401f | ||
|
||
assert torch.ops.tensorrt.quantize_op.default | ||
except Exception: | ||
logger.warning("Unable to import quantization op. Please install modelopt library") | ||
|
||
from modelopt.core.torch.quantization.qtensor.nvfp4_tensor import NVFP4QTensor | ||
from modelopt.torch.quantization.config import QuantizerAttributeConfig | ||
from modelopt.torch.quantization.nn.modules.tensor_quantizer import TensorQuantizer | ||
from modelopt.torch.utils.dataset_utils import ( | ||
create_forward_loop, | ||
get_dataset_dataloader, | ||
) | ||
from safetensors import safe_open | ||
|
||
|
||
def quantize_model(model, args, tokenizer): | ||
""" | ||
Quantize a PyTorch model using ModelOpt quantization. | ||
|
||
This function performs post-training quantization (PTQ) on the model using | ||
calibration data from the provided tokenizer. It supports both FP8 and NVFP4 | ||
quantization formats. | ||
|
||
Args: | ||
model: PyTorch model to quantize | ||
args: Arguments containing quantization format and debug settings | ||
tokenizer: Tokenizer for creating calibration dataloader | ||
|
||
Returns: | ||
Quantized model with reduced precision weights and activations | ||
|
||
Raises: | ||
RuntimeError: If unsupported quantization format is specified | ||
""" | ||
# Create calibration dataloader for quantization | ||
calib_dataloader = get_dataset_dataloader( | ||
tokenizer=tokenizer, | ||
batch_size=32, | ||
num_samples=512, | ||
device="cuda:0", | ||
) | ||
if args.qformat == "fp8": | ||
quant_cfg = mtq.FP8_DEFAULT_CFG | ||
elif args.qformat == "nvfp4": | ||
quant_cfg = mtq.NVFP4_DEFAULT_CFG | ||
else: | ||
raise RuntimeError("Unsupported quantization format") | ||
calibrate_loop = create_forward_loop(dataloader=calib_dataloader) | ||
|
||
model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop) | ||
if args.debug: | ||
mtq.print_quant_summary(model) | ||
|
||
return model | ||
|
||
|
||
class TensorRTQuantizedLinear(torch.nn.Module): | ||
""" | ||
TensorRT quantized linear layer that applies quantization to both input and weight tensors. | ||
""" | ||
|
||
def __init__( | ||
self, original_linear: torch.nn.Linear, input_amax, weight_amax, quant_cfg | ||
): | ||
""" | ||
Initialize quantized linear layer. | ||
|
||
Args: | ||
original_linear: Original PyTorch linear layer to quantize | ||
input_amax: Maximum absolute value for input quantization scaling | ||
weight_amax: Maximum absolute value for weight quantization scaling | ||
quant_cfg: Quantization configuration for TensorQuantizer | ||
""" | ||
super().__init__() | ||
|
||
# Store reference to original linear layer for weight access | ||
self.original_linear = original_linear | ||
|
||
# Copy bias from original layer if it exists | ||
if original_linear.bias is not None: | ||
self.bias = torch.nn.Parameter(original_linear.bias.clone()).cuda() | ||
else: | ||
self.bias = None | ||
|
||
# Create quantizers for input and weight tensors | ||
self.input_quantizer = TensorQuantizer( | ||
quant_attribute_cfg=quant_cfg, amax=input_amax | ||
) | ||
self.weight_quantizer = TensorQuantizer( | ||
quant_attribute_cfg=quant_cfg, amax=weight_amax | ||
) | ||
|
||
def forward(self, input): | ||
input = self.input_quantizer(input) | ||
weight = self.weight_quantizer(self.original_linear.weight) | ||
return torch.nn.functional.linear(input, weight, self.bias) | ||
|
||
|
||
def convert_linear_to_tensorrt_quantized(model, model_name): | ||
""" | ||
Convert linear layers in a model to TensorRT quantized versions from pre-quantized weights. | ||
|
||
This function is specifically designed for Hugging Face quantized models and only | ||
applies quantization to linear operations. It loads pre-quantized models from | ||
Hugging Face format and replaces standard linear layers with TensorRTQuantizedLinear | ||
layers. It supports both FP8 and NVFP4 quantization formats. | ||
|
||
The function: | ||
1. Loads quantization scales from Hugging Face model files (SafeTensors) | ||
2. Parses quantization configuration from hf_quant_config.json | ||
3. Replaces standard linear layers with TensorRTQuantizedLinear layers | ||
4. Applies appropriate quantization based on the model's quantization format | ||
|
||
Note: This function only quantizes linear operations and is intended for use | ||
with pre-quantized Hugging Face models that have been quantized using ModelOpt. | ||
|
||
Args: | ||
model: PyTorch model to quantize | ||
model_name: Path to Hugging Face model directory or model identifier | ||
|
||
Returns: | ||
Model with quantized linear layers | ||
|
||
Raises: | ||
RuntimeError: If quantization config is not found or unsupported format | ||
""" | ||
# Determine if model_name is a local directory or needs to be downloaded | ||
if os.path.isdir(model_name): | ||
hf_folder = model_name | ||
else: | ||
# Download model from Hugging Face Hub | ||
hf_folder = snapshot_download( | ||
model_name, | ||
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, | ||
ignore_patterns=["original/**/*"], | ||
revision=None, | ||
) | ||
|
||
# Load all tensors from SafeTensors files | ||
tensors = {} | ||
for file in os.listdir(hf_folder): | ||
if file.endswith(".safetensors"): | ||
with safe_open( | ||
os.path.join(hf_folder, file), framework="pt", device="cpu" | ||
) as f: | ||
tensor_names = f.keys() | ||
for name in tensor_names: | ||
tensors[name] = f.get_tensor(name) | ||
|
||
# Load and parse quantization configuration | ||
hf_quant_config_path = f"{hf_folder}/hf_quant_config.json" | ||
if os.path.exists(hf_quant_config_path): | ||
with open(hf_quant_config_path, "r") as f: | ||
hf_quant_config = json.load(f) | ||
hf_quant_config = hf_quant_config["quantization"] | ||
|
||
hf_quant_algo = hf_quant_config.pop("quant_algo", None) | ||
if hf_quant_algo != "FP8" and hf_quant_algo != "NVFP4": | ||
raise RuntimeError("Only FP8 or NVFP4 quantization is supported") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How would it be different for MXFP4? |
||
else: | ||
raise RuntimeError("No quantization config found") | ||
|
||
# Iterate through all modules in the model | ||
for name, module in model.named_modules(): | ||
# Check if the module is a linear layer | ||
target = torch.nn.modules.linear.Linear | ||
if isinstance(module, target): | ||
# Construct names for quantization scale tensors | ||
# These follow the naming convention: module_name.weight_scale and module_name.input_scale | ||
weight_scale_name = name + ".weight_scale" | ||
input_scale_name = name + ".input_scale" | ||
|
||
if weight_scale_name not in tensors: | ||
logger.warning(f"Weight scale tensor {weight_scale_name} not found") | ||
continue | ||
if input_scale_name not in tensors: | ||
logger.warning(f"Input scale tensor {input_scale_name} not found") | ||
continue | ||
|
||
if hf_quant_algo == "FP8": | ||
# FP8 E4M3 format has a maximum representable value of 448.0 | ||
# Scale the quantization parameters accordingly | ||
weight_scale = tensors.pop(weight_scale_name) | ||
weight_amax = weight_scale * 448.0 | ||
input_amax = tensors.pop(input_scale_name) * 448.0 | ||
|
||
# Dequantize the weight using the scale factor | ||
dequantized_weight_data = module.weight.to(torch.float32) * weight_scale | ||
|
||
# Configure quantizer for FP8 format (4 exponent bits, 3 mantissa bits) | ||
quantizer_attribute_config = QuantizerAttributeConfig( | ||
num_bits=(4, 3), axis=None | ||
) | ||
|
||
elif hf_quant_algo == "NVFP4": | ||
# NVFP4 format requires additional scale tensor and different configuration | ||
weight_name = name + ".weight" | ||
weight_scale2_name = name + ".weight_scale_2" | ||
weight_scale = tensors.pop(weight_scale_name) | ||
input_scale = tensors.pop(input_scale_name) | ||
weight_scale2 = tensors.pop(weight_scale2_name) | ||
|
||
# Calculate amax values with additional scaling factor for NVFP4 | ||
input_amax = input_scale * 448.0 * 6.0 | ||
weight_amax = weight_scale2 * 448.0 * 6.0 | ||
|
||
# Handle NVFP4 tensor format | ||
weight_data = tensors.pop(weight_name) | ||
original_shape = list(weight_data.shape) | ||
original_shape[-1] *= 2 # NVFP4 packs 2 values per element | ||
nvfp4_tensor = NVFP4QTensor( | ||
torch.Size(original_shape), torch.float32, weight_data | ||
) | ||
|
||
# Dequantize using both scales and block size configuration | ||
dequantized_weight_data = nvfp4_tensor.dequantize( | ||
scale=weight_scale, double_scale=weight_scale2, block_sizes={-1: 16} | ||
) | ||
|
||
# Configure quantizer for NVFP4 format with dynamic block quantization | ||
quantizer_attribute_config = QuantizerAttributeConfig( | ||
num_bits=(2, 1), | ||
axis=None, | ||
block_sizes={-1: 16, "type": "dynamic", "scale_bits": (4, 3)}, | ||
enable=True, | ||
) | ||
|
||
# Restore the weight to its original full-precision format so that QDQ nodes | ||
# can be properly inserted and optimized during TensorRT compilation | ||
module.weight.data = dequantized_weight_data | ||
|
||
# Create the quantized linear layer with calculated amax values | ||
quantized_module = TensorRTQuantizedLinear( | ||
module, input_amax, weight_amax, quantizer_attribute_config | ||
) | ||
|
||
# Replace the original module with the quantized version | ||
# Extract parent module name and child module name | ||
parent_name = ".".join(name.split(".")[:-1]) | ||
child_name = name.split(".")[-1] | ||
|
||
if parent_name: | ||
# Get the parent module and replace the child | ||
parent_module = model.get_submodule(parent_name) | ||
setattr(parent_module, child_name, quantized_module) | ||
else: | ||
# If no parent, replace at model level | ||
setattr(model, child_name, quantized_module) | ||
|
||
# Log any unused tensors for debugging | ||
if len(tensors) > 0: | ||
logger.debug(f"{len(tensors)} tensors not used") | ||
for key in tensors: | ||
logger.debug(f" {key}") | ||
return model |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,7 @@ | |
|
||
import argparse | ||
import copy | ||
import json | ||
import os | ||
import timeit | ||
from contextlib import nullcontext | ||
|
@@ -54,10 +55,13 @@ def get_model(args): | |
args.model, | ||
use_cache=False, | ||
attn_implementation="sdpa", | ||
ignore_mismatched_sizes=True, | ||
) | ||
.eval() | ||
.cuda() | ||
) | ||
if args.pre_quantized: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this something we could determine automatically? |
||
model = convert_linear_to_tensorrt_quantized(model, args.model).cuda() | ||
|
||
if args.precision == "FP16": | ||
model = model.to(torch.float16) | ||
|
@@ -91,7 +95,8 @@ def compile_torchtrt(model, input_ids, args): | |
for optimized inference | ||
""" | ||
max_seq_len = input_ids.shape[1] + args.num_tokens | ||
ep = export_llm(model, input_ids, max_seq_len=max_seq_len) | ||
with export_torch_mode() if args.qformat or args.pre_quantized else nullcontext(): | ||
ep = export_llm(model, input_ids, max_seq_len=max_seq_len) | ||
position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0).to(DEVICE) | ||
# Set precision specific flags | ||
use_fp32_acc = False | ||
|
@@ -234,13 +239,36 @@ def measure_perf(trt_model, input_signature, backend_name): | |
arg_parser.add_argument( | ||
"--benchmark", action="store_true", help="Enable benchmark (default: False)" | ||
) | ||
|
||
arg_parser.add_argument( | ||
"--qformat", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For the sake of clarity, this should be something like |
||
help=("Apply quantization format. Options: fp8, nvfp4 (default: None)"), | ||
default=None, | ||
) | ||
arg_parser.add_argument( | ||
"--pre_quantized", | ||
action="store_true", | ||
help="Use pre-quantized hf model weights (default: False)", | ||
) | ||
args = arg_parser.parse_args() | ||
|
||
if args.qformat and args.pre_quantized: | ||
print("Error: --qformat and --pre_quantized cannot be used together") | ||
exit() | ||
|
||
if args.qformat or args.pre_quantized: | ||
from modelopt.torch.quantization.utils import export_torch_mode | ||
from quantize_utils import ( | ||
convert_linear_to_tensorrt_quantized, | ||
quantize_model, | ||
) | ||
|
||
with torch.inference_mode(): | ||
model = get_model(args) | ||
|
||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer or args.model) | ||
|
||
# Set pad token | ||
if tokenizer.pad_token is None: | ||
tokenizer.pad_token = tokenizer.eos_token | ||
# Prepare input for benchmarking or evaluation | ||
if args.benchmark: | ||
input_ids = torch.randint( | ||
|
@@ -258,6 +286,8 @@ def measure_perf(trt_model, input_signature, backend_name): | |
pyt_timings = None | ||
pyt_stats = None | ||
|
||
if args.qformat != None: | ||
model = quantize_model(model, args, tokenizer) | ||
if args.enable_pytorch_run: | ||
pyt_gen_tokens = generate( | ||
model, input_ids.clone(), MAX_OUTPUT_SEQ_LENGTH, tokenizer.eos_token_id | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@peri044 Is this something we might want to upstream to ModelOpt in the future?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or pull into main torch-tensorrt as a pass?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess its somewhat HF specific, so remaining in this tool would make sense but are there some parts we could make generic for any sort of quantization workflow (e.g. torchao)?