diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 9d65fe0e52da..268d91b7500a 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1014,6 +1014,7 @@ class Concat(OnnxOpConverter): @classmethod def _impl_v13(cls, bb, inputs, attr, params): axis = attr.get("axis", 0) + _, param_dict = params def is_shape_like(x: Any) -> bool: if isinstance(x, relax.ShapeExpr): @@ -1023,10 +1024,22 @@ def is_shape_like(x: Any) -> bool: else: return False + # Resolve 1D-int64 param Vars to constants only for the shape-like + # fast path; tensor fallback keeps the original Vars so runtime + # weights aren't folded under keep_params_in_input=True. + def resolve(x): + if isinstance(x, relax.Var) and x.name_hint in param_dict: + arr = param_dict[x.name_hint][1].numpy() + if arr.ndim == 1 and arr.dtype == _np.int64: + return relax.const(arr, "int64") + return x + + resolved = [resolve(inp) for inp in inputs] + # If all inputs are shape expr, perform computation directly. - if all([is_shape_like(inp) for inp in inputs]): + if all([is_shape_like(inp) for inp in resolved]): const_inputs = [] - for inp in inputs: + for inp in resolved: if isinstance(inp, relax.ShapeExpr): const_inputs.extend(inp.values) elif isinstance(inp, relax.Constant): diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index db68476609fb..5a8d84b0900c 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -29,7 +29,7 @@ import onnxruntime import pytest import tvm_ffi -from onnx import ModelProto, TensorProto, helper +from onnx import ModelProto, TensorProto, helper, numpy_helper import tvm import tvm.testing @@ -533,6 +533,57 @@ def test_concat(): verify_binary("Concat", [1, 32], [1, 32], [2, 32], attrs={"axis": 0}) +def test_concat_with_param_shape_value(): + """Concat must handle a 1D-int64 initializer mixed with a ShapeExpr when + keep_params_in_input=True. Standard pattern in PyTorch-exported ONNX + models for dynamic-batch Reshape: Reshape(x, Concat(Shape(x)[:1], [12])).""" + inp = helper.make_tensor_value_info("x", TensorProto.FLOAT, ["N", 3, 4]) + out = helper.make_tensor_value_info("y", TensorProto.FLOAT, ["N", 12]) + twelve = numpy_helper.from_array(np.array([12], dtype=np.int64), "twelve") + starts = numpy_helper.from_array(np.array([0], dtype=np.int64), "starts") + ends = numpy_helper.from_array(np.array([1], dtype=np.int64), "ends") + nodes = [ + helper.make_node("Shape", ["x"], ["x_shape"]), + helper.make_node("Slice", ["x_shape", "starts", "ends"], ["dyn_n"]), + helper.make_node("Concat", ["dyn_n", "twelve"], ["new_shape"], axis=0), + helper.make_node("Reshape", ["x", "new_shape"], ["y"]), + ] + graph = helper.make_graph( + nodes, "concat_param_shape", [inp], [out], + initializer=[twelve, starts, ends], + ) + model = helper.make_model( + graph, opset_imports=[helper.make_opsetid("", 13)] + ) + model.ir_version = 8 + onnx.checker.check_model(model) + # Both modes should succeed; previously True crashed with + # "Op(relax.concat) expects the input to be a Tuple of Tensors". + from_onnx(model, keep_params_in_input=False) + from_onnx(model, keep_params_in_input=True) + + +def test_concat_with_param_tensor_keeps_runtime_param(): + """Concat(input, weight) under keep_params_in_input=True must keep `weight` + as a runtime param, not fold it into a constant.""" + weight_np = np.arange(8, dtype=np.float32).reshape(2, 4) + graph = helper.make_graph( + [helper.make_node("Concat", ["x", "w"], ["y"], axis=0)], + "concat_param_tensor", + [helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 4])], + [helper.make_tensor_value_info("y", TensorProto.FLOAT, [4, 4])], + initializer=[numpy_helper.from_array(weight_np, "w")], + ) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + model.ir_version = 8 + onnx.checker.check_model(model) + + mod, params = relax.frontend.detach_params(from_onnx(model, keep_params_in_input=True)) + assert "w" in [p.name_hint for p in mod["main"].params] + assert len(params["main"]) == 1 + np.testing.assert_array_equal(params["main"][0].numpy(), weight_np) + + @pytest.mark.parametrize("op_name", ["Add", "Sub", "Mul", "Div", "Pow"]) def test_binary(op_name: str): verify_binary(op_name, [1, 32], [1, 32], [1, 32])