diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py
index 9fc33c2d0d054..ad33007867104 100755
--- a/onnxruntime/python/tools/symbolic_shape_infer.py
+++ b/onnxruntime/python/tools/symbolic_shape_infer.py
@@ -825,7 +825,7 @@ def _fuse_tensor_type(self, node, out_idx, dst_type, src_type):
             src_type.sequence_type.elem_type.tensor_type if is_sequence(src_type) else src_type.tensor_type
         )
         if dst_tensor_type.elem_type != src_tensor_type.elem_type:
-            node_id = node.name if node.name else node.op_type
+            node_id = node.name or node.op_type
             raise ValueError(
                 f"For node {node_id}, dst_tensor_type.elem_type != src_tensor_type.elem_type: "
                 f"{onnx.onnx_pb.TensorProto.DataType.Name(dst_tensor_type.elem_type)} vs "
@@ -1436,7 +1436,7 @@ def _infer_aten_multinomial(self, node):
         assert rank in [1, 2]
         num_samples = self._try_get_value(node, 1)
         di = rank - 1
-        last_dim = num_samples if num_samples else str(self._new_symbolic_dim_from_output(node, 0, di))
+        last_dim = num_samples or str(self._new_symbolic_dim_from_output(node, 0, di))
         output_shape = [*sympy_shape[:-1], last_dim]
         vi = self.known_vi_[node.output[0]]
         vi.CopyFrom(
@@ -1670,6 +1670,10 @@ def _infer_RelativePositionBias(self, node):  # noqa: N802
     def _infer_Reshape(self, node):  # noqa: N802
         shape_value = self._try_get_value(node, 1)
         vi = self.known_vi_[node.output[0]]
+        # allowzero (opset 14+) determines whether a 0 in the shape input means
+        # "copy the corresponding input dim" (allowzero=0, the legacy default)
+        # or "literal zero" (allowzero=1). Defaults to 0 when not present.
+        allow_zero = bool(get_attribute(node, "allowzero", 0))
         if shape_value is None:
             shape_shape = self._get_shape(node, 1)
             assert len(shape_shape) == 1
@@ -1693,20 +1697,29 @@ def _infer_Reshape(self, node):  # noqa: N802
             for i, d in enumerate(shape_value):
                 if type(d) is sympy.Symbol:
                     new_sympy_shape.append(d)
-                elif d == 0:
+                    non_deferred_size = non_deferred_size * d
+                elif d == 0 and not allow_zero:
                     new_sympy_shape.append(input_sympy_shape[i])
                     non_deferred_size = non_deferred_size * input_sympy_shape[i]
-                else:
+                elif d == -1:
                     new_sympy_shape.append(d)
-                    if d == -1:
-                        deferred_dim_idx = i
-                    elif d != 0:
-                        non_deferred_size = non_deferred_size * d
+                    deferred_dim_idx = i
+                else:
+                    # explicit non-zero dim, or a literal 0 when allow_zero=True
+                    new_sympy_shape.append(d)
+                    non_deferred_size = non_deferred_size * d
 
             assert new_sympy_shape.count(-1) < 2
             if -1 in new_sympy_shape:
-                new_dim = total // non_deferred_size
-                new_sympy_shape[deferred_dim_idx] = new_dim
+                # When allow_zero is True a literal 0 contributes 0 to non_deferred_size,
+                # which would make total // non_deferred_size raise ZeroDivisionError.
+                # Per ONNX spec, combining allowzero=1 with -1 is invalid; emit a
+                # symbolic dim rather than crash the inference pass.
+                if non_deferred_size == 0:
+                    new_sympy_shape[deferred_dim_idx] = self._new_symbolic_dim_from_output(node, 0, deferred_dim_idx)
+                else:
+                    new_dim = total // non_deferred_size
+                    new_sympy_shape[deferred_dim_idx] = new_dim
 
             self._update_computed_dims(new_sympy_shape)
             vi.CopyFrom(
diff --git a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py
index 0fdad07556db9..892bf20737a50 100644
--- a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py
+++ b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py
@@ -725,6 +725,87 @@ def test_qlinear_binary(self):
         ]
         self._check_shapes(graph, inferred.graph, expected_shapes)
 
+    def test_reshape_allowzero_default(self):
+        # allowzero is absent: a 0 in the shape tensor means "copy the
+        # corresponding input dim" (legacy ONNX <14 behaviour).
+        graph = helper.make_graph(
+            [helper.make_node("Reshape", ["input", "shape"], ["output"])],
+            "Reshape_AllowZero_Default",
+            [helper.make_tensor_value_info("input", TensorProto.FLOAT, [3, 4, 5])],
+            [helper.make_tensor_value_info("output", TensorProto.FLOAT, [3, 4, 5])],
+            [helper.make_tensor("shape", TensorProto.INT64, [3], [0, 4, 5])],
+        )
+        model = helper.make_model(graph)
+        model.opset_import[0].version = 18
+        inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)
+        expected_shapes = [
+            helper.make_tensor_value_info("output", TensorProto.FLOAT, [3, 4, 5]),
+        ]
+        self._check_shapes(graph, inferred.graph, expected_shapes)
+
+    def test_reshape_allowzero_zero(self):
+        # allowzero=0 explicit: same as the default — 0 means copy from input.
+        graph = helper.make_graph(
+            [helper.make_node("Reshape", ["input", "shape"], ["output"], allowzero=0)],
+            "Reshape_AllowZero_0",
+            [helper.make_tensor_value_info("input", TensorProto.FLOAT, [3, 4, 5])],
+            [helper.make_tensor_value_info("output", TensorProto.FLOAT, [3, 4, 5])],
+            [helper.make_tensor("shape", TensorProto.INT64, [3], [0, 4, 5])],
+        )
+        model = helper.make_model(graph)
+        model.opset_import[0].version = 18
+        inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)
+        expected_shapes = [
+            helper.make_tensor_value_info("output", TensorProto.FLOAT, [3, 4, 5]),
+        ]
+        self._check_shapes(graph, inferred.graph, expected_shapes)
+
+    def test_reshape_allowzero_one(self):
+        # allowzero=1: a 0 in the shape tensor means a literal zero dim.
+        # On a non-zero-element input this combination is invalid per the
+        # ONNX spec, so use a 0-element input to keep the model resolvable.
+        graph = helper.make_graph(
+            [helper.make_node("Reshape", ["input", "shape"], ["output"], allowzero=1)],
+            "Reshape_AllowZero_1",
+            [helper.make_tensor_value_info("input", TensorProto.FLOAT, [0, 4, 5])],
+            [helper.make_tensor_value_info("output", TensorProto.FLOAT, [0, 4, 5])],
+            [helper.make_tensor("shape", TensorProto.INT64, [3], [0, 4, 5])],
+        )
+        model = helper.make_model(graph)
+        model.opset_import[0].version = 18
+        inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)
+        expected_shapes = [
+            helper.make_tensor_value_info("output", TensorProto.FLOAT, [0, 4, 5]),
+        ]
+        self._check_shapes(graph, inferred.graph, expected_shapes)
+
+    def test_reshape_allowzero_chained_zero_element(self):
+        # Regression for the chained-Reshape case reported in #28449.
+        # Reshape([0,8,2] -> [4,2,-1]) -> mid (4, 2, 0) -> Reshape([0,0,4], allowzero=1)
+        # The second Reshape must preserve the explicit zeros from its shape input
+        # instead of copying mid's dims (which would yield [4, 2, 4]).
+        graph = helper.make_graph(
+            [
+                helper.make_node("Reshape", ["input", "shape1"], ["mid"]),
+                helper.make_node("Reshape", ["mid", "shape2"], ["output"], allowzero=1),
+            ],
+            "Reshape_AllowZero_Chained",
+            [helper.make_tensor_value_info("input", TensorProto.FLOAT, [0, 8, 2])],
+            [helper.make_tensor_value_info("output", TensorProto.FLOAT, [0, 0, 4])],
+            [
+                helper.make_tensor("shape1", TensorProto.INT64, [3], [4, 2, -1]),
+                helper.make_tensor("shape2", TensorProto.INT64, [3], [0, 0, 4]),
+            ],
+        )
+        model = helper.make_model(graph)
+        model.opset_import[0].version = 18
+        inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)
+        expected_shapes = [
+            helper.make_tensor_value_info("mid", TensorProto.FLOAT, [4, 2, 0]),
+            helper.make_tensor_value_info("output", TensorProto.FLOAT, [0, 0, 4]),
+        ]
+        self._check_shapes(graph, inferred.graph, expected_shapes)
+
 
 class TestSymbolicShapeInferenceForSlice(unittest.TestCase):
     def check_slice_of_concat(self, input_dims, start, end, step, expected_output_dim):
@@ -781,7 +862,7 @@ def make_concat_dims(concat_name, dims):
         graph_def = onnx.helper.make_graph(nodes, "graph", inputs, [output], initializer=initializers)
         model = SymbolicShapeInference.infer_shapes(onnx.helper.make_model(graph_def))
         output = unique_element(model.graph.output)
-        shape = [d.dim_param if d.dim_param else d.dim_value for d in output.type.tensor_type.shape.dim]
+        shape = [d.dim_param or d.dim_value for d in output.type.tensor_type.shape.dim]
         self.assertEqual(shape, ["B", expected_output_dim])
 
     def test_numeric_negative_indices_forward(self):