187 changes: 9 additions & 178 deletions tensorrt_llm/_torch/model_config.py
@@ -1,6 +1,4 @@
import contextlib
import json
import os
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
@@ -14,15 +12,14 @@
from tensorrt_llm import logger
from tensorrt_llm._torch.pyexecutor.config_utils import (is_nemotron_hybrid,
load_pretrained_config)
from tensorrt_llm._utils import get_sm_version, torch_dtype_to_binding
from tensorrt_llm._utils import torch_dtype_to_binding
from tensorrt_llm.bindings import LayerType as LayerTypeCpp
from tensorrt_llm.functional import AllReduceStrategy
from tensorrt_llm.llmapi.llm_args import (DeepSeekSparseAttentionConfig,
MoeLoadBalancerConfig)
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization.mode import QuantAlgo

TConfig = TypeVar("TConfig", bound=transformers.PretrainedConfig)

@@ -219,173 +216,6 @@ def is_generation_model(model_architectures: Optional[List[str]],
# TODO: should be 'not model_type == ModelType.ENCODER_ONLY'
# once ModelType is used in pytorch flow.

@staticmethod
def load_modelopt_quant_config(quant_config_file, checkpoint_dir,
moe_backend):
quant_config = QuantConfig()
layer_quant_config = None

with open(quant_config_file) as f:
quant_config_dict = json.load(f)

json_quant_configs = quant_config_dict['quantization']

quant_config.quant_algo = json_quant_configs.get('quant_algo', None)
# fp8_pb_wo from modelopt is the same as FP8_BLOCK_SCALES
if quant_config.quant_algo == "fp8_pb_wo":
quant_config.quant_algo = 'FP8_BLOCK_SCALES'
quant_config.kv_cache_quant_algo = json_quant_configs.get(
'kv_cache_quant_algo', None)
quant_config.group_size = json_quant_configs.get('group_size', None)
quant_config.exclude_modules = json_quant_configs.get(
'exclude_modules', None)

if quant_config.quant_algo == QuantAlgo.MIXED_PRECISION:
json_extended_quant_configs: dict = {}
# See tests/unittest/llmapi/test_llm_quant.py
try:
mixed_quant_config_file = transformers.utils.hub.cached_file(
checkpoint_dir, 'quant_cfg.json')
with open(mixed_quant_config_file) as fm:
json_extended_quant_configs = json.load(fm)
except Exception:
logger.info(
f"No quant_cfg.json found for layer quant info, using hf_quant_config.json."
)
json_quant_configs.update(json_extended_quant_configs)
# kv_cache_quant_algo is global regardless of MIXED_PRECISION
kv_cache_quant_algo = json_quant_configs.get(
'kv_cache_quant_algo', None)
mixed_quant_configs = json_quant_configs.get(
'quantized_layers', None)
if (kv_quant_lhs := json_extended_quant_configs.get(
"kv_cache_quant_algo", None)) is not None and (
kv_quant_rhs :=
quant_config.kv_cache_quant_algo) is not None:
if kv_quant_lhs != kv_quant_rhs:
raise RuntimeError(
f"The kvcache config in 'quant_cfg.json', {kv_quant_lhs},"
f"is different from 'hf_quant_config.json', {kv_quant_rhs}!"
)
quant_config.kv_cache_quant_algo = json_quant_configs[
"kv_cache_quant_algo"]
for layer in mixed_quant_configs:
config = QuantConfig()
config.kv_cache_quant_algo = kv_cache_quant_algo
config.quant_algo = mixed_quant_configs[layer]['quant_algo']
config.group_size = mixed_quant_configs[layer].get(
'group_size', None)
mixed_quant_configs[layer] = config
layer_quant_config = mixed_quant_configs
elif quant_config.quant_algo == QuantAlgo.FP8_BLOCK_SCALES:
if quant_config.group_size is None:
quant_config.group_size = 128

if moe_backend == 'TRTLLM' and quant_config.quant_algo == "FP8_BLOCK_SCALES" and quant_config.exclude_modules is None:
quant_config.exclude_modules = [
"*kv_b_proj*", "*k_b_proj*", "*eh_proj"
]
return quant_config, layer_quant_config

@staticmethod
def get_mxfp4_quant_algo(moe_backend, is_dynamic_quant=False):
quant_algo = ModelConfig.override_quant_algo()
if quant_algo is None and not is_dynamic_quant:
if get_sm_version() >= 100:
if moe_backend == 'TRITON':
return QuantAlgo.W4A8_MXFP4_FP8
else:
return QuantAlgo.W4A8_MXFP4_MXFP8
else:
return QuantAlgo.W4A16_MXFP4
else:
return quant_algo

@staticmethod
def load_hf_quant_config(hf_quant_config, moe_backend):
quant_config = QuantConfig()
layer_quant_config = None

# DeepSeek V3 FP8 ckpt
if hf_quant_config.get("quant_method") == "fp8" and hf_quant_config.get(
"weight_block_size", []):
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
if moe_backend == 'TRTLLM':
# TODO: This is a hack. Remove after fp8 bmm is integrated.
quant_config.exclude_modules = [
"*kv_b_proj*", "*k_b_proj*", "*eh_proj"
]
else:
quant_config.exclude_modules = ["*eh_proj"]

block_size = hf_quant_config.get("weight_block_size", [])
assert tuple(block_size) == (
128, 128), "FP8_BLOCK_SCALES only supports block_size=(128,128)"
quant_config.group_size = block_size[0]
# MXFP4 checkpoints.
elif hf_quant_config.get("quant_method") == "mxfp4":
quant_config.quant_algo = ModelConfig.get_mxfp4_quant_algo(
moe_backend)
quant_config.group_size = 32
quant_config.exclude_modules = [
'block.*.attn.out', 'block.*.mlp.gate', 'block.*.attn.qkv',
'embedding', 'unembedding'
]

return quant_config, layer_quant_config

@staticmethod
def load_quant_config_from_dtypes_json(dtypes_json_file, moe_backend: str):
quant_config = QuantConfig()
layer_quant_config = None

exclude_modules = set()
has_mxfp4 = False
is_dynamic_quant = False
with open(dtypes_json_file) as f:
dtypes_json = json.load(f)
for layer, dtype in dtypes_json.items():
if layer.endswith("weight"):
if dtype == "BF16" or dtype == "FP16":
names = layer.split(".")
exclude_modules.add('.'.join(names[:-1]))
elif dtype == "MXFP4":
# This is the path for the fp8 checkpoint which requires dynamic quantization.
is_dynamic_quant = True
has_mxfp4 = True
elif layer.endswith("weight.blocks"):
scale_name = layer.replace("weight.blocks", "weight.scales")
scale_dtype = dtypes_json.get(scale_name, None)
assert scale_dtype == "UE8"
is_dynamic_quant = False
has_mxfp4 = True

if has_mxfp4:
quant_config.quant_algo = ModelConfig.get_mxfp4_quant_algo(
moe_backend, is_dynamic_quant)
quant_config.group_size = 32
quant_config.exclude_modules = list(exclude_modules)
logger.info(f"Setting quant_config: {quant_config}")

return quant_config, layer_quant_config

@staticmethod
def override_quant_algo():
new_algo = os.environ.get("OVERRIDE_QUANT_ALGO", None)
supported_algos = {
"W4A16_MXFP4": QuantAlgo.W4A16_MXFP4,
"W4A8_MXFP4_MXFP8": QuantAlgo.W4A8_MXFP4_MXFP8,
"W4A8_MXFP4_FP8": QuantAlgo.W4A8_MXFP4_FP8,
}
if new_algo is not None:
if new_algo.upper() in supported_algos:
return supported_algos[new_algo.upper()]
else:
logger.warning(
f"Unsupported quant algo: {new_algo}, supported algos: {supported_algos.keys()}"
)
return None

@classmethod
def from_pretrained(cls,
checkpoint_dir: str,
@@ -445,16 +275,17 @@ def cached_file(path_or_repo_id, file_name):
# quantized ckpt in modelopt format
if quant_config_file := cached_file(checkpoint_dir,
'hf_quant_config.json'):
quant_config, layer_quant_config = cls.load_modelopt_quant_config(
quant_config_file, checkpoint_dir, moe_backend)
is_quant_config_changed, layer_quant_config = quant_config.update_from_model_ckpt(
checkpoint_dir, moe_backend)

# quantized ckpt in other formats
elif hasattr(pretrained_config, "quantization_config"):
hf_quant_config = pretrained_config.quantization_config
quant_config, layer_quant_config = cls.load_hf_quant_config(
hf_quant_config, moe_backend)
is_quant_config_changed, layer_quant_config = quant_config.update_from_model_ckpt(
checkpoint_dir, moe_backend)

elif quant_config_file := cached_file(checkpoint_dir, 'dtypes.json'):
quant_config, layer_quant_config = cls.load_quant_config_from_dtypes_json(
quant_config_file, moe_backend)
is_quant_config_changed, layer_quant_config = quant_config.update_from_model_ckpt(
checkpoint_dir, moe_backend)

model_config = cls(pretrained_config=pretrained_config,
quant_config=quant_config,
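
Note: this hunk collapses the three format-specific loaders removed above (load_modelopt_quant_config, load_hf_quant_config, load_quant_config_from_dtypes_json) into a single QuantConfig.update_from_model_ckpt call. The method's implementation is not part of this diff, so the sketch below only illustrates the call pattern implied by the new call sites; the checkpoint path is a placeholder and the return contract (changed flag plus optional per-layer config) is inferred, not confirmed.

# Sketch of the consolidated interface, inferred from the call sites above.
# update_from_model_ckpt is assumed to probe the checkpoint directory for
# hf_quant_config.json, a quantization_config block in config.json, or
# dtypes.json, as the removed loaders did.
from tensorrt_llm.models.modeling_utils import QuantConfig

quant_config = QuantConfig()
changed, layer_quant_config = quant_config.update_from_model_ckpt(
    "/path/to/checkpoint",  # placeholder checkpoint directory
    "TRTLLM",               # moe_backend, passed through as in from_pretrained
)
if changed:
    print(quant_config.quant_algo, quant_config.kv_cache_quant_algo,
          quant_config.group_size, quant_config.exclude_modules)
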
105 changes: 5 additions & 100 deletions tensorrt_llm/llmapi/llm_utils.py
@@ -25,7 +25,8 @@
from ..logger import logger
from ..mapping import Mapping
from ..models.automodel import MODEL_MAP, AutoConfig, AutoModelForCausalLM
from ..models.modeling_utils import PretrainedConfig, QuantAlgo, QuantConfig
from ..models.modeling_utils import QuantAlgo # noqa: F401
from ..models.modeling_utils import PretrainedConfig, QuantConfig
from ..module import Module
from .build_cache import (BuildCache, BuildCacheConfig, CachedStage,
get_build_cache_config_from_env)
@@ -332,106 +333,10 @@ def _download_hf_model(self):
assert self.speculative_model_obj.is_local_model

def _update_from_hf_quant_config(self) -> bool:
"""Update quant_config from the config file of pre-quantized HF checkpoint.

Returns:
prequantized (bool): Whether the checkpoint is pre-quantized.
"""
quant_config = self.llm_args.quant_config

hf_quant_config_path = f"{self._model_dir}/hf_quant_config.json"
if os.path.exists(hf_quant_config_path):
logger.info(
f"Found {hf_quant_config_path}, pre-quantized checkpoint is used."
)
with open(hf_quant_config_path, "r") as f:
hf_quant_config = json.load(f)
hf_quant_config = hf_quant_config["quantization"]

hf_quant_algo = hf_quant_config.pop("quant_algo", None)
if hf_quant_algo is not None:
# fp8_pb_wo from modelopt is the same as fp8_block_scales
if hf_quant_algo == "fp8_pb_wo":
hf_quant_algo = QuantAlgo.FP8_BLOCK_SCALES
else:
hf_quant_algo = QuantAlgo(hf_quant_algo)
if quant_config.quant_algo is None:
logger.info(
f"Setting quant_algo={hf_quant_algo} form HF quant config."
)
quant_config.quant_algo = hf_quant_algo
elif quant_config.quant_algo != hf_quant_algo:
raise ValueError(
f"Specified quant_algo={quant_config.quant_algo}, conflicting with quant_algo={hf_quant_algo} from HF quant config."
)
else:
raise ValueError(
"Pre-quantized checkpoint must have quant_algo.")

hf_kv_cache_quant_algo = hf_quant_config.pop(
"kv_cache_quant_algo", None)
if hf_kv_cache_quant_algo is not None:
hf_kv_cache_quant_algo = QuantAlgo(hf_kv_cache_quant_algo)
if quant_config.kv_cache_quant_algo is None:
logger.info(
f"Setting kv_cache_quant_algo={hf_kv_cache_quant_algo} form HF quant config."
)
quant_config.kv_cache_quant_algo = hf_kv_cache_quant_algo
elif quant_config.kv_cache_quant_algo != hf_kv_cache_quant_algo:
raise ValueError(
f"Specified kv_cache_quant_algo={quant_config.kv_cache_quant_algo}, conflicting with kv_cache_quant_algo={hf_kv_cache_quant_algo} from HF quant config."
)
else:
if quant_config.kv_cache_quant_algo not in [
None, QuantAlgo.FP8, QuantAlgo.NVFP4
]:
raise ValueError(
f"Only kv_cache_quant_algo={QuantAlgo.FP8} or {QuantAlgo.NVFP4} is allowed for pre-quantized checkpoint, got {quant_config.kv_cache_quant_algo}."
)

for key, value in hf_quant_config.items():
logger.info(
f"Setting {key}={str(value)[:100]}{'...' if len(str(value)) > 100 else ''} from HF quant config."
)
setattr(quant_config, key, value)

# Update the quant_config in llm_args for pytorch
self.llm_args.quant_config = quant_config

return True

hf_config_path = f"{self._model_dir}/config.json"
if os.path.exists(hf_config_path):
with open(hf_config_path, "r") as f:
hf_config = json.load(f)
hf_quant_config = hf_config.get("quantization_config", None)

if hf_quant_config is not None:
logger.info(
f"Found quantization_config field in {hf_config_path}, pre-quantized checkpoint is used."
)
# DeepSeek V3 FP8 ckpt
if hf_quant_config.get(
"quant_method") == "fp8" and hf_quant_config.get(
"weight_block_size"):
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
quant_config.exclude_modules = ["*eh_proj"]
elif hf_quant_config.get("quant_method") == "mxfp4":
from .._torch.model_config import ModelConfig
quant_config.quant_algo = ModelConfig.get_mxfp4_quant_algo(
self.llm_args.moe_config.backend)
quant_config.group_size = 32
quant_config.exclude_modules = [
'block.*.attn.out', 'block.*.mlp.gate',
'block.*.attn.qkv', 'embedding', 'unembedding'
]
else:
raise NotImplementedError(
f"Unsupported quantization_config: {hf_quant_config}.")

return True

return False
quant_config_changed, _ = quant_config.update_from_model_ckpt(
self._model_dir, self.llm_args.moe_config.backend)
return quant_config_changed

def _load_model_from_hf(self):
''' Load a TRT-LLM model from a HF model. '''
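
Both call sites now rely on the same checkpoint probing, and the ModelOpt-style hf_quant_config.json layout parsed by the removed code remains the relevant input format. Below is an illustrative file written with Python's json module; the keys mirror those read above (quant_algo, kv_cache_quant_algo, group_size, exclude_modules), while the values are example placeholders rather than contents of any real checkpoint.

import json

# Illustrative hf_quant_config.json content. The keys mirror those parsed by
# the removed load_modelopt_quant_config / _update_from_hf_quant_config code;
# the values below are placeholders for demonstration only.
example = {
    "quantization": {
        "quant_algo": "FP8_BLOCK_SCALES",
        "kv_cache_quant_algo": "FP8",
        "group_size": 128,
        "exclude_modules": ["*eh_proj"],
    }
}

with open("hf_quant_config.json", "w") as f:
    json.dump(example, f, indent=2)

A MIXED_PRECISION checkpoint would additionally ship a quant_cfg.json with a quantized_layers mapping, as the fallback path in the removed loader shows.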