From 2847be1084b5a623c5041ea6e1347bdca2287999 Mon Sep 17 00:00:00 2001 From: Daniel Afrimi Date: Tue, 18 Nov 2025 13:13:05 +0000 Subject: [PATCH 1/5] new format Signed-off-by: Daniel Afrimi --- tensorrt_llm/_torch/model_config.py | 187 +------------- tensorrt_llm/llmapi/llm_utils.py | 104 +------- tensorrt_llm/models/modeling_utils.py | 311 ++++++++++++++++++++++++ tests/unittest/llmapi/test_llm_quant.py | 16 +- 4 files changed, 333 insertions(+), 285 deletions(-) diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py index b7e42fc09b0..f5357050cca 100644 --- a/tensorrt_llm/_torch/model_config.py +++ b/tensorrt_llm/_torch/model_config.py @@ -1,6 +1,4 @@ import contextlib -import json -import os import tempfile from dataclasses import dataclass, field from pathlib import Path @@ -14,7 +12,7 @@ from tensorrt_llm import logger from tensorrt_llm._torch.pyexecutor.config_utils import (is_nemotron_hybrid, load_pretrained_config) -from tensorrt_llm._utils import get_sm_version, torch_dtype_to_binding +from tensorrt_llm._utils import torch_dtype_to_binding from tensorrt_llm.bindings import LayerType as LayerTypeCpp from tensorrt_llm.functional import AllReduceStrategy from tensorrt_llm.llmapi.llm_args import (DeepSeekSparseAttentionConfig, @@ -22,7 +20,6 @@ from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.modeling_utils import QuantConfig -from tensorrt_llm.quantization.mode import QuantAlgo TConfig = TypeVar("TConfig", bound=transformers.PretrainedConfig) @@ -219,173 +216,6 @@ def is_generation_model(model_architectures: Optional[List[str]], # TODO: should be 'not model_type == ModelType.ENCODER_ONLY' # once ModelType is used in pytorch flow. - @staticmethod - def load_modelopt_quant_config(quant_config_file, checkpoint_dir, - moe_backend): - quant_config = QuantConfig() - layer_quant_config = None - - with open(quant_config_file) as f: - quant_config_dict = json.load(f) - - json_quant_configs = quant_config_dict['quantization'] - - quant_config.quant_algo = json_quant_configs.get('quant_algo', None) - # fp8_pb_wo from modelopt is the same as FP8_BLOCK_SCALES - if quant_config.quant_algo == "fp8_pb_wo": - quant_config.quant_algo = 'FP8_BLOCK_SCALES' - quant_config.kv_cache_quant_algo = json_quant_configs.get( - 'kv_cache_quant_algo', None) - quant_config.group_size = json_quant_configs.get('group_size', None) - quant_config.exclude_modules = json_quant_configs.get( - 'exclude_modules', None) - - if quant_config.quant_algo == QuantAlgo.MIXED_PRECISION: - json_extended_quant_configs: dict = {} - # See tests/unittest/llmapi/test_llm_quant.py - try: - mixed_quant_config_file = transformers.utils.hub.cached_file( - checkpoint_dir, 'quant_cfg.json') - with open(mixed_quant_config_file) as fm: - json_extended_quant_configs = json.load(fm) - except Exception: - logger.info( - f"No quant_cfg.json found for layer quant info, using hf_quant_config.json." 
- ) - json_quant_configs.update(json_extended_quant_configs) - # kv_cache_quant_algo is global regardless of MIXED_PRECISION - kv_cache_quant_algo = json_quant_configs.get( - 'kv_cache_quant_algo', None) - mixed_quant_configs = json_quant_configs.get( - 'quantized_layers', None) - if (kv_quant_lhs := json_extended_quant_configs.get( - "kv_cache_quant_algo", None)) is not None and ( - kv_quant_rhs := - quant_config.kv_cache_quant_algo) is not None: - if kv_quant_lhs != kv_quant_rhs: - raise RuntimeError( - f"The kvcache config in 'quant_cfg.json', {kv_quant_lhs}," - f"is different from 'hf_quant_config.json', {kv_quant_rhs}!" - ) - quant_config.kv_cache_quant_algo = json_quant_configs[ - "kv_cache_quant_algo"] - for layer in mixed_quant_configs: - config = QuantConfig() - config.kv_cache_quant_algo = kv_cache_quant_algo - config.quant_algo = mixed_quant_configs[layer]['quant_algo'] - config.group_size = mixed_quant_configs[layer].get( - 'group_size', None) - mixed_quant_configs[layer] = config - layer_quant_config = mixed_quant_configs - elif quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES: - if quant_config.group_size is None: - quant_config.group_size = 128 - - if moe_backend == 'TRTLLM' and quant_config.quant_algo == "FP8_BLOCK_SCALES" and quant_config.exclude_modules is None: - quant_config.exclude_modules = [ - "*kv_b_proj*", "*k_b_proj*", "*eh_proj" - ] - return quant_config, layer_quant_config - - @staticmethod - def get_mxfp4_quant_algo(moe_backend, is_dynamic_quant=False): - quant_algo = ModelConfig.override_quant_algo() - if quant_algo is None and not is_dynamic_quant: - if get_sm_version() >= 100: - if moe_backend == 'TRITON': - return QuantAlgo.W4A8_MXFP4_FP8 - else: - return QuantAlgo.W4A8_MXFP4_MXFP8 - else: - return QuantAlgo.W4A16_MXFP4 - else: - return quant_algo - - @staticmethod - def load_hf_quant_config(hf_quant_config, moe_backend): - quant_config = QuantConfig() - layer_quant_config = None - - # DeepSeek V3 FP8 ckpt - if hf_quant_config.get("quant_method") == "fp8" and hf_quant_config.get( - "weight_block_size", []): - quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES - if moe_backend == 'TRTLLM': - # TODO: This is a hack. Remove after fp8 bmm is integrated. - quant_config.exclude_modules = [ - "*kv_b_proj*", "*k_b_proj*", "*eh_proj" - ] - else: - quant_config.exclude_modules = ["*eh_proj"] - - block_size = hf_quant_config.get("weight_block_size", []) - assert tuple(block_size) == ( - 128, 128), "FP8_BLOCK_SCALES only supports block_size=(128,128)" - quant_config.group_size = block_size[0] - # MXFP4 checkpoints. - elif hf_quant_config.get("quant_method") == "mxfp4": - quant_config.quant_algo = ModelConfig.get_mxfp4_quant_algo( - moe_backend) - quant_config.group_size = 32 - quant_config.exclude_modules = [ - 'block.*.attn.out', 'block.*.mlp.gate', 'block.*.attn.qkv', - 'embedding', 'unembedding' - ] - - return quant_config, layer_quant_config - - @staticmethod - def load_quant_config_from_dtypes_json(dtypes_json_file, moe_backend: str): - quant_config = QuantConfig() - layer_quant_config = None - - exclude_modules = set() - has_mxfp4 = False - is_dynamic_quant = False - with open(dtypes_json_file) as f: - dtypes_json = json.load(f) - for layer, dtype in dtypes_json.items(): - if layer.endswith("weight"): - if dtype == "BF16" or dtype == "FP16": - names = layer.split(".") - exclude_modules.add('.'.join(names[:-1])) - elif dtype == "MXFP4": - # This is the path for the fp8 checkpoint which requires dynamic quantization. 
- is_dynamic_quant = True - has_mxfp4 = True - elif layer.endswith("weight.blocks"): - scale_name = layer.replace("weight.blocks", "weight.scales") - scale_dtype = dtypes_json.get(scale_name, None) - assert scale_dtype == "UE8" - is_dynamic_quant = False - has_mxfp4 = True - - if has_mxfp4: - quant_config.quant_algo = ModelConfig.get_mxfp4_quant_algo( - moe_backend, is_dynamic_quant) - quant_config.group_size = 32 - quant_config.exclude_modules = list(exclude_modules) - logger.info(f"Setting quant_config: {quant_config}") - - return quant_config, layer_quant_config - - @staticmethod - def override_quant_algo(): - new_algo = os.environ.get("OVERRIDE_QUANT_ALGO", None) - supported_algos = { - "W4A16_MXFP4": QuantAlgo.W4A16_MXFP4, - "W4A8_MXFP4_MXFP8": QuantAlgo.W4A8_MXFP4_MXFP8, - "W4A8_MXFP4_FP8": QuantAlgo.W4A8_MXFP4_FP8, - } - if new_algo is not None: - if new_algo.upper() in supported_algos: - return supported_algos[new_algo.upper()] - else: - logger.warning( - f"Unsupported quant algo: {new_algo}, supported algos: {supported_algos.keys()}" - ) - return None - @classmethod def from_pretrained(cls, checkpoint_dir: str, @@ -445,16 +275,17 @@ def cached_file(path_or_repo_id, file_name): # quantized ckpt in modelopt format if quant_config_file := cached_file(checkpoint_dir, 'hf_quant_config.json'): - quant_config, layer_quant_config = cls.load_modelopt_quant_config( - quant_config_file, checkpoint_dir, moe_backend) + is_quant_config_changed, layer_quant_config = quant_config.update_from_model_ckpt( + checkpoint_dir, moe_backend) + # quantized ckpt in other formats elif hasattr(pretrained_config, "quantization_config"): - hf_quant_config = pretrained_config.quantization_config - quant_config, layer_quant_config = cls.load_hf_quant_config( - hf_quant_config, moe_backend) + is_quant_config_changed, layer_quant_config = quant_config.update_from_model_ckpt( + checkpoint_dir, moe_backend) + elif quant_config_file := cached_file(checkpoint_dir, 'dtypes.json'): - quant_config, layer_quant_config = cls.load_quant_config_from_dtypes_json( - quant_config_file, moe_backend) + is_quant_config_changed, layer_quant_config = quant_config.update_from_model_ckpt( + checkpoint_dir, moe_backend) model_config = cls(pretrained_config=pretrained_config, quant_config=quant_config, diff --git a/tensorrt_llm/llmapi/llm_utils.py b/tensorrt_llm/llmapi/llm_utils.py index fc1647a8070..4b46b4f791c 100644 --- a/tensorrt_llm/llmapi/llm_utils.py +++ b/tensorrt_llm/llmapi/llm_utils.py @@ -25,7 +25,7 @@ from ..logger import logger from ..mapping import Mapping from ..models.automodel import MODEL_MAP, AutoConfig, AutoModelForCausalLM -from ..models.modeling_utils import PretrainedConfig, QuantAlgo, QuantConfig +from ..models.modeling_utils import PretrainedConfig, QuantConfig from ..module import Module from .build_cache import (BuildCache, BuildCacheConfig, CachedStage, get_build_cache_config_from_env) @@ -332,106 +332,10 @@ def _download_hf_model(self): assert self.speculative_model_obj.is_local_model def _update_from_hf_quant_config(self) -> bool: - """Update quant_config from the config file of pre-quantized HF checkpoint. - - Returns: - prequantized (bool): Whether the checkpoint is pre-quantized. - """ quant_config = self.llm_args.quant_config - - hf_quant_config_path = f"{self._model_dir}/hf_quant_config.json" - if os.path.exists(hf_quant_config_path): - logger.info( - f"Found {hf_quant_config_path}, pre-quantized checkpoint is used." 
- ) - with open(hf_quant_config_path, "r") as f: - hf_quant_config = json.load(f) - hf_quant_config = hf_quant_config["quantization"] - - hf_quant_algo = hf_quant_config.pop("quant_algo", None) - if hf_quant_algo is not None: - # fp8_pb_wo from modelopt is the same as fp8_block_scales - if hf_quant_algo == "fp8_pb_wo": - hf_quant_algo = QuantAlgo.FP8_BLOCK_SCALES - else: - hf_quant_algo = QuantAlgo(hf_quant_algo) - if quant_config.quant_algo is None: - logger.info( - f"Setting quant_algo={hf_quant_algo} form HF quant config." - ) - quant_config.quant_algo = hf_quant_algo - elif quant_config.quant_algo != hf_quant_algo: - raise ValueError( - f"Specified quant_algo={quant_config.quant_algo}, conflicting with quant_algo={hf_quant_algo} from HF quant config." - ) - else: - raise ValueError( - "Pre-quantized checkpoint must have quant_algo.") - - hf_kv_cache_quant_algo = hf_quant_config.pop( - "kv_cache_quant_algo", None) - if hf_kv_cache_quant_algo is not None: - hf_kv_cache_quant_algo = QuantAlgo(hf_kv_cache_quant_algo) - if quant_config.kv_cache_quant_algo is None: - logger.info( - f"Setting kv_cache_quant_algo={hf_kv_cache_quant_algo} form HF quant config." - ) - quant_config.kv_cache_quant_algo = hf_kv_cache_quant_algo - elif quant_config.kv_cache_quant_algo != hf_kv_cache_quant_algo: - raise ValueError( - f"Specified kv_cache_quant_algo={quant_config.kv_cache_quant_algo}, conflicting with kv_cache_quant_algo={hf_kv_cache_quant_algo} from HF quant config." - ) - else: - if quant_config.kv_cache_quant_algo not in [ - None, QuantAlgo.FP8, QuantAlgo.NVFP4 - ]: - raise ValueError( - f"Only kv_cache_quant_algo={QuantAlgo.FP8} or {QuantAlgo.NVFP4} is allowed for pre-quantized checkpoint, got {quant_config.kv_cache_quant_algo}." - ) - - for key, value in hf_quant_config.items(): - logger.info( - f"Setting {key}={str(value)[:100]}{'...' if len(str(value)) > 100 else ''} from HF quant config." - ) - setattr(quant_config, key, value) - - # Update the quant_config in llm_args for pytorch - self.llm_args.quant_config = quant_config - - return True - - hf_config_path = f"{self._model_dir}/config.json" - if os.path.exists(hf_config_path): - with open(hf_config_path, "r") as f: - hf_config = json.load(f) - hf_quant_config = hf_config.get("quantization_config", None) - - if hf_quant_config is not None: - logger.info( - f"Found quantization_config field in {hf_config_path}, pre-quantized checkpoint is used." - ) - # DeepSeek V3 FP8 ckpt - if hf_quant_config.get( - "quant_method") == "fp8" and hf_quant_config.get( - "weight_block_size"): - quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES - quant_config.exclude_modules = ["*eh_proj"] - elif hf_quant_config.get("quant_method") == "mxfp4": - from .._torch.model_config import ModelConfig - quant_config.quant_algo = ModelConfig.get_mxfp4_quant_algo( - self.llm_args.moe_config.backend) - quant_config.group_size = 32 - quant_config.exclude_modules = [ - 'block.*.attn.out', 'block.*.mlp.gate', - 'block.*.attn.qkv', 'embedding', 'unembedding' - ] - else: - raise NotImplementedError( - f"Unsupported quantization_config: {hf_quant_config}.") - - return True - - return False + quant_config_changed, _ = quant_config.update_from_model_ckpt( + self._model_dir, self.llm_args.moe_config.backend) + return quant_config_changed def _load_model_from_hf(self): ''' Load a TRT-LLM model from a HF model. 
''' diff --git a/tensorrt_llm/models/modeling_utils.py b/tensorrt_llm/models/modeling_utils.py index 03c1ee60ae5..bfccdcaa28f 100644 --- a/tensorrt_llm/models/modeling_utils.py +++ b/tensorrt_llm/models/modeling_utils.py @@ -15,6 +15,8 @@ import safetensors import torch +from tensorrt_llm._utils import get_sm_version + from .._common import default_net from .._utils import (QuantModeWrapper, get_init_params, numpy_to_torch, release_gc, str_dtype_to_torch, str_dtype_to_trt, @@ -196,6 +198,315 @@ def _requires_modelopt_quantization(self): else: return False + @staticmethod + def override_quant_algo(): + new_algo = os.environ.get("OVERRIDE_QUANT_ALGO", None) + supported_algos = { + "W4A16_MXFP4": QuantAlgo.W4A16_MXFP4, + "W4A8_MXFP4_MXFP8": QuantAlgo.W4A8_MXFP4_MXFP8, + "W4A8_MXFP4_FP8": QuantAlgo.W4A8_MXFP4_FP8, + } + if new_algo is not None: + if new_algo.upper() in supported_algos: + return supported_algos[new_algo.upper()] + else: + logger.warning( + f"Unsupported quant algo: {new_algo}, supported algos: {supported_algos.keys()}" + ) + return None + + @staticmethod + def get_mxfp4_quant_algo(moe_backend, is_dynamic_quant=False): + quant_algo = QuantConfig.override_quant_algo() + if quant_algo is None and not is_dynamic_quant: + if get_sm_version() >= 100: + if moe_backend == 'TRITON': + return QuantAlgo.W4A8_MXFP4_FP8 + else: + return QuantAlgo.W4A8_MXFP4_MXFP8 + else: + return QuantAlgo.W4A16_MXFP4 + else: + return quant_algo + + @staticmethod + def _infer_kv_cache_quant_algo_from_scheme(kv_scheme: dict) -> str | None: + kv_type = (kv_scheme.get("type") or "").lower() + bits = kv_scheme.get("num_bits") + dynamic = bool(kv_scheme.get("dynamic", False)) + + # todo add here all options... + if kv_type == "float" and bits == 8 and not dynamic: + return QuantAlgo("FP8_BLOCK_SCALES") + if kv_type in ("int", "uint") and bits == 8: + return QuantAlgo("INT8") + return None + + def _map_new_to_legacy_args(self, hf_quant_config: dict) -> dict: + qunatization_dict = {} + quant_algo = hf_quant_config.get("quant_algo") + if quant_algo == "fp8_pb_wo": + quant_algo = "FP8_BLOCK_SCALES" + if quant_algo is not None: + qunatization_dict["quant_algo"] = quant_algo + + if quant_algo == QuantAlgo.W4A16_AWQ or quant_algo == QuantAlgo.W4A8_AWQ: + qunatization_dict["pre_quant_scale"] = True + + if "group_size" in hf_quant_config: + qunatization_dict["group_size"] = hf_quant_config["group_size"] + + if "ignore" in hf_quant_config: + qunatization_dict["exclude_modules"] = list( + hf_quant_config.get("ignore") or []) + + kv_scheme = hf_quant_config.get("kv_cache_scheme") or {} + kv_algo = QuantConfig._infer_kv_cache_quant_algo_from_scheme( + kv_scheme) # todo check it + if kv_algo is not None: + qunatization_dict["kv_cache_quant_algo"] = kv_algo + + if "quantized_layers" in hf_quant_config: + qunatization_dict["quantized_layers"] = hf_quant_config[ + "quantized_layers"] + + if "symmetric" in hf_quant_config: + qunatization_dict["zero_point"] = hf_quant_config["symmetric"] + + # todo add here pre qunat scale and other keys.... 
+ return qunatization_dict + + def _update_from_quant_config_json(self, path, moe_backend: str, + model_ckpt_path) -> bool: + with open(path, "r") as f: + hf_config = json.load(f) + hf_quant_config = hf_config.get("quantization_config", None) + + if hf_quant_config is not None: + quant_method = hf_quant_config.get("quant_method") + + # DeepSeek V3 FP8 ckpt + if quant_method == "fp8" and hf_quant_config.get( + "weight_block_size"): + self.quant_algo = QuantAlgo.FP8_BLOCK_SCALES + if moe_backend == 'TRTLLM': + # TODO: This is a hack. Remove after fp8 bmm is integrated. + self.exclude_modules = [ + "*kv_b_proj*", "*k_b_proj*", "*eh_proj" + ] + else: + self.exclude_modules = ["*eh_proj"] + + block_size = hf_quant_config.get("weight_block_size", []) + + assert tuple(block_size) == ( + 128, 128 + ), "FP8_BLOCK_SCALES only supports block_size=(128,128)" + self.group_size = block_size[0] + + elif quant_method == "mxfp4": + self.quant_algo = QuantConfig.get_mxfp4_quant_algo( + self.llm_args.moe_config.backend) + self.group_size = 32 + self.exclude_modules = [ + 'block.*.attn.out', 'block.*.mlp.gate', + 'block.*.attn.qkv', 'embedding', 'unembedding' + ] + elif quant_method == "modelopt": + mapped_new_args = self._map_new_to_legacy_args( + hf_quant_config) + return self._update_from_legacy_args( + mapped_new_args, moe_backend, model_ckpt_path) + else: + raise NotImplementedError( + f"Unsupported quantization_config: {hf_quant_config}.") + + return True, None + return False, None + + def _update_from_legacy_args(self, args, moe_backend: str, + checkpoint_dir) -> bool: + hf_quant_algo = args.pop("quant_algo", None) + layer_quant_config = None + + if hf_quant_algo is not None: + # fp8_pb_wo from modelopt is the same as fp8_block_scales + if hf_quant_algo == "fp8_pb_wo": + hf_quant_algo = QuantAlgo.FP8_BLOCK_SCALES + else: + hf_quant_algo = QuantAlgo(hf_quant_algo) + if self.quant_algo is None: + logger.info( + f"Setting quant_algo={hf_quant_algo} form HF quant config.") + self.quant_algo = hf_quant_algo + elif self.quant_algo != hf_quant_algo: + raise ValueError( + f"Specified quant_algo={self.quant_algo}, conflicting with quant_algo={hf_quant_algo} from HF quant config." + ) + else: + raise ValueError("Pre-quantized checkpoint must have quant_algo.") + + hf_kv_cache_quant_algo = args.pop("kv_cache_quant_algo", None) + + if hf_kv_cache_quant_algo is not None: + hf_kv_cache_quant_algo = QuantAlgo(hf_kv_cache_quant_algo) + + if self.kv_cache_quant_algo is None and hf_kv_cache_quant_algo is not None: + logger.info( + f"Setting kv_cache_quant_algo={hf_kv_cache_quant_algo} form HF quant config." + ) + self.kv_cache_quant_algo = hf_kv_cache_quant_algo + + elif self.kv_cache_quant_algo != hf_kv_cache_quant_algo: + raise ValueError( + f"Specified kv_cache_quant_algo={self.kv_cache_quant_algo}, conflicting with kv_cache_quant_algo={hf_kv_cache_quant_algo} from HF quant config." + ) + + if self.kv_cache_quant_algo not in [ + None, QuantAlgo.FP8, QuantAlgo.NVFP4 + ]: + raise ValueError( + f"Only kv_cache_quant_algo={QuantAlgo.FP8} or {QuantAlgo.NVFP4} is allowed for pre-quantized checkpoint, got {quant_config.kv_cache_quant_algo}." 
+ ) + + if self.quant_algo == QuantAlgo.MIXED_PRECISION: + json_extended_quant_configs: dict = {} + + # Only attempt layer-merge if we know checkpoint_dir + if checkpoint_dir is not None: + try: + import transformers + mixed_quant_config_file = transformers.utils.hub.cached_file( + checkpoint_dir, 'quant_cfg.json') + with open(mixed_quant_config_file) as fm: + json_extended_quant_configs = json.load(fm) + except Exception: + logger.info( + "No quant_cfg.json found for layer quant info, using base quantization config." + ) + + # Merge extended info (if any) over base + merged_quant_configs = dict( + args) # todo we pop up some args so it moght not a good idea... + merged_quant_configs.update(json_extended_quant_configs) + + # kv_cache_quant_algo is global regardless of MIXED_PRECISION + kv_cache_quant_algo = merged_quant_configs.get( + 'kv_cache_quant_algo', None) + mixed_quant_configs = merged_quant_configs.get( + 'quantized_layers', None) + + # Consistency check if both sources specified kv_cache_quant_algo + if (kv_quant_lhs := json_extended_quant_configs.get("kv_cache_quant_algo", None)) is not None and \ + (kv_quant_rhs := self.kv_cache_quant_algo) is not None: + if kv_quant_lhs != kv_quant_rhs: + raise RuntimeError( + f"The kvcache config in 'quant_cfg.json', {kv_quant_lhs}, " + f"is different from the base config, {kv_quant_rhs}!") + + # Set the final global kv_cache_quant_algo + if "kv_cache_quant_algo" in merged_quant_configs: + logger.info( + f"Setting kv_cache_quant_algo={kv_quant_lhs} form quant config." + ) + self.kv_cache_quant_algo = merged_quant_configs[ + "kv_cache_quant_algo"] + + # Build per-layer QuantConfig objects + if mixed_quant_configs: + layer_quant_config = {} + for layer, layer_cfg in mixed_quant_configs.items(): + cfg = QuantConfig() + cfg.kv_cache_quant_algo = kv_cache_quant_algo + cfg.quant_algo = layer_cfg['quant_algo'] + cfg.group_size = layer_cfg.get('group_size', None) + layer_quant_config[layer] = cfg + + for arg, val in args.items(): + if not hasattr(self, arg): + raise ValueError(f"{arg} can't be found in quant config") + setattr(self, arg, val) + + if self.quant_algo == QuantAlgo.FP8_BLOCK_SCALES: + if self.group_size is None: + self.group_size = 128 + + if moe_backend == 'TRTLLM' and self.quant_algo == "FP8_BLOCK_SCALES" and self.exclude_modules is None: + self.exclude_modules = [ + "*kv_b_proj*", "*k_b_proj*", "*eh_proj" + ] # todo maybe merge or it ight be okay to override + + return True, layer_quant_config + + def _update_from_legacy_quant_config_json(self, path, moe_backend: str, + checkpoint_dir: Path): + with open(path, "r") as f: + hf_quant_config = json.load(f) + hf_quant_config = hf_quant_config["quantization"] + return self._update_from_legacy_args(hf_quant_config, moe_backend, + checkpoint_dir) + + def load_quant_config_from_dtypes_json(self, dtypes_json_file, + moe_backend: str): + layer_quant_config = None + + exclude_modules = set() + has_mxfp4 = False + is_dynamic_quant = False + with open(dtypes_json_file) as f: + dtypes_json = json.load(f) + for layer, dtype in dtypes_json.items(): + if layer.endswith("weight"): + if dtype == "BF16" or dtype == "FP16": + names = layer.split(".") + exclude_modules.add('.'.join(names[:-1])) + elif dtype == "MXFP4": + # This is the path for the fp8 checkpoint which requires dynamic quantization. 
+ is_dynamic_quant = True + has_mxfp4 = True + elif layer.endswith("weight.blocks"): + scale_name = layer.replace("weight.blocks", "weight.scales") + scale_dtype = dtypes_json.get(scale_name, None) + assert scale_dtype == "UE8" + is_dynamic_quant = False + has_mxfp4 = True + + if has_mxfp4: + self.quant_algo = QuantConfig.get_mxfp4_quant_algo( + moe_backend, is_dynamic_quant) + self.group_size = 32 + self.exclude_modules = list(exclude_modules) + + return True, layer_quant_config + + def update_from_model_ckpt(self, model_ckpt_path: Path, moe_backend: str): + hf_quant_config_path = Path(model_ckpt_path / "hf_quant_config.json") + hf_config_path = Path(model_ckpt_path / "config.json") + quant_config_dtypes = Path(model_ckpt_path / 'dtypes.json') + if hf_quant_config_path.exists(): + logger.info( + f"Found {hf_quant_config_path}, pre-quantized checkpoint is used." + ) + return self._update_from_legacy_quant_config_json( + hf_quant_config_path, moe_backend, model_ckpt_path) + + elif hf_config_path.exists(): + logger.info( + f"Found {hf_config_path}, pre-quantized checkpoint is used.") + return self._update_from_quant_config_json(hf_config_path, + moe_backend, + model_ckpt_path) + + elif quant_config_dtypes.exists(): + logger.info( + f"Found {quant_config_dtypes}, pre-quantized checkpoint is used." + ) + self.load_quant_config_from_dtypes_json(quant_config_dtypes, + moe_backend) + + logger.warning(f"No quant config found in {model_ckpt_path}") + return False, None + def _get_quant_cfg(self, module_name=None): if self.exclude_modules is not None: for exclude_module in self.exclude_modules: diff --git a/tests/unittest/llmapi/test_llm_quant.py b/tests/unittest/llmapi/test_llm_quant.py index 573a8bf0ef9..7e8033ed52a 100644 --- a/tests/unittest/llmapi/test_llm_quant.py +++ b/tests/unittest/llmapi/test_llm_quant.py @@ -5,7 +5,6 @@ import pytest from tensorrt_llm._tensorrt_engine import LLM -from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm.llmapi import KvCacheConfig, SamplingParams from tensorrt_llm.llmapi.llm_utils import CalibConfig, QuantAlgo, QuantConfig @@ -114,12 +113,13 @@ def test_quant_cfg_from_quant_cfg_json(): } } - hf_quant_config_file = model_dir / "hf_quant_config.json" + hf_quant_config_file = Path(model_dir / "hf_quant_config.json") with open(hf_quant_config_file, 'w') as f: json.dump(hf_quant_config_content, f) - quant_config, layer_quant_config = ModelConfig.load_modelopt_quant_config( - hf_quant_config_file, model_dir, None) + quant_config = QuantConfig() + is_quant_config_changed, layer_quant_config = quant_config.update_from_model_ckpt( + model_ckpt_path=model_dir, moe_backend=None) # Verify quant_cfg.json was loaded assert quant_config.quant_algo == QuantAlgo.MIXED_PRECISION @@ -157,11 +157,13 @@ def test_quant_cfg_from_hf_quant_config(): } } } - hf_quant_config_file = model_dir / "hf_quant_config.json" + hf_quant_config_file = Path(model_dir / "hf_quant_config.json") with open(hf_quant_config_file, 'w') as f: json.dump(hf_quant_config_content, f) - quant_config, layer_quant_config = ModelConfig.load_modelopt_quant_config( - hf_quant_config_file, model_dir, None) + + quant_config = QuantConfig() + is_quant_config_changed, layer_quant_config = quant_config.update_from_model_ckpt( + model_ckpt_path=model_dir, moe_backend=None) # Verify layer configs assert quant_config.quant_algo == QuantAlgo.MIXED_PRECISION From bb7e1e1eaf597e788d3713660df4b2d3eb37fa85 Mon Sep 17 00:00:00 2001 From: Daniel Afrimi Date: Tue, 18 Nov 2025 13:29:12 +0000 Subject: [PATCH 
2/5] wip Signed-off-by: Daniel Afrimi --- tensorrt_llm/models/modeling_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tensorrt_llm/models/modeling_utils.py b/tensorrt_llm/models/modeling_utils.py index bfccdcaa28f..9bc403818aa 100644 --- a/tensorrt_llm/models/modeling_utils.py +++ b/tensorrt_llm/models/modeling_utils.py @@ -423,8 +423,6 @@ def _update_from_legacy_args(self, args, moe_backend: str, layer_quant_config[layer] = cfg for arg, val in args.items(): - if not hasattr(self, arg): - raise ValueError(f"{arg} can't be found in quant config") setattr(self, arg, val) if self.quant_algo == QuantAlgo.FP8_BLOCK_SCALES: From 2facb5e86d1808514aa285e488a4752832f70c9d Mon Sep 17 00:00:00 2001 From: Daniel Afrimi Date: Tue, 18 Nov 2025 14:33:48 +0000 Subject: [PATCH 3/5] wip Signed-off-by: Daniel Afrimi --- tensorrt_llm/llmapi/llm_utils.py | 1 + tensorrt_llm/models/modeling_utils.py | 10 ++++------ 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/tensorrt_llm/llmapi/llm_utils.py b/tensorrt_llm/llmapi/llm_utils.py index 4b46b4f791c..8b2e208d444 100644 --- a/tensorrt_llm/llmapi/llm_utils.py +++ b/tensorrt_llm/llmapi/llm_utils.py @@ -25,6 +25,7 @@ from ..logger import logger from ..mapping import Mapping from ..models.automodel import MODEL_MAP, AutoConfig, AutoModelForCausalLM +from ..models.modeling_utils import QuantAlgo # noqa: F401 from ..models.modeling_utils import PretrainedConfig, QuantConfig from ..module import Module from .build_cache import (BuildCache, BuildCacheConfig, CachedStage, diff --git a/tensorrt_llm/models/modeling_utils.py b/tensorrt_llm/models/modeling_utils.py index 9bc403818aa..2a3b5d74da3 100644 --- a/tensorrt_llm/models/modeling_utils.py +++ b/tensorrt_llm/models/modeling_utils.py @@ -235,11 +235,11 @@ def _infer_kv_cache_quant_algo_from_scheme(kv_scheme: dict) -> str | None: bits = kv_scheme.get("num_bits") dynamic = bool(kv_scheme.get("dynamic", False)) - # todo add here all options... + # TODO (danielafrimi) needs to check all supported options... if kv_type == "float" and bits == 8 and not dynamic: - return QuantAlgo("FP8_BLOCK_SCALES") + return QuantAlgo.FP8 if kv_type in ("int", "uint") and bits == 8: - return QuantAlgo("INT8") + return QuantAlgo.INT8 return None def _map_new_to_legacy_args(self, hf_quant_config: dict) -> dict: @@ -261,8 +261,7 @@ def _map_new_to_legacy_args(self, hf_quant_config: dict) -> dict: hf_quant_config.get("ignore") or []) kv_scheme = hf_quant_config.get("kv_cache_scheme") or {} - kv_algo = QuantConfig._infer_kv_cache_quant_algo_from_scheme( - kv_scheme) # todo check it + kv_algo = QuantConfig._infer_kv_cache_quant_algo_from_scheme(kv_scheme) if kv_algo is not None: qunatization_dict["kv_cache_quant_algo"] = kv_algo @@ -273,7 +272,6 @@ def _map_new_to_legacy_args(self, hf_quant_config: dict) -> dict: if "symmetric" in hf_quant_config: qunatization_dict["zero_point"] = hf_quant_config["symmetric"] - # todo add here pre qunat scale and other keys.... 
return qunatization_dict def _update_from_quant_config_json(self, path, moe_backend: str, From 55bb2c6b7fcd01b0b2b05af20993c5f386690f99 Mon Sep 17 00:00:00 2001 From: Daniel Afrimi Date: Tue, 18 Nov 2025 15:05:53 +0000 Subject: [PATCH 4/5] wip Signed-off-by: Daniel Afrimi --- tensorrt_llm/models/modeling_utils.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tensorrt_llm/models/modeling_utils.py b/tensorrt_llm/models/modeling_utils.py index 2a3b5d74da3..7557896e0e3 100644 --- a/tensorrt_llm/models/modeling_utils.py +++ b/tensorrt_llm/models/modeling_utils.py @@ -235,7 +235,6 @@ def _infer_kv_cache_quant_algo_from_scheme(kv_scheme: dict) -> str | None: bits = kv_scheme.get("num_bits") dynamic = bool(kv_scheme.get("dynamic", False)) - # TODO (danielafrimi) needs to check all supported options... if kv_type == "float" and bits == 8 and not dynamic: return QuantAlgo.FP8 if kv_type in ("int", "uint") and bits == 8: @@ -384,8 +383,7 @@ def _update_from_legacy_args(self, args, moe_backend: str, ) # Merge extended info (if any) over base - merged_quant_configs = dict( - args) # todo we pop up some args so it moght not a good idea... + merged_quant_configs = dict(args) merged_quant_configs.update(json_extended_quant_configs) # kv_cache_quant_algo is global regardless of MIXED_PRECISION @@ -428,9 +426,7 @@ def _update_from_legacy_args(self, args, moe_backend: str, self.group_size = 128 if moe_backend == 'TRTLLM' and self.quant_algo == "FP8_BLOCK_SCALES" and self.exclude_modules is None: - self.exclude_modules = [ - "*kv_b_proj*", "*k_b_proj*", "*eh_proj" - ] # todo maybe merge or it ight be okay to override + self.exclude_modules = ["*kv_b_proj*", "*k_b_proj*", "*eh_proj"] return True, layer_quant_config From 9d736e7e8eb3927a7b62cc10b98c399ec0a87cc3 Mon Sep 17 00:00:00 2001 From: Daniel Afrimi Date: Tue, 18 Nov 2025 15:11:59 +0000 Subject: [PATCH 5/5] wip Signed-off-by: Daniel Afrimi --- tensorrt_llm/models/modeling_utils.py | 34 +++++++++++++-------------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/tensorrt_llm/models/modeling_utils.py b/tensorrt_llm/models/modeling_utils.py index 7557896e0e3..59ab8448441 100644 --- a/tensorrt_llm/models/modeling_utils.py +++ b/tensorrt_llm/models/modeling_utils.py @@ -242,36 +242,36 @@ def _infer_kv_cache_quant_algo_from_scheme(kv_scheme: dict) -> str | None: return None def _map_new_to_legacy_args(self, hf_quant_config: dict) -> dict: - qunatization_dict = {} + quantization_dict = {} quant_algo = hf_quant_config.get("quant_algo") if quant_algo == "fp8_pb_wo": quant_algo = "FP8_BLOCK_SCALES" if quant_algo is not None: - qunatization_dict["quant_algo"] = quant_algo + quantization_dict["quant_algo"] = quant_algo if quant_algo == QuantAlgo.W4A16_AWQ or quant_algo == QuantAlgo.W4A8_AWQ: - qunatization_dict["pre_quant_scale"] = True + quantization_dict["pre_quant_scale"] = True if "group_size" in hf_quant_config: - qunatization_dict["group_size"] = hf_quant_config["group_size"] + quantization_dict["group_size"] = hf_quant_config["group_size"] if "ignore" in hf_quant_config: - qunatization_dict["exclude_modules"] = list( + quantization_dict["exclude_modules"] = list( hf_quant_config.get("ignore") or []) kv_scheme = hf_quant_config.get("kv_cache_scheme") or {} kv_algo = QuantConfig._infer_kv_cache_quant_algo_from_scheme(kv_scheme) if kv_algo is not None: - qunatization_dict["kv_cache_quant_algo"] = kv_algo + quantization_dict["kv_cache_quant_algo"] = kv_algo if "quantized_layers" in hf_quant_config: - 
qunatization_dict["quantized_layers"] = hf_quant_config[ + quantization_dict["quantized_layers"] = hf_quant_config[ "quantized_layers"] if "symmetric" in hf_quant_config: - qunatization_dict["zero_point"] = hf_quant_config["symmetric"] + quantization_dict["zero_point"] = hf_quant_config["symmetric"] - return qunatization_dict + return quantization_dict def _update_from_quant_config_json(self, path, moe_backend: str, model_ckpt_path) -> bool: @@ -321,8 +321,7 @@ def _update_from_quant_config_json(self, path, moe_backend: str, return True, None return False, None - def _update_from_legacy_args(self, args, moe_backend: str, - checkpoint_dir) -> bool: + def _update_from_legacy_args(self, args, moe_backend, checkpoint_dir): hf_quant_algo = args.pop("quant_algo", None) layer_quant_config = None @@ -363,7 +362,7 @@ def _update_from_legacy_args(self, args, moe_backend: str, None, QuantAlgo.FP8, QuantAlgo.NVFP4 ]: raise ValueError( - f"Only kv_cache_quant_algo={QuantAlgo.FP8} or {QuantAlgo.NVFP4} is allowed for pre-quantized checkpoint, got {quant_config.kv_cache_quant_algo}." + f"Only kv_cache_quant_algo={QuantAlgo.FP8} or {QuantAlgo.NVFP4} is allowed for pre-quantized checkpoint, got {self.kv_cache_quant_algo}." ) if self.quant_algo == QuantAlgo.MIXED_PRECISION: @@ -430,16 +429,15 @@ def _update_from_legacy_args(self, args, moe_backend: str, return True, layer_quant_config - def _update_from_legacy_quant_config_json(self, path, moe_backend: str, - checkpoint_dir: Path): + def _update_from_legacy_quant_config_json(self, path, moe_backend, + checkpoint_dir): with open(path, "r") as f: hf_quant_config = json.load(f) hf_quant_config = hf_quant_config["quantization"] return self._update_from_legacy_args(hf_quant_config, moe_backend, checkpoint_dir) - def load_quant_config_from_dtypes_json(self, dtypes_json_file, - moe_backend: str): + def load_quant_config_from_dtypes_json(self, dtypes_json_file, moe_backend): layer_quant_config = None exclude_modules = set() @@ -493,8 +491,8 @@ def update_from_model_ckpt(self, model_ckpt_path: Path, moe_backend: str): logger.info( f"Found {quant_config_dtypes}, pre-quantized checkpoint is used." ) - self.load_quant_config_from_dtypes_json(quant_config_dtypes, - moe_backend) + return self.load_quant_config_from_dtypes_json( + quant_config_dtypes, moe_backend) logger.warning(f"No quant config found in {model_ckpt_path}") return False, None
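
Usage sketch for the refactored API (illustrative only; the temporary checkpoint directory and the hf_quant_config.json payload below are made up, loosely mirroring the updated tests in tests/unittest/llmapi/test_llm_quant.py). After this series, callers construct a QuantConfig and let update_from_model_ckpt() populate it from whichever quant config file the checkpoint ships (hf_quant_config.json, config.json with a quantization_config field, or dtypes.json), instead of going through the removed ModelConfig.load_modelopt_quant_config():

    # Illustrative sketch, not part of the series: exercises
    # QuantConfig.update_from_model_ckpt() on a fabricated modelopt-style
    # hf_quant_config.json (FP8 weights + FP8 KV cache).
    import json
    import tempfile
    from pathlib import Path

    from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig

    with tempfile.TemporaryDirectory() as tmp:
        model_dir = Path(tmp)
        # Minimal hf_quant_config.json with the "quantization" section the
        # legacy-format loader expects.
        with open(model_dir / "hf_quant_config.json", "w") as f:
            json.dump({"quantization": {"quant_algo": "FP8",
                                        "kv_cache_quant_algo": "FP8"}}, f)

        quant_config = QuantConfig()
        changed, layer_quant_config = quant_config.update_from_model_ckpt(
            model_ckpt_path=model_dir, moe_backend=None)

        assert changed
        assert quant_config.quant_algo == QuantAlgo.FP8
        assert quant_config.kv_cache_quant_algo == QuantAlgo.FP8
        # layer_quant_config is only populated for MIXED_PRECISION checkpoints.
        assert layer_quant_config is None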