diff --git a/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc index 61b18859ba685..8be860fd0f462 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc @@ -83,15 +83,19 @@ Status CastOpBuilder::AddToModelBuilderImpl([[maybe_unused]] ModelBuilder& model bool CastOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { + if (input_params.create_mlprogram) { + // The ML Program 'cast' op stands alone, so a Cast fed directly by a graph + // input (no preceding node) is fine here. + return true; + } + + // The NeuralNetwork path only supports a Cast that consumes an ArgMax, so it + // needs a preceding node to inspect (InputEdgesBegin() must be dereferenceable). if (node.GetInputEdgesCount() == 0) { LOGS(logger, VERBOSE) << "Cast has no preceding nodes."; return false; } - if (input_params.create_mlprogram) { - return true; - } - const auto& prec_node = node.InputEdgesBegin()->GetNode(); /*Cast node is only aimed for supporting argmax and we are only handling the case where an argmax @@ -141,11 +145,13 @@ bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, [[maybe_unused]] co if ((input_type == ONNX_NAMESPACE::TensorProto_DataType_INT32 || input_type == ONNX_NAMESPACE::TensorProto_DataType_INT64 || input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT || - input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) && + input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 || + input_type == ONNX_NAMESPACE::TensorProto_DataType_BOOL) && (output_type == ONNX_NAMESPACE::TensorProto_DataType_INT32 || output_type == ONNX_NAMESPACE::TensorProto_DataType_INT64 || output_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT || - output_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)) { + output_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 || + output_type == ONNX_NAMESPACE::TensorProto_DataType_BOOL)) { return true; } else { LOGS(logger, VERBOSE) << "[" << node.OpType() diff --git a/onnxruntime/test/providers/coreml/coreml_basic_test.cc b/onnxruntime/test/providers/coreml/coreml_basic_test.cc index b67a9bde5c63b..77f43b60dd6f8 100644 --- a/onnxruntime/test/providers/coreml/coreml_basic_test.cc +++ b/onnxruntime/test/providers/coreml/coreml_basic_test.cc @@ -2361,6 +2361,48 @@ TEST(CoreMLExecutionProviderTest, Split11SingleOutputNotSupported) { } namespace { +// int64 -> Cast(bool) -> Cast(float) [-> Sqrt]; the first Cast is fed directly +// by a graph input (no preceding node). +// +// append_nontrivial=false gives the all-Cast graph used by the NeuralNetwork +// negative test below. append_nontrivial=true appends a Sqrt: a CoreML partition +// made up only of trivial ops (Cast is marked trivial) is dropped, so the extra +// non-trivial op keeps the partition and lets the test below assert the bool +// Casts are claimed. +std::string MakeCastBoolModelData(bool append_nontrivial = false) { + onnxruntime::Model model("cast_bool_test", false, DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + auto make_type = [](int32_t elem_type) { + ONNX_NAMESPACE::TypeProto t; + t.mutable_tensor_type()->set_elem_type(elem_type); + for (int64_t d : {1, 4}) t.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(d); + return t; + }; + const auto int64_type = make_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + const auto bool_type = make_type(ONNX_NAMESPACE::TensorProto_DataType_BOOL); + const auto float_type = make_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + + auto& x = graph.GetOrCreateNodeArg("X", &int64_type); + auto& b = graph.GetOrCreateNodeArg("B", &bool_type); + auto& y = graph.GetOrCreateNodeArg("Y", &float_type); + + auto& to_bool = graph.AddNode("cast_to_bool", "Cast", "int64 -> bool", {&x}, {&b}); + to_bool.AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_BOOL)); + auto& to_float = graph.AddNode("cast_to_float", "Cast", "bool -> float", {&b}, {&y}); + to_float.AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT)); + + if (append_nontrivial) { + auto& z = graph.GetOrCreateNodeArg("Z", &float_type); + graph.AddNode("sqrt", "Sqrt", "float -> float", {&y}, {&z}); + } + + ORT_THROW_IF_ERROR(graph.Resolve()); + std::string model_data; + model.ToProto().SerializeToString(&model_data); + return model_data; +} + // Single-input model with both Sin and Cos consuming `X`, used by the // Sin/Cos tests below. std::string MakeSinCosModelData() { @@ -2386,6 +2428,28 @@ std::string MakeSinCosModelData() { } } // namespace +// On the NeuralNetwork format the Cast builder only supports a Cast that +// consumes an ArgMax, so these graph-input / Cast-fed Casts must fall back to +// CPU. Guards the IsOpSupportedImpl reordering that moved the preceding-node +// check into the NeuralNetwork branch. +TEST(CoreMLExecutionProviderTest, CastNonArgMaxNeuralNetworkNotSupported) { + const std::string model_data = MakeCastBoolModelData(); + gsl::span model_span{reinterpret_cast(model_data.data()), + model_data.size()}; + TestModelLoad(model_span, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::None); +} + +// Load-time partition check on the ML Program path: confirms the EP claims both +// bool Casts. A non-trivial Sqrt is appended so the partition isn't dropped as +// all-trivial (see MakeCastBoolModelData); all three nodes -- both Casts and the +// Sqrt -- must land on CoreML. +TEST(CoreMLExecutionProviderTest, CastBoolMLProgramPartition) { + const std::string model_data = MakeCastBoolModelData(/*append_nontrivial=*/true); + gsl::span model_span{reinterpret_cast(model_data.data()), + model_data.size()}; + TestModelLoad(model_span, MakeCoreMLExecutionProvider("MLProgram"), ExpectedEPNodeAssignment::All); +} + // Sin and Cos are lowered to the ML Program 'sin' / 'cos' ops. TEST(CoreMLExecutionProviderTest, SinCos_MLProgram) { const std::string model_data = MakeSinCosModelData();