diff --git a/docs/source/features/onnx-transformations.md b/docs/source/features/onnx-transformations.md index 9dfcd1c67c..df1079b390 100644 --- a/docs/source/features/onnx-transformations.md +++ b/docs/source/features/onnx-transformations.md @@ -2100,6 +2100,91 @@ Two cases are supported: ``` +### `RenameOutputDims` + +#### Description + +Renames a dimension in an output tensor's shape. Useful for restoring meaningful symbolic dimension names after graph transformations that may have changed them (e.g. after `OrtTransformersOptimization`). + +#### Configurations + +- `output_idx`: Index of the output tensor to modify. +- `dim_idx`: Index of the dimension within the output's shape. +- `dim_name`: New symbolic name for the dimension. + +#### Example + +```json +{ + "type": "GraphSurgeries", + "surgeries": [ + { + "surgeon": "RenameOutputDims", + "output_idx": 0, + "dim_idx": 0, + "dim_name": "num_logical_patches" + } + ] +} +``` + +### `RenameInputDims` + +#### Description + +Renames or promotes a dimension in an input tensor's shape to a named symbolic dimension. Useful when `torch.export` specializes a batch-like input dimension to a concrete value but ONNX Runtime needs to accept a variable-length tensor at inference time. The target input can be specified by name (preferred) or by index. + +#### Configurations + +- `dim_idx`: Index of the dimension within the input's shape. +- `dim_name`: New symbolic name for the dimension. +- `input_name` *(optional)*: Name of the input tensor to modify. +- `input_idx` *(optional)*: Index of the input tensor to modify. Either `input_name` or `input_idx` must be provided. 
class RemoveMemcpy(ProtoSurgeon):
    """Remove MemcpyToHost and MemcpyFromHost nodes from the graph.

    These nodes are inserted by ORT's ``OrtTransformersOptimization`` when it
    pre-partitions the graph for a GPU execution provider. They represent
    explicit GPU↔CPU data copies for tensors whose consumers require CPU memory
    (e.g. shape arguments to Reshape, start/end for Slice, trip counts for Loop).

    Removing them is safe because ORT's runtime ``MemcpyTransformer`` will
    re-insert only the truly necessary copies when the session is created.
    The runtime also has a ``GetCpuPreferredNodes`` heuristic that may keep
    entire shape-computation subgraphs on CPU, potentially avoiding some
    copies entirely.

    This surgery processes both the main graph and all Loop/If subgraphs
    recursively. After removal the graph nodes are topologically re-sorted
    to satisfy the ONNX requirement that every input is produced before use.

    When to use:
        Run **after** ``OrtTransformersOptimization`` to remove pre-baked memcpy
        nodes and let ORT's runtime re-partition optimally.
    """

    # Op types stripped by this surgery (single source of truth for all checks).
    _MEMCPY_OPS = ("MemcpyToHost", "MemcpyFromHost")

    def __call__(self, model: ModelProto):
        """Apply the surgery in place on ``model`` and return it."""
        total = self._remove_from_graph(model.graph)
        if total:
            logger.debug("Removed %d Memcpy nodes total", total)
        return model

    @staticmethod
    def _remove_from_graph(graph) -> int:
        """Remove MemcpyToHost/MemcpyFromHost from a graph, then topo-sort.

        Returns the total number of removed nodes, including nodes removed
        from nested Loop/If subgraphs.
        """
        memcpy_ops = RemoveMemcpy._MEMCPY_OPS
        removed = 0

        # Build output→input bypass mapping for 1-in/1-out Memcpy nodes only.
        # Memcpy nodes with any other arity are left untouched: no bypass is
        # built for them, and they are not deleted below.
        bypass: dict[str, str] = {}
        for node in graph.node:
            if node.op_type in memcpy_ops and len(node.input) == 1 and len(node.output) == 1:
                bypass[node.output[0]] = node.input[0]

        if bypass:
            # Resolve chained Memcpy transitively: if A→B→C are both Memcpy,
            # bypass = {B: A, C: B}. Follow the chain so C maps to A.
            # (Valid ONNX graphs are SSA, so the chain cannot cycle.)
            for key, value in bypass.items():
                target = value
                while target in bypass:
                    target = bypass[target]
                bypass[key] = target

            # Rewrite consumer references: replace memcpy output with its input.
            for node in graph.node:
                if node.op_type in memcpy_ops:
                    continue
                for i, inp in enumerate(node.input):
                    if inp in bypass:
                        node.input[i] = bypass[inp]
                # Also rewrite implicit outer-scope references inside Loop/If
                # bodies. Use HasField rather than truthiness: protobuf returns
                # a (truthy) default GraphProto for unset singular fields, so
                # `if attr.g:` would fire for every attribute.
                for attr in node.attribute:
                    if attr.HasField("g"):
                        RemoveMemcpy._rewrite_subgraph_refs(attr.g, bypass)

            # Preserve graph output names: if a Memcpy sits on the output
            # boundary, rename the upstream producer's output to match the
            # original graph output name instead of changing the public name.
            graph_input_names = {i.name for i in graph.input}
            graph_output_names = {o.name for o in graph.output}
            renamed: dict[str, str] = {}  # src tensor → public name it now carries
            for out in graph.output:
                if out.name not in bypass:
                    continue
                src = bypass[out.name]
                if src in graph_input_names or src in graph_output_names or src in renamed:
                    # `src` must keep its own name (it is a graph input, another
                    # graph output, or already claimed by a previous output), so
                    # renaming the producer would break the other reference.
                    # Bridge with an Identity node instead; topo-sort below
                    # places it correctly.
                    ident = graph.node.add()
                    ident.op_type = "Identity"
                    ident.name = f"{out.name}_identity"
                    ident.input.append(renamed.get(src, src))
                    ident.output.append(out.name)
                    continue
                renamed[src] = out.name
                # Rename the producer node's output to keep the public name.
                for node in graph.node:
                    for j, o in enumerate(node.output):
                        if o == src:
                            node.output[j] = out.name
                # Also update any other consumers of `src` to use the output name.
                for node in graph.node:
                    for j, inp in enumerate(node.input):
                        if inp == src:
                            node.input[j] = out.name

        # Remove only 1-in/1-out Memcpy nodes (the ones bypass was built for).
        indices = [
            i
            for i, n in enumerate(graph.node)
            if n.op_type in memcpy_ops and len(n.input) == 1 and len(n.output) == 1
        ]
        for i in reversed(indices):
            del graph.node[i]
        removed += len(indices)

        # Topological re-sort to fix ordering after node removal. Skipped when
        # nothing was removed so untouched graphs keep their node order.
        if indices:
            RemoveMemcpy._topo_sort(graph)

        # Recurse into Loop/If subgraphs.
        for node in list(graph.node):
            for attr in node.attribute:
                if attr.HasField("g"):
                    removed += RemoveMemcpy._remove_from_graph(attr.g)

        return removed

    @staticmethod
    def _rewrite_subgraph_refs(subgraph, bypass: dict[str, str]):
        """Rewrite implicit references inside a subgraph body.

        Loop/If subgraphs can reference outer-scope tensors by name in their
        node inputs. If an outer Memcpy was removed, those references must
        be updated too.
        """
        for node in subgraph.node:
            for i, inp in enumerate(node.input):
                if inp in bypass:
                    node.input[i] = bypass[inp]
            for attr in node.attribute:
                if attr.HasField("g"):
                    RemoveMemcpy._rewrite_subgraph_refs(attr.g, bypass)

    @staticmethod
    def _topo_sort(graph):
        """Topologically sort graph.node in place using Kahn's algorithm.

        Edges are derived purely from tensor names: a node depends on the
        producer of each of its inputs, unless the input is a graph input /
        initializer (always available) or an outer-scope reference (no
        producer in this graph).
        """
        from collections import deque

        # Tensor names available before any node runs.
        available: set[str] = {inp.name for inp in graph.input}
        available.update(init.name for init in graph.initializer)

        nodes = list(graph.node)
        node_outputs: list[set[str]] = [{o for o in n.output if o} for n in nodes]

        n = len(nodes)
        in_degree = [0] * n
        dependents: list[list[int]] = [[] for _ in range(n)]

        # Map output name → producing node index.
        output_to_idx: dict[str, int] = {}
        for idx, outs in enumerate(node_outputs):
            for o in outs:
                output_to_idx[o] = idx

        for idx, node in enumerate(nodes):
            seen_deps: set[int] = set()
            for inp in node.input:
                # Empty inputs are optional; names with no producer here are
                # graph inputs/initializers or outer-scope references — none
                # impose an ordering edge.
                if inp and inp not in available and inp in output_to_idx:
                    dep = output_to_idx[inp]
                    if dep != idx and dep not in seen_deps:
                        seen_deps.add(dep)
                        in_degree[idx] += 1
                        dependents[dep].append(idx)

        # Kahn's algorithm; seeding the queue in original index order keeps the
        # sort stable for independent nodes.
        queue: deque[int] = deque(idx for idx in range(n) if in_degree[idx] == 0)

        sorted_indices: list[int] = []
        while queue:
            idx = queue.popleft()
            sorted_indices.append(idx)
            for dep_idx in dependents[idx]:
                in_degree[dep_idx] -= 1
                if in_degree[dep_idx] == 0:
                    queue.append(dep_idx)

        if len(sorted_indices) != n:
            # A cycle (invalid graph) or unresolved reference: keep the leftover
            # nodes in original relative order rather than dropping them.
            logger.warning(
                "Topo-sort could not order all nodes (%d/%d). Keeping original order for unresolved nodes.",
                len(sorted_indices),
                n,
            )
            remaining = set(range(n)) - set(sorted_indices)
            sorted_indices.extend(sorted(remaining))

        # Rewrite graph.node in sorted order.
        sorted_nodes = [nodes[i] for i in sorted_indices]
        del graph.node[:]
        graph.node.extend(sorted_nodes)
class RenameInputDims(Surgeon):
    """Rename / promote a dimension in an input tensor's shape to a named symbolic dim.

    This surgery replaces a concrete dim_value (e.g. ``1``) with a symbolic
    dim_param string (e.g. ``"num_images"``). Useful when torch.export
    specialises a batch-like input dimension to a concrete value because its
    shape is algebraically derived from another symbolic dimension, yet ONNX
    Runtime must accept a variable-length tensor at inference time.

    Specify the target input either by name (preferred; takes precedence when
    both are given) or by index.

    Example usage:
        {
            "surgeon": "RenameInputDims",
            "input_name": "image_grid_thw",
            "dim_idx": 0,
            "dim_name": "num_images"
        }
    """

    def __init__(
        self,
        dim_idx: int,
        dim_name: str,
        input_name: str | None = None,
        input_idx: int | None = None,
    ):
        super().__init__()
        if input_name is None and input_idx is None:
            raise ValueError("Either 'input_name' or 'input_idx' must be provided.")
        self.input_name = input_name
        self.input_idx = input_idx
        self.dim_idx = dim_idx
        self.dim_name = dim_name

    def call_ir(self, model: ir.Model) -> ir.Model:
        """Replace the selected dimension of the selected input with a symbolic name.

        Raises:
            ValueError: if the target input cannot be found, has no shape
                information, or ``dim_idx`` is out of range.
        """
        graph_inputs = list(model.graph.inputs)
        target = self._locate_input(graph_inputs)

        if target.shape is None:
            raise ValueError(f"Input '{target.name}' has no shape information; cannot rename dimensions.")

        rank = len(target.shape)
        if self.dim_idx >= rank:
            raise ValueError(
                f"dim_idx {self.dim_idx} is out of range. Input '{target.name}' has {rank} dimensions."
            )

        # Rebuild the shape with the chosen dim replaced by the symbolic name.
        dims = list(target.shape)
        dims[self.dim_idx] = self.dim_name
        target.shape = ir.Shape(dims)
        return model

    def _locate_input(self, graph_inputs):
        """Return the target input value, selected by name (preferred) or by index."""
        if self.input_name is not None:
            for value in graph_inputs:
                if value.name == self.input_name:
                    return value
            available = [v.name for v in graph_inputs]
            raise ValueError(f"Input '{self.input_name}' not found in graph. Available inputs: {available}")
        if self.input_idx >= len(graph_inputs):
            raise ValueError(f"input_idx {self.input_idx} is out of range. Model has {len(graph_inputs)} inputs.")
        return graph_inputs[self.input_idx]
    output_model = p.run(input_model, output_folder)
    output_model_def = output_model.load_model()

    # The renamed dim should now carry the symbolic name instead of dim_value 1.
    input_shape = output_model_def.graph.input[0].type.tensor_type.shape
    dim_names = [dim.dim_param if dim.dim_param else str(dim.dim_value) for dim in input_shape.dim]
    assert dim_names[0] == "batch"
    # Other dims should be unchanged
    assert input_shape.dim[1].dim_value == 3
    assert input_shape.dim[2].dim_value == 4


def test_rename_input_dims_by_index(tmp_path):
    """Test that RenameInputDims works with input_idx instead of input_name."""
    # Two inputs so that input_idx=1 exercises non-trivial index selection.
    input_tensor = helper.make_tensor_value_info("pixel_values", TensorProto.FLOAT, [10, 1176])
    input2_tensor = helper.make_tensor_value_info("image_grid_thw", TensorProto.INT64, [1, 3])
    output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, [10, 1176])

    node = helper.make_node("Identity", inputs=["pixel_values"], outputs=["output"], name="Identity")

    graph = helper.make_graph(
        nodes=[node],
        name="TestGraph",
        inputs=[input_tensor, input2_tensor],
        outputs=[output_tensor],
    )

    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 20)])
    model.ir_version = 10
    model_path = tmp_path / "model.onnx"
    onnx.save(model, model_path)

    input_model = ONNXModelHandler(model_path=str(model_path))
    output_folder = str(tmp_path / "onnx")

    p = create_pass_from_dict(
        GraphSurgeries,
        {"surgeries": [{"surgeon": "RenameInputDims", "input_idx": 1, "dim_idx": 0, "dim_name": "num_images"}]},
        disable_search=True,
    )

    output_model = p.run(input_model, output_folder)
    output_model_def = output_model.load_model()

    # Second input's first dim becomes symbolic; its other dim stays concrete.
    input_shape = output_model_def.graph.input[1].type.tensor_type.shape
    assert input_shape.dim[0].dim_param == "num_images"
    assert input_shape.dim[1].dim_value == 3


def test_rename_input_dims_invalid_input_name(tmp_path):
    """Test that RenameInputDims raises ValueError for non-existent input name."""
    input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 3])
    output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 3])

    node = helper.make_node("Identity", inputs=["input"], outputs=["output"])

    graph = helper.make_graph(
        nodes=[node],
        name="TestGraph",
        inputs=[input_tensor],
        outputs=[output_tensor],
    )

    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 20)])
    model.ir_version = 10
    model_path = tmp_path / "model.onnx"
    onnx.save(model, model_path)

    input_model = ONNXModelHandler(model_path=str(model_path))
    output_folder = str(tmp_path / "onnx")

    p = create_pass_from_dict(
        GraphSurgeries,
        {"surgeries": [{"surgeon": "RenameInputDims", "input_name": "nonexistent", "dim_idx": 0, "dim_name": "x"}]},
        disable_search=True,
    )

    # The surgeon should fail fast with a message listing available inputs.
    with pytest.raises(ValueError, match="not found in graph"):
        p.run(input_model, output_folder)


def test_remove_memcpy_main_graph(tmp_path):
    """Test that RemoveMemcpy removes MemcpyToHost/MemcpyFromHost nodes from the main graph."""
    input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [4, 8])
    output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, [4, 8])

    # Build: input -> MemcpyToHost -> Relu -> MemcpyFromHost -> output
    memcpy_to = helper.make_node("MemcpyToHost", inputs=["input"], outputs=["cpu_input"], name="MemcpyTo")
    relu = helper.make_node("Relu", inputs=["cpu_input"], outputs=["relu_out"], name="Relu")
    memcpy_from = helper.make_node("MemcpyFromHost", inputs=["relu_out"], outputs=["output"], name="MemcpyFrom")

    graph = helper.make_graph(
        nodes=[memcpy_to, relu, memcpy_from],
        name="TestGraph",
        inputs=[input_tensor],
        outputs=[output_tensor],
    )

    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 20)])
    model.ir_version = 10
    model_path = tmp_path / "model.onnx"
    onnx.save(model, model_path)

    input_model = ONNXModelHandler(model_path=str(model_path))
    output_folder = str(tmp_path / "onnx")

    p = create_pass_from_dict(
        GraphSurgeries,
        {"surgeries": [{"surgeon": "RemoveMemcpy"}]},
        disable_search=True,
    )

    output_model = p.run(input_model, output_folder)
    output_model_def = output_model.load_model()

    # Only the Relu should survive; both memcpy nodes are stripped.
    op_types = [n.op_type for n in output_model_def.graph.node]
    assert "MemcpyToHost" not in op_types
    assert "MemcpyFromHost" not in op_types
    assert "Relu" in op_types
    assert len(output_model_def.graph.node) == 1

    # Public output name should be preserved
    assert output_model_def.graph.output[0].name == "output"

    # Verify inference still works and the session exposes the same output name
    sess = InferenceSession(output_model.model_path, providers=["CPUExecutionProvider"])
    sess_outputs = sess.get_outputs()
    assert sess_outputs[0].name == "output"

    x = np.random.randn(4, 8).astype(np.float32)
    result = sess.run(["output"], {"input": x})[0]
    np.testing.assert_allclose(result, np.maximum(x, 0), atol=1e-6)


def test_remove_memcpy_loop_subgraph(tmp_path):
    """Test that RemoveMemcpy removes MemcpyToHost from Loop subgraphs."""
    input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [4, 8])
    output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, [4, 8])

    # Build a Loop subgraph that contains a MemcpyToHost
    # Loop body: (iter, cond, carry_in) -> (cond_out, carry_out)
    body_iter = helper.make_tensor_value_info("iter", TensorProto.INT64, [])
    body_cond_in = helper.make_tensor_value_info("cond_in", TensorProto.BOOL, [])
    body_carry_in = helper.make_tensor_value_info("carry_in", TensorProto.FLOAT, [4, 8])
    body_cond_out = helper.make_tensor_value_info("cond_out", TensorProto.BOOL, [])
    body_carry_out = helper.make_tensor_value_info("carry_out", TensorProto.FLOAT, [4, 8])

    body_memcpy = helper.make_node("MemcpyToHost", inputs=["carry_in"], outputs=["carry_cpu"], name="BodyMemcpy")
    body_relu = helper.make_node("Relu", inputs=["carry_cpu"], outputs=["carry_out"], name="BodyRelu")
    body_cond = helper.make_node("Identity", inputs=["cond_in"], outputs=["cond_out"], name="BodyCond")

    body_graph = helper.make_graph(
        nodes=[body_memcpy, body_relu, body_cond],
        name="LoopBody",
        inputs=[body_iter, body_cond_in, body_carry_in],
        outputs=[body_cond_out, body_carry_out],
    )

    # Main graph: input -> Loop(1 iteration) -> output
    trip_count = helper.make_tensor("trip", TensorProto.INT64, [], [1])
    cond_init = helper.make_tensor("cond_init", TensorProto.BOOL, [], [True])
    loop_node = helper.make_node(
        "Loop",
        inputs=["trip", "cond_init", "input"],
        outputs=["output"],
        name="Loop",
        body=body_graph,
    )

    graph = helper.make_graph(
        nodes=[loop_node],
        name="TestGraph",
        inputs=[input_tensor],
        outputs=[output_tensor],
        initializer=[trip_count, cond_init],
    )

    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 20)])
    model.ir_version = 10
    model_path = tmp_path / "model.onnx"
    onnx.save(model, model_path)

    input_model = ONNXModelHandler(model_path=str(model_path))
    output_folder = str(tmp_path / "onnx")

    p = create_pass_from_dict(
        GraphSurgeries,
        {"surgeries": [{"surgeon": "RemoveMemcpy"}]},
        disable_search=True,
    )

    output_model = p.run(input_model, output_folder)
    output_model_def = output_model.load_model()

    # Check Loop subgraph has no MemcpyToHost
    loop_node = output_model_def.graph.node[0]
    assert loop_node.op_type == "Loop"
    body = None
    for attr in loop_node.attribute:
        if attr.name == "body":
            body = attr.g
    assert body is not None
    sub_ops = [n.op_type for n in body.node]
    assert "MemcpyToHost" not in sub_ops
    assert "Relu" in sub_ops


def test_remove_memcpy_topo_sort(tmp_path):
    """Test that RemoveMemcpy properly re-sorts nodes topologically after removal."""
    input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [4, 8])
    output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, [4, 8])

    # Build: input -> Relu -> MemcpyToHost -> Sigmoid -> output
    # The MemcpyToHost renames relu_out -> cpu_relu, so after removal
    # Sigmoid must reference relu_out directly and appear after Relu.
    relu = helper.make_node("Relu", inputs=["input"], outputs=["relu_out"], name="Relu")
    memcpy = helper.make_node("MemcpyToHost", inputs=["relu_out"], outputs=["cpu_relu"], name="Memcpy")
    sigmoid = helper.make_node("Sigmoid", inputs=["cpu_relu"], outputs=["output"], name="Sigmoid")

    graph = helper.make_graph(
        nodes=[relu, memcpy, sigmoid],
        name="TestGraph",
        inputs=[input_tensor],
        outputs=[output_tensor],
    )

    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 20)])
    model.ir_version = 10
    model_path = tmp_path / "model.onnx"
    onnx.save(model, model_path)

    input_model = ONNXModelHandler(model_path=str(model_path))
    output_folder = str(tmp_path / "onnx")

    p = create_pass_from_dict(
        GraphSurgeries,
        {"surgeries": [{"surgeon": "RemoveMemcpy"}]},
        disable_search=True,
    )

    output_model = p.run(input_model, output_folder)
    output_model_def = output_model.load_model()

    # Order must satisfy producer-before-consumer after the memcpy is gone.
    op_types = [n.op_type for n in output_model_def.graph.node]
    assert op_types == ["Relu", "Sigmoid"]

    # Verify Sigmoid now reads from relu_out (bypassing removed memcpy)
    sigmoid_node = output_model_def.graph.node[1]
    assert sigmoid_node.input[0] == "relu_out"

    # Verify inference
    sess = InferenceSession(output_model.model_path, providers=["CPUExecutionProvider"])
    x = np.random.randn(4, 8).astype(np.float32)
    result = sess.run(None, {"input": x})[0]
    expected = 1.0 / (1.0 + np.exp(-np.maximum(x, 0)))
    np.testing.assert_allclose(result, expected, atol=1e-6)


def test_remove_memcpy_chained(tmp_path):
    """Test that RemoveMemcpy handles chained Memcpy nodes and preserves output names."""
    input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [4, 8])
    output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, [4, 8])

    # Build: input -> MemcpyToHost -> MemcpyToHost -> Relu -> MemcpyFromHost -> MemcpyFromHost -> output
    memcpy_to_1 = helper.make_node("MemcpyToHost", inputs=["input"], outputs=["cpu1"], name="MemcpyTo1")
    memcpy_to_2 = helper.make_node("MemcpyToHost", inputs=["cpu1"], outputs=["cpu2"], name="MemcpyTo2")
    relu = helper.make_node("Relu", inputs=["cpu2"], outputs=["relu_out"], name="Relu")
    memcpy_from_1 = helper.make_node("MemcpyFromHost", inputs=["relu_out"], outputs=["gpu1"], name="MemcpyFrom1")
    memcpy_from_2 = helper.make_node("MemcpyFromHost", inputs=["gpu1"], outputs=["output"], name="MemcpyFrom2")

    graph = helper.make_graph(
        nodes=[memcpy_to_1, memcpy_to_2, relu, memcpy_from_1, memcpy_from_2],
        name="TestGraph",
        inputs=[input_tensor],
        outputs=[output_tensor],
    )

    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 20)])
    model.ir_version = 10
    model_path = tmp_path / "model.onnx"
    onnx.save(model, model_path)

    input_model = ONNXModelHandler(model_path=str(model_path))
    output_folder = str(tmp_path / "onnx")

    p = create_pass_from_dict(
        GraphSurgeries,
        {"surgeries": [{"surgeon": "RemoveMemcpy"}]},
        disable_search=True,
    )

    output_model = p.run(input_model, output_folder)
    output_model_def = output_model.load_model()

    # All four chained memcpy nodes resolve transitively; only Relu remains.
    op_types = [n.op_type for n in output_model_def.graph.node]
    assert "MemcpyToHost" not in op_types
    assert "MemcpyFromHost" not in op_types
    assert "Relu" in op_types
    assert len(output_model_def.graph.node) == 1

    # Public output name must be preserved
    assert output_model_def.graph.output[0].name == "output"

    # Verify inference with preserved output name
    sess = InferenceSession(output_model.model_path, providers=["CPUExecutionProvider"])
    sess_outputs = sess.get_outputs()
    assert sess_outputs[0].name == "output"

    x = np.random.randn(4, 8).astype(np.float32)
    result = sess.run(["output"], {"input": x})[0]
    np.testing.assert_allclose(result, np.maximum(x, 0), atol=1e-6)