-
Notifications
You must be signed in to change notification settings - Fork 3.9k
[CoreML EP] Support Gather with scalar 'indices' #28278
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
8d966dd
df9481a
9cb8aad
0509992
d2dde53
e0171ae
690b1d4
ee743fa
887448d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,27 +30,121 @@ | |
| } // namespace | ||
|
|
||
| Status GatherOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, | ||
| const logging::Logger& /*logger*/) const { | ||
| const logging::Logger& logger) const { | ||
| const auto axis = GetAxisAttribute(node); | ||
| const auto& data_def = *node.InputDefs()[0]; | ||
| const auto& indices_def = *node.InputDefs()[1]; | ||
| const auto& output_def = *node.OutputDefs()[0]; | ||
|
|
||
| std::vector<int64_t> data_shape, indices_shape; | ||
| ORT_RETURN_IF_NOT(GetShape(data_def, data_shape, logger), "Failed to get 'data' shape"); | ||
| ORT_RETURN_IF_NOT(GetShape(indices_def, indices_shape, logger), "Failed to get 'indices' shape"); | ||
|
|
||
| // ONNX Gather: out_shape = data_shape[:axis] + indices_shape + data_shape[axis+1:] | ||
| // CoreML's gather requires rank-1+ indices, so for scalar indices we promote | ||
| // them to [1], gather, and then squeeze the resulting axis to restore the | ||
| // original output rank. The positive axis after wrapping is needed for the | ||
| // squeeze axis below regardless of path. | ||
| const bool scalar_indices = indices_shape.empty(); | ||
| const int64_t pos_axis = HandleNegativeAxis(axis, data_shape.size()); | ||
|
|
||
| if (model_builder.CreateMLProgram()) { | ||
| using CoreML::Specification::MILSpec::Operation; | ||
| std::unique_ptr<Operation> op = model_builder.CreateOperation(node, "gather"); | ||
|
|
||
| const auto axis = GetAxisAttribute(node); | ||
| // IsOpSupportedImpl gates indices to INT32 or INT64, so we can pass the | ||
| // dtype straight through to the reshape's intermediate output. | ||
| int32_t indices_dtype{}; | ||
| ORT_RETURN_IF_NOT(GetType(indices_def, indices_dtype, logger), | ||
| "Failed to get 'indices' dtype"); | ||
| const int32_t output_dtype = static_cast<int32_t>(output_def.TypeAsProto()->tensor_type().elem_type()); | ||
|
|
||
| std::string indices_name = indices_def.Name(); | ||
|
|
||
| if (scalar_indices) { | ||
| // [] -> [1] via reshape. We use reshape rather than expand_dims because | ||
| // CoreML internally pads scalars; expand_dims on the padded tensor can | ||
| // push the apparent rank past the rank-5 limit on high-rank `data`. | ||
| auto reshape = model_builder.CreateOperation(node, "reshape", "indices"); | ||
| AddOperationInput(*reshape, "x", indices_def.Name()); | ||
| const std::vector<int64_t> indices_1d_shape = {1}; | ||
| AddOperationInput(*reshape, "shape", | ||
| model_builder.AddConstant(reshape->type(), "shape", indices_1d_shape)); | ||
|
|
||
| indices_name = model_builder.GetUniqueName(node, "indices_1d"); | ||
| AddIntermediateOperationOutput(*reshape, indices_name, indices_dtype, indices_1d_shape); | ||
| model_builder.AddOperation(std::move(reshape)); | ||
| } | ||
|
|
||
| std::unique_ptr<Operation> gather = model_builder.CreateOperation(node, "gather"); | ||
|
Check warning on line 77 in onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc
|
||
| // coreml docs claims validate_indices is optional but in practice it is required | ||
| const auto validate_indices = false; | ||
| AddOperationInput(*op, "x", node.InputDefs()[0]->Name()); // data | ||
| AddOperationInput(*op, "indices", node.InputDefs()[1]->Name()); // indices | ||
| AddOperationInput(*op, "axis", model_builder.AddScalarConstant(op->type(), "axis", axis)); // axis attr | ||
| AddOperationInput(*op, "validate_indices", model_builder.AddScalarConstant(op->type(), "validate_indices", validate_indices)); | ||
| AddOperationOutput(*op, *node.OutputDefs()[0]); // output | ||
| model_builder.AddOperation(std::move(op)); | ||
| AddOperationInput(*gather, "x", data_def.Name()); | ||
| AddOperationInput(*gather, "indices", indices_name); | ||
| AddOperationInput(*gather, "axis", model_builder.AddScalarConstant(gather->type(), "axis", axis)); | ||
| AddOperationInput(*gather, "validate_indices", | ||
| model_builder.AddScalarConstant(gather->type(), "validate_indices", validate_indices)); | ||
|
|
||
| if (!scalar_indices) { | ||
| AddOperationOutput(*gather, output_def); | ||
| model_builder.AddOperation(std::move(gather)); | ||
| } else { | ||
| // gather output here has the data's rank (one more than ONNX scalar-gather output); | ||
| // squeeze the inserted axis to recover the original output shape. | ||
| std::vector<int64_t> gather_shape = data_shape; | ||
| gather_shape[pos_axis] = 1; | ||
| const std::string& gather_out_name = model_builder.GetUniqueName(node, "gather_out"); | ||
| AddIntermediateOperationOutput(*gather, gather_out_name, output_dtype, gather_shape); | ||
| model_builder.AddOperation(std::move(gather)); | ||
|
|
||
| auto squeeze = model_builder.CreateOperation(node, "squeeze", "post"); | ||
| AddOperationInput(*squeeze, "x", gather_out_name); | ||
| const std::vector<int64_t> sq_axes = {pos_axis}; | ||
| AddOperationInput(*squeeze, "axes", model_builder.AddConstant(squeeze->type(), "axes", sq_axes)); | ||
| AddOperationOutput(*squeeze, output_def); | ||
| model_builder.AddOperation(std::move(squeeze)); | ||
| } | ||
| } else { | ||
| auto layer = model_builder.CreateNNLayer(node); | ||
| layer->mutable_gather()->set_axis(GetAxisAttribute(node)); | ||
| *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); // data | ||
| *layer->mutable_input()->Add() = node.InputDefs()[1]->Name(); // indices | ||
| *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); // output | ||
| model_builder.AddLayer(std::move(layer)); | ||
| if (!scalar_indices) { | ||
| auto layer = model_builder.CreateNNLayer(node); | ||
| layer->mutable_gather()->set_axis(axis); | ||
| *layer->mutable_input()->Add() = data_def.Name(); | ||
| *layer->mutable_input()->Add() = indices_def.Name(); | ||
| *layer->mutable_output()->Add() = output_def.Name(); | ||
| model_builder.AddLayer(std::move(layer)); | ||
| } else { | ||
| // expand_dims indices: [] -> [1]. Unlike the MLProgram reshape path | ||
| // above, NN's expand_dims doesn't internally pad rank, so we don't run | ||
| // into the apparent-rank inflation that forced reshape+gather there; | ||
| // expand_dims is the natural choice on this path. | ||
| const std::string& indices_1d_name = model_builder.GetUniqueName(node, "indices_1d"); | ||
| { | ||
| auto expand_layer = model_builder.CreateNNLayer(node, "_indices_expand"); | ||
| expand_layer->mutable_expanddims()->add_axes(0); | ||
| *expand_layer->mutable_input()->Add() = indices_def.Name(); | ||
| *expand_layer->mutable_output()->Add() = indices_1d_name; | ||
| model_builder.AddLayer(std::move(expand_layer)); | ||
| } | ||
|
|
||
| // gather with the promoted indices | ||
| const std::string& gather_out_name = model_builder.GetUniqueName(node, "gather_out"); | ||
| { | ||
| auto gather_layer = model_builder.CreateNNLayer(node); | ||
| gather_layer->mutable_gather()->set_axis(axis); | ||
| *gather_layer->mutable_input()->Add() = data_def.Name(); | ||
| *gather_layer->mutable_input()->Add() = indices_1d_name; | ||
| *gather_layer->mutable_output()->Add() = gather_out_name; | ||
| model_builder.AddLayer(std::move(gather_layer)); | ||
| } | ||
|
|
||
| // squeeze the inserted axis | ||
| { | ||
| auto squeeze_layer = model_builder.CreateNNLayer(node, "_post_squeeze"); | ||
| squeeze_layer->mutable_squeeze()->add_axes(pos_axis); | ||
| squeeze_layer->mutable_squeeze()->set_squeezeall(false); | ||
| *squeeze_layer->mutable_input()->Add() = gather_out_name; | ||
| *squeeze_layer->mutable_output()->Add() = output_def.Name(); | ||
| model_builder.AddLayer(std::move(squeeze_layer)); | ||
|
Check warning on line 145 in onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc
|
||
| } | ||
| } | ||
| } | ||
| return Status::OK(); | ||
| } | ||
|
|
@@ -87,14 +181,45 @@ | |
| return false; | ||
| } | ||
|
|
||
| // Don't allow scalar 'indices' input. | ||
| // We convert scalar inputs to tensors with shape [1] before providing them to CoreML. | ||
| // This modification changes the shape of the Gather output. | ||
| if (indices_shape.empty()) { | ||
| LOGS(logger, VERBOSE) << "Gather does not support scalar 'indices'"; | ||
| // ONNX Gather schema constrains indices to int32 or int64. Validate here so | ||
| // AddToModelBuilderImpl can trust the dtype rather than silently defaulting | ||
| // on an unexpected value. | ||
| int32_t indices_dtype{}; | ||
| if (!GetType(*node.InputDefs()[1], indices_dtype, logger)) { | ||
| return false; | ||
| } | ||
| if (indices_dtype != ONNX_NAMESPACE::TensorProto_DataType_INT32 && | ||
| indices_dtype != ONNX_NAMESPACE::TensorProto_DataType_INT64) { | ||
| LOGS(logger, VERBOSE) << "Gather 'indices' dtype [" << indices_dtype | ||
| << "] is not supported (expected INT32 or INT64)"; | ||
| return false; | ||
| } | ||
|
|
||
| // For scalar indices we internally emit gather with promoted [1] indices | ||
| // then squeeze. That requires us to claim a static intermediate shape, so | ||
| // we only handle scalar indices when the data shape itself is fully | ||
| // static. (Dynamic-shape scalar Gather still falls back to CPU.) | ||
| if (indices_shape.empty()) { | ||
| if (!IsStaticShape(data_shape)) { | ||
| LOGS(logger, VERBOSE) << "Gather with scalar 'indices' requires static 'data' shape"; | ||
| return false; | ||
| } | ||
| // The pre-squeeze intermediate has the same rank as `data`. CoreML's | ||
| // compiler reports "Invalid rank: 6" when a rank-5 intermediate is | ||
| // produced via reshape+gather, even though rank-5 intermediates are | ||
| // accepted in other op chains. Cap scalar-indices Gather at data rank 4 | ||
| // until that compiler limit is lifted. | ||
| // | ||
| // TODO: re-test on newer macOS / CoreML versions; if Apple lifts the | ||
|
Check warning on line 213 in onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc
|
||
| // intermediate rank limit, this cap can be raised to 5 (matching the | ||
| // general Gather output-rank check below). | ||
| if (data_shape.size() > 4) { | ||
| LOGS(logger, VERBOSE) << "Gather with scalar 'indices' supports 'data' rank up to 4"; | ||
| return false; | ||
| } | ||
| } | ||
|
|
||
| // Output rank = data_rank + indices_rank - 1. The rank-5 limit applies. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Potential underflow in output-rank check (line ~225 in the original): data_shape.size() + indices_shape.size() - 1 — when indices is scalar (indices_shape.size() == 0) and data rank is 1, this is 1 + 0 - 1 = 0, which is fine. But for size_t arithmetic, if somehow both were 0 (can't happen since data rank >= 1 is checked), this would underflow. Not a real bug since data rank >= 1 is enforced by ONNX schema and GetShape succeeds, but worth noting the unsigned arithmetic. |
||
| if (data_shape.size() + indices_shape.size() - 1 > 5) { | ||
| LOGS(logger, VERBOSE) << "Gather does not support output with rank greater than 5"; | ||
| return false; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Prefer TensorShapeVector for inline storage.