    preset_name_to_scheme,
)
from compressed_tensors.utils import match_named_modules
-from pydantic import Field, PrivateAttr, field_validator, model_validator
+from pydantic import Field, PrivateAttr, field_validator
from torch.utils.hooks import RemovableHandle

from llmcompressor.modifiers.quantization.calibration import (
@@ -62,6 +62,9 @@ class QuantizationMixin(HooksMixin):
    :param targets: list of layer names to quantize if a scheme is provided. If unset,
        will contain all targets listed in config_groups. If config_groups is also
        unset, will default to ["Linear"] (i.e. all Linear layers will be targeted).
+        This field is not the source of truth for finding all matching target
+        layers in a model, since additional targets may be defined in
+        `config_groups`. Use `self.resolved_targets` instead.
    :param ignore: optional list of module class names or submodule names to not
        quantize even if they match a target in config_groups. Defaults to empty list.
    :param scheme: a single quantization scheme to apply to the model. This is a
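The distinction is easiest to see with a recipe that only uses `config_groups`. A minimal sketch, assuming `QuantizationModifier` (which mixes in `QuantizationMixin`) and illustrative regex targets:

```python
# Sketch (not from this PR): a modifier configured via config_groups,
# where the effective targets never pass through the `targets` field.
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
from llmcompressor.modifiers.quantization import QuantizationModifier

modifier = QuantizationModifier(
    config_groups={
        "group_0": QuantizationScheme(
            targets=["re:.*q_proj", "re:.*k_proj"],  # illustrative regexes
            weights=QuantizationArgs(num_bits=8, symmetric=True),
        )
    }
)
# modifier.targets still holds the default ["Linear"]; the regex targets
# above are only discoverable through modifier.resolved_targets.
```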
@@ -83,12 +86,16 @@
    """

    config_groups: Optional[Dict[str, QuantizationScheme]] = None
-    targets: Union[str, List[str]] = Field(default_factory=list)
+    # NOTE: targets is not the sole source of truth for finding all matching
+    # target layers in a model; additional targets may be defined in
+    # `config_groups`. Use self.resolved_targets as the source of truth.
+    targets: Union[str, List[str]] = Field(default_factory=lambda: ["Linear"])
    ignore: List[str] = Field(default_factory=list)
    scheme: Optional[Union[str, Dict[str, Any]]] = None
    kv_cache_scheme: Optional[QuantizationArgs] = None

    _calibration_hooks: Set[RemovableHandle] = PrivateAttr(default_factory=set)
+    _resolved_config: Optional[QuantizationConfig] = PrivateAttr(None)

    @field_validator("targets", mode="before")
    def validate_targets(cls, value: Union[str, List[str]]) -> List[str]:
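The `mode="before"` validator shown above (its body is outside this hunk) accepts either a bare string or a list, per the field annotation. A hypothetical illustration of the equivalence:

```python
# Hypothetical: the "before" validator normalizes a bare string into a
# one-element list, so these two spellings configure the same targets.
from llmcompressor.modifiers.quantization import QuantizationModifier

QuantizationModifier(scheme="W8A8", targets="Linear")
QuantizationModifier(scheme="W8A8", targets=["Linear"])
```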
@@ -116,27 +123,29 @@ def validate_scheme(

        return value

-    @model_validator(mode="after")
-    def validate_model_after(model: "QuantizationMixin") -> "QuantizationMixin":
+    @property
+    def resolved_config(self) -> QuantizationConfig:
        """
-        - If targets have not been set, aggregate targets from config_groups
-          into a single unique list
-        - If targets have still not been found, default to targets=["Linear"]
+        The quantization config, resolved exactly once from the
+        scheme and config_groups inputs and then cached.
        """
+        if self._resolved_config is None:
+            self._resolved_config = self.resolve_quantization_config()
+        return self._resolved_config

-        if len(model.targets) > 0 and model.config_groups is not None:
-            raise ValueError("Please specify either `targets` or `config_groups`")
-
-        if len(model.targets) == 0 and model.config_groups is not None:
-            for config_group in model.config_groups.values():
-                for target in config_group.targets:
-                    if target not in model.targets:
-                        model.targets.append(target)
-
-        if len(model.targets) == 0:
-            model.targets.append("Linear")
-
-        return model
+    @property
+    def resolved_targets(self) -> Set[str]:
+        """
+        Set of all resolved targets, i.e. all unique targets listed
+        in the resolved quantization config.
+        Use this property instead of the targets field, since targets can
+        also come from config_groups, depending on how the recipe is configured.
+        """
+        targets = set()
+        for config_group in self.resolved_config.config_groups.values():
+            for target in config_group.targets:
+                targets.add(target)
+        return targets

    def initialize_quantization(self, model: torch.nn.Module):
        """
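For reference, the new `resolved_targets` property reduces to a set union over all config groups. A standalone sketch of the same aggregation, using plain dicts as stand-ins for the real `QuantizationScheme` objects:

```python
# Standalone sketch of the aggregation in resolved_targets; the dicts
# below are hypothetical stand-ins for QuantizationScheme objects.
config_groups = {
    "group_0": {"targets": ["Linear"]},
    "group_1": {"targets": ["re:.*q_proj", "Linear"]},
}
resolved = {t for group in config_groups.values() for t in group["targets"]}
assert resolved == {"Linear", "re:.*q_proj"}  # duplicates collapse into a set
```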
@@ -145,13 +154,11 @@ def initialize_quantization(self, model: torch.nn.Module):

        :param model: model to attach schemes and observers to
        """
-        # apply scheme and status to model
-        config = self.resolve_quantization_config()

-        for _, module in match_named_modules(model, self.targets, self.ignore):
+        for _, module in match_named_modules(model, self.resolved_targets, self.ignore):
            reset_quantization_status(module)  # reset any previously applied qconfigs

-        apply_quantization_config(model, config)
+        apply_quantization_config(model, self.resolved_config)

        # disable quantization until calibration
        model.apply(disable_quantization)
@@ -164,7 +171,7 @@ def start_calibration(self, model: torch.nn.Module):
        :param model: model to prepare for calibration
        """
        self._calibration_hooks = self._initialize_hooks(model)
-        for _, module in match_named_modules(model, self.targets, self.ignore):
+        for _, module in match_named_modules(model, self.resolved_targets, self.ignore):
            self._initialize_observers(module)
            apply_calibration_status(module)

@@ -178,7 +185,7 @@ def end_calibration(self, model: torch.nn.Module):
        :param model: model to end calibration for
        """
        self.remove_hooks(self._calibration_hooks)
-        for _, module in match_named_modules(model, self.targets, self.ignore):
+        for _, module in match_named_modules(model, self.resolved_targets, self.ignore):
            freeze_module_quantization(module)  # remove observers

        model.apply(enable_quantization)  # keep quantization enabled
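Together, the three hooks touched above form the calibration lifecycle. A rough usage sketch, assuming a `modifier` that mixes in `QuantizationMixin` and a caller-supplied `model` and `dataloader`:

```python
modifier.initialize_quantization(model)  # attach schemes; quantization disabled
modifier.start_calibration(model)        # attach observers and calibration hooks
for batch in dataloader:                 # forward passes collect statistics
    model(**batch)
modifier.end_calibration(model)          # freeze qparams; quantization stays on
```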
@@ -270,7 +277,7 @@ def _initialize_observers(self, module: torch.nn.Module):

    def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
        hooks = set()
-        for _, module in match_named_modules(model, self.targets, self.ignore):
+        for _, module in match_named_modules(model, self.resolved_targets, self.ignore):
            if not hasattr(module, "quantization_scheme"):
                continue

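The `_resolved_config` private attribute follows a common pydantic pattern: derive a value lazily on first property access and cache it, instead of mutating fields in a model validator. A minimal self-contained sketch of the pattern (names are illustrative):

```python
from typing import Optional
from pydantic import BaseModel, PrivateAttr

class Resolver(BaseModel):
    scheme: str = "W8A8"
    _resolved: Optional[dict] = PrivateAttr(None)  # cache, not a model field

    @property
    def resolved(self) -> dict:
        # Resolve lazily on first access, then reuse the cached result.
        if self._resolved is None:
            self._resolved = {"scheme": self.scheme}
        return self._resolved

r = Resolver()
assert r.resolved is r.resolved  # resolution runs exactly once
```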