Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 142 additions & 1 deletion tensorrt_llm/_torch/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization.mode import QuantAlgo
from tensorrt_llm.quantization.mode import ActivationScheme, QuantAlgo

TConfig = TypeVar("TConfig", bound=transformers.PretrainedConfig)

Expand Down Expand Up @@ -324,6 +324,70 @@ def load_modelopt_quant_config(quant_config_file, checkpoint_dir,
]
return quant_config, layer_quant_config

@staticmethod
def load_angelslim_quant_config(quant_config_file, checkpoint_dir,
moe_backend):
quant_config = QuantConfig()
layer_quant_config = None

with open(quant_config_file) as f:
quant_config_dict = json.load(f)

json_quant_configs = quant_config_dict['quantization']

quant_config.quant_algo = QuantAlgo(
json_quant_configs.get(
'quant_algo',
None).upper()) if json_quant_configs.get("quant_algo") else None
# fp8_pb_wo from modelopt is the same as FP8_BLOCK_SCALES
if quant_config.quant_algo == "fp8_pb_wo":
quant_config.quant_algo = QuantAlgo('FP8_BLOCK_SCALES')

quant_config.kv_cache_quant_algo = QuantAlgo(
json_quant_configs.get("kv_cache_quant_algo").upper()
) if json_quant_configs.get("kv_cache_quant_algo") else None
quant_config.group_size = json_quant_configs.get('group_size', None)
quant_config.exclude_modules = json_quant_configs.get(
'exclude_modules', None)
quant_config.activation_scheme = ActivationScheme(
json_quant_configs.get('activation_scheme', None).upper()
) if json_quant_configs.get("activation_scheme") else None

json_exclude_quantization = json_quant_configs.get(
'exclude_quantization', None)
if json_exclude_quantization:
quant_config.exclude_quant_config = {
"quant_algo":
QuantAlgo(
json_exclude_quantization.get('quant_algo', None).upper())
if json_exclude_quantization.get("quant_algo") else None,
"kv_cache_quant_algo":
QuantAlgo(
json_exclude_quantization.get(
"kv_cache_quant_algo").upper()) if
json_exclude_quantization.get("kv_cache_quant_algo") else None,
"activation_scheme":
ActivationScheme(
json_exclude_quantization.get('activation_scheme',
None).upper())
if json_exclude_quantization.get("activation_scheme") else None,
"group_size":
json_exclude_quantization.get('group_size', None),
}
if quant_config.exclude_quantization["quant_algo"] in [
QuantAlgo.FP8_BLOCK_SCALES, QuantAlgo.W4A8_AWQ
]:
if quant_config.exclude_quantization["group_size"] is None:
quant_config.exclude_quantization["group_size"] = 128

if quant_config.quant_algo in [
QuantAlgo.FP8_BLOCK_SCALES, QuantAlgo.W4A8_AWQ
]:
if quant_config.group_size is None:
quant_config.group_size = 128

return quant_config, layer_quant_config

@staticmethod
def get_mxfp4_quant_algo(moe_backend, is_dynamic_quant=False):
quant_algo = ModelConfig.override_quant_algo()
Expand Down Expand Up @@ -368,6 +432,79 @@ def load_hf_quant_config(hf_quant_config, moe_backend):
'block.*.attn.out', 'block.*.mlp.gate', 'block.*.attn.qkv',
'embedding', 'unembedding'
]
elif hf_quant_config.get("quant_method") == "fp8":
Copy link
Collaborator

@QiJune QiJune Nov 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please add some comments like "AngleSlim fp8 ckpt"? A reference:

# DeepSeek V3 FP8 ckpt

quant_config.quant_algo = QuantAlgo.FP8
elif hf_quant_config.get("quant_method") == "w4a8_awq":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please add some comments like "AngleSlim w4a8_awq ckpt"?

quant_config.quant_algo = QuantAlgo.W4A8_AWQ
quant_config.group_size = hf_quant_config.get(
"weight_group_size", 128)
else:
raise NotImplementedError(
f"Unsupported quantization_config: {hf_quant_config}.")

# set kv_cache_quant_algo
Copy link
Collaborator

@QiJune QiJune Nov 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please make the following codes into a single function, something like load_angleslim_config?

# DeepSeek V3 FP8 ckpt
if hf_quant_config.get("quant_method") == "fp8" xxx:
     xxx
# MXFP4 checkpoints.
elif hf_quant_config.get("quant_method") == "mxfp4":
    xxx
# Angleslim FP8 checkpoint
elif hf_quant_config.get("quant_method") == "fp8" :
    xxx
# Angleslim w4a8_awq checkpoint
elif hf_quant_config.get("quant_method") == "w4a8_awq":
    quant_config = load_angleslim_config(xxx)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My PR refactors the codebase of how create the quant_config

quant_config.kv_cache_quant_algo = QuantAlgo(hf_quant_config.get("kv_cache_quant_method").upper()) \
if hf_quant_config.get("kv_cache_quant_method") else None
# set activation_scheme
quant_config.activation_scheme = ActivationScheme(hf_quant_config.get("activation_scheme").upper()) \
if hf_quant_config.get("activation_scheme") else None
# set exclude_modules
if quant_config.exclude_modules:
if hf_quant_config.get("ignored_layers"):
quant_config.exclude_modules += hf_quant_config.get(
"ignored_layers")
else:
quant_config.exclude_modules = hf_quant_config.get("ignored_layers")

# set exclude_quant_config
hf_ignored_quantization_config = hf_quant_config.get(
"ignored_quantization_config")
if hf_ignored_quantization_config:
quant_config.exclude_quant_config = {
"kv_cache_quant_algo":
QuantAlgo(
hf_ignored_quantization_config.get(
"kv_cache_quant_method").upper())
if hf_ignored_quantization_config.get("kv_cache_quant_method")
else None,
"activation_scheme":
ActivationScheme(
hf_ignored_quantization_config.get(
"activation_scheme").upper())
if hf_ignored_quantization_config.get("activation_scheme") else
None,
"group_size":
128,
}
if hf_ignored_quantization_config.get(
"quant_method"
) == "fp8" and hf_ignored_quantization_config.get(
"weight_block_size", []):
quant_config.exclude_quantization[
"quant_algo"] = QuantAlgo.FP8_BLOCK_SCALES
block_size = hf_ignored_quantization_config.get(
"weight_block_size", [])
assert tuple(block_size) == (
128,
128), "FP8_BLOCK_SCALES only supports block_size=(128,128)"
quant_config.exclude_quantization["group_size"] = block_size[0]
elif hf_ignored_quantization_config.get("quant_method") == "fp8":
quant_config.exclude_quantization["quant_algo"] = QuantAlgo.FP8
elif hf_ignored_quantization_config.get(
"quant_method") == "w4a8_awq":
quant_config.exclude_quantization[
"quant_algo"] = QuantAlgo.W4A8_AWQ
quant_config.exclude_quantization[
"group_size"] = hf_ignored_quantization_config.get(
"weight_group_size", 128)
else:
raise NotImplementedError(
f"Unsupported quantization_config.ignored_quantization_config: "
f"{hf_ignored_quantization_config}.")

logger.info(
f"Load quantization config from pretrained config, quant_config: {quant_config}"
)

return quant_config, layer_quant_config

Expand Down Expand Up @@ -484,6 +621,10 @@ def cached_file(path_or_repo_id, file_name):
'hf_quant_config.json'):
quant_config, layer_quant_config = cls.load_modelopt_quant_config(
quant_config_file, checkpoint_dir, moe_backend)
elif quant_config_file := cached_file(checkpoint_dir,
'angelslim_hf_quant_config.json'):
Copy link
Collaborator

@QiJune QiJune Nov 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @bppan , thanks for your contribution. Does all quantized checkpoint generated by angleslim tool contain angelslim_hf_quant_config.json? I check one example: https://huggingface.co/AngelSlim/DeepSeek-R1-0528_w4a8_fp8/blob/main/config.json

I only find config.json, and here is the content:

"quantization_config": {
  "quant_method": "w4a8_awq",
  "weight_group_size": 128,
  "activation_scheme": "static",
  "kv_cache_quant_method": "fp8",
  "ignored_layers": [
    "*self_attn*",
    "*gate_up_proj",
    "*down_proj",
    "*layers.61*"
  ],
  "ignored_quantization_config": {
    "quant_method": "fp8",
    "activation_scheme": "dynamic",
    "fmt": "e4m3",
    "kv_cache_quant_method": "fp8",
    "weight_block_size": [
      128,
      128
    ]
  }
},

In other words, will the following command work?

from tensorrt_llm import LLM
llm = LLM(model='AngelSlim/DeepSeek-R1-0528_w4a8_fp8')
llm.generate("Hello, my name is")

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

angleslim does not generate the file angelslim_hf_quant_config.json, when the file does not exist, the quantization_config field will be read from config.json. It is defined in the load_hf_quant_config method of the tensorrt_llm/_torch/model_config.py file.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will run the AngelSlim/DeepSeek-R1-0528_w4a8_fp8 model to verify if it works.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It works!

Processed requests: 100%|██████████| 3/3 [00:19<00:00, 6.43s/it]
[0] Prompt: 'Hello, my name is', Generated text: " Dr. David Hill and today we're going to be talking about how to treat a child with a fever. Now, fever is one of the most common things that parents worry about. It's one of the most common reasons that children go to the doctor. But actually, fever is a good thing. Fever is a"
[1] Prompt: 'The capital of France is', Generated text: ' Paris. Paris is located in the north-central part of the country, along the Seine River. It is one of the most famous and visited cities in the world, known for its rich history, culture, art, fashion, and cuisine. Paris is also a major global center for diplomacy, commerce, education, science,'
[2] Prompt: 'The future of AI is', Generated text: ' a topic that has been discussed and debated for decades. With the rapid advancements in technology, it is becoming increasingly clear that artificial intelligence will play a significant role in shaping our future. From self-driving cars to virtual assistants, AI is already transforming the way we live and work. However, as with any new technology, there'

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @bppan , if I understand correctly,load_angelslim_quant_config will never be called. Then, let's remove it.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I have removed it.

quant_config, layer_quant_config = cls.load_angelslim_quant_config(
quant_config_file, checkpoint_dir, moe_backend)
# quantized ckpt in other formats
elif hasattr(pretrained_config, "quantization_config"):
hf_quant_config = pretrained_config.quantization_config
Expand Down
3 changes: 2 additions & 1 deletion tensorrt_llm/_torch/models/modeling_deepseekv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,8 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
if names[-1] == "kv_b_proj":
# TODO: remove weight_dequant after enabling fp8_bmm
dequant_kv_b_proj = self.model_config.quant_config.is_module_excluded_from_quantization(
names[-1])
names[-1]
) and self.model_config.quant_config.exclude_quantization is None
if dequant_kv_b_proj:
kv_b_proj, k_b_proj_trans = load_kv_b_proj_and_k_b_proj_trans_dequant(
name)
Expand Down
16 changes: 15 additions & 1 deletion tensorrt_llm/_torch/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,9 +478,23 @@ def apply_quant_config_exclude_modules(self):
"""
quant_config = self.model_config.quant_config
kv_cache_quant_algo = None
quant_algo = None
activation_scheme = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is activation_scheme left as a placeholder for now or we have corresponding forward logic for this?

group_size = 128
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are quantization parameters like quant_algo and group_size necessary in the modules that skip quantization?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When exclude_quantization is present, the exclude module will use the quantization settings specified within exclude_quantization.

if quant_config:
kv_cache_quant_algo = quant_config.kv_cache_quant_algo
new_config = QuantConfig(kv_cache_quant_algo=kv_cache_quant_algo)
exclude_quantization = quant_config.exclude_quantization
if exclude_quantization:
quant_algo = exclude_quantization.get("quant_algo", None)
activation_scheme = exclude_quantization.get(
"activation_scheme", None)
group_size = exclude_quantization.get("group_size", 128)
new_config = QuantConfig(
quant_algo=quant_algo,
kv_cache_quant_algo=kv_cache_quant_algo,
activation_scheme=activation_scheme,
group_size=group_size,
)

if quant_config is not None:
if quant_config.exclude_modules is not None:
Expand Down
70 changes: 60 additions & 10 deletions tensorrt_llm/_torch/modules/fused_moe/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,10 @@ def load_expert_weights_to_dst(
MoEWeightLoadingMode.VANILLA,
MoEWeightLoadingMode.W4A8_CUSTOM
]:
w1_weight = weights[f"{expert_id}.w1.weight"]
w3_weight = weights[f"{expert_id}.w3.weight"]
w2_weight = weights[f"{expert_id}.w2.weight"]
weight_name = "qweight" if f"{expert_id}.w1.qweight" in weights else "weight"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @Barry-Delaney @rosenrodt to help review this part.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think add another MoEWeightLoadingMode will be better. Also added in another comment.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If refactoring into different MoEWeightLoadingModes, please also be aware you would need to put extra flag during create_moe(). For DS example: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/models/modeling_deepseekv3.py#L840

self.experts = create_moe(..., weight_loading_mode=...)

w1_weight = weights[f"{expert_id}.w1.{weight_name}"]
w3_weight = weights[f"{expert_id}.w3.{weight_name}"]
w2_weight = weights[f"{expert_id}.w2.{weight_name}"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For changes in quantization.py, we should add respective tests in
tests/unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8 to avoid breaking this feature in the future

if module.bias:
w1_bias = weights[f"{expert_id}.w1.bias"]
w3_bias = weights[f"{expert_id}.w3.bias"]
Expand Down Expand Up @@ -1085,6 +1086,10 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
w4a8_custom = module.weight_loading_mode == MoEWeightLoadingMode.W4A8_CUSTOM
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now, MoEWeightLoadingMode.VANILLA is used for ModelOpt, and MoEWeightLoadingMode.W4A8_CUSTOM is used for the checkpoint produced by TRT-LLM scripts, is it okay to add MoEWeightLoadingMode.ANGELSIM or something similar to distinguish the following logics, or it's safe to reuse the TRT-LLM one?

if w4a8_custom:
weight_scale_name = "weight_scale_inv"
for expert_id in module.initial_local_expert_ids:
if f"{expert_id}.w3.weight_scale.int4" in weights:
weight_scale_name = "weight_scale.int4"
break
else:
weight_scale_name = "weight_scale"

Expand All @@ -1107,14 +1112,36 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
all_w3_w1_input_scales_max = torch.max(
torch.stack(all_w3_input_scales),
torch.stack(all_w1_input_scales)).max()
all_w3_w1_scales_fp8_max = None
has_fp8_weight_scale = False
if w4a8_custom:
# In custom W4A8 ckpt, per-tensor input_scale and per-channel pre_quant_scale are fused into input_scale
module.fc31_act_scale.data.copy_(
torch.ones_like(module.fc31_act_scale, device=self.device) *
(1 / all_w3_w1_input_scales_max))

for expert_id in module.initial_local_expert_ids:
if f"{expert_id}.w1.weight_scale" in weights:
has_fp8_weight_scale = True
break
if has_fp8_weight_scale:
all_w3_w1_scales_fp8_max = []
for expert_id in module.initial_local_expert_ids:
w1_weight_scale_fp8 = load_weight_shard(
weights[f"{expert_id}.w1.weight_scale"],
device=self.device)
w3_weight_scale_fp8 = load_weight_shard(
weights[f"{expert_id}.w3.weight_scale"],
device=self.device)
all_w3_w1_scales_fp8_max.append(
torch.max(w3_weight_scale_fp8, w1_weight_scale_fp8))
all_w3_w1_scales_fp8_max = torch.stack(
all_w3_w1_scales_fp8_max).reshape(module.fc31_alpha.shape)
else:
all_w3_w1_scales_fp8_max = torch.ones_like(module.fc31_alpha,
device=self.device)
module.fc31_alpha.data.copy_(
(torch.ones_like(module.fc31_alpha, device=self.device) *
all_w3_w1_input_scales_max).float())
(all_w3_w1_scales_fp8_max * all_w3_w1_input_scales_max).float())
else:
# In vanilla ckpt (at least from ModelOpt), per-tensor input_scale and per-channel pre_quant_scale are separately stored
all_w3_pre_quant_scales = [
Expand Down Expand Up @@ -1192,6 +1219,9 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
device=self.device)
for expert_id in module.initial_local_expert_ids
]
if w4a8_custom and has_fp8_weight_scale:
all_w3_scales = torch.stack(
all_w3_scales) / all_w3_w1_scales_fp8_max.unsqueeze(2)
all_w1_scales = [
load_weight_shard(weights[f"{expert_id}.w1.{weight_scale_name}"],
module.tp_size,
Expand All @@ -1200,9 +1230,15 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
device=self.device)
for expert_id in module.initial_local_expert_ids
]
all_w3_w1_scales = torch.cat(
[torch.stack(all_w3_scales),
torch.stack(all_w1_scales)], dim=-2)
if w4a8_custom and has_fp8_weight_scale:
all_w1_scales = torch.stack(
all_w1_scales) / all_w3_w1_scales_fp8_max.unsqueeze(2)
all_w3_w1_scales = torch.cat([all_w3_scales, all_w1_scales], dim=-2)
else:
all_w3_w1_scales = torch.cat(
[torch.stack(all_w3_scales),
torch.stack(all_w1_scales)],
dim=-2)
if module.sm_version == 89:
w3_w1_scales = all_w3_w1_scales.to(torch.float16).view(module.dtype)
else:
Expand Down Expand Up @@ -1234,15 +1270,26 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
all_w2_input_scales_max = torch.stack(all_w2_input_scales).to(
module.dtype).max()

all_w2_scales_fp8 = None
if w4a8_custom:
# In custom W4A8 ckpt, per-tensor input_scale and per-channel pre_quant_scale are fused into input_scale
module.fc2_act_scale.data.copy_(
torch.ones_like(module.fc2_act_scale, device=self.device) *
(1 / all_w2_input_scales_max))
# In custom W4A8 ckpt, per-tensor weight_scale_2 is fused into alpha
if has_fp8_weight_scale:
all_w2_scales_fp8 = [
load_weight_shard(weights[f"{expert_id}.w2.weight_scale"],
device=self.device)
for expert_id in module.initial_local_expert_ids
]
all_w2_scales_fp8 = torch.stack(all_w2_scales_fp8).reshape(
module.fc2_alpha.shape)
else:
all_w2_scales_fp8 = torch.ones_like(module.fc2_alpha,
device=self.device)
module.fc2_alpha.data.copy_(
(torch.ones_like(module.fc2_alpha, device=self.device) *
all_w2_input_scales_max).float())
(all_w2_scales_fp8 * all_w2_input_scales_max).float())
else:
# In vanilla ckpt (at least from ModelOpt), per-tensor input_scale and per-channel pre_quant_scale are separately stored
all_w2_pre_quant_scales = [
Expand Down Expand Up @@ -1288,6 +1335,9 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
device=self.device)
for expert_id in module.initial_local_expert_ids
]
if w4a8_custom and has_fp8_weight_scale:
all_w2_scales = torch.stack(
all_w2_scales) / all_w2_scales_fp8.unsqueeze(2)
if module.sm_version == 89:
w2_scales = torch.stack(all_w2_scales).to(torch.float16).view(
module.dtype)
Expand Down
Loading