
Commit 7194d4b

[Bugfix] Refactor QuantizationMixin to use resolved config (#1912)
SUMMARY: Fixes #1906. This refactors QuantizationMixin so that it no longer updates any pydantic fields during validation. Rather than modifying those fields to make them the source of truth, it adds `resolved_config` and `resolved_targets` properties that all modifiers should use as the source of truth instead. These are resolved once, when needed, and are not serialized, which should fix the bug in #1906.

TEST PLAN: Added a `test_resolved_targets` unit test.

---------

Signed-off-by: Brian Dellabetta <[email protected]>
1 parent 6a71591 commit 7194d4b
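
To make the mechanism concrete, below is a minimal, self-contained sketch of the pattern this commit adopts: a pydantic `PrivateAttr` cache behind a property, so resolution happens once on first access and never touches the serialized fields. The class name, field names, and resolution logic here are illustrative stand-ins, not the project's actual implementation.

from typing import Dict, List, Optional, Set

from pydantic import BaseModel, Field, PrivateAttr


class ResolvedTargetsSketch(BaseModel):
    """Toy stand-in for QuantizationMixin, showing only the caching pattern."""

    # user-facing fields are left untouched by validation
    targets: List[str] = Field(default_factory=lambda: ["Linear"])
    config_groups: Optional[Dict[str, dict]] = None

    # private cache; pydantic excludes private attrs from model_dump()/serialization
    _resolved_targets: Optional[Set[str]] = PrivateAttr(None)

    @property
    def resolved_targets(self) -> Set[str]:
        # resolved lazily, exactly once, and never written back into `targets`
        if self._resolved_targets is None:
            resolved = set(self.targets)
            for group in (self.config_groups or {}).values():
                resolved.update(group.get("targets", []))
            self._resolved_targets = resolved
        return self._resolved_targets


sketch = ResolvedTargetsSketch(config_groups={"group_0": {"targets": ["re:.*q_proj$"]}})
assert "re:.*q_proj$" in sketch.resolved_targets
assert "_resolved_targets" not in sketch.model_dump()  # cache never leaks into the recipe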

File tree: 5 files changed, +94 −31 lines changed

src/llmcompressor/modifiers/awq/base.py

Lines changed: 1 addition & 1 deletion
@@ -268,7 +268,7 @@ def on_end(self, state: State, event: Event, **kwargs):
         self.ended_ = True

         for _, module in tqdm(
-            match_named_modules(state.model, self.targets, self.ignore),
+            match_named_modules(state.model, self.resolved_targets, self.ignore),
             desc="Calibrating weights",
         ):
             update_weight_zp_scale(module)

src/llmcompressor/modifiers/quantization/gptq/base.py

Lines changed: 6 additions & 2 deletions
@@ -162,7 +162,9 @@ def on_initialize(self, state: State, **kwargs) -> bool:
         # prepare module names
         self._module_names = {
             m: name
-            for name, m in match_named_modules(state.model, self.targets, self.ignore)
+            for name, m in match_named_modules(
+                state.model, self.resolved_targets, self.ignore
+            )
         }

         return True
@@ -176,7 +178,9 @@ def on_start(self, state: State, event: Event, **kwargs):

         # register gptq hooks
         added_hook = False
-        for _, module in match_named_modules(state.model, self.targets, self.ignore):
+        for _, module in match_named_modules(
+            state.model, self.resolved_targets, self.ignore
+        ):
             if getattr_chain(module, "quantization_scheme.weights", None) is not None:
                 # HACK: previously, embeddings were not quantized because they were not
                 # accessible by the layer compressor. For now, we manually ignore it,

src/llmcompressor/modifiers/quantization/quantization/base.py

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ def on_start(self, state: State, event: Event, **kwargs):
         QuantizationMixin.start_calibration(self, state.model)

         named_modules = list(
-            match_named_modules(state.model, self.targets, self.ignore)
+            match_named_modules(state.model, self.resolved_targets, self.ignore)
         )
         # TODO: this step can be combined with update_weight_zp_scale
         # once update_fused_layer_weight_global_scales is removed

src/llmcompressor/modifiers/quantization/quantization/mixin.py

Lines changed: 34 additions & 27 deletions
@@ -15,7 +15,7 @@
     preset_name_to_scheme,
 )
 from compressed_tensors.utils import match_named_modules
-from pydantic import Field, PrivateAttr, field_validator, model_validator
+from pydantic import Field, PrivateAttr, field_validator
 from torch.utils.hooks import RemovableHandle

 from llmcompressor.modifiers.quantization.calibration import (
@@ -62,6 +62,9 @@ class QuantizationMixin(HooksMixin):
     :param targets: list of layer names to quantize if a scheme is provided. If unset,
         will contain all targets listed in config_groups. If config_groups is also
         unset, will default to ["Linear"] (i.e. all Linear layers will be targeted).
+        This field is not the source of truth for finding all matching target layers
+        in a model. Additional information can be stored in `config_groups`. Use
+        self.resolved_targets instead.
     :param ignore: optional list of module class names or submodule names to not
         quantize even if they match a target in config_groups. Defaults to empty list.
     :param scheme: a single quantization scheme to apply to the model. This is a
@@ -83,12 +86,16 @@ class QuantizationMixin(HooksMixin):
     """

     config_groups: Optional[Dict[str, QuantizationScheme]] = None
-    targets: Union[str, List[str]] = Field(default_factory=list)
+    # NOTE: targets is not the sole source of truth for finding all matching target
+    # layers in a model. Additional information can be stored in `config_groups`
+    # Use self.resolved_targets as source of truth.
+    targets: Union[str, List[str]] = Field(default_factory=lambda: ["Linear"])
     ignore: List[str] = Field(default_factory=list)
     scheme: Optional[Union[str, Dict[str, Any]]] = None
     kv_cache_scheme: Optional[QuantizationArgs] = None

     _calibration_hooks: Set[RemovableHandle] = PrivateAttr(default_factory=set)
+    _resolved_config: Optional[QuantizationConfig] = PrivateAttr(None)

     @field_validator("targets", mode="before")
     def validate_targets(cls, value: Union[str, List[str]]) -> List[str]:
@@ -116,27 +123,29 @@ def validate_scheme(

         return value

-    @model_validator(mode="after")
-    def validate_model_after(model: "QuantizationMixin") -> "QuantizationMixin":
+    @property
+    def resolved_config(self) -> QuantizationConfig:
         """
-        - If targets have not been set, aggregate targets from config_groups
-          into a single unique list
-        - If targets have still not been found, default to targets=["Linear"]
+        Quantization config needs to be resolved just once based on
+        scheme and config_groups inputs.
         """
+        if self._resolved_config is None:
+            self._resolved_config = self.resolve_quantization_config()
+        return self._resolved_config

-        if len(model.targets) > 0 and model.config_groups is not None:
-            raise ValueError("Please specify either `targets` or `config_groups`")
-
-        if len(model.targets) == 0 and model.config_groups is not None:
-            for config_group in model.config_groups.values():
-                for target in config_group.targets:
-                    if target not in model.targets:
-                        model.targets.append(target)
-
-        if len(model.targets) == 0:
-            model.targets.append("Linear")
-
-        return model
+    @property
+    def resolved_targets(self) -> Set[str]:
+        """
+        Set of all resolved targets, i.e. all unique targets listed
+        in resolved quantization config.
+        Use this property instead of the targets field, as targets can
+        also come from config_groups depending on how recipe is configured.
+        """
+        targets = set()
+        for config_group in self.resolved_config.config_groups.values():
+            for target in config_group.targets:
+                targets.add(target)
+        return targets

     def initialize_quantization(self, model: torch.nn.Module):
         """
@@ -145,13 +154,11 @@ def initialize_quantization(self, model: torch.nn.Module):

         :param model: model to attach schemes and observers to
         """
-        # apply scheme and status to model
-        config = self.resolve_quantization_config()

-        for _, module in match_named_modules(model, self.targets, self.ignore):
+        for _, module in match_named_modules(model, self.resolved_targets, self.ignore):
             reset_quantization_status(module)  # reset any previously applied qconfigs

-        apply_quantization_config(model, config)
+        apply_quantization_config(model, self.resolved_config)

         # disable quantization until calibration
         model.apply(disable_quantization)
@@ -164,7 +171,7 @@ def start_calibration(self, model: torch.nn.Module):
         :param model: model to prepare for calibration
         """
         self._calibration_hooks = self._initialize_hooks(model)
-        for _, module in match_named_modules(model, self.targets, self.ignore):
+        for _, module in match_named_modules(model, self.resolved_targets, self.ignore):
             self._initialize_observers(module)
             apply_calibration_status(module)

@@ -178,7 +185,7 @@ def end_calibration(self, model: torch.nn.Module):
         :param model: model to end calibration for
         """
         self.remove_hooks(self._calibration_hooks)
-        for _, module in match_named_modules(model, self.targets, self.ignore):
+        for _, module in match_named_modules(model, self.resolved_targets, self.ignore):
             freeze_module_quantization(module)  # remove observers

         model.apply(enable_quantization)  # keep quantization enabled
@@ -270,7 +277,7 @@ def _initialize_observers(self, module: torch.nn.Module):

     def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
         hooks = set()
-        for _, module in match_named_modules(model, self.targets, self.ignore):
+        for _, module in match_named_modules(model, self.resolved_targets, self.ignore):
             if not hasattr(module, "quantization_scheme"):
                 continue
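For reference, every updated call site above follows the same consumer pattern: iterate the modules matched by `resolved_targets` via `match_named_modules`. The following is a rough, hedged sketch of that usage; the toy model and target set are made up for illustration, and the `compressed_tensors` call is the one imported in this file, assumed to accept a set of target names plus an ignore list as the diff shows.

import torch
from compressed_tensors.utils import match_named_modules

# toy model just for illustration; any torch.nn.Module works
model = torch.nn.Sequential(
    torch.nn.Linear(8, 8),
    torch.nn.ReLU(),
    torch.nn.Linear(8, 2),
)

resolved_targets = {"Linear"}  # what resolved_targets would yield for the default scheme
ignore = []

# yields (name, module) pairs for every matched, non-ignored submodule
for name, module in match_named_modules(model, resolved_targets, ignore):
    print(name, type(module).__name__)  # e.g. "0 Linear", "2 Linear"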
tests/llmcompressor/modifiers/quantization/test_base.py

Lines changed: 52 additions & 0 deletions
@@ -159,3 +159,55 @@ def test_serialize_actorder(has_actorder, actorder, exp_actorder):
     modifier = GPTQModifier(targets=["Linear"], scheme="W8A8")

     assert modifier.model_dump()["actorder"] == exp_actorder
+
+
+@pytest.mark.parametrize(
+    "scheme,targets,config_groups,resolved_targets,should_error",
+    [
+        ("W4A16", ["Linear"], None, {"Linear"}, False),
+        (
+            "W4A16",
+            [r"re:.*q_proj$", r"re:.*k_proj$"],
+            None,
+            {r"re:.*q_proj$", r"re:.*k_proj$"},
+            False,
+        ),
+        (
+            None,
+            ["Linear"],
+            dict(
+                group_0=dict(
+                    targets=[r"re:.*q_proj$"],
+                ),
+                group_1=dict(
+                    targets=[r"re:.*k_proj$"],
+                ),
+            ),
+            {r"re:.*q_proj$", r"re:.*k_proj$"},
+            False,
+        ),
+        (
+            "W4AA16",
+            ["Linear"],
+            dict(
+                group_0=dict(
+                    targets=[r"re:.*q_proj$"],
+                ),
+            ),
+            {},
+            True,
+        ),
+    ],
+)
+def test_resolved_targets(
+    scheme, targets, config_groups, should_error, resolved_targets
+):
+    if should_error:
+        with pytest.raises(ValueError):
+            GPTQModifier(targets=targets, scheme=scheme, config_groups=config_groups)
+    else:
+        modifier = GPTQModifier(
+            targets=targets, scheme=scheme, config_groups=config_groups
+        )
+
+        assert modifier.resolved_targets == resolved_targets
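
As a usage note, the new test boils down to behavior like the following hedged sketch, assuming the usual `from llmcompressor.modifiers.quantization import GPTQModifier` import path; the group name and regex are illustrative only.

from llmcompressor.modifiers.quantization import GPTQModifier

modifier = GPTQModifier(
    scheme=None,
    targets=["Linear"],
    config_groups={"group_0": {"targets": [r"re:.*q_proj$"]}},
)

# targets are aggregated from config_groups when the config is resolved ...
assert r"re:.*q_proj$" in modifier.resolved_targets
# ... while serialization still reflects only the user-supplied fields (the #1906 fix)
assert "resolved_targets" not in modifier.model_dump()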

0 commit comments
