@@ -31,13 +31,14 @@
     SPARSITY_CONFIG_NAME,
 )
 from compressed_tensors.compressors.base import BaseCompressor
+from compressed_tensors.compressors.sparse_compressors import DenseCompressor
 from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
 from compressed_tensors.quantization import (
     DEFAULT_QUANTIZATION_METHOD,
     QuantizationConfig,
     QuantizationStatus,
     apply_quantization_config,
-    load_pretrained_quantization,
+    load_pretrained_quantization_parameters,
 )
 from compressed_tensors.quantization.lifecycle import expand_target_names
 from compressed_tensors.quantization.quant_args import QuantizationArgs
@@ -47,7 +48,9 @@
 )
 from compressed_tensors.utils import (
     get_safetensors_folder,
+    has_offloaded_params,
     merge_names,
+    register_offload_parameter,
     update_parameter_data,
 )
 from compressed_tensors.utils.helpers import (
@@ -412,6 +415,13 @@ def decompress(self, model_path: str, model: Module):
 
         :param model_path: path to compressed weights
         :param model: pytorch model to load decompressed weights into
+
+        Note: decompress makes use of both _replace_sparsity_weights and _replace_weights.
+        The variation between these methods is a result of the subtle differences between
+        the sparsity and quantization compressors. Specifically, quantization compressors
+        return not just the decompressed weight but also the quantization parameters
+        (e.g. scales, zero_point), whereas sparsity compressors only return the decompressed weight.
+
         """
         model_path = get_safetensors_folder(model_path)
         sparse_decompressed = False
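The docstring note above is the crux of why two replacement helpers exist. As a rough, hypothetical illustration only (toy generators and names, not the library's actual compressor classes), the two decompression paths yield differently shaped items:

```python
from typing import Dict, Iterator, Tuple

import torch


def toy_sparsity_decompress(
    compressed: Dict[str, torch.Tensor]
) -> Iterator[Tuple[str, torch.Tensor]]:
    # Sparsity compressors yield only the reconstructed dense weight per parameter name.
    for param_name, dense_weight in compressed.items():
        yield param_name, dense_weight


def toy_quantization_decompress(
    compressed: Dict[str, Dict[str, torch.Tensor]]
) -> Iterator[Tuple[str, Dict[str, torch.Tensor]]]:
    # Quantization compressors yield the dense weight together with its quantization
    # parameters (e.g. "weight_scale", "weight_zero_point"), keyed by module name.
    for module_name, params in compressed.items():
        yield module_name, params
```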
@@ -420,9 +430,16 @@ def decompress(self, model_path: str, model: Module):
             self.sparsity_compressor is not None
             and self.sparsity_config.format != CompressionFormat.dense.value
         ):
+            params_to_ignore = None
+            if self.quantization_compressor is not None:
+                params_to_ignore = self.quantization_compressor.compression_param_names
             # Sparse decompression is applied on the model_path
-            dense_gen = self.sparsity_compressor.decompress(model_path)
-            self._replace_weights(dense_gen, model)
+            # The compressor will also try to load any quantization parameters it finds;
+            # params_to_skip_load prevents those quantization params from being loaded here
+            dense_gen = self.sparsity_compressor.decompress(
+                model_path, params_to_skip_load=params_to_ignore
+            )
+            self._replace_sparsity_weights(dense_gen, model)
             setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
             sparse_decompressed = True
 
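For intuition only, the sketch below shows roughly what skipping quantization parameters during sparse decompression accomplishes; the real filtering happens inside `sparsity_compressor.decompress` via the `params_to_skip_load` argument, and the helper and names here are purely hypothetical:

```python
from typing import Dict, Iterable, Iterator, Tuple

import torch


def iter_sparse_params(
    state_dict: Dict[str, torch.Tensor], params_to_skip_load: Iterable[str]
) -> Iterator[Tuple[str, torch.Tensor]]:
    # Skip entries such as "weight_scale" / "weight_zero_point" so they are left
    # for load_pretrained_quantization_parameters instead of being loaded here.
    skip_suffixes = tuple(params_to_skip_load)
    for name, tensor in state_dict.items():
        if skip_suffixes and name.endswith(skip_suffixes):
            continue
        yield name, tensor
```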
@@ -431,13 +448,27 @@ def decompress(self, model_path: str, model: Module):
             # quantization during apply_quantization_config. This ensures
             # that the dtypes of the weights are not unintentionally updated.
             # The status is restored after quantization params are loaded.
+
             with override_quantization_status(
                 self.quantization_config, QuantizationStatus.FROZEN
             ):
+
                 names_to_scheme = apply_quantization_config(
                     model, self.quantization_config
                 )
-                load_pretrained_quantization(model, model_path)
+                # Load activation scales/zp and any other quantization parameters.
+                # Weight quantization parameters are loaded only if we have a dense
+                # compressor or if a sparsity compressor has already been applied.
+                load_pretrained_quantization_parameters(
+                    model,
+                    model_path,
+                    # TODO: all weight quantization params will be moved to the
+                    # compressor in a follow-up, including initialization
+                    load_weight_quantization=(
+                        sparse_decompressed
+                        or isinstance(self.quantization_compressor, DenseCompressor)
+                    ),
+                )
 
             model_path_or_state_dict = (
                 model.state_dict() if sparse_decompressed else model_path
@@ -446,6 +477,8 @@ def decompress(self, model_path: str, model: Module):
             dense_gen = self.quantization_compressor.decompress(
                 model_path_or_state_dict, names_to_scheme=names_to_scheme
             )
+            # TODO: all weight quantization params will be moved to the compressor
+            # to prevent duplicate parameter updates in update_parameter_data
             self._replace_weights(dense_gen, model)
 
             def freeze_quantization_status(module):
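A hedged end-to-end usage sketch of the `decompress` entry point touched by this diff. The checkpoint path is hypothetical, and `ModelCompressor.from_pretrained` is assumed to be the usual way to build a compressor from a saved config; in practice the transformers compressed-tensors integration drives this call rather than user code.

```python
from compressed_tensors.compressors import ModelCompressor
from transformers import AutoModelForCausalLM

model_path = "./compressed-model"  # hypothetical local checkpoint

# Build the model skeleton; its weights are still in the compressed on-disk format.
model = AutoModelForCausalLM.from_pretrained(model_path)

compressor = ModelCompressor.from_pretrained(model_path)
if compressor is not None:
    # Runs sparse decompression (if configured), reloads quantization parameters,
    # then runs quantization decompression, replacing module weights in place.
    compressor.decompress(model_path, model)
```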
@@ -501,7 +534,7 @@ def update_config(self, save_directory: str):
         with open(config_file_path, "w") as config_file:
             json.dump(config_data, config_file, indent=2, sort_keys=True)
 
-    def _replace_weights(self, dense_weight_generator, model: Module):
+    def _replace_sparsity_weights(self, dense_weight_generator, model: Module):
         """
         Replace the weights of the model with the
         provided dense weights.
@@ -516,11 +549,60 @@ def _replace_weights(self, dense_weight_generator, model: Module):
         :param model: The model whose weights are to be updated.
         """
         for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
+
             split_name = name.split(".")
             prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
             module = operator.attrgetter(prefix)(model)
-            if hasattr(module, param_name):
-                update_parameter_data(module, data, param_name)
+
+            params_device = next(module.parameters()).device
+            device = "cpu" if has_offloaded_params(module) else params_device
+            delattr(module, param_name)
+            requires_grad = data.dtype in (torch.float16, torch.float32, torch.bfloat16)
+            param = torch.nn.Parameter(data.to(device), requires_grad=requires_grad)
+            register_offload_parameter(module, param_name, param)
+
+    def _replace_weights(self, dense_weight_generator, model: Module):
+        """
+        Replace the weights of the model with the
+        provided dense weights.
+
+        This method iterates over the dense_weight_generator and
+        updates the corresponding weights in the model. If a parameter
+        name does not exist in the model, it will be skipped.
+
+        :param dense_weight_generator (generator): A generator that yields
+            tuples of (name, data), where 'name' is a module name and 'data' is
+            a dict mapping parameter names to their updated tensor data
+        :param model: The model whose weights are to be updated.
+        """
+
+        for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
+            module = operator.attrgetter(name)(model)
+
+            params_device = next(module.parameters()).device
+            device = "cpu" if has_offloaded_params(module) else params_device
+
+            for param_name, param_data in data.items():
+                if hasattr(module, param_name):
+                    # If compressed, the weight will have an incorrect dtype for transformers >4.49
+                    # TODO: we could also skip initialization of scales/zp during decompression
+                    # in init, to stay consistent with the loading that happens later as well;
+                    # however, update_data does a useful shape check that should move to the compressor
+                    if param_name == "weight":
+                        delattr(module, param_name)
+                        requires_grad = param_data.dtype in (
+                            torch.float16,
+                            torch.float32,
+                            torch.bfloat16,
+                        )
+                        param = torch.nn.Parameter(
+                            param_data.to(device), requires_grad=requires_grad
+                        )
+                        register_offload_parameter(module, param_name, param)
+                    else:
+                        # Scales/zero-points should already be registered
+                        # to the correct device
+                        update_parameter_data(module, param_data, param_name)
 
 
 def map_modules_to_quant_args(
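For reference, the offload-aware replacement pattern shared by both helpers above can be distilled into a small standalone sketch. The function name `replace_module_param` is hypothetical; `has_offloaded_params` and `register_offload_parameter` are the utilities imported at the top of this diff.

```python
import torch

from compressed_tensors.utils import has_offloaded_params, register_offload_parameter


def replace_module_param(module: torch.nn.Module, param_name: str, data: torch.Tensor):
    # Offloaded modules keep their materialized tensors on CPU, so target CPU there;
    # otherwise follow whatever device the module's parameters currently live on.
    params_device = next(module.parameters()).device
    device = "cpu" if has_offloaded_params(module) else params_device

    # Only floating-point dtypes should remain trainable after decompression.
    requires_grad = data.dtype in (torch.float16, torch.float32, torch.bfloat16)

    # Drop the stale (compressed) parameter before registering the decompressed one,
    # so its old dtype/shape cannot leak through.
    delattr(module, param_name)
    param = torch.nn.Parameter(data.to(device), requires_grad=requires_grad)
    register_offload_parameter(module, param_name, param)
```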