Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions onnxruntime/python/tools/symbolic_shape_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,7 @@ def _fuse_tensor_type(self, node, out_idx, dst_type, src_type):
src_type.sequence_type.elem_type.tensor_type if is_sequence(src_type) else src_type.tensor_type
)
if dst_tensor_type.elem_type != src_tensor_type.elem_type:
node_id = node.name if node.name else node.op_type
node_id = node.name or node.op_type
raise ValueError(
f"For node {node_id}, dst_tensor_type.elem_type != src_tensor_type.elem_type: "
f"{onnx.onnx_pb.TensorProto.DataType.Name(dst_tensor_type.elem_type)} vs "
Expand Down Expand Up @@ -1436,7 +1436,7 @@ def _infer_aten_multinomial(self, node):
assert rank in [1, 2]
num_samples = self._try_get_value(node, 1)
di = rank - 1
last_dim = num_samples if num_samples else str(self._new_symbolic_dim_from_output(node, 0, di))
last_dim = num_samples or str(self._new_symbolic_dim_from_output(node, 0, di))
output_shape = [*sympy_shape[:-1], last_dim]
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(
Expand Down Expand Up @@ -1670,6 +1670,10 @@ def _infer_RelativePositionBias(self, node): # noqa: N802
def _infer_Reshape(self, node): # noqa: N802
shape_value = self._try_get_value(node, 1)
vi = self.known_vi_[node.output[0]]
# allowzero (opset 14+) determines whether a 0 in the shape input means
# "copy the corresponding input dim" (allowzero=0, the legacy default)
# or "literal zero" (allowzero=1). Defaults to 0 when not present.
allow_zero = bool(get_attribute(node, "allowzero", 0))
if shape_value is None:
shape_shape = self._get_shape(node, 1)
assert len(shape_shape) == 1
Expand All @@ -1693,20 +1697,29 @@ def _infer_Reshape(self, node): # noqa: N802
for i, d in enumerate(shape_value):
if type(d) is sympy.Symbol:
new_sympy_shape.append(d)
elif d == 0:
non_deferred_size = non_deferred_size * d
elif d == 0 and not allow_zero:
new_sympy_shape.append(input_sympy_shape[i])
non_deferred_size = non_deferred_size * input_sympy_shape[i]
else:
elif d == -1:
new_sympy_shape.append(d)
if d == -1:
deferred_dim_idx = i
elif d != 0:
else:
# explicit non-zero dim, or a literal 0 when allow_zero=True
new_sympy_shape.append(d)
non_deferred_size = non_deferred_size * d

assert new_sympy_shape.count(-1) < 2
if -1 in new_sympy_shape:
new_dim = total // non_deferred_size
new_sympy_shape[deferred_dim_idx] = new_dim
# When allow_zero is True a literal 0 contributes 0 to non_deferred_size,
# which would make total // non_deferred_size raise ZeroDivisionError.
# Per ONNX spec, combining allowzero=1 with -1 is invalid; emit a
# symbolic dim rather than crash the inference pass.
if non_deferred_size == 0:
new_sympy_shape[deferred_dim_idx] = self._new_symbolic_dim_from_output(node, 0, deferred_dim_idx)
else:
new_dim = total // non_deferred_size
new_sympy_shape[deferred_dim_idx] = new_dim

self._update_computed_dims(new_sympy_shape)
vi.CopyFrom(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,87 @@
]
self._check_shapes(graph, inferred.graph, expected_shapes)

def test_reshape_allowzero_default(self):
    """Reshape with no allowzero attribute: a 0 in the shape tensor copies
    the matching input dim (legacy ONNX <14 behavior, the opset-14 default).
    """
    dims = [3, 4, 5]
    reshape_node = helper.make_node("Reshape", ["input", "shape"], ["output"])
    graph = helper.make_graph(
        [reshape_node],
        "Reshape_AllowZero_Default",
        [helper.make_tensor_value_info("input", TensorProto.FLOAT, dims)],
        [helper.make_tensor_value_info("output", TensorProto.FLOAT, dims)],
        [helper.make_tensor("shape", TensorProto.INT64, [3], [0, 4, 5])],
    )
    model = helper.make_model(graph)
    model.opset_import[0].version = 18
    inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)
    # The leading 0 in the shape input must resolve to the input's dim 0 (=3).
    expected = [helper.make_tensor_value_info("output", TensorProto.FLOAT, dims)]
    self._check_shapes(graph, inferred.graph, expected)

def test_reshape_allowzero_zero(self):
    """Explicit allowzero=0 must behave exactly like the default: a 0 in
    the shape tensor means "copy the corresponding input dim".
    """
    in_dims = [3, 4, 5]
    graph = helper.make_graph(
        [helper.make_node("Reshape", ["input", "shape"], ["output"], allowzero=0)],
        "Reshape_AllowZero_0",
        [helper.make_tensor_value_info("input", TensorProto.FLOAT, in_dims)],
        [helper.make_tensor_value_info("output", TensorProto.FLOAT, in_dims)],
        [helper.make_tensor("shape", TensorProto.INT64, [3], [0, 4, 5])],
    )
    model = helper.make_model(graph)
    model.opset_import[0].version = 18
    inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)
    self._check_shapes(
        graph,
        inferred.graph,
        [helper.make_tensor_value_info("output", TensorProto.FLOAT, in_dims)],
    )

def test_reshape_allowzero_one(self):
    """With allowzero=1 a 0 in the shape tensor is a literal zero dim.

    Per the ONNX spec, combining a literal 0 with a non-zero-element input
    is invalid, so the input here has 0 elements to keep the model
    resolvable.
    """
    zero_elem_dims = [0, 4, 5]
    node = helper.make_node("Reshape", ["input", "shape"], ["output"], allowzero=1)
    graph = helper.make_graph(
        [node],
        "Reshape_AllowZero_1",
        [helper.make_tensor_value_info("input", TensorProto.FLOAT, zero_elem_dims)],
        [helper.make_tensor_value_info("output", TensorProto.FLOAT, zero_elem_dims)],
        [helper.make_tensor("shape", TensorProto.INT64, [3], [0, 4, 5])],
    )
    model = helper.make_model(graph)
    model.opset_import[0].version = 18
    inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)
    # The 0 must pass through literally, not be replaced by the input dim.
    expected = [helper.make_tensor_value_info("output", TensorProto.FLOAT, zero_elem_dims)]
    self._check_shapes(graph, inferred.graph, expected)

def test_reshape_allowzero_chained_zero_element(self):
    """Regression test for the chained-Reshape case reported in #28449.

    Reshape([0,8,2] -> [4,2,-1]) produces mid with shape (4, 2, 0); the
    second Reshape([0,0,4], allowzero=1) must keep the explicit zeros from
    its shape input rather than copying mid's dims (which would wrongly
    give [4, 2, 4]).
    """
    nodes = [
        helper.make_node("Reshape", ["input", "shape1"], ["mid"]),
        helper.make_node("Reshape", ["mid", "shape2"], ["output"], allowzero=1),
    ]
    initializers = [
        helper.make_tensor("shape1", TensorProto.INT64, [3], [4, 2, -1]),
        helper.make_tensor("shape2", TensorProto.INT64, [3], [0, 0, 4]),
    ]
    graph = helper.make_graph(
        nodes,
        "Reshape_AllowZero_Chained",
        [helper.make_tensor_value_info("input", TensorProto.FLOAT, [0, 8, 2])],
        [helper.make_tensor_value_info("output", TensorProto.FLOAT, [0, 0, 4])],
        initializers,
    )
    model = helper.make_model(graph)
    model.opset_import[0].version = 18
    inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)
    expected = [
        helper.make_tensor_value_info("mid", TensorProto.FLOAT, [4, 2, 0]),
        helper.make_tensor_value_info("output", TensorProto.FLOAT, [0, 0, 4]),
    ]
    self._check_shapes(graph, inferred.graph, expected)


class TestSymbolicShapeInferenceForSlice(unittest.TestCase):
def check_slice_of_concat(self, input_dims, start, end, step, expected_output_dim):
Expand Down Expand Up @@ -781,7 +862,7 @@
graph_def = onnx.helper.make_graph(nodes, "graph", inputs, [output], initializer=initializers)
model = SymbolicShapeInference.infer_shapes(onnx.helper.make_model(graph_def))
output = unique_element(model.graph.output)
shape = [d.dim_param if d.dim_param else d.dim_value for d in output.type.tensor_type.shape.dim]
shape = [d.dim_param or d.dim_value for d in output.type.tensor_type.shape.dim]
self.assertEqual(shape, ["B", expected_output_dim])

def test_numeric_negative_indices_forward(self):
Expand Down
Loading