Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -74,27 +74,46 @@ bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderIn
return true;
}

bool GatherOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/,
bool GatherOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
Comment thread
edgchen1 marked this conversation as resolved.
const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
std::vector<int64_t> data_shape, indices_shape;
if (!GetShape(*node.InputDefs()[0], data_shape, logger)) {
if (!GetShape(*input_defs[0], data_shape, logger)) {
LOGS(logger, VERBOSE) << "Failed to get 'data' shape";
return false;
}

if (!GetShape(*node.InputDefs()[1], indices_shape, logger)) {
if (!GetShape(*input_defs[1], indices_shape, logger)) {
LOGS(logger, VERBOSE) << "Failed to get 'indices' shape";
return false;
}

// Don't allow scalar 'indices' input.
// We convert scalar inputs to tensors with shape [1] before providing them to CoreML.
// This modification changes the shape of the Gather output.
// Scalar (rank-0) 'indices' input.
//
// MIL `gather` accepts rank-0 indices and produces the correct ONNX output
// (the gathered axis is dropped). However, the CoreML EP reshapes any rank-0
// *graph-boundary* tensor to {1} when it crosses into the CoreML subgraph
// (see ModelBuilder::RegisterModelInputOutput) and a rank-0 input cannot be
// represented as an MLMultiArray. So we only allow scalar indices when they
// are a constant initializer: those flow through OnnxTensorToCoreMLTensor
// with rank preserved and MIL gather can consume them directly.
//
// On the NeuralNetwork path scalar initializers are also reshaped to {1}
// (ModelBuilder::RegisterInitializers, LoadConstantND requires rank >= 1),
// so the gather output shape ends up wrong there. Keep rejecting that case.
if (indices_shape.empty()) {
LOGS(logger, VERBOSE) << "Gather does not support scalar 'indices'";
return false;
if (!input_params.create_mlprogram) {
LOGS(logger, VERBOSE) << "Gather does not support scalar 'indices' on the NeuralNetwork path";
return false;
}
if (input_params.graph_viewer.GetConstantInitializer(input_defs[1]->Name()) == nullptr) {
LOGS(logger, VERBOSE) << "Gather with scalar 'indices' is only supported when 'indices' is a constant initializer";
return false;
}
}

// ONNX Gather output rank = data_rank + indices_rank - 1.
// For scalar indices (rank 0) this is data_rank - 1, which is what MIL also produces.
if (data_shape.size() + indices_shape.size() - 1 > 5) {
LOGS(logger, VERBOSE) << "Gather does not support output with rank greater than 5";
return false;
Expand Down
Loading