Skip to content
18 changes: 12 additions & 6 deletions onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,19 @@ Status CastOpBuilder::AddToModelBuilderImpl([[maybe_unused]] ModelBuilder& model

bool CastOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params,
const logging::Logger& logger) const {
if (input_params.create_mlprogram) {
// The ML Program 'cast' op stands alone, so a Cast fed directly by a graph
// input (no preceding node) is fine here.
return true;
}
Comment thread
yuslepukhin marked this conversation as resolved.

// The NeuralNetwork path only supports a Cast that consumes an ArgMax, so it
// needs a preceding node to inspect (InputEdgesBegin() must be dereferenceable).
if (node.GetInputEdgesCount() == 0) {
LOGS(logger, VERBOSE) << "Cast has no preceding nodes.";
return false;
}

if (input_params.create_mlprogram) {
return true;
}

const auto& prec_node = node.InputEdgesBegin()->GetNode();

/*Cast node is only aimed for supporting argmax and we are only handling the case where an argmax
Expand Down Expand Up @@ -141,11 +145,13 @@ bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, [[maybe_unused]] co
if ((input_type == ONNX_NAMESPACE::TensorProto_DataType_INT32 ||
input_type == ONNX_NAMESPACE::TensorProto_DataType_INT64 ||
input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT ||
input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) &&
input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 ||
input_type == ONNX_NAMESPACE::TensorProto_DataType_BOOL) &&
(output_type == ONNX_NAMESPACE::TensorProto_DataType_INT32 ||
output_type == ONNX_NAMESPACE::TensorProto_DataType_INT64 ||
output_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT ||
output_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)) {
output_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 ||
output_type == ONNX_NAMESPACE::TensorProto_DataType_BOOL)) {
return true;
} else {
LOGS(logger, VERBOSE) << "[" << node.OpType()
Expand Down
64 changes: 64 additions & 0 deletions onnxruntime/test/providers/coreml/coreml_basic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2361,6 +2361,48 @@ TEST(CoreMLExecutionProviderTest, Split11SingleOutputNotSupported) {
}

namespace {
// int64 -> Cast(bool) -> Cast(float) [-> Sqrt]; the first Cast is fed directly
// by a graph input (no preceding node).
//
// append_nontrivial=false gives the all-Cast graph used by the NeuralNetwork
// negative test below. append_nontrivial=true appends a Sqrt: a CoreML partition
// made up only of trivial ops (Cast is marked trivial) is dropped, so the extra
// non-trivial op keeps the partition and lets the test below assert the bool
// Casts are claimed.
std::string MakeCastBoolModelData(bool append_nontrivial = false) {
onnxruntime::Model model("cast_bool_test", false, DefaultLoggingManager().DefaultLogger());
auto& graph = model.MainGraph();

auto make_type = [](int32_t elem_type) {
ONNX_NAMESPACE::TypeProto t;
t.mutable_tensor_type()->set_elem_type(elem_type);
for (int64_t d : {1, 4}) t.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(d);
return t;
};
const auto int64_type = make_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
const auto bool_type = make_type(ONNX_NAMESPACE::TensorProto_DataType_BOOL);
const auto float_type = make_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);

auto& x = graph.GetOrCreateNodeArg("X", &int64_type);
auto& b = graph.GetOrCreateNodeArg("B", &bool_type);
auto& y = graph.GetOrCreateNodeArg("Y", &float_type);

auto& to_bool = graph.AddNode("cast_to_bool", "Cast", "int64 -> bool", {&x}, {&b});
to_bool.AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_BOOL));
auto& to_float = graph.AddNode("cast_to_float", "Cast", "bool -> float", {&b}, {&y});
to_float.AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT));

if (append_nontrivial) {
auto& z = graph.GetOrCreateNodeArg("Z", &float_type);
graph.AddNode("sqrt", "Sqrt", "float -> float", {&y}, {&z});
}

ORT_THROW_IF_ERROR(graph.Resolve());
std::string model_data;
model.ToProto().SerializeToString(&model_data);
return model_data;
}

// Single-input model with both Sin and Cos consuming `X`, used by the
// Sin/Cos tests below.
std::string MakeSinCosModelData() {
Expand All @@ -2386,6 +2428,28 @@ std::string MakeSinCosModelData() {
}
} // namespace

// On the NeuralNetwork format the Cast builder only supports a Cast that
// consumes an ArgMax, so these graph-input / Cast-fed Casts must fall back to
// CPU. Guards the IsOpSupportedImpl reordering that moved the preceding-node
// check into the NeuralNetwork branch.
TEST(CoreMLExecutionProviderTest, CastNonArgMaxNeuralNetworkNotSupported) {
const std::string model_data = MakeCastBoolModelData();
gsl::span<const std::byte> model_span{reinterpret_cast<const std::byte*>(model_data.data()),
model_data.size()};
TestModelLoad(model_span, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::None);
}

// Load-time partition check on the ML Program path: confirms the EP claims both
// bool Casts. A non-trivial Sqrt is appended so the partition isn't dropped as
// all-trivial (see MakeCastBoolModelData); all three nodes -- both Casts and the
// Sqrt -- must land on CoreML.
TEST(CoreMLExecutionProviderTest, CastBoolMLProgramPartition) {
const std::string model_data = MakeCastBoolModelData(/*append_nontrivial=*/true);
gsl::span<const std::byte> model_span{reinterpret_cast<const std::byte*>(model_data.data()),
model_data.size()};
TestModelLoad(model_span, MakeCoreMLExecutionProvider("MLProgram"), ExpectedEPNodeAssignment::All);
}

// Sin and Cos are lowered to the ML Program 'sin' / 'cos' ops.
TEST(CoreMLExecutionProviderTest, SinCos_MLProgram) {
const std::string model_data = MakeSinCosModelData();
Expand Down