diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py
index 148e4c06a8051..0dac146f9eb59 100644
--- a/onnxruntime/python/tools/quantization/onnx_quantizer.py
+++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py
@@ -1071,7 +1071,7 @@ def quantize_weight_per_channel(
             scale_name,
             zp_name,
             QuantizedValueType.Initializer,
-            None,
+            channel_axis,
         )
         self.quantized_value_map[weight_name] = quantized_value
 
@@ -1096,8 +1096,9 @@ def _dequantize_value(self, value_name):
             if self.model.model.producer_name != "onnx-quantizer" or (
                 self.model.model.producer_name == "onnx-quantizer" and scale_init is not None
             ):
-                # axis is not specified so scale_init must be a scalar.
-                assert scale_init is None or onnx.numpy_helper.to_array(scale_init).size == 1
+                # Per-tensor (axis=None) requires a scalar scale.
+                if quantized_value.axis is None:
+                    assert scale_init is None or onnx.numpy_helper.to_array(scale_init).size == 1
 
                 dqlinear_name = value_name + "_DequantizeLinear"
                 dqlinear_node = self.model.find_node_by_name(dqlinear_name, self.new_nodes, self.model.graph())
@@ -1108,7 +1109,11 @@
                         quantized_value.zp_name,
                     ]
                     dequantize_node = onnx.helper.make_node(
-                        "DequantizeLinear", dqlinear_inputs, [value_name], dqlinear_name
+                        "DequantizeLinear",
+                        dqlinear_inputs,
+                        [value_name],
+                        dqlinear_name,
+                        axis=quantized_value.axis,
                     )
                     return dequantize_node
                 else:
diff --git a/onnxruntime/test/python/quantization/test_quant_issues.py b/onnxruntime/test/python/quantization/test_quant_issues.py
index 7b4e78faf3dd7..5f96b59ab21b4 100644
--- a/onnxruntime/test/python/quantization/test_quant_issues.py
+++ b/onnxruntime/test/python/quantization/test_quant_issues.py
@@ -67,6 +67,95 @@ def get_next(self):
             assert os.path.exists(preprocessed_path), f"missing output {preprocessed_path!r}"
             assert os.path.exists(quantized_path), f"missing output {quantized_path!r}"
 
+    def test_dynamic_quantize_per_channel_emits_axis_attribute(self):
+        """Per-channel dynamic quantization must emit axis on DequantizeLinear nodes.
+
+        Regression test for https://github.com/microsoft/onnxruntime/issues/19997.
+        `quantize_dynamic(per_channel=True)` previously constructed QuantizedValue
+        with axis=None and built DequantizeLinear without an axis attribute, producing
+        an invalid per-tensor dequantization for per-channel quantized weights.
+        When the per-channel quantized weight also appears as a graph output,
+        _dequantize_outputs calls _dequantize_value, which triggered an assertion
+        error (scale not scalar) and would have emitted a DequantizeLinear lacking
+        the required axis attribute.
+        """
+        try:
+            import numpy as np  # noqa: PLC0415
+            import onnx  # noqa: PLC0415
+            from onnx import TensorProto, helper, numpy_helper  # noqa: PLC0415
+
+            from onnxruntime.quantization import QuantType, quantize_dynamic  # noqa: PLC0415
+        except ImportError as exc:
+            raise unittest.SkipTest(f"Required import missing: {exc}") from exc
+
+        # Build a model: input (5, 4) @ weight (4, 8) -> output (5, 8).
+        # The weight is also passed through Identity and exposed as a second graph
+        # output so that _dequantize_outputs calls _dequantize_value on the
+        # per-channel-quantized weight initializer.
+        # Weight axis=1 is the output-feature axis (per-channel quantization target).
+        np.random.seed(42)
+        weight_data = np.random.normal(0, 0.1, (4, 8)).astype(np.float32)
+        weight_init = numpy_helper.from_array(weight_data, name="weight")
+
+        input_vi = helper.make_tensor_value_info("input", TensorProto.FLOAT, [5, 4])
+        output_vi = helper.make_tensor_value_info("output", TensorProto.FLOAT, [5, 8])
+        weight_out_vi = helper.make_tensor_value_info("weight_out", TensorProto.FLOAT, [4, 8])
+
+        matmul_node = helper.make_node("MatMul", ["input", "weight"], ["output"])
+        identity_node = helper.make_node("Identity", ["weight"], ["weight_out"])
+
+        graph = helper.make_graph(
+            [matmul_node, identity_node],
+            "test_graph",
+            [input_vi],
+            [output_vi, weight_out_vi],
+            [weight_init],
+        )
+        model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])
+        model.ir_version = 8
+
+        with tempfile.TemporaryDirectory() as tmp:
+            model_fp_path = os.path.join(tmp, "model_fp.onnx")
+            model_q_path = os.path.join(tmp, "model_q.onnx")
+            onnx.save(model, model_fp_path)
+
+            # This must not raise AssertionError due to per-channel scale not being scalar.
+            quantize_dynamic(
+                model_fp_path,
+                model_q_path,
+                per_channel=True,
+                weight_type=QuantType.QInt8,
+            )
+
+            q_model = onnx.load(model_q_path)
+
+            # Find the DequantizeLinear node that dequantizes the weight initializer.
+            init_names = {init.name for init in q_model.graph.initializer}
+            dq_nodes = [n for n in q_model.graph.node if n.op_type == "DequantizeLinear"]
+            self.assertGreater(len(dq_nodes), 0, "Expected at least one DequantizeLinear node")
+
+            weight_dq = None
+            for node in dq_nodes:
+                if node.input[0] in init_names:
+                    weight_dq = node
+                    break
+            self.assertIsNotNone(weight_dq, "No DequantizeLinear node found with a weight initializer as input")
+
+            # The axis attribute must be present.
+            # MatMulInteger passes axis=-1 (last dimension) to quantize_weight_per_channel.
+            axis_attrs = [attr for attr in weight_dq.attribute if attr.name == "axis"]
+            self.assertEqual(len(axis_attrs), 1, "DequantizeLinear node is missing the 'axis' attribute")
+            # MatMulInteger quantizes weight with axis=-1 (default in __quantize_inputs).
+            self.assertEqual(axis_attrs[0].i, -1, f"Expected axis=-1, got axis={axis_attrs[0].i}")
+
+            # The scale initializer must be 1-D with size > 1 (truly per-channel, not collapsed).
+            scale_name = weight_dq.input[1]
+            scale_init = next((i for i in q_model.graph.initializer if i.name == scale_name), None)
+            self.assertIsNotNone(scale_init, f"Scale initializer '{scale_name}' not found")
+            scale_array = numpy_helper.to_array(scale_init)
+            self.assertEqual(scale_array.ndim, 1, f"Expected 1-D scale, got shape {scale_array.shape}")
+            self.assertGreater(scale_array.size, 1, "Scale has only one element; expected per-channel scale")
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)