From 84a3aa044b4867a6c063dda6fce540201061e5aa Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Wed, 26 Nov 2025 15:44:52 -0500 Subject: [PATCH 1/3] Update value info in graph Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- modelopt/onnx/autocast/precisionconverter.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py index 38820479c..9b238f728 100644 --- a/modelopt/onnx/autocast/precisionconverter.py +++ b/modelopt/onnx/autocast/precisionconverter.py @@ -955,6 +955,13 @@ def _cleanup(self): # Remove redundant casts self._remove_redundant_casts() + # Update value_info in model's graph + self._update_value_info_in_graph() + + def _update_value_info_in_graph(self): + for vi in self.model.graph.value_info: + vi.type.tensor_type.elem_type = self.value_info_map[vi.name].type.tensor_type.elem_type + def _cleanup_no_consumer_nodes(self): network_outputs = {o.name for o in self.model.graph.output} nodes_to_remove = [ From aa936315498fb597c4f72820c50ea5ac3989c788 Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Wed, 3 Dec 2025 14:51:43 -0500 Subject: [PATCH 2/3] Ensure correct output tensor type in _fix_network_output_names Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- modelopt/onnx/autocast/precisionconverter.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py index 9b238f728..4f380e1d2 100644 --- a/modelopt/onnx/autocast/precisionconverter.py +++ b/modelopt/onnx/autocast/precisionconverter.py @@ -955,13 +955,6 @@ def _cleanup(self): # Remove redundant casts self._remove_redundant_casts() - # Update value_info in model's graph - self._update_value_info_in_graph() - - def _update_value_info_in_graph(self): - for vi in 
self.model.graph.value_info: - vi.type.tensor_type.elem_type = self.value_info_map[vi.name].type.tensor_type.elem_type - def _cleanup_no_consumer_nodes(self): network_outputs = {o.name for o in self.model.graph.output} nodes_to_remove = [ @@ -1121,6 +1114,13 @@ def _fix_network_output_names(self): break cast_node.input[0] = pre_cast_name cast_node.output[0] = original_name + # Ensure correct output tensor type + cast_to_precision = next( + attr.i for attr in cast_node.attribute if attr.name == "to" + ) + self.value_info_map[ + cast_node.output[0] + ].type.tensor_type.elem_type = cast_to_precision modified = True logger.debug(f"Fixed network output names: {post_cast_name} -> {output.name}") From 815ebf7f1d3a00179cf2fa515f767fe79667a722 Mon Sep 17 00:00:00 2001 From: gcunhase <4861122+gcunhase@users.noreply.github.com> Date: Wed, 3 Dec 2025 15:23:22 -0500 Subject: [PATCH 3/3] Add unittest Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com> --- tests/unit/onnx/test_autocast_quantize.py | 70 +++++++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 tests/unit/onnx/test_autocast_quantize.py diff --git a/tests/unit/onnx/test_autocast_quantize.py b/tests/unit/onnx/test_autocast_quantize.py new file mode 100644 index 000000000..b85f4a41a --- /dev/null +++ b/tests/unit/onnx/test_autocast_quantize.py @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import onnx
+import onnx_graphsurgeon as gs
+import pytest
+import torch
+from _test_utils.onnx.lib_test_models import SimpleMLP, export_as_onnx
+
+from modelopt.onnx.autocast import convert_to_mixed_precision
+from modelopt.onnx.quantization import quantize
+
+
+def assert_nodes_are_quantized(nodes):
+    """Assert that each variable input of every node is fed by a DequantizeLinear; return True."""
+    for node in nodes:
+        for inp in node.inputs:
+            # A variable with no producer (e.g. a graph input) must fail the check, not raise.
+            if isinstance(inp, gs.Variable):
+                assert inp.inputs and inp.inputs[0].op == "DequantizeLinear", (
+                    f"Input '{inp.name}' of node '{node.name}' is not quantized but should be!"
+                )
+    return True
+
+
+@pytest.mark.parametrize("keep_io_types", [True, False])
+def test_autocast_quantize_int8(tmp_path, keep_io_types):
+    """Run AutoCast (fp16) then INT8 quantization and check that all MatMul inputs are quantized."""
+    low_precision_type = "fp16"
+    onnx_path = os.path.join(tmp_path, "model.onnx")
+    export_as_onnx(SimpleMLP(), torch.randn(2, 16, 16), onnx_filename=onnx_path)
+
+    # Convert model to low precision
+    converted_model = convert_to_mixed_precision(
+        onnx_path, keep_io_types=keep_io_types, low_precision_type=low_precision_type
+    )
+    # Only add the I/O tag when it applies, so the file name never contains an empty "..".
+    io_tag = ".keepIOTypes" if keep_io_types else ""
+    converted_model_path = onnx_path.replace(".onnx", f".{low_precision_type}{io_tag}.onnx")
+    onnx.save(converted_model, converted_model_path)
+
+    # Quantize converted model
+    quantize(converted_model_path, quantize_mode="int8", high_precision_dtype=low_precision_type)
+
+    # Output model should be produced in the same tmp_path
+    output_onnx_path = converted_model_path.replace(".onnx", ".quant.onnx")
+
+    # Check that quantized explicit model is generated
+    assert os.path.isfile(output_onnx_path)
+
+    # Load the output model and check QDQ node placements on all MatMul nodes
+    graph = gs.import_onnx(onnx.load(output_onnx_path))
+    mm_nodes = [n for n in graph.nodes if n.op == "MatMul"]
+    # Without this guard the check below would pass vacuously on a graph with no MatMul.
+    assert mm_nodes, "Expected at least one MatMul node in the quantized model"
+    assert assert_nodes_are_quantized(mm_nodes)