[BugFix][Relax][ONNX] Resolve param Vars in Concat to handle mixed Shape/Tensor inputs #19498
Conversation
…ape/Tensor inputs

When `from_onnx(model, keep_params_in_input=True)` is used, every ONNX initializer becomes a `relax.Var` instead of a `relax.Constant`. The Concat handler's `is_shape_like()` check only recognizes `ShapeExpr` and 1D-int64 `Constant`, so a 1D-int64 shape value loaded as a Var is no longer recognized. When such a Var is concatenated with a `ShapeExpr` (the standard pattern for dynamic-batch Reshape in PyTorch-exported ONNX models), the heterogeneous `Tuple(ShapeExpr, Tensor)` is rejected by `relax.op.concat` with an InternalError.

Run each Concat input through the existing `get_constant` helper before the shape-like check; this resolves any Var that maps to a known param back to its baked Constant, restoring the all-shape-like fast path. Adds a regression test exercising the dynamic-batch Reshape pattern with `keep_params_in_input=True`.
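The failure mode described above can be sketched in plain Python. The classes below are hypothetical stand-ins for the relax IR nodes involved, not the real TVM types; the point is that the shape-like check accepts `ShapeExpr` and 1D-int64 `Constant`, but a param loaded as a `Var` falls through, so the all-shape-like fast path is not taken.

```python
from dataclasses import dataclass

# Hypothetical stand-ins for the relax IR nodes (NOT the real tvm.relax classes).
@dataclass
class ShapeExpr:          # a symbolic shape value, e.g. [n]
    values: list

@dataclass
class Constant:           # a baked-in tensor with known data
    data: object
    ndim: int
    dtype: str

@dataclass
class Var:                # a graph input; with keep_params_in_input=True,
    name: str             # ONNX initializers arrive as Vars like this
    ndim: int
    dtype: str

def is_shape_like(expr) -> bool:
    """Mirrors the check described in the PR: only ShapeExpr and
    1D int64 Constants count as shape-like."""
    if isinstance(expr, ShapeExpr):
        return True
    if isinstance(expr, Constant):
        return expr.ndim == 1 and expr.dtype == "int64"
    return False  # a param Var never qualifies, even if it holds a 1D int64 shape

dyn_n = ShapeExpr(values=["n"])                          # dynamic batch dim from Shape/Slice
shape_tail = Var("reshape_tail", ndim=1, dtype="int64")  # initializer kept as an input

print([is_shape_like(x) for x in [dyn_n, shape_tail]])  # [True, False] -> fast path skipped
```

With `keep_params_in_input=False` the second input would instead be a 1D-int64 `Constant` and both checks would pass.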
Code Review
This pull request updates the ONNX frontend's Concat operator to resolve input parameters to their constant values, ensuring that 1D-int64 shape values can correctly follow the shape-like fast path when `keep_params_in_input=True`. This change specifically addresses issues with dynamic-batch Reshape patterns in PyTorch-exported models. A new test case has been added to verify this fix. I have no further feedback to provide.
Hi @swjng. This fixes the shape-construction case, but it can also change the semantics of ordinary tensor concat under `keep_params_in_input=True`. Could we narrow the fix so param resolution is used only for the all-shape-like fast path, preferably without mutating `graph_nodes`?
…st path

Address review feedback from @tlopex on apache#19498: the previous fix called `get_constant` on every input, which mutates `graph_nodes` and would fold a runtime weight into a Constant for ordinary `Concat(input, weight)` under `keep_params_in_input=True`. Replace with a local non-mutating peek that only resolves a param Var when it is a 1D int64 tensor, and only feed the resolved values into the shape-like fast path. The tensor fallback keeps the original inputs so runtime parameters remain runtime parameters. Add a regression test for ordinary `Concat(input, weight)` verifying the weight stays in `main`'s param list and is detached as a real parameter.
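The narrowed approach can be sketched as a local, non-mutating peek. This is stand-in code, not the actual frontend implementation: params are modeled as a plain name-to-array dict, a param is resolved only when it holds a 1D int64 value, and only the fast-path decision uses the resolved values; the tensor fallback keeps the original inputs.

```python
import numpy as np

def peek_shape_constant(expr, params):
    """Non-mutating peek: if expr names a param bound to a 1D int64
    array, return the baked value; otherwise return expr unchanged.
    (Sketch only; the real code works on relax.Var/relax.Constant.)"""
    if isinstance(expr, str) and expr in params:
        value = params[expr]
        if value.ndim == 1 and value.dtype == np.int64:
            return value
    return expr

def concat_inputs(inputs, params):
    resolved = [peek_shape_constant(x, params) for x in inputs]
    if all(isinstance(x, np.ndarray) for x in resolved):
        # all-shape-like fast path: fold to a concrete shape value
        return ("shape", np.concatenate(resolved))
    # tensor fallback: keep the ORIGINAL inputs, so runtime params stay params
    return ("tensor_concat", inputs)

params = {
    "reshape_tail": np.array([12], dtype=np.int64),    # 1D int64 -> resolvable
    "weight": np.zeros((4, 4), dtype=np.float32),      # stays a runtime param
}

# dynamic-batch Reshape pattern: both inputs resolve, fast path taken
print(concat_inputs([np.array([-1], dtype=np.int64), "reshape_tail"], params))
# ordinary Concat(input, weight): falls back, "weight" is NOT folded
print(concat_inputs(["activations", "weight"], params))
```

Because `peek_shape_constant` never writes into `params` (or, in the real frontend, `graph_nodes`), resolving a shape param for the fast path cannot turn an ordinary weight into a baked Constant.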
Thanks. Narrowed it as suggested in the follow-up commit: a local non-mutating peek now resolves a param Var only for the all-shape-like fast path. Added a regression test for ordinary `Concat(input, weight)`.
Description
When `from_onnx(model, keep_params_in_input=True)` is used, every ONNX initializer becomes a `relax.Var` instead of a `relax.Constant`. The Concat handler's `is_shape_like()` check only recognizes `relax.ShapeExpr` and 1D-int64 `relax.Constant`, so a 1D-int64 shape value loaded as a Var is no longer recognized.

When such a Var is concatenated with a `ShapeExpr` (the standard pattern for dynamic-batch Reshape in PyTorch-exported ONNX models), the heterogeneous `Tuple(ShapeExpr, Tensor)` is rejected by `relax.op.concat` with an InternalError.

This effectively breaks `keep_params_in_input=True` for any model with a dynamic-batch Reshape (extremely common in PyTorch ONNX exports).

Fix
Run each Concat input through the existing `get_constant` helper before the `is_shape_like` check. This resolves any Var that maps to a known param back to its baked Constant, restoring the all-shape-like fast path.

Minimal repro
An 8-node ONNX graph (`Shape → Slice → Concat([dyn_n, [12]]) → Reshape`) fails with `keep_params_in_input=True` before this PR and passes after. A regression test (`test_concat_with_param_shape_value`) covers this pattern.

Testing
9 passed (1 new + 8 existing).