Skip to content

Add inplace quantizer examples #2345

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions test/quantization/pt2e/test_quantize_pt2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -2826,6 +2826,78 @@ def check_nn_module(node):
if node.name == "mul":
check_nn_module(node)

def test_quantize_in_place_ops(self):
class TestQuantizer(Quantizer):
example_inputs = None

def set_example_inputs(self, example_inputs):
self.example_inputs = example_inputs

def transform_for_annotation(
self, model: torch.fx.GraphModule
) -> torch.fx.GraphModule:
# Make a copy of the graph to ensure that we are using the
# return value of this function.
ep = torch.export.export(model, self.example_inputs)
ep = ep.run_decompositions({})
return ep.module()

def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
act_qspec = QuantizationSpec(
dtype=torch.uint8,
quant_min=0,
quant_max=255,
qscheme=torch.per_tensor_affine,
is_dynamic=False,
observer_or_fake_quant_ctr=observer.default_observer
)
for node in model.graph.nodes:
if (
node.op == "call_function"
and node.target == torch.ops.aten.add.Tensor
):
input_act0 = node.args[0]
assert isinstance(input_act0, torch.fx.Node)
input_act1 = node.args[1]
if isinstance(input_act1, torch.fx.Node):
node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map={
input_act0: act_qspec,
input_act1: act_qspec,
},
output_qspec=act_qspec,
_annotated=True,
)
else:
# Handle case where second input is a constant
node.meta["quantization_annotation"] = QuantizationAnnotation(
input_qspec_map={
input_act0: act_qspec,
},
output_qspec=act_qspec,
_annotated=True,
)

def validate(self, model: torch.fx.GraphModule) -> None:
pass

class M(torch.nn.Module):
def forward(self, x):
return x + 3

m = M().eval()
quantizer = TestQuantizer()
example_inputs = (torch.randn(1, 2, 3, 3),)
quantizer.set_example_inputs(example_inputs)
m = export_for_training(m, example_inputs, strict=True).module()
m = prepare_pt2e(m, quantizer)
m(*example_inputs)
m = convert_pt2e(m)

# Verify the quantized model works
result = m(*example_inputs)
self.assertIsNotNone(result)


@skipIfNoQNNPACK
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")
Expand Down
Loading