Skip to content
167 changes: 146 additions & 21 deletions onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,27 +30,121 @@
} // namespace

Status GatherOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                              const logging::Logger& logger) const {
  const auto axis = GetAxisAttribute(node);
  const auto& data_def = *node.InputDefs()[0];
  const auto& indices_def = *node.InputDefs()[1];
  const auto& output_def = *node.OutputDefs()[0];

  std::vector<int64_t> data_shape, indices_shape;
  ORT_RETURN_IF_NOT(GetShape(data_def, data_shape, logger), "Failed to get 'data' shape");
  ORT_RETURN_IF_NOT(GetShape(indices_def, indices_shape, logger), "Failed to get 'indices' shape");

  // ONNX Gather: out_shape = data_shape[:axis] + indices_shape + data_shape[axis+1:]
  // CoreML's gather requires rank-1+ indices, so for scalar indices we promote
  // them to [1], gather, and then squeeze the resulting axis to restore the
  // original output rank. The positive axis after wrapping is needed for the
  // squeeze axis below regardless of path.
  const bool scalar_indices = indices_shape.empty();
  const int64_t pos_axis = HandleNegativeAxis(axis, data_shape.size());

  if (model_builder.CreateMLProgram()) {
    using CoreML::Specification::MILSpec::Operation;

    // IsOpSupportedImpl gates indices to INT32 or INT64, so we can pass the
    // dtype straight through to the reshape's intermediate output.
    int32_t indices_dtype{};
    ORT_RETURN_IF_NOT(GetType(indices_def, indices_dtype, logger),
                      "Failed to get 'indices' dtype");
    const int32_t output_dtype = static_cast<int32_t>(output_def.TypeAsProto()->tensor_type().elem_type());

    std::string indices_name = indices_def.Name();

    if (scalar_indices) {
      // [] -> [1] via reshape. We use reshape rather than expand_dims because
      // CoreML internally pads scalars; expand_dims on the padded tensor can
      // push the apparent rank past the rank-5 limit on high-rank `data`.
      auto reshape = model_builder.CreateOperation(node, "reshape", "indices");
      AddOperationInput(*reshape, "x", indices_def.Name());
      const std::vector<int64_t> indices_1d_shape = {1};
      AddOperationInput(*reshape, "shape",
                        model_builder.AddConstant(reshape->type(), "shape", indices_1d_shape));

      indices_name = model_builder.GetUniqueName(node, "indices_1d");
      AddIntermediateOperationOutput(*reshape, indices_name, indices_dtype, indices_1d_shape);
      model_builder.AddOperation(std::move(reshape));
    }

    std::unique_ptr<Operation> gather = model_builder.CreateOperation(node, "gather");

    // CoreML docs claim validate_indices is optional but in practice it is required
    const auto validate_indices = false;
    AddOperationInput(*gather, "x", data_def.Name());
    AddOperationInput(*gather, "indices", indices_name);
    AddOperationInput(*gather, "axis", model_builder.AddScalarConstant(gather->type(), "axis", axis));
    AddOperationInput(*gather, "validate_indices",
                      model_builder.AddScalarConstant(gather->type(), "validate_indices", validate_indices));

    if (!scalar_indices) {
      AddOperationOutput(*gather, output_def);
      model_builder.AddOperation(std::move(gather));
    } else {
      // gather output here has the data's rank (one more than ONNX scalar-gather output);
      // squeeze the inserted axis to recover the original output shape.
      std::vector<int64_t> gather_shape = data_shape;
      gather_shape[pos_axis] = 1;
      const std::string& gather_out_name = model_builder.GetUniqueName(node, "gather_out");
      AddIntermediateOperationOutput(*gather, gather_out_name, output_dtype, gather_shape);
      model_builder.AddOperation(std::move(gather));

      auto squeeze = model_builder.CreateOperation(node, "squeeze", "post");
      AddOperationInput(*squeeze, "x", gather_out_name);
      const std::vector<int64_t> sq_axes = {pos_axis};
      AddOperationInput(*squeeze, "axes", model_builder.AddConstant(squeeze->type(), "axes", sq_axes));
      AddOperationOutput(*squeeze, output_def);
      model_builder.AddOperation(std::move(squeeze));
    }
  } else {
    if (!scalar_indices) {
      auto layer = model_builder.CreateNNLayer(node);
      layer->mutable_gather()->set_axis(axis);
      *layer->mutable_input()->Add() = data_def.Name();
      *layer->mutable_input()->Add() = indices_def.Name();
      *layer->mutable_output()->Add() = output_def.Name();
      model_builder.AddLayer(std::move(layer));
    } else {
      // expand_dims indices: [] -> [1]. Unlike the MLProgram reshape path
      // above, NN's expand_dims doesn't internally pad rank, so we don't run
      // into the apparent-rank inflation that forced reshape+gather there;
      // expand_dims is the natural choice on this path.
      const std::string& indices_1d_name = model_builder.GetUniqueName(node, "indices_1d");
      {
        auto expand_layer = model_builder.CreateNNLayer(node, "_indices_expand");
        expand_layer->mutable_expanddims()->add_axes(0);
        *expand_layer->mutable_input()->Add() = indices_def.Name();
        *expand_layer->mutable_output()->Add() = indices_1d_name;
        model_builder.AddLayer(std::move(expand_layer));
      }

      // gather with the promoted indices
      const std::string& gather_out_name = model_builder.GetUniqueName(node, "gather_out");
      {
        auto gather_layer = model_builder.CreateNNLayer(node);
        gather_layer->mutable_gather()->set_axis(axis);
        *gather_layer->mutable_input()->Add() = data_def.Name();
        *gather_layer->mutable_input()->Add() = indices_1d_name;
        *gather_layer->mutable_output()->Add() = gather_out_name;
        model_builder.AddLayer(std::move(gather_layer));
      }

      // squeeze the inserted axis
      {
        auto squeeze_layer = model_builder.CreateNNLayer(node, "_post_squeeze");
        squeeze_layer->mutable_squeeze()->add_axes(pos_axis);
        squeeze_layer->mutable_squeeze()->set_squeezeall(false);
        *squeeze_layer->mutable_input()->Add() = gather_out_name;
        *squeeze_layer->mutable_output()->Add() = output_def.Name();
        model_builder.AddLayer(std::move(squeeze_layer));
      }
    }
  }
  return Status::OK();
}
Expand Down Expand Up @@ -87,14 +181,45 @@
return false;
}

// Don't allow scalar 'indices' input.
// We convert scalar inputs to tensors with shape [1] before providing them to CoreML.
// This modification changes the shape of the Gather output.
if (indices_shape.empty()) {
LOGS(logger, VERBOSE) << "Gather does not support scalar 'indices'";
// ONNX Gather schema constrains indices to int32 or int64. Validate here so
// AddToModelBuilderImpl can trust the dtype rather than silently defaulting
// on an unexpected value.
int32_t indices_dtype{};
if (!GetType(*node.InputDefs()[1], indices_dtype, logger)) {
return false;
}
if (indices_dtype != ONNX_NAMESPACE::TensorProto_DataType_INT32 &&
indices_dtype != ONNX_NAMESPACE::TensorProto_DataType_INT64) {
LOGS(logger, VERBOSE) << "Gather 'indices' dtype [" << indices_dtype
<< "] is not supported (expected INT32 or INT64)";
return false;
}

// For scalar indices we internally emit gather with promoted [1] indices
// then squeeze. That requires us to claim a static intermediate shape, so
// we only handle scalar indices when the data shape itself is fully
// static. (Dynamic-shape scalar Gather still falls back to CPU.)
if (indices_shape.empty()) {
if (!IsStaticShape(data_shape)) {
LOGS(logger, VERBOSE) << "Gather with scalar 'indices' requires static 'data' shape";
return false;
}
// The pre-squeeze intermediate has the same rank as `data`. CoreML's
// compiler reports "Invalid rank: 6" when a rank-5 intermediate is
// produced via reshape+gather, even though rank-5 intermediates are
// accepted in other op chains. Cap scalar-indices Gather at data rank 4
// until that compiler limit is lifted.
//
// TODO(anyone): re-test on newer macOS / CoreML versions; if Apple lifts the

Check warning on line 213 in onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc:213: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
// intermediate rank limit, this cap can be raised to 5 (matching the
// general Gather output-rank check below).
if (data_shape.size() > 4) {
LOGS(logger, VERBOSE) << "Gather with scalar 'indices' supports 'data' rank up to 4";
return false;
}
}

// Output rank = data_rank + indices_rank - 1. The rank-5 limit applies.
Copy link
Copy Markdown
Member

@yuslepukhin yuslepukhin May 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential underflow in output-rank check (line ~225 in the original): data_shape.size() + indices_shape.size() - 1 — when indices is scalar (indices_shape.size() == 0) and data rank is 1, this is 1 + 0 - 1 = 0, which is fine. But for size_t arithmetic, if somehow both were 0 (can't happen since data rank >= 1 is checked), this would underflow. Not a real bug since data rank >= 1 is enforced by ONNX schema and GetShape succeeds, but worth noting the unsigned arithmetic.
Note, however, that this safety argument relies on the ONNX schema guaranteeing data rank >= 1; it is not enforced by a check in this function.

if (data_shape.size() + indices_shape.size() - 1 > 5) {
LOGS(logger, VERBOSE) << "Gather does not support output with rank greater than 5";
return false;
Expand Down
Loading
Loading