77 changes: 76 additions & 1 deletion tensorrt_llm/_torch/model_config.py
@@ -21,7 +21,7 @@
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization.mode import QuantAlgo
from tensorrt_llm.quantization.mode import ActivationScheme, QuantAlgo

TConfig = TypeVar("TConfig", bound=transformers.PretrainedConfig)

@@ -368,6 +368,81 @@ def load_hf_quant_config(hf_quant_config, moe_backend):
'block.*.attn.out', 'block.*.mlp.gate', 'block.*.attn.qkv',
'embedding', 'unembedding'
]
# FP8 per-tensor checkpoints.
Collaborator:
For these changes, I'd recommend adding tests in tests/unittest/llmapi/test_llm_quant.py for ModelConfig.load_hf_quant_config(), which you modified, extending the current tests that only cover ModelConfig.load_modelopt_quant_config(). The reason is to guard against any accidental breakage in the future.
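A minimal sketch of such a test, assuming ModelConfig.load_hf_quant_config() can be called directly with (hf_quant_config, moe_backend) and returns (quant_config, layer_quant_config) as shown in this diff; the moe_backend value and the checkpoint fields below are illustrative, not taken from a real config:

from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm.quantization.mode import ActivationScheme, QuantAlgo


def test_load_hf_quant_config_fp8_per_tensor():
    # Illustrative AngleSlim-style FP8 per-tensor config; key names follow this diff.
    hf_quant_config = {
        "quant_method": "fp8",
        "kv_cache_quant_method": "fp8",
        "activation_scheme": "static",
        "ignored_layers": ["lm_head"],
    }
    quant_config, _ = ModelConfig.load_hf_quant_config(hf_quant_config,
                                                       moe_backend="CUTLASS")
    # The assertions mirror the parsing added above.
    assert quant_config.quant_algo == QuantAlgo.FP8
    assert quant_config.kv_cache_quant_algo == QuantAlgo.FP8
    assert quant_config.activation_scheme == ActivationScheme.STATIC
    assert "lm_head" in quant_config.exclude_modules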

elif hf_quant_config.get("quant_method") == "fp8":
@QiJune (Collaborator, Nov 10, 2025):
Could you please add some comments like "AngleSlim fp8 ckpt"? A reference:

# DeepSeek V3 FP8 ckpt

quant_config.quant_algo = QuantAlgo.FP8
# W4A8_AWQ checkpoints.
elif hf_quant_config.get("quant_method") == "w4a8_awq":
Collaborator:
Could you please add some comments like "AngleSlim w4a8_awq ckpt"?

quant_config.quant_algo = QuantAlgo.W4A8_AWQ
quant_config.group_size = hf_quant_config.get(
"weight_group_size", 128)
else:
raise NotImplementedError(
f"Unsupported quantization_config: {hf_quant_config}.")

# set kv_cache_quant_algo
@QiJune (Collaborator, Nov 19, 2025):
Could you please move the following code into a single function, something like load_angleslim_config?

# DeepSeek V3 FP8 ckpt
if hf_quant_config.get("quant_method") == "fp8" xxx:
     xxx
# MXFP4 checkpoints.
elif hf_quant_config.get("quant_method") == "mxfp4":
    xxx
# Angleslim FP8 checkpoint
elif hf_quant_config.get("quant_method") == "fp8" :
    xxx
# Angleslim w4a8_awq checkpoint
elif hf_quant_config.get("quant_method") == "w4a8_awq":
    quant_config = load_angleslim_config(xxx)
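A slightly fuller sketch of the suggested helper; the name load_angleslim_config comes from this comment, and the body only restates the parsing already present in this diff (QuantAlgo and ActivationScheme are the imports added at the top of model_config.py):

def load_angleslim_config(hf_quant_config, quant_config):
    """Populate quant_config from an AngleSlim-style quantization_config."""
    quant_method = hf_quant_config.get("quant_method")
    if quant_method == "fp8":
        # AngleSlim FP8 per-tensor ckpt
        quant_config.quant_algo = QuantAlgo.FP8
    elif quant_method == "w4a8_awq":
        # AngleSlim W4A8_AWQ ckpt
        quant_config.quant_algo = QuantAlgo.W4A8_AWQ
        quant_config.group_size = hf_quant_config.get("weight_group_size", 128)
    else:
        raise NotImplementedError(
            f"Unsupported quantization_config: {hf_quant_config}.")

    kv_method = hf_quant_config.get("kv_cache_quant_method")
    quant_config.kv_cache_quant_algo = QuantAlgo(kv_method.upper()) if kv_method else None
    act_scheme = hf_quant_config.get("activation_scheme")
    quant_config.activation_scheme = ActivationScheme(act_scheme.upper()) if act_scheme else None
    return quant_config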

Collaborator:
My PR refactors how the quant_config is created in the codebase.

quant_config.kv_cache_quant_algo = QuantAlgo(hf_quant_config.get("kv_cache_quant_method").upper()) \
if hf_quant_config.get("kv_cache_quant_method") else None
# set activation_scheme
quant_config.activation_scheme = ActivationScheme(hf_quant_config.get("activation_scheme").upper()) \
if hf_quant_config.get("activation_scheme") else None
# set exclude_modules
if quant_config.exclude_modules:
if hf_quant_config.get("ignored_layers"):
quant_config.exclude_modules += hf_quant_config.get(
"ignored_layers")
else:
quant_config.exclude_modules = hf_quant_config.get("ignored_layers")

# set exclude_quantization
hf_ignored_quantization_config = hf_quant_config.get(
"ignored_quantization_config")
if hf_ignored_quantization_config:
quant_config.exclude_quantization = {
"kv_cache_quant_algo":
QuantAlgo(
hf_ignored_quantization_config.get(
"kv_cache_quant_method").upper())
if hf_ignored_quantization_config.get("kv_cache_quant_method")
else None,
"activation_scheme":
ActivationScheme(
hf_ignored_quantization_config.get(
"activation_scheme").upper())
if hf_ignored_quantization_config.get("activation_scheme") else
None,
"group_size":
128,
}
if hf_ignored_quantization_config.get(
"quant_method"
) == "fp8" and hf_ignored_quantization_config.get(
"weight_block_size", []):
quant_config.exclude_quantization[
"quant_algo"] = QuantAlgo.FP8_BLOCK_SCALES
block_size = hf_ignored_quantization_config.get(
"weight_block_size", [])
assert tuple(block_size) == (
128,
128), "FP8_BLOCK_SCALES only supports block_size=(128,128)"
quant_config.exclude_quantization["group_size"] = block_size[0]
elif hf_ignored_quantization_config.get("quant_method") == "fp8":
quant_config.exclude_quantization["quant_algo"] = QuantAlgo.FP8
elif hf_ignored_quantization_config.get(
"quant_method") == "w4a8_awq":
quant_config.exclude_quantization[
"quant_algo"] = QuantAlgo.W4A8_AWQ
quant_config.exclude_quantization[
"group_size"] = hf_ignored_quantization_config.get(
"weight_group_size", 128)
else:
raise NotImplementedError(
f"Unsupported quantization_config.ignored_quantization_config: "
f"{hf_ignored_quantization_config}.")

logger.info(
f"Load quantization config from pretrained config, quant_config: {quant_config}"
)

return quant_config, layer_quant_config
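For reference, a sketch of the kind of checkpoint quantization_config this branch consumes and what it maps to; the key names come from the code above, while the concrete values are illustrative:

# Illustrative AngleSlim-style quantization_config (values are made up).
hf_quant_config = {
    "quant_method": "w4a8_awq",
    "weight_group_size": 128,
    "kv_cache_quant_method": "fp8",
    "activation_scheme": "static",
    "ignored_layers": ["lm_head"],
    "ignored_quantization_config": {
        "quant_method": "fp8",
        "weight_block_size": [128, 128],
        "activation_scheme": "dynamic",
    },
}
# With the parsing above, this yields roughly:
#   quant_config.quant_algo           == QuantAlgo.W4A8_AWQ
#   quant_config.group_size           == 128
#   quant_config.kv_cache_quant_algo  == QuantAlgo.FP8
#   quant_config.activation_scheme    == ActivationScheme.STATIC
#   quant_config.exclude_modules      == ["lm_head"]
#   quant_config.exclude_quantization == {
#       "quant_algo": QuantAlgo.FP8_BLOCK_SCALES,
#       "kv_cache_quant_algo": None,
#       "activation_scheme": ActivationScheme.DYNAMIC,
#       "group_size": 128,
#   }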

3 changes: 2 additions & 1 deletion tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -290,7 +290,8 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
if names[-1] == "kv_b_proj":
# TODO: remove weight_dequant after enabling fp8_bmm
dequant_kv_b_proj = self.model_config.quant_config.is_module_excluded_from_quantization(
names[-1])
names[-1]
) and self.model_config.quant_config.exclude_quantization is None
if dequant_kv_b_proj:
kv_b_proj, k_b_proj_trans = load_kv_b_proj_and_k_b_proj_trans_dequant(
name)
16 changes: 15 additions & 1 deletion tensorrt_llm/_torch/models/modeling_utils.py
@@ -478,9 +478,23 @@ def apply_quant_config_exclude_modules(self):
"""
quant_config = self.model_config.quant_config
kv_cache_quant_algo = None
quant_algo = None
activation_scheme = None
Collaborator:
Is activation_scheme left as a placeholder for now, or do we have corresponding forward logic for it?

group_size = 128
if quant_config:
kv_cache_quant_algo = quant_config.kv_cache_quant_algo
new_config = QuantConfig(kv_cache_quant_algo=kv_cache_quant_algo)
exclude_quantization = quant_config.exclude_quantization
if exclude_quantization:
quant_algo = exclude_quantization.get("quant_algo", None)
activation_scheme = exclude_quantization.get(
"activation_scheme", None)
group_size = exclude_quantization.get("group_size", 128)
new_config = QuantConfig(
quant_algo=quant_algo,
kv_cache_quant_algo=kv_cache_quant_algo,
activation_scheme=activation_scheme,
group_size=group_size,
)

if quant_config is not None:
if quant_config.exclude_modules is not None:
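To connect this with the config parsing earlier in the PR, a small sketch of how an exclude_quantization dict would populate the QuantConfig built here, presumably so that excluded modules follow the checkpoint's ignored_quantization_config rather than staying unquantized; field names come from this diff and the values are illustrative:

from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization.mode import ActivationScheme, QuantAlgo

# Illustrative exclude_quantization dict, as produced by load_hf_quant_config earlier in this PR.
exclude_quantization = {
    "quant_algo": QuantAlgo.FP8_BLOCK_SCALES,
    "kv_cache_quant_algo": None,
    "activation_scheme": ActivationScheme.DYNAMIC,
    "group_size": 128,
}
kv_cache_quant_algo = QuantAlgo.FP8  # taken from the model-level quant_config in the code above
new_config = QuantConfig(
    quant_algo=exclude_quantization.get("quant_algo"),
    kv_cache_quant_algo=kv_cache_quant_algo,
    activation_scheme=exclude_quantization.get("activation_scheme"),
    group_size=exclude_quantization.get("group_size", 128),
)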
78 changes: 64 additions & 14 deletions tensorrt_llm/_torch/modules/fused_moe/quantization.py
@@ -224,9 +224,10 @@ def load_expert_weights_to_dst(
MoEWeightLoadingMode.VANILLA,
MoEWeightLoadingMode.W4A8_CUSTOM
]:
w1_weight = weights[f"{expert_id}.w1.weight"]
w3_weight = weights[f"{expert_id}.w3.weight"]
w2_weight = weights[f"{expert_id}.w2.weight"]
weight_name = "qweight" if f"{expert_id}.w1.qweight" in weights else "weight"
Collaborator:
cc @Barry-Delaney @rosenrodt to help review this part.

Collaborator:
I think adding another MoEWeightLoadingMode would be better. Also mentioned in another comment.

Collaborator:
If refactoring into different MoEWeightLoadingModes, please also be aware that you would need to pass an extra flag during create_moe(). For a DeepSeek example: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/models/modeling_deepseekv3.py#L840

self.experts = create_moe(..., weight_loading_mode=...)

w1_weight = weights[f"{expert_id}.w1.{weight_name}"]
w3_weight = weights[f"{expert_id}.w3.{weight_name}"]
w2_weight = weights[f"{expert_id}.w2.{weight_name}"]
Collaborator:
For changes in quantization.py, we should add respective tests in tests/unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8 to avoid breaking this feature in the future.

if module.bias:
w1_bias = weights[f"{expert_id}.w1.bias"]
w3_bias = weights[f"{expert_id}.w3.bias"]
@@ -1085,6 +1086,10 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
w4a8_custom = module.weight_loading_mode == MoEWeightLoadingMode.W4A8_CUSTOM
Collaborator:
For now, MoEWeightLoadingMode.VANILLA is used for ModelOpt, and MoEWeightLoadingMode.W4A8_CUSTOM is used for checkpoints produced by TRT-LLM scripts. Is it okay to add MoEWeightLoadingMode.ANGELSIM or something similar to distinguish the following logic, or is it safe to reuse the TRT-LLM one?

if w4a8_custom:
weight_scale_name = "weight_scale_inv"
for expert_id in module.initial_local_expert_ids:
if f"{expert_id}.w3.weight_scale.int4" in weights:
weight_scale_name = "weight_scale.int4"
break
else:
weight_scale_name = "weight_scale"

@@ -1107,14 +1112,36 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
all_w3_w1_input_scales_max = torch.max(
torch.stack(all_w3_input_scales),
torch.stack(all_w1_input_scales)).max()
all_w3_w1_scales_fp8_max = None
has_fp8_weight_scale = False
if w4a8_custom:
# In custom W4A8 ckpt, per-tensor input_scale and per-channel pre_quant_scale are fused into input_scale
module.fc31_act_scale.data.copy_(
torch.ones_like(module.fc31_act_scale, device=self.device) *
(1 / all_w3_w1_input_scales_max))

for expert_id in module.initial_local_expert_ids:
if f"{expert_id}.w1.weight_scale" in weights:
has_fp8_weight_scale = True
break
if has_fp8_weight_scale:
all_w3_w1_scales_fp8_max = []
for expert_id in module.initial_local_expert_ids:
w1_weight_scale_fp8 = load_weight_shard(
weights[f"{expert_id}.w1.weight_scale"],
device=self.device)
w3_weight_scale_fp8 = load_weight_shard(
weights[f"{expert_id}.w3.weight_scale"],
device=self.device)
all_w3_w1_scales_fp8_max.append(
torch.max(w3_weight_scale_fp8, w1_weight_scale_fp8))
all_w3_w1_scales_fp8_max = torch.stack(
all_w3_w1_scales_fp8_max).reshape(module.fc31_alpha.shape)
else:
all_w3_w1_scales_fp8_max = torch.ones_like(module.fc31_alpha,
device=self.device)
module.fc31_alpha.data.copy_(
(torch.ones_like(module.fc31_alpha, device=self.device) *
all_w3_w1_input_scales_max).float())
(all_w3_w1_scales_fp8_max * all_w3_w1_input_scales_max).float())
else:
# In vanilla ckpt (at least from ModelOpt), per-tensor input_scale and per-channel pre_quant_scale are separately stored
all_w3_pre_quant_scales = [
@@ -1192,6 +1219,9 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
device=self.device)
for expert_id in module.initial_local_expert_ids
]
if w4a8_custom and has_fp8_weight_scale:
all_w3_scales = torch.stack(
all_w3_scales) / all_w3_w1_scales_fp8_max.unsqueeze(2)
all_w1_scales = [
load_weight_shard(weights[f"{expert_id}.w1.{weight_scale_name}"],
module.tp_size,
@@ -1200,9 +1230,15 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
device=self.device)
for expert_id in module.initial_local_expert_ids
]
all_w3_w1_scales = torch.cat(
[torch.stack(all_w3_scales),
torch.stack(all_w1_scales)], dim=-2)
if w4a8_custom and has_fp8_weight_scale:
all_w1_scales = torch.stack(
all_w1_scales) / all_w3_w1_scales_fp8_max.unsqueeze(2)
all_w3_w1_scales = torch.cat([all_w3_scales, all_w1_scales], dim=-2)
else:
all_w3_w1_scales = torch.cat(
[torch.stack(all_w3_scales),
torch.stack(all_w1_scales)],
dim=-2)
if module.sm_version == 89:
w3_w1_scales = all_w3_w1_scales.to(torch.float16).view(module.dtype)
else:
@@ -1234,15 +1270,26 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
all_w2_input_scales_max = torch.stack(all_w2_input_scales).to(
module.dtype).max()

all_w2_scales_fp8 = None
if w4a8_custom:
# In custom W4A8 ckpt, per-tensor input_scale and per-channel pre_quant_scale are fused into input_scale
module.fc2_act_scale.data.copy_(
torch.ones_like(module.fc2_act_scale, device=self.device) *
(1 / all_w2_input_scales_max))
# In custom W4A8 ckpt, per-tensor weight_scale_2 is fused into alpha
if has_fp8_weight_scale:
all_w2_scales_fp8 = [
load_weight_shard(weights[f"{expert_id}.w2.weight_scale"],
device=self.device)
for expert_id in module.initial_local_expert_ids
]
all_w2_scales_fp8 = torch.stack(all_w2_scales_fp8).reshape(
module.fc2_alpha.shape)
else:
all_w2_scales_fp8 = torch.ones_like(module.fc2_alpha,
device=self.device)
module.fc2_alpha.data.copy_(
(torch.ones_like(module.fc2_alpha, device=self.device) *
all_w2_input_scales_max).float())
(all_w2_scales_fp8 * all_w2_input_scales_max).float())
else:
# In vanilla ckpt (at least from ModelOpt), per-tensor input_scale and per-channel pre_quant_scale are separately stored
all_w2_pre_quant_scales = [
@@ -1288,12 +1335,15 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
device=self.device)
for expert_id in module.initial_local_expert_ids
]
if w4a8_custom and has_fp8_weight_scale:
all_w2_scales = torch.stack(
all_w2_scales) / all_w2_scales_fp8.unsqueeze(2)
else:
all_w2_scales = torch.stack(all_w2_scales)
if module.sm_version == 89:
w2_scales = torch.stack(all_w2_scales).to(torch.float16).view(
module.dtype)
w2_scales = all_w2_scales.to(torch.float16).view(module.dtype)
else:
w2_scales = torch.stack(all_w2_scales).to(torch.bfloat16).view(
module.dtype)
w2_scales = all_w2_scales.to(torch.bfloat16).view(module.dtype)

if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
w2_scales = w2_scales.permute(1, 2, 0)
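For intuition, the W4A8 custom path above folds the per-expert FP8 weight_scale into alpha and divides it out of the INT4 group scales; a minimal sketch of that arithmetic with made-up names and shapes (it only restates the tensor operations in the diff):

import torch

num_experts = 2
w_scale_fp8 = torch.tensor([[0.02], [0.03]])   # per-expert FP8 weight_scale (illustrative)
input_scale_max = torch.tensor(0.5)            # max per-tensor input_scale across experts
group_scales = torch.rand(num_experts, 4, 8)   # per-expert INT4 group scales (illustrative shape)

alpha = (w_scale_fp8 * input_scale_max).float()          # mirrors fc31_alpha / fc2_alpha
group_scales = group_scales / w_scale_fp8.unsqueeze(2)   # mirrors the unsqueeze(2) division above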
55 changes: 54 additions & 1 deletion tensorrt_llm/llmapi/llm_utils.py
@@ -25,7 +25,8 @@
from ..logger import logger
from ..mapping import Mapping
from ..models.automodel import MODEL_MAP, AutoConfig, AutoModelForCausalLM
from ..models.modeling_utils import PretrainedConfig, QuantAlgo, QuantConfig
from ..models.modeling_utils import (ActivationScheme, PretrainedConfig,
QuantAlgo, QuantConfig)
from ..module import Module
from .build_cache import (BuildCache, BuildCacheConfig, CachedStage,
get_build_cache_config_from_env)
@@ -416,6 +417,11 @@ def _update_from_hf_quant_config(self) -> bool:
"weight_block_size"):
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
quant_config.exclude_modules = ["*eh_proj"]
block_size = hf_quant_config.get("weight_block_size", [])
assert tuple(block_size) == (
128, 128
), "FP8_BLOCK_SCALES only supports block_size=(128,128)"
quant_config.group_size = block_size[0]
elif hf_quant_config.get("quant_method") == "mxfp4":
from .._torch.model_config import ModelConfig
quant_config.quant_algo = ModelConfig.get_mxfp4_quant_algo(
@@ -425,10 +431,57 @@
'block.*.attn.out', 'block.*.mlp.gate',
'block.*.attn.qkv', 'embedding', 'unembedding'
]
elif hf_quant_config.get("quant_method") == "fp8":
quant_config.quant_algo = QuantAlgo.FP8
elif hf_quant_config.get("quant_method") == "w4a8_awq":
quant_config.quant_algo = QuantAlgo.W4A8_AWQ
quant_config.group_size = hf_quant_config.get(
"weight_group_size", 128)
else:
raise NotImplementedError(
f"Unsupported quantization_config: {hf_quant_config}.")

# set kv_cache_quant_algo
quant_config.kv_cache_quant_algo = QuantAlgo(
hf_quant_config.get("kv_cache_quant_method").upper()
) if hf_quant_config.get("kv_cache_quant_method") else None
# set activation_scheme
quant_config.activation_scheme = ActivationScheme(
hf_quant_config.get("activation_scheme").upper()
) if hf_quant_config.get("activation_scheme") else None
# set exclude_modules
if quant_config.exclude_modules:
if hf_quant_config.get("ignored_modules"):
quant_config.exclude_modules += hf_quant_config.get(
"ignored_modules")
else:
quant_config.exclude_modules = hf_quant_config.get(
"ignored_modules")
# set exclude_quantization
hf_ignored_quantization_config = hf_quant_config.get(
"ignored_quantization_config")
if hf_ignored_quantization_config:
quant_config.exclude_quantization = {
"quant_algo":
QuantAlgo(
hf_ignored_quantization_config.get(
"quant_method").upper())
if hf_ignored_quantization_config.get("quant_method")
else None,
"kv_cache_quant_algo":
QuantAlgo(
hf_ignored_quantization_config.get(
"kv_cache_quant_method").upper())
if hf_ignored_quantization_config.get(
"kv_cache_quant_method") else None,
"activation_scheme":
ActivationScheme(
hf_ignored_quantization_config.get(
"activation_scheme").upper()) if
hf_ignored_quantization_config.get("activation_scheme")
else None,
}
logger.info(f"Detected quantization_config: {quant_config}.")
Collaborator:
As I mentioned before, my PR refactors how we set the quantization config, so this code will change; if my PR is merged first, a rebase will be needed.

In my PR we avoid code duplication like this, so we can align on one source of truth.

return True

return False
7 changes: 6 additions & 1 deletion tensorrt_llm/models/modeling_utils.py
@@ -45,7 +45,8 @@
WeightOnlyQuantLinear,
WeightOnlyQuantRowLinear)
from ..quantization.mode import (KV_CACHE_QUANT_ALGO_LIST, QUANT_ALGO_LIST,
W8A8_SQ_PLUGIN_LIST, QuantAlgo)
W8A8_SQ_PLUGIN_LIST, ActivationScheme,
QuantAlgo)
from ..quantization.utils import fp4_utils
from ..top_model_mixin import TopModelMixin
from .convert_utils import weight_only_quantize_dict
@@ -143,6 +144,8 @@ class QuantConfig:
pre_quant_scale (bool): Whether to use pre-quant scale for quantization. Defaults to False.
exclude_modules (List[str], optional): The module name patterns that are skipped in quantization. Defaults to None.
mamba_ssm_cache_dtype (str, optional): The data type for mamba SSM cache. Defaults to None.
exclude_quantization (Dict, optional): The quantization settings applied to the modules matched by exclude_modules. Defaults to None.
@syuoni (Collaborator, Nov 19, 2025):
Could you please explain more on this field? For example, what are the types and meanings of the dict?

I guess it specifies the quant_config for "excluded modules"; if so, could we use LayerQuantConfig instead? (Please refer to #8617; it's better to reuse the existing logic, but it's not mandatory.)

activation_scheme (tensorrt_llm.quantization.mode.ActivationScheme, optional): The activation quantization scheme (static or dynamic). Defaults to None.
"""
quant_algo: Optional[QuantAlgo] = None
kv_cache_quant_algo: Optional[QuantAlgo] = None
@@ -154,6 +157,8 @@ class QuantConfig:
pre_quant_scale: bool = False
exclude_modules: Optional[List[str]] = None
mamba_ssm_cache_dtype: Optional[str] = None
exclude_quantization: Optional[Dict] = None
activation_scheme: Optional[ActivationScheme] = None

@cached_property
def quant_mode(self) -> QuantModeWrapper:
5 changes: 5 additions & 0 deletions tensorrt_llm/quantization/mode.py
@@ -473,3 +473,8 @@ class GroupwiseQuantAlgo:
PRE_QUANT_SCALE = 4
W4A8_ALPHA = 8
INT8_WEIGHT = 16


class ActivationScheme(StrEnum, metaclass=BaseEnumMeta):
STATIC = auto()
DYNAMIC = auto()
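A quick illustration of how the parsing code in this PR consumes the new enum; it assumes the project's StrEnum maps auto() to the member name, which is what the .upper() calls in the config parsing rely on:

from tensorrt_llm.quantization.mode import ActivationScheme

scheme_str = "static"                          # e.g. from hf_quant_config["activation_scheme"]
scheme = ActivationScheme(scheme_str.upper())  # -> ActivationScheme.STATIC
assert scheme == ActivationScheme.STATIC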
6 changes: 6 additions & 0 deletions tests/unittest/api_stability/references/quant_config.yaml
@@ -31,6 +31,12 @@ methods:
use_meta_recipe:
annotation: bool
default: false
exclude_quantization:
annotation: Optional[dict]
default: null
activation_scheme:
annotation: Optional[tensorrt_llm.quantization.mode.ActivationScheme]
default: null
return_annotation: None
from_dict:
parameters: