Support W4A8 method of AngleSlim tool #6857
base: main
Changes from 6 commits
tensorrt_llm/_torch/model_config.py
@@ -21,7 +21,7 @@
    from tensorrt_llm.logger import logger
    from tensorrt_llm.mapping import Mapping
    from tensorrt_llm.models.modeling_utils import QuantConfig
    from tensorrt_llm.quantization.mode import QuantAlgo
    from tensorrt_llm.quantization.mode import ActivationScheme, QuantAlgo

    TConfig = TypeVar("TConfig", bound=transformers.PretrainedConfig)
@@ -324,6 +324,70 @@ def load_modelopt_quant_config(quant_config_file, checkpoint_dir,
            ]
        return quant_config, layer_quant_config

    @staticmethod
    def load_angelslim_quant_config(quant_config_file, checkpoint_dir,
                                    moe_backend):
        quant_config = QuantConfig()
        layer_quant_config = None

        with open(quant_config_file) as f:
            quant_config_dict = json.load(f)

        json_quant_configs = quant_config_dict['quantization']

        quant_config.quant_algo = QuantAlgo(
            json_quant_configs.get(
                'quant_algo',
                None).upper()) if json_quant_configs.get("quant_algo") else None
        # fp8_pb_wo from modelopt is the same as FP8_BLOCK_SCALES
        if quant_config.quant_algo == "fp8_pb_wo":
            quant_config.quant_algo = QuantAlgo('FP8_BLOCK_SCALES')

        quant_config.kv_cache_quant_algo = QuantAlgo(
            json_quant_configs.get("kv_cache_quant_algo").upper()
        ) if json_quant_configs.get("kv_cache_quant_algo") else None
        quant_config.group_size = json_quant_configs.get('group_size', None)
        quant_config.exclude_modules = json_quant_configs.get(
            'exclude_modules', None)
        quant_config.activation_scheme = ActivationScheme(
            json_quant_configs.get('activation_scheme', None).upper()
        ) if json_quant_configs.get("activation_scheme") else None

        json_exclude_quantization = json_quant_configs.get(
            'exclude_quantization', None)
        if json_exclude_quantization:
            quant_config.exclude_quant_config = {
                "quant_algo":
                QuantAlgo(
                    json_exclude_quantization.get('quant_algo', None).upper())
                if json_exclude_quantization.get("quant_algo") else None,
                "kv_cache_quant_algo":
                QuantAlgo(
                    json_exclude_quantization.get(
                        "kv_cache_quant_algo").upper()) if
                json_exclude_quantization.get("kv_cache_quant_algo") else None,
                "activation_scheme":
                ActivationScheme(
                    json_exclude_quantization.get('activation_scheme',
                                                  None).upper())
                if json_exclude_quantization.get("activation_scheme") else None,
                "group_size":
                json_exclude_quantization.get('group_size', None),
            }
            if quant_config.exclude_quantization["quant_algo"] in [
                    QuantAlgo.FP8_BLOCK_SCALES, QuantAlgo.W4A8_AWQ
            ]:
                if quant_config.exclude_quantization["group_size"] is None:
                    quant_config.exclude_quantization["group_size"] = 128

        if quant_config.quant_algo in [
                QuantAlgo.FP8_BLOCK_SCALES, QuantAlgo.W4A8_AWQ
        ]:
            if quant_config.group_size is None:
                quant_config.group_size = 128

        return quant_config, layer_quant_config

    @staticmethod
    def get_mxfp4_quant_algo(moe_backend, is_dynamic_quant=False):
        quant_algo = ModelConfig.override_quant_algo()
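For orientation, here is a minimal sketch of the `quantization` block this loader expects in `angelslim_hf_quant_config.json`. The key names come from the parsing code above; the concrete values and the `ActivationScheme` member names are illustrative assumptions, not an official AngleSlim example.

```python
# Hypothetical angelslim_hf_quant_config.json contents, shown as the dict returned
# by json.load(); only the keys read by load_angelslim_quant_config are listed.
quant_config_dict = {
    "quantization": {
        "quant_algo": "w4a8_awq",        # upper-cased into QuantAlgo.W4A8_AWQ
        "kv_cache_quant_algo": None,     # optional
        "group_size": 128,               # defaults to 128 for W4A8_AWQ / FP8_BLOCK_SCALES
        "activation_scheme": "static",   # upper-cased into an ActivationScheme member
        "exclude_modules": ["lm_head"],  # modules excluded from the main quant algo
        "exclude_quantization": {        # how the excluded modules are quantized instead
            "quant_algo": "fp8",
            "activation_scheme": "dynamic",
            "group_size": None,
        },
    }
}
```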
@@ -368,6 +432,79 @@ def load_hf_quant_config(hf_quant_config, moe_backend):
                'block.*.attn.out', 'block.*.mlp.gate', 'block.*.attn.qkv',
                'embedding', 'unembedding'
            ]
        elif hf_quant_config.get("quant_method") == "fp8":
            quant_config.quant_algo = QuantAlgo.FP8
        elif hf_quant_config.get("quant_method") == "w4a8_awq":
Collaborator: Could you please add some comments like "AngleSlim w4a8_awq ckpt"?
            quant_config.quant_algo = QuantAlgo.W4A8_AWQ
            quant_config.group_size = hf_quant_config.get(
                "weight_group_size", 128)
        else:
            raise NotImplementedError(
                f"Unsupported quantization_config: {hf_quant_config}.")

        # set kv_cache_quant_algo
Collaborator: Could you please make the following code into a single function? Something like:

    # DeepSeek V3 FP8 ckpt
    if hf_quant_config.get("quant_method") == "fp8" xxx:
        xxx
    # MXFP4 checkpoints.
    elif hf_quant_config.get("quant_method") == "mxfp4":
        xxx
    # Angleslim FP8 checkpoint
    elif hf_quant_config.get("quant_method") == "fp8":
        xxx
    # Angleslim w4a8_awq checkpoint
    elif hf_quant_config.get("quant_method") == "w4a8_awq":
        quant_config = load_angleslim_config(xxx)

Collaborator: My PR refactors how the quant_config is created.
        quant_config.kv_cache_quant_algo = QuantAlgo(hf_quant_config.get("kv_cache_quant_method").upper()) \
            if hf_quant_config.get("kv_cache_quant_method") else None
        # set activation_scheme
        quant_config.activation_scheme = ActivationScheme(hf_quant_config.get("activation_scheme").upper()) \
            if hf_quant_config.get("activation_scheme") else None
        # set exclude_modules
        if quant_config.exclude_modules:
            if hf_quant_config.get("ignored_layers"):
                quant_config.exclude_modules += hf_quant_config.get(
                    "ignored_layers")
        else:
            quant_config.exclude_modules = hf_quant_config.get("ignored_layers")

        # set exclude_quant_config
        hf_ignored_quantization_config = hf_quant_config.get(
            "ignored_quantization_config")
        if hf_ignored_quantization_config:
            quant_config.exclude_quant_config = {
                "kv_cache_quant_algo":
                QuantAlgo(
                    hf_ignored_quantization_config.get(
                        "kv_cache_quant_method").upper())
                if hf_ignored_quantization_config.get("kv_cache_quant_method")
                else None,
                "activation_scheme":
                ActivationScheme(
                    hf_ignored_quantization_config.get(
                        "activation_scheme").upper())
                if hf_ignored_quantization_config.get("activation_scheme") else
                None,
                "group_size":
                128,
            }
            if hf_ignored_quantization_config.get(
                    "quant_method"
            ) == "fp8" and hf_ignored_quantization_config.get(
                    "weight_block_size", []):
                quant_config.exclude_quantization[
                    "quant_algo"] = QuantAlgo.FP8_BLOCK_SCALES
                block_size = hf_ignored_quantization_config.get(
                    "weight_block_size", [])
                assert tuple(block_size) == (
                    128,
                    128), "FP8_BLOCK_SCALES only supports block_size=(128,128)"
                quant_config.exclude_quantization["group_size"] = block_size[0]
            elif hf_ignored_quantization_config.get("quant_method") == "fp8":
                quant_config.exclude_quantization["quant_algo"] = QuantAlgo.FP8
            elif hf_ignored_quantization_config.get(
                    "quant_method") == "w4a8_awq":
                quant_config.exclude_quantization[
                    "quant_algo"] = QuantAlgo.W4A8_AWQ
                quant_config.exclude_quantization[
                    "group_size"] = hf_ignored_quantization_config.get(
                        "weight_group_size", 128)
            else:
                raise NotImplementedError(
                    f"Unsupported quantization_config.ignored_quantization_config: "
                    f"{hf_ignored_quantization_config}.")

        logger.info(
            f"Load quantization config from pretrained config, quant_config: {quant_config}"
        )

        return quant_config, layer_quant_config
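As a rough illustration of the shape of `pretrained_config.quantization_config` that this branch handles, here is a sketch using the field names read by the code above; the concrete layer names and values are made up for illustration.

```python
# Hypothetical Hugging Face quantization_config for an AngleSlim W4A8 checkpoint;
# only the keys consumed above are shown.
hf_quant_config = {
    "quant_method": "w4a8_awq",
    "weight_group_size": 128,
    "kv_cache_quant_method": None,
    "activation_scheme": "static",
    "ignored_layers": ["model.layers.0.mlp.gate"],
    "ignored_quantization_config": {      # quantization used by the ignored layers
        "quant_method": "fp8",
        "weight_block_size": [128, 128],  # maps to FP8_BLOCK_SCALES with group_size 128
        "activation_scheme": "dynamic",
    },
}
```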
@@ -484,6 +621,10 @@ def cached_file(path_or_repo_id, file_name):
                                      'hf_quant_config.json'):
            quant_config, layer_quant_config = cls.load_modelopt_quant_config(
                quant_config_file, checkpoint_dir, moe_backend)
        elif quant_config_file := cached_file(checkpoint_dir,
                                              'angelslim_hf_quant_config.json'):
            quant_config, layer_quant_config = cls.load_angelslim_quant_config(
                quant_config_file, checkpoint_dir, moe_backend)
        # quantized ckpt in other formats
        elif hasattr(pretrained_config, "quantization_config"):
            hf_quant_config = pretrained_config.quantization_config
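In practice this means an AngleSlim checkpoint only needs to ship its config file next to the weights; loading it through the LLM API should then pick the quantization settings up automatically. A minimal usage sketch (the path is a placeholder, and the exact constructor arguments may vary by TensorRT-LLM version):

```python
from tensorrt_llm import LLM

# Assumes /models/my-w4a8-ckpt contains the HF weights plus
# angelslim_hf_quant_config.json (or a quantization_config in config.json).
llm = LLM(model="/models/my-w4a8-ckpt")
print(llm.generate(["Hello, my name is"]))
```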
@@ -478,9 +478,23 @@ def apply_quant_config_exclude_modules(self):
        """
        quant_config = self.model_config.quant_config
        kv_cache_quant_algo = None
        quant_algo = None
        activation_scheme = None
Collaborator: Is …
        group_size = 128
Collaborator: Why are quantization parameters like quant_algo and group_size necessary in the modules that skip quantization?

Author: When …
        if quant_config:
            kv_cache_quant_algo = quant_config.kv_cache_quant_algo
            new_config = QuantConfig(kv_cache_quant_algo=kv_cache_quant_algo)
            exclude_quantization = quant_config.exclude_quantization
            if exclude_quantization:
                quant_algo = exclude_quantization.get("quant_algo", None)
                activation_scheme = exclude_quantization.get(
                    "activation_scheme", None)
                group_size = exclude_quantization.get("group_size", 128)
            new_config = QuantConfig(
                quant_algo=quant_algo,
                kv_cache_quant_algo=kv_cache_quant_algo,
                activation_scheme=activation_scheme,
                group_size=group_size,
            )

        if quant_config is not None:
            if quant_config.exclude_modules is not None:
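To make the intent of `exclude_quantization` concrete: when the checkpoint provides an `exclude_quantization` section, modules listed in `exclude_modules` are not left unquantized but are given the fallback quantization it describes. A small sketch of the resulting configs, assuming an AngleSlim W4A8 checkpoint whose excluded modules fall back to FP8 (the field values are illustrative, not taken from the PR):

```python
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization.mode import QuantAlgo

# Main config applied to most modules.
main_config = QuantConfig(quant_algo=QuantAlgo.W4A8_AWQ, group_size=128)

# Parsed from "exclude_quantization" in the checkpoint config.
exclude_quantization = {"quant_algo": QuantAlgo.FP8, "group_size": 128}

# Roughly what apply_quant_config_exclude_modules builds for the excluded modules:
fallback_config = QuantConfig(
    quant_algo=exclude_quantization.get("quant_algo"),
    kv_cache_quant_algo=main_config.kv_cache_quant_algo,
    group_size=exclude_quantization.get("group_size", 128),
)
```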
@@ -224,9 +224,10 @@ def load_expert_weights_to_dst(
                    MoEWeightLoadingMode.VANILLA,
                    MoEWeightLoadingMode.W4A8_CUSTOM
            ]:
                w1_weight = weights[f"{expert_id}.w1.weight"]
                w3_weight = weights[f"{expert_id}.w3.weight"]
                w2_weight = weights[f"{expert_id}.w2.weight"]
                weight_name = "qweight" if f"{expert_id}.w1.qweight" in weights else "weight"
Collaborator: cc @Barry-Delaney @rosenrodt to help review this part.

Collaborator: I think add another …

Collaborator: If refactoring into different … `self.experts = create_moe(..., weight_loading_mode=...)`
                w1_weight = weights[f"{expert_id}.w1.{weight_name}"]
                w3_weight = weights[f"{expert_id}.w3.{weight_name}"]
                w2_weight = weights[f"{expert_id}.w2.{weight_name}"]
Collaborator: For changes in …
                if module.bias:
                    w1_bias = weights[f"{expert_id}.w1.bias"]
                    w3_bias = weights[f"{expert_id}.w3.bias"]
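The `weight_name` switch above is what lets the same loader consume both layouts. A tiny sketch of the assumed key naming (the helper name and the dict contents are mine, not from the PR):

```python
# ModelOpt-style checkpoints store plain "weight" tensors per expert, while the
# AngleSlim W4A8 export is assumed to store packed INT4 tensors under "qweight".
def pick_weight_key(weights: dict, expert_id: int) -> str:
    return "qweight" if f"{expert_id}.w1.qweight" in weights else "weight"

weights = {"0.w1.qweight": ..., "0.w3.qweight": ..., "0.w2.qweight": ...}
assert pick_weight_key(weights, 0) == "qweight"
```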
@@ -1085,6 +1086,10 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
        w4a8_custom = module.weight_loading_mode == MoEWeightLoadingMode.W4A8_CUSTOM
Collaborator: For now, …
        if w4a8_custom:
            weight_scale_name = "weight_scale_inv"
            for expert_id in module.initial_local_expert_ids:
                if f"{expert_id}.w3.weight_scale.int4" in weights:
                    weight_scale_name = "weight_scale.int4"
                    break
        else:
            weight_scale_name = "weight_scale"
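The scale-name probing above implies a particular checkpoint layout; the comment-only sketch below records one reading of it (an assumption based on this diff, not AngleSlim documentation):

```python
# Assumed per-expert tensor names in a W4A8_CUSTOM checkpoint (from the probing above):
#   "<e>.w1.qweight"            packed INT4 weights (likewise for w2 / w3)
#   "<e>.w1.weight_scale.int4"  per-group INT4 scales; falls back to "weight_scale_inv"
#   "<e>.w1.weight_scale"       optional per-tensor FP8 scale, folded into alpha below
#   "<e>.w1.input_scale"        per-tensor activation scale
# Vanilla (ModelOpt) checkpoints keep using "weight" / "weight_scale".
```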
@@ -1107,14 +1112,36 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
        all_w3_w1_input_scales_max = torch.max(
            torch.stack(all_w3_input_scales),
            torch.stack(all_w1_input_scales)).max()
        all_w3_w1_scales_fp8_max = None
        has_fp8_weight_scale = False
        if w4a8_custom:
            # In custom W4A8 ckpt, per-tensor input_scale and per-channel pre_quant_scale are fused into input_scale
            module.fc31_act_scale.data.copy_(
                torch.ones_like(module.fc31_act_scale, device=self.device) *
                (1 / all_w3_w1_input_scales_max))

            for expert_id in module.initial_local_expert_ids:
                if f"{expert_id}.w1.weight_scale" in weights:
                    has_fp8_weight_scale = True
                    break
            if has_fp8_weight_scale:
                all_w3_w1_scales_fp8_max = []
                for expert_id in module.initial_local_expert_ids:
                    w1_weight_scale_fp8 = load_weight_shard(
                        weights[f"{expert_id}.w1.weight_scale"],
                        device=self.device)
                    w3_weight_scale_fp8 = load_weight_shard(
                        weights[f"{expert_id}.w3.weight_scale"],
                        device=self.device)
                    all_w3_w1_scales_fp8_max.append(
                        torch.max(w3_weight_scale_fp8, w1_weight_scale_fp8))
                all_w3_w1_scales_fp8_max = torch.stack(
                    all_w3_w1_scales_fp8_max).reshape(module.fc31_alpha.shape)
            else:
                all_w3_w1_scales_fp8_max = torch.ones_like(module.fc31_alpha,
                                                           device=self.device)
            module.fc31_alpha.data.copy_(
                (torch.ones_like(module.fc31_alpha, device=self.device) *
                 all_w3_w1_input_scales_max).float())
                (all_w3_w1_scales_fp8_max * all_w3_w1_input_scales_max).float())
        else:
            # In vanilla ckpt (at least from ModelOpt), per-tensor input_scale and per-channel pre_quant_scale are separately stored
            all_w3_pre_quant_scales = [
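Numerically, the branch above folds the optional per-expert FP8 `weight_scale` into the kernel's `alpha` and divides it back out of the per-group INT4 scales, so the overall product of scales is unchanged. A small self-contained sketch of that bookkeeping (shapes and values are illustrative, not taken from the PR):

```python
import torch

num_experts, out_ch, groups = 2, 8, 4
input_scale = torch.tensor([0.02, 0.03])      # per-expert activation input_scale
w_scale_fp8 = torch.tensor([[0.50], [0.25]])  # per-expert FP8 weight_scale (max of w1/w3)
int4_group_scales = torch.rand(num_experts, out_ch, groups)

act_scale = 1.0 / input_scale.max()                # analogue of fc31_act_scale
alpha = (w_scale_fp8 * input_scale.max()).float()  # analogue of fc31_alpha, FP8 scale folded in
# Renormalize the group scales by the same FP8 factor so group_scale * alpha is unchanged.
group_scales = int4_group_scales / w_scale_fp8.unsqueeze(2)

# Sanity check: the combined scaling is preserved.
assert torch.allclose(group_scales * alpha.unsqueeze(2),
                      int4_group_scales * input_scale.max())
```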
@@ -1192,6 +1219,9 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
                              device=self.device)
            for expert_id in module.initial_local_expert_ids
        ]
        if w4a8_custom and has_fp8_weight_scale:
            all_w3_scales = torch.stack(
                all_w3_scales) / all_w3_w1_scales_fp8_max.unsqueeze(2)
        all_w1_scales = [
            load_weight_shard(weights[f"{expert_id}.w1.{weight_scale_name}"],
                              module.tp_size,
@@ -1200,9 +1230,15 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
                              device=self.device)
            for expert_id in module.initial_local_expert_ids
        ]
        all_w3_w1_scales = torch.cat(
            [torch.stack(all_w3_scales),
             torch.stack(all_w1_scales)], dim=-2)
        if w4a8_custom and has_fp8_weight_scale:
            all_w1_scales = torch.stack(
                all_w1_scales) / all_w3_w1_scales_fp8_max.unsqueeze(2)
            all_w3_w1_scales = torch.cat([all_w3_scales, all_w1_scales], dim=-2)
        else:
            all_w3_w1_scales = torch.cat(
                [torch.stack(all_w3_scales),
                 torch.stack(all_w1_scales)],
                dim=-2)
        if module.sm_version == 89:
            w3_w1_scales = all_w3_w1_scales.to(torch.float16).view(module.dtype)
        else:
@@ -1234,15 +1270,26 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
        all_w2_input_scales_max = torch.stack(all_w2_input_scales).to(
            module.dtype).max()

        all_w2_scales_fp8 = None
        if w4a8_custom:
            # In custom W4A8 ckpt, per-tensor input_scale and per-channel pre_quant_scale are fused into input_scale
            module.fc2_act_scale.data.copy_(
                torch.ones_like(module.fc2_act_scale, device=self.device) *
                (1 / all_w2_input_scales_max))
            # In custom W4A8 ckpt, per-tensor weight_scale_2 is fused into alpha
            if has_fp8_weight_scale:
                all_w2_scales_fp8 = [
                    load_weight_shard(weights[f"{expert_id}.w2.weight_scale"],
                                      device=self.device)
                    for expert_id in module.initial_local_expert_ids
                ]
                all_w2_scales_fp8 = torch.stack(all_w2_scales_fp8).reshape(
                    module.fc2_alpha.shape)
            else:
                all_w2_scales_fp8 = torch.ones_like(module.fc2_alpha,
                                                    device=self.device)
            module.fc2_alpha.data.copy_(
                (torch.ones_like(module.fc2_alpha, device=self.device) *
                 all_w2_input_scales_max).float())
                (all_w2_scales_fp8 * all_w2_input_scales_max).float())
        else:
            # In vanilla ckpt (at least from ModelOpt), per-tensor input_scale and per-channel pre_quant_scale are separately stored
            all_w2_pre_quant_scales = [
@@ -1288,6 +1335,9 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
                              device=self.device)
            for expert_id in module.initial_local_expert_ids
        ]
        if w4a8_custom and has_fp8_weight_scale:
            all_w2_scales = torch.stack(
                all_w2_scales) / all_w2_scales_fp8.unsqueeze(2)
        if module.sm_version == 89:
            w2_scales = torch.stack(all_w2_scales).to(torch.float16).view(
                module.dtype)
Could you please add some comments like "AngleSlim fp8 ckpt"? A reference: TensorRT-LLM/tensorrt_llm/_torch/model_config.py, line 346 in 67af7c1.