17 changes: 15 additions & 2 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1014,6 +1014,7 @@ class Concat(OnnxOpConverter):
     @classmethod
     def _impl_v13(cls, bb, inputs, attr, params):
         axis = attr.get("axis", 0)
+        _, param_dict = params
 
         def is_shape_like(x: Any) -> bool:
             if isinstance(x, relax.ShapeExpr):
@@ -1023,10 +1024,22 @@ def is_shape_like(x: Any) -> bool:
             else:
                 return False
 
+        # Resolve 1D-int64 param Vars to constants only for the shape-like
+        # fast path; tensor fallback keeps the original Vars so runtime
+        # weights aren't folded under keep_params_in_input=True.
+        def resolve(x):
+            if isinstance(x, relax.Var) and x.name_hint in param_dict:
+                arr = param_dict[x.name_hint][1].numpy()
+                if arr.ndim == 1 and arr.dtype == _np.int64:
+                    return relax.const(arr, "int64")
+            return x
+
+        resolved = [resolve(inp) for inp in inputs]
+
         # If all inputs are shape expr, perform computation directly.
-        if all([is_shape_like(inp) for inp in inputs]):
+        if all([is_shape_like(inp) for inp in resolved]):
             const_inputs = []
-            for inp in inputs:
+            for inp in resolved:
                 if isinstance(inp, relax.ShapeExpr):
                     const_inputs.extend(inp.values)
                 elif isinstance(inp, relax.Constant):
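For orientation, a minimal usage sketch of the import path this hunk touches (not part of the diff; it assumes a valid onnx.ModelProto bound to `model`, e.g. from onnx.load). With keep_params_in_input=True, initializers stay as parameters of `main` and detach_params splits them out for runtime binding; only 1D-int64 initializers feeding the shape-like fast path above are folded to constants.

    from tvm import relax
    from tvm.relax.frontend.onnx import from_onnx

    # `model` is an assumption here: any valid onnx.ModelProto will do.
    irmod = from_onnx(model, keep_params_in_input=True)
    mod, params = relax.frontend.detach_params(irmod)
    # params["main"] holds the runtime weights; 1D-int64 shape initializers
    # consumed by Concat's fast path were folded to constants at import time.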
53 changes: 52 additions & 1 deletion tests/python/relax/test_frontend_onnx.py
@@ -29,7 +29,7 @@
 import onnxruntime
 import pytest
 import tvm_ffi
-from onnx import ModelProto, TensorProto, helper
+from onnx import ModelProto, TensorProto, helper, numpy_helper
 
 import tvm
 import tvm.testing
@@ -533,6 +533,57 @@ def test_concat():
     verify_binary("Concat", [1, 32], [1, 32], [2, 32], attrs={"axis": 0})
 
 
+def test_concat_with_param_shape_value():
+    """Concat must handle a 1D-int64 initializer mixed with a ShapeExpr when
+    keep_params_in_input=True. This is a standard pattern in PyTorch-exported
+    ONNX models for dynamic-batch Reshape: Reshape(x, Concat(Shape(x)[:1], [12]))."""
+    inp = helper.make_tensor_value_info("x", TensorProto.FLOAT, ["N", 3, 4])
+    out = helper.make_tensor_value_info("y", TensorProto.FLOAT, ["N", 12])
+    twelve = numpy_helper.from_array(np.array([12], dtype=np.int64), "twelve")
+    starts = numpy_helper.from_array(np.array([0], dtype=np.int64), "starts")
+    ends = numpy_helper.from_array(np.array([1], dtype=np.int64), "ends")
+    nodes = [
+        helper.make_node("Shape", ["x"], ["x_shape"]),
+        helper.make_node("Slice", ["x_shape", "starts", "ends"], ["dyn_n"]),
+        helper.make_node("Concat", ["dyn_n", "twelve"], ["new_shape"], axis=0),
+        helper.make_node("Reshape", ["x", "new_shape"], ["y"]),
+    ]
+    graph = helper.make_graph(
+        nodes, "concat_param_shape", [inp], [out],
+        initializer=[twelve, starts, ends],
+    )
+    model = helper.make_model(
+        graph, opset_imports=[helper.make_opsetid("", 13)]
+    )
+    model.ir_version = 8
+    onnx.checker.check_model(model)
+    # Both modes should succeed; keep_params_in_input=True previously crashed
+    # with "Op(relax.concat) expects the input to be a Tuple of Tensors".
+    from_onnx(model, keep_params_in_input=False)
+    from_onnx(model, keep_params_in_input=True)
+
+
+def test_concat_with_param_tensor_keeps_runtime_param():
+    """Concat(input, weight) under keep_params_in_input=True must keep `weight`
+    as a runtime param, not fold it into a constant."""
+    weight_np = np.arange(8, dtype=np.float32).reshape(2, 4)
+    graph = helper.make_graph(
+        [helper.make_node("Concat", ["x", "w"], ["y"], axis=0)],
+        "concat_param_tensor",
+        [helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 4])],
+        [helper.make_tensor_value_info("y", TensorProto.FLOAT, [4, 4])],
+        initializer=[numpy_helper.from_array(weight_np, "w")],
+    )
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
+    model.ir_version = 8
+    onnx.checker.check_model(model)
+
+    mod, params = relax.frontend.detach_params(from_onnx(model, keep_params_in_input=True))
+    assert "w" in [p.name_hint for p in mod["main"].params]
+    assert len(params["main"]) == 1
+    np.testing.assert_array_equal(params["main"][0].numpy(), weight_np)
+
+
 @pytest.mark.parametrize("op_name", ["Add", "Sub", "Mul", "Div", "Pow"])
 def test_binary(op_name: str):
     verify_binary(op_name, [1, 32], [1, 32], [1, 32])
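A companion sketch for test_concat_with_param_shape_value above (hedged; Flatten34 and the output path are hypothetical, and the exact op mix varies by torch version): reshaping with x.shape[0] under a dynamic batch axis is what typically makes the PyTorch exporter emit the Shape -> Slice/Gather -> Concat -> Reshape chain that the test reproduces.

    import torch

    class Flatten34(torch.nn.Module):  # hypothetical example module
        def forward(self, x):  # x: (N, 3, 4)
            # Using the dynamic dim in reshape is what forces the exporter
            # to build the new shape with Shape/Slice(or Gather)/Concat.
            return x.reshape(x.shape[0], 12)

    torch.onnx.export(
        Flatten34(),
        torch.randn(2, 3, 4),
        "flatten34.onnx",  # illustrative path
        input_names=["x"],
        output_names=["y"],
        dynamic_axes={"x": {0: "N"}, "y": {0: "N"}},
        opset_version=13,
    )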
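And a follow-up to test_concat_with_param_tensor_keeps_runtime_param: since `w` survives import as a runtime parameter, a deployment can still choose to fold it afterwards. A hedged sketch using the relax BindParams transform, with mod and params as produced by detach_params in that test (keys are the ONNX initializer names):

    from tvm import relax

    # Fold the detached weight back into the module at deploy time; after
    # this, "w" is a constant in `bound` and no longer a parameter of main.
    bound = relax.transform.BindParams("main", {"w": params["main"][0]})(mod)
    # Leaving it unbound instead keeps "w" as a call-time argument, which is
    # exactly what the tensor fallback in the Concat fix preserves.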