From 98917b778781bc3bd168ca80f2e81bd126e6b756 Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Mon, 10 Nov 2025 21:21:47 +0000 Subject: [PATCH 1/4] fix qparams decompression Signed-off-by: shanjiaz --- src/compressed_tensors/compressors/base.py | 18 +++++++++++++++++- .../compressors/quantized_compressors/base.py | 12 +----------- .../quantized_compressors/fp4_quantized.py | 6 ++++++ .../quantized_compressors/pack_quantized.py | 2 ++ 4 files changed, 26 insertions(+), 12 deletions(-) diff --git a/src/compressed_tensors/compressors/base.py b/src/compressed_tensors/compressors/base.py index 3a8a97eb..d565de7c 100644 --- a/src/compressed_tensors/compressors/base.py +++ b/src/compressed_tensors/compressors/base.py @@ -20,6 +20,11 @@ from compressed_tensors.quantization import QuantizationArgs, QuantizationConfig from compressed_tensors.registry import RegistryMixin from compressed_tensors.utils import has_offloaded_params +from compressed_tensors.utils.offload import ( + delete_offload_parameter, + get_offloaded_device, + register_offload_parameter, +) from torch import Tensor from torch.nn import Module @@ -185,10 +190,21 @@ def decompress_module(self, module: Module): for name, parameter in module.named_parameters(): compressed_data[name] = parameter - return self.decompress_weight( + result = self.decompress_weight( compressed_data=compressed_data, quantization_args=quantization_args ).to(device) + # Update module's parameters if they were unpacked/upcast during decompression + for param_name in ["weight_zero_point", "weight_scale"]: + if param_name in compressed_data and hasattr(module, param_name): + # Delete the old parameter and register the updated one + delete_offload_parameter(module, param_name) + offload_device = get_offloaded_device(module) + param = torch.nn.Parameter(compressed_data[param_name], requires_grad=False) + register_offload_parameter(module, param_name, param, offload_device) + + return result + def decompress_weight( self, compressed_data: Dict[str, Tensor], **kwargs ) -> torch.Tensor: diff --git a/src/compressed_tensors/compressors/quantized_compressors/base.py b/src/compressed_tensors/compressors/quantized_compressors/base.py index 19f6c9c0..ad69708d 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/base.py +++ b/src/compressed_tensors/compressors/quantized_compressors/base.py @@ -155,17 +155,7 @@ def _skip_zp( if zp_name == "output_zero_point": args = scheme.output_activations - symmetric = args.symmetric - packable_strategies = [ - QuantizationStrategy.GROUP.value, - QuantizationStrategy.CHANNEL.value, - ] - packed = ( - isinstance(self, PackedQuantizationCompressor) - and args.strategy in packable_strategies - ) - - return symmetric or packed + return args.symmetric def decompress( self, diff --git a/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py index b9bd8ede..7cdcc981 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py @@ -117,6 +117,12 @@ def decompress_weight( m, n = weight.shape # TODO: use a user provided dequant dtype unpacked = unpack_fp4_from_uint8(weight, m, n * 2) + + # cast scale dtype to match unpacked dtype for dequantization + if scale.dtype != unpacked.dtype: + scale = scale.to(unpacked.dtype) + compressed_data["weight_scale"] = scale + decompressed_weight = dequantize( x_q=unpacked, scale=scale, global_scale=global_scale, dtype=unpacked.dtype ) diff --git a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py index 07da50c7..c41a38c8 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py @@ -175,6 +175,8 @@ def decompress_weight( zero_point = unpack_from_int32( zero_point, num_bits, original_zp_shape, packed_dim=0 ) + # Update the compressed_data dict with the unpacked zero_point + compressed_data["weight_zero_point"] = zero_point decompressed_weight = dequantize( x_q=unpacked, scale=scale, zero_point=zero_point, g_idx=g_idx From 70e08389523016afddffffdfa0da6618e78a7822 Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Mon, 10 Nov 2025 21:29:28 +0000 Subject: [PATCH 2/4] quality Signed-off-by: shanjiaz --- src/compressed_tensors/compressors/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/compressed_tensors/compressors/base.py b/src/compressed_tensors/compressors/base.py index d565de7c..be607f4f 100644 --- a/src/compressed_tensors/compressors/base.py +++ b/src/compressed_tensors/compressors/base.py @@ -200,7 +200,9 @@ def decompress_module(self, module: Module): # Delete the old parameter and register the updated one delete_offload_parameter(module, param_name) offload_device = get_offloaded_device(module) - param = torch.nn.Parameter(compressed_data[param_name], requires_grad=False) + param = torch.nn.Parameter( + compressed_data[param_name], requires_grad=False + ) register_offload_parameter(module, param_name, param, offload_device) return result From e7c8becbc48cdae4c0a0bffdc4edad2b1d6fa973 Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Mon, 10 Nov 2025 21:35:42 +0000 Subject: [PATCH 3/4] quality Signed-off-by: shanjiaz --- .../compressors/quantized_compressors/base.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/compressed_tensors/compressors/quantized_compressors/base.py b/src/compressed_tensors/compressors/quantized_compressors/base.py index ad69708d..7a8f201c 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/base.py +++ b/src/compressed_tensors/compressors/quantized_compressors/base.py @@ -18,7 +18,7 @@ import torch from compressed_tensors.compressors.base import BaseCompressor -from compressed_tensors.quantization import QuantizationScheme, QuantizationStrategy +from compressed_tensors.quantization import QuantizationScheme from compressed_tensors.utils import ( get_nested_mappings_from_state_dict, get_nested_weight_mappings, @@ -143,8 +143,6 @@ def _skip_scale(self): def _skip_zp( self, name: str, names_to_scheme: Dict[str, QuantizationScheme] ) -> bool: - from compressed_tensors.compressors import PackedQuantizationCompressor - module_name, zp_name = name.rsplit(".", 1) if "." in name else ("", name) scheme = names_to_scheme[module_name] From ac326eede1b37d941b49688d01d32d3a54628bb1 Mon Sep 17 00:00:00 2001 From: shanjiaz Date: Tue, 18 Nov 2025 19:38:32 +0000 Subject: [PATCH 4/4] Add zero-point compression for asymmetric quantization Signed-off-by: shanjiaz --- .../compressors/quantized_compressors/base.py | 33 +++++++++---------- .../quantized_compressors/pack_quantized.py | 16 +++++++++ 2 files changed, 31 insertions(+), 18 deletions(-) diff --git a/src/compressed_tensors/compressors/quantized_compressors/base.py b/src/compressed_tensors/compressors/quantized_compressors/base.py index 7a8f201c..71913e6f 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/base.py +++ b/src/compressed_tensors/compressors/quantized_compressors/base.py @@ -124,9 +124,21 @@ def compress( compressed_dict[prefix + key] = value.to(compression_device) else: - # omit saving zero points for symmetric or packed quantization - if name.endswith("zero_point") and self._skip_zp(name, names_to_scheme): - continue + # omit saving zero points for symmetric quantization + if name.endswith("weight_zero_point"): + module_path = name.rsplit(".", 1)[0] + if ( + module_path in names_to_scheme + and names_to_scheme[module_path].weights.symmetric + ): + continue + # Call compress_zp if available (for PackedQuantizationCompressor) + if module_path in names_to_scheme and hasattr(self, "compress_zp"): + value = self.compress_zp( + value, names_to_scheme[module_path].weights + ) + if value is None: + continue if name.endswith("weight_scale") and self._skip_scale(): continue @@ -140,21 +152,6 @@ def _skip_scale(self): return isinstance(self, NVFP4PackedCompressor) - def _skip_zp( - self, name: str, names_to_scheme: Dict[str, QuantizationScheme] - ) -> bool: - module_name, zp_name = name.rsplit(".", 1) if "." in name else ("", name) - scheme = names_to_scheme[module_name] - - if zp_name == "weight_zero_point": - args = scheme.weights - if zp_name == "input_zero_point": - args = scheme.input_activations - if zp_name == "output_zero_point": - args = scheme.output_activations - - return args.symmetric - def decompress( self, path_to_model_or_tensors: Union[str, Path, Dict[str, Any]], diff --git a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py index c41a38c8..d8560f3a 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py @@ -184,6 +184,22 @@ def decompress_weight( return decompressed_weight + def compress_zp( + self, zero_point: Tensor, quantization_args: Optional[QuantizationArgs] = None + ) -> Optional[Tensor]: + if zero_point is None or quantization_args.symmetric: + return None + if zero_point.dtype == torch.int32: + return zero_point + if quantization_args.strategy in [ + QuantizationStrategy.GROUP.value, + QuantizationStrategy.CHANNEL.value, + ]: + return pack_to_int32( + zero_point, quantization_args.num_bits, packed_dim=0 + ).contiguous() + return zero_point + def pack_to_int32( value: torch.Tensor,