diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py
index 730969ba9c..163f1ae5c3 100644
--- a/test/quantization/pt2e/test_quantize_pt2e.py
+++ b/test/quantization/pt2e/test_quantize_pt2e.py
@@ -2826,6 +2826,78 @@ def check_nn_module(node):
             if node.name == "mul":
                 check_nn_module(node)
 
+    def test_quantize_in_place_ops(self):
+        class TestQuantizer(Quantizer):
+            example_inputs = None
+
+            def set_example_inputs(self, example_inputs):
+                self.example_inputs = example_inputs
+
+            def transform_for_annotation(
+                self, model: torch.fx.GraphModule
+            ) -> torch.fx.GraphModule:
+                # Make a copy of the graph to ensure that we are using the
+                # return value of this function.
+                ep = torch.export.export(model, self.example_inputs)
+                ep = ep.run_decompositions({})
+                return ep.module()
+
+            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                act_qspec = QuantizationSpec(
+                    dtype=torch.uint8,
+                    quant_min=0,
+                    quant_max=255,
+                    qscheme=torch.per_tensor_affine,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_observer
+                )
+                for node in model.graph.nodes:
+                    if (
+                        node.op == "call_function"
+                        and node.target == torch.ops.aten.add.Tensor
+                    ):
+                        input_act0 = node.args[0]
+                        assert isinstance(input_act0, torch.fx.Node)
+                        input_act1 = node.args[1]
+                        if isinstance(input_act1, torch.fx.Node):
+                            node.meta["quantization_annotation"] = QuantizationAnnotation(
+                                input_qspec_map={
+                                    input_act0: act_qspec,
+                                    input_act1: act_qspec,
+                                },
+                                output_qspec=act_qspec,
+                                _annotated=True,
+                            )
+                        else:
+                            # Handle case where second input is a constant
+                            node.meta["quantization_annotation"] = QuantizationAnnotation(
+                                input_qspec_map={
+                                    input_act0: act_qspec,
+                                },
+                                output_qspec=act_qspec,
+                                _annotated=True,
+                            )
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+        class M(torch.nn.Module):
+            def forward(self, x):
+                return x + 3
+
+        m = M().eval()
+        quantizer = TestQuantizer()
+        example_inputs = (torch.randn(1, 2, 3, 3),)
+        quantizer.set_example_inputs(example_inputs)
+        m = export_for_training(m, example_inputs, strict=True).module()
+        m = prepare_pt2e(m, quantizer)
+        m(*example_inputs)
+        m = convert_pt2e(m)
+
+        # Verify the quantized model works
+        result = m(*example_inputs)
+        self.assertIsNotNone(result)
+
 
 @skipIfNoQNNPACK
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_7, "Requires torch 2.7+")