diff --git a/CHANGELOG.rst b/CHANGELOG.rst index beb01abf0..60c7ec5ca 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -28,6 +28,7 @@ Model Optimizer Changelog (Linux) **Misc** - Bump minimum recommended transformers version to 4.53. +- Replace ONNX simplification package from 'onnxsim' to 'onnxslim'. 0.39 (2025-11-11) ^^^^^^^^^^^^^^^^^ diff --git a/modelopt/onnx/quantization/quantize.py b/modelopt/onnx/quantization/quantize.py index 800124646..00b4f5d75 100755 --- a/modelopt/onnx/quantization/quantize.py +++ b/modelopt/onnx/quantization/quantize.py @@ -39,6 +39,7 @@ from typing import Any import onnx +import onnxslim import onnx.onnx_cpp2py_export.checker as C import onnx_graphsurgeon as gs @@ -133,16 +134,8 @@ def _preprocess_onnx( if simplify: logger.info("Attempting to simplify model") try: - import onnxsim - except ModuleNotFoundError as e: - logger.warning( - "onnxsim is not installed. Please install it with 'pip install onnxsim'." - ) - raise e - - try: - model_simp, check = onnxsim.simplify(onnx_model) - if check: + model_simp = onnxslim.slim(onnx_model, skip_fusion_patterns=["FusionGemm"]) + if model_simp: onnx_model = model_simp onnx_path = os.path.join(output_dir, f"{model_name}_simp.onnx") save_onnx(onnx_model, onnx_path, use_external_data_format) diff --git a/setup.py b/setup.py index 85b79e729..7befe9a47 100644 --- a/setup.py +++ b/setup.py @@ -52,7 +52,7 @@ "onnxruntime-gpu~=1.22.0 ; platform_machine != 'aarch64' and platform_system != 'Darwin' and platform_system != 'Windows'", # noqa: E501 "onnxruntime-directml==1.20.0; platform_system == 'Windows'", "onnxscript", # For autocast opset conversion and test_onnx_dynamo_export unit test - "onnxsim ; python_version < '3.12' and platform_machine != 'aarch64'", + "onnxslim>=0.1.75, "polygraphy>=0.49.22", ], "hf": [ diff --git a/tests/gpu/onnx/test_simplify.py b/tests/gpu/onnx/test_simplify.py index 3b6acccb6..5ca8449b3 100644 --- a/tests/gpu/onnx/test_simplify.py +++ b/tests/gpu/onnx/test_simplify.py @@ -57,14 +57,14 @@ def test_onnx_simplification(tmp_path): assert os.path.isfile(output_onnx_path), "Quantized ONNX was not found!" # Load the simplified model and check that the model doesn't contain Identity nodes, - # only 3 layers (Conv->BN->Relu). + # only 2 layers (Conv->Relu). graph = gs.import_onnx(onnx.load(simplified_onnx_path)) identity_nodes = [n for n in graph.nodes if n.op == "Identity"] assert not identity_nodes, "Simplified ONNX model contains Identity nodes but it shouldn't." - assert len(graph.nodes) == 3, ( - f"Number of nodes doesn't match the expected: {len(graph.nodes)} vs 3." + assert len(graph.nodes) == 2, ( + f"Number of nodes doesn't match the expected: {len(graph.nodes)} vs 2." ) - assert all(n.op in ["Conv", "BatchNormalization", "Relu"] for n in graph.nodes), ( + assert all(n.op in ["Conv", "Relu"] for n in graph.nodes), ( "Graph contains more ops than expected." )