diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 76fb7ce93b600..811502b78c27b 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -2567,6 +2567,42 @@ struct OrtEp { * \since Version 1.27. */ ORT_API2_STATUS(OnSessionInitializationEnd, _In_ OrtEp* this_ptr); + + /** \brief Get the EP's default memory device. + * + * The EP's default memory device identifies the hardware the EP operates on. ORT uses it to: + * - Determine if data copies are needed between EPs (inserting memcpy nodes at EP boundaries) + * - Determine if the EP is CPU-based (which affects synchronization and data transfer decisions) + * - Bind execution streams to the correct device + * + * An OrtMemoryDevice is obtained from an OrtMemoryInfo via `OrtEpApi::MemoryInfo_GetMemoryDevice()`. + * Typically, an EP creates OrtMemoryInfo instances and registers them with its OrtEpDevice(s) via + * `OrtEpApi::EpDevice_AddAllocatorInfo()`. The OrtMemoryDevice returned here must correspond to an + * OrtMemoryInfo registered as an `OrtDeviceAllocator` entry (either `OrtDeviceMemoryType_DEFAULT` or + * `OrtDeviceMemoryType_HOST_ACCESSIBLE`). An OrtMemoryDevice from an `OrtReadOnlyAllocator` entry is + * not accepted as the EP's default/identity device. + * + * The returned pointer must remain valid for the lifetime of the OrtEp instance + * (typically by storing the parent OrtMemoryInfo as a member of the EP). + * + * If this function is not implemented (NULL), or if it sets `device` to NULL, ORT infers + * the default memory device from the `OrtDeviceAllocator` entry with `OrtDeviceMemoryType_DEFAULT` + * registered via `EpDevice_AddAllocatorInfo`. In this fallback case, all OrtEpDevice instances must + * use the same `OrtDeviceMemoryType_DEFAULT` OrtMemoryInfo (or ORT cannot determine which device to + * use). 
If no such entry is registered, the EP defaults to a CPU memory device. + * + * \param[in] this_ptr The OrtEp instance. + * \param[out] device Set to the EP's default OrtMemoryDevice, or NULL to use the default behavior (described above). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \note Implementation of this function is optional. If set to NULL (not implemented), ORT + * infers the default memory device using the default behavior described above. + * + * \since Version 1.27. + */ + ORT_API2_STATUS(GetDefaultMemoryDevice, _In_ const OrtEp* this_ptr, + _Outptr_result_maybenull_ const OrtMemoryDevice** device); }; /** \brief The function signature that ORT will call to create OrtEpFactory instances. diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index 1c80d83f99feb..006e6e0da8b56 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -913,12 +913,7 @@ class PlannerImpl { ProcessDef(index, node_output); OrtDevice output_device = exec_provider->GetOrtDeviceByMemType(p_kernel_def->OutputMemoryType(i)); #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - // Downstream nodes of certain providers may require a CPU accessible location override - // to make sure the EP does not incur an unnecessary copy. - // We only do it for CPU based EPs. We are not likely to encounter - // non CPU devices here since they are already taken care of by using MemCpy nodes earlier. - // However, we still ignore them. 
- if (output_device.Type() == OrtDevice::CPU) { + if (output_device.UsesCpuMemory()) { const auto& output_name = node_output->Name(); const auto consumers = graph_viewer_.GetConsumerNodes(output_name); for (const auto* consumer : consumers) { diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index f1945ded10b07..f375bc134ea66 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -50,6 +50,53 @@ bool ProviderIsCpuBased(const IExecutionProvider& provider) { return provider.GetDevice().Type() == OrtDevice::CPU; } +// Returns true if src memory can satisfy tgt's requirements without a data copy. +// +// HOST_ACCESSIBLE → DEFAULT is valid: the device can access HOST_ACCESSIBLE memory directly. +// DEFAULT → HOST_ACCESSIBLE is NOT valid: HOST_ACCESSIBLE implies CPU consumers, and DEFAULT +// memory is device-only — the CPU cannot read it. +// +// For non-identical devices, src alignment must meet tgt's minimum requirement. +// Alignment 0 on tgt means "no alignment requirement". Alignment 0 on src means "unknown" and +// does not satisfy a non-zero tgt alignment requirement. +bool CanSourceSatisfyTarget(const OrtDevice& src, const OrtDevice& tgt) { + const bool src_is_cpu_mem = src.UsesCpuMemory(); + const bool tgt_is_cpu_mem = tgt.UsesCpuMemory(); + + // Identical devices are always compatible. + if (src == tgt) { + return true; + } + + // A tgt alignment of 0 means "no requirement", so any src alignment (even 0/unknown) satisfies it. + const bool is_alignment_satisfied = tgt.GetAlignment() == 0 || + src.GetAlignment() >= tgt.GetAlignment(); + + const bool is_same_physical_device = src.Type() == tgt.Type() && + src.Vendor() == tgt.Vendor() && + src.Id() == tgt.Id(); + + // Both are CPU-accessible (CPU type or HOST_ACCESSIBLE memory). 
+ if (src_is_cpu_mem && tgt_is_cpu_mem) { + // CPU target can read from any CPU or HOST_ACCESSIBLE source, regardless of the source device + if (tgt.Type() == OrtDevice::CPU) { + return is_alignment_satisfied; + } + // Both are HOST_ACCESSIBLE on some device: require the same physical device. + return is_same_physical_device && is_alignment_satisfied; + } + + // HOST_ACCESSIBLE source can serve a DEFAULT target on the same physical device — + // the device can DMA from HOST_ACCESSIBLE memory directly. + // The reverse (DEFAULT → HOST_ACCESSIBLE) is unsafe: HOST_ACCESSIBLE implies CPU consumers, + // and DEFAULT memory is device-only so the CPU cannot read it. + if (src_is_cpu_mem && !tgt_is_cpu_mem) { + return is_same_physical_device && is_alignment_satisfied; + } + + return false; +} + bool IsMemcpyNode(const Node& node) { return node.Domain() == kOnnxDomain && (node.OpType() == "MemcpyFromHost" || node.OpType() == "MemcpyToHost"); @@ -117,6 +164,28 @@ const std::string& GetNodeInputProviderType(const SessionState::NodeInfo& info) return required_provider_type; } +// Populate device_fetches for the output-copy path. +// When the user pre-allocates a fetch buffer, reuse it directly as the EP's output buffer if +// the user's buffer (tgt) can satisfy the EP's output device (src) requirements — i.e., +// CanSourceSatisfyTarget(tgt, src). This avoids a post-execution copy. +// Otherwise inserts an empty placeholder for the EP to allocate into. 
+static void PopulateDeviceFetches(gsl::span fetch_copy_info, + const std::vector& fetches, + std::vector& device_fetches) { + ORT_ENFORCE(fetch_copy_info.size() >= fetches.size()); + device_fetches.clear(); + device_fetches.reserve(fetches.size()); + for (size_t i = 0; i < fetches.size(); ++i) { + const auto& src = fetch_copy_info[i].source_device; + const auto& tgt = fetch_copy_info[i].target_device; + if (CanSourceSatisfyTarget(tgt, src) && fetches[i].IsAllocated()) { + device_fetches.push_back(fetches[i]); + } else { + device_fetches.push_back({}); + } + } +} + // Copy MLValue. Uses DataTransferManager for device copy if necessary. If copy_tensor_pairs/copy_sparse_pairs is provided, // src/dst pairs that need a device copy are added to copy_pairs so copying can be batches by the DataTransferManager // implementation for performance reasons. @@ -132,8 +201,9 @@ static Status BatchOrCopyMLValue(const SessionState& session_state, std::vector* copy_tensor_pairs = nullptr) #endif { - // same device so direct copy - if (copy_info.source_device == copy_info.target_device) { + // No data transfer needed if devices are identical, or the source can satisfy the target + // (HOST_ACCESSIBLE source serving a DEFAULT target on the same physical device). 
+ if (CanSourceSatisfyTarget(copy_info.source_device, copy_info.target_device)) { target_mlvalue = source_mlvalue; return Status::OK(); } @@ -324,7 +394,7 @@ static bool FinalizeCopyInfoForFeeds(gsl::span feed_locations, for (size_t i = 0, end = feed_locations.size(); i < end; ++i) { copy_info[i].source_device = feed_locations[i]; - if (copy_info[i].source_device != copy_info[i].target_device) { + if (!CanSourceSatisfyTarget(copy_info[i].source_device, copy_info[i].target_device)) { copy_needed = true; } } @@ -345,7 +415,7 @@ static bool FinalizeCopyInfoForFetches(gsl::span& fetch_ copy_info[i].target_device = *alloc_info; } - if (copy_info[i].source_device != copy_info[i].target_device) { + if (!CanSourceSatisfyTarget(copy_info[i].source_device, copy_info[i].target_device)) { copy_needed = true; } } @@ -652,22 +722,9 @@ ExecuteGraphImpl(const SessionState& session_state, feeds_to_use = device_feeds; } - auto num_outputs = fetches.size(); const auto& fetch_copy_info = feeds_fetches_manager.GetFetchesDeviceCopyInfo(); - if (device_copy_checks.output_copy_needed == DeviceCopyCheck::Copy) { - // need intermediate fetches. use pre-allocated fetches where possible. - device_fetches.reserve(num_outputs); - - for (size_t i = 0; i < num_outputs; ++i) { - if (fetch_copy_info[i].source_device == fetch_copy_info[i].target_device && fetches[i].IsAllocated()) { - device_fetches.push_back(fetches[i]); - } else { - // use temporary value - device_fetches.push_back({}); - } - } - + PopulateDeviceFetches(fetch_copy_info, fetches, device_fetches); p_fetches = &device_fetches; } @@ -808,22 +865,10 @@ common::Status ExecutePartialGraphImpl(const SessionState& session_state, FeedsF p_feeds = device_feeds; } - auto num_outputs = fetches.size(); const auto& fetch_copy_info = feeds_fetches_manager.GetFetchesDeviceCopyInfo(); if (device_copy_checks.output_copy_needed == DeviceCopyCheck::Copy) { - // need intermediate fetches. use pre-allocated fetches where possible. 
- device_fetches.reserve(num_outputs); - - for (size_t i = 0; i < num_outputs; ++i) { - if (fetch_copy_info[i].source_device == fetch_copy_info[i].target_device && fetches[i].IsAllocated()) { - device_fetches.push_back(fetches[i]); - } else { - // use temporary value - device_fetches.push_back({}); - } - } - + PopulateDeviceFetches(fetch_copy_info, fetches, device_fetches); p_fetches = &device_fetches; } diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h index 9c637d4bac81f..20ef81d12f1c8 100644 --- a/onnxruntime/core/framework/utils.h +++ b/onnxruntime/core/framework/utils.h @@ -57,6 +57,11 @@ bool ProviderIsCpuBased(const IExecutionProvider& provider); bool IsMemcpyNode(const Node& node); +// Returns true if src memory can satisfy tgt's requirements without a data copy. +// HOST_ACCESSIBLE -> DEFAULT is valid (device can access HOST_ACCESSIBLE memory directly). +// DEFAULT -> HOST_ACCESSIBLE is NOT valid (CPU cannot read device-only memory). +bool CanSourceSatisfyTarget(const OrtDevice& src, const OrtDevice& tgt); + common::Status CopyOneInputAcrossDevices(const SessionState& session_state, const std::string& input_name, const OrtValue& orig_mlvalue, OrtValue& new_mlvalue); diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index d8094fe68ea53..f7691ed5dc02a 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -121,12 +121,21 @@ struct PluginEpMetaDefNameFunctor { // PluginExecutionProvider // -static OrtDevice GetOrtDeviceForPluginEp(gsl::span ep_devices) { - // Get the OrtDevice from OrtEpDevice.device_memory_info if it is set. Otherwise, we set it to CPU. - // If there are multiple OrtEpDevice instances, the device_memory_info must be consistent for all. 
+static OrtDevice GetOrtDeviceForPluginEp(const OrtEp& ep, gsl::span ep_devices) { + // Resolve the EP's default device. If the EP (ort_version_supported >= 27) implements + // GetDefaultMemoryDevice and it returns a non-null device, use that directly. Otherwise infer + // from OrtEpDevice.device_memory_info (and enforce that all OrtEpDevice instances agree). ORT_ENFORCE(!ep_devices.empty()); // Should not be possible to create an EP without OrtEpDevices. + if (ep.ort_version_supported >= 27 && ep.GetDefaultMemoryDevice != nullptr) { + const OrtMemoryDevice* memory_device = nullptr; + Ort::ThrowOnError(ep.GetDefaultMemoryDevice(&ep, &memory_device)); + if (memory_device != nullptr) { + return *static_cast(memory_device); + } + } + const OrtMemoryInfo* device_memory_info = ep_devices[0]->device_memory_info; // Check assertion that all OrtEpDevice instances must have equivalent device_memory_infos @@ -169,7 +178,7 @@ PluginExecutionProvider::PluginExecutionProvider(UniqueOrtEp ep, const OrtSessio gsl::span ep_devices, std::shared_ptr kernel_registry, const logging::Logger& logger) - : IExecutionProvider(ep->GetName(ep.get()), GetOrtDeviceForPluginEp(ep_devices), + : IExecutionProvider(ep->GetName(ep.get()), GetOrtDeviceForPluginEp(*ep, ep_devices), std::vector(ep_devices.begin(), ep_devices.end()), logger), ort_ep_(std::move(ep)), ep_factory_(ep_factory), diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc index ca9a296501b04..fecf7ac9a4038 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc @@ -185,6 +185,7 @@ ExampleEp::ExampleEp(ExampleEpFactory& factory, const std::string& name, const C CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; // optional. can be nullptr GetCompiledModelCompatibilityInfo = GetCompiledModelCompatibilityInfoImpl; // compatibility info for compiled models Sync = SyncImpl; // optional. 
can be nullptr + GetDefaultMemoryDevice = GetDefaultMemoryDeviceImpl; // optional. can be nullptr IGNORE_ORTSTATUS(ort_api.Logger_LogMessage(&logger_, OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, @@ -602,6 +603,14 @@ OrtStatus* ORT_API_CALL ExampleEp::SyncImpl(_In_ OrtEp* this_ptr) noexcept { return nullptr; } +/*static*/ +OrtStatus* ORT_API_CALL ExampleEp::GetDefaultMemoryDeviceImpl(_In_ const OrtEp* this_ptr, + _Outptr_ const OrtMemoryDevice** device) noexcept { + const auto* ep = static_cast(this_ptr); + *device = ep->ep_api.MemoryInfo_GetMemoryDevice(ep->factory_.GetDefaultMemoryInfo()); + return nullptr; +} + // // Implementation of ExampleNodeComputeInfo // diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep.h index 2ba13658c3364..5dcd9f07bef1f 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep.h @@ -105,6 +105,9 @@ class ExampleEp : public OrtEp, public ApiPtrs { static OrtStatus* ORT_API_CALL SyncImpl(_In_ OrtEp* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL GetDefaultMemoryDeviceImpl(_In_ const OrtEp* this_ptr, + _Outptr_ const OrtMemoryDevice** device) noexcept; + OrtStatus* CreateEpContextNodes(gsl::span fused_nodes, /*out*/ gsl::span ep_context_nodes); diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h index 91478047afb0a..4bb23f1bddace 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h @@ -38,6 +38,10 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { return vendor_id_; } + const OrtMemoryInfo* GetDefaultMemoryInfo() const { + return default_memory_info_; + } + const OrtLogger& default_logger_; // default logger for the EP factory private: diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc 
b/onnxruntime/test/framework/ep_plugin_provider_test.cc index 80b638314bad9..0908527721824 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -4,8 +4,10 @@ #include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" #include +#include #include #include +#include #include #include "gsl/gsl" #include "gtest/gtest.h" @@ -73,6 +75,18 @@ struct TestOrtEp : ::OrtEp, ApiPtrs { constexpr const char* ep_name = "TestOrtEp"; return ep_name; } + + // OrtMemoryDevice returned by GetDefaultMemoryDeviceImpl. nullptr means "defer to ORT". + const OrtMemoryDevice* test_default_memory_device = nullptr; + mutable std::atomic get_default_memory_device_call_count{0}; + + static OrtStatus* ORT_API_CALL GetDefaultMemoryDeviceImpl(const OrtEp* this_ptr, + const OrtMemoryDevice** device) noexcept { + const auto* test_ep = static_cast(this_ptr); + test_ep->get_default_memory_device_call_count.fetch_add(1, std::memory_order_relaxed); + *device = test_ep->test_default_memory_device; + return nullptr; + } }; // This factory doesn't do anything other than implement ReleaseEp(). @@ -123,13 +137,20 @@ struct MakeTestOrtEpResult { // Creates an IExecutionProvider that wraps a TestOrtEp. // The TestOrtEp is also exposed so that tests can manipulate its function pointers directly. -MakeTestOrtEpResult MakeTestOrtEp(std::vector ep_devices = {}) { +// `setup` runs on the raw TestOrtEp before PluginExecutionProvider is constructed -- +// callbacks consulted at construction time (e.g., GetDefaultMemoryDevice seeding +// default_device_) must be configured here. +MakeTestOrtEpResult MakeTestOrtEp(std::vector ep_devices = {}, + std::function setup = nullptr) { // Default OrtHardwareDevice and OrtEpDevice used if the caller does not explicitly provide ep_devices. 
static std::unique_ptr ort_hw_device = MakeTestOrtHardwareDevice(OrtHardwareDeviceType_CPU); static std::unique_ptr ort_ep_device = MakeTestOrtEpDevice(ort_hw_device.get()); auto ort_ep_raw = std::make_unique().release(); auto ort_ep = UniqueOrtEp(ort_ep_raw, OrtEpDeleter{g_test_ort_ep_factory}); + if (setup) { + setup(*ort_ep_raw); + } auto ort_session_options = Ort::SessionOptions{}; if (ep_devices.empty()) { @@ -365,6 +386,96 @@ TEST(PluginExecutionProviderTest, InferOrtDeviceFromDeviceMemoryInfo) { #endif // !defined(ORT_NO_EXCEPTIONS) } +// When the EP implements GetDefaultMemoryDevice, the result seeds default_device_ at +// construction and is returned by GetOrtDeviceByMemType(OrtMemTypeDefault) at runtime. +TEST(PluginExecutionProviderTest, GetDefaultMemoryDevice_SeedsDefaultDevice) { + auto ort_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE); + + // ep_device intentionally has no device_memory_info -- the legacy path would yield + // OrtDevice() (plain CPU). The callback must override that. + auto ort_hw_device = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU); + auto ort_ep_device = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device.get()); + std::vector ep_devices{ort_ep_device.get()}; + + auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices, [&](test_plugin_ep::TestOrtEp& test_ep) { + test_ep.GetDefaultMemoryDevice = test_plugin_ep::TestOrtEp::GetDefaultMemoryDeviceImpl; + test_ep.test_default_memory_device = static_cast(&ort_device); + }); + + ASSERT_EQ(ep->GetDevice(), ort_device); + ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), ort_device); + ASSERT_GE(ort_ep->get_default_memory_device_call_count.load(), 1); +} + +// Version gate: ort_version_supported < 27 must bypass the callback. Without this guard +// ORT would call into a function pointer the EP didn't claim to support. 
+TEST(PluginExecutionProviderTest, GetDefaultMemoryDevice_VersionGateBypassesCallback) { + auto callback_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT); + + auto fallback_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT); + auto fallback_mem_info = std::make_unique("TestOrtEp NPU", OrtAllocatorType::OrtDeviceAllocator, + fallback_device, OrtMemTypeDefault); + + auto ort_hw_device = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_NPU); + auto ort_ep_device = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device.get(), fallback_mem_info.get()); + std::vector ep_devices{ort_ep_device.get()}; + + auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices, [&](test_plugin_ep::TestOrtEp& test_ep) { + test_ep.ort_version_supported = 26; // older than the GetDefaultMemoryDevice API version + test_ep.GetDefaultMemoryDevice = test_plugin_ep::TestOrtEp::GetDefaultMemoryDeviceImpl; + test_ep.test_default_memory_device = static_cast(&callback_device); + }); + + ASSERT_EQ(ep->GetDevice(), fallback_device); + ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), fallback_device); + ASSERT_EQ(ort_ep->get_default_memory_device_call_count.load(), 0); +} + +// Heterogeneous ep_devices: GetOrtDeviceForPluginEp throws when ep_devices have +// inconsistent device_memory_info. The callback unblocks that case by letting the EP +// name a representative device directly. 
+TEST(PluginExecutionProviderTest, GetDefaultMemoryDevice_HeterogeneousEpDevicesUnblocked) { + auto gpu_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT); + auto gpu_mem_info = std::make_unique("TestOrtEp GPU", OrtAllocatorType::OrtDeviceAllocator, + gpu_device, OrtMemTypeDefault); + auto npu_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT); + auto npu_mem_info = std::make_unique("TestOrtEp NPU", OrtAllocatorType::OrtDeviceAllocator, + npu_device, OrtMemTypeDefault); + + auto representative_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, + OrtDevice::MemType::HOST_ACCESSIBLE); + + auto ort_hw_device_gpu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU); + auto ort_hw_device_npu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_NPU); + auto ort_ep_device_gpu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_gpu.get(), gpu_mem_info.get()); + auto ort_ep_device_npu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_npu.get(), npu_mem_info.get()); + std::vector ep_devices{ort_ep_device_gpu.get(), ort_ep_device_npu.get()}; + + auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices, [&](test_plugin_ep::TestOrtEp& test_ep) { + test_ep.GetDefaultMemoryDevice = test_plugin_ep::TestOrtEp::GetDefaultMemoryDeviceImpl; + test_ep.test_default_memory_device = static_cast(&representative_device); + }); + + ASSERT_EQ(ep->GetDevice(), representative_device); +} + +#if !defined(ORT_NO_EXCEPTIONS) +// A non-OK status from the callback must propagate out of PluginExecutionProvider +// construction via Ort::ThrowOnError. 
+TEST(PluginExecutionProviderTest, GetDefaultMemoryDevice_StatusErrorThrows) { + auto ort_hw_device = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU); + auto ort_ep_device = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device.get()); + std::vector ep_devices{ort_ep_device.get()}; + + ASSERT_THROW(test_plugin_ep::MakeTestOrtEp(ep_devices, [&](test_plugin_ep::TestOrtEp& test_ep) { + test_ep.GetDefaultMemoryDevice = [](const OrtEp* /*this_ptr*/, const OrtMemoryDevice** /*device*/) noexcept { + return Ort::Status("injected failure", ORT_FAIL).release(); + }; + }), + Ort::Exception); +} +#endif // !defined(ORT_NO_EXCEPTIONS) + static void LoadModelAndAssignNodesToEp(const ORTCHAR_T* model_path, const char* ep_name, const std::unordered_set& ep_node_names, diff --git a/onnxruntime/test/framework/utils_test.cc b/onnxruntime/test/framework/utils_test.cc new file mode 100644 index 0000000000000..2b2f2ec9376ba --- /dev/null +++ b/onnxruntime/test/framework/utils_test.cc @@ -0,0 +1,129 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "gtest/gtest.h" +#include "core/framework/utils.h" + +namespace onnxruntime { +namespace test { + +constexpr OrtDevice::VendorId kTestVendor1 = 0x1234; +constexpr OrtDevice::VendorId kTestVendor2 = 0x5678; + +static OrtDevice Cpu() { + return OrtDevice{OrtDevice::CPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0}; +} + +static OrtDevice HostAccessible(OrtDevice::VendorId vendor, OrtDevice::DeviceId id, + OrtDevice::Alignment align = 0) { + return OrtDevice{OrtDevice::NPU, OrtDevice::MemType::HOST_ACCESSIBLE, vendor, id, align}; +} + +static OrtDevice Default(OrtDevice::VendorId vendor, OrtDevice::DeviceId id, + OrtDevice::Alignment align = 0) { + return OrtDevice{OrtDevice::NPU, OrtDevice::MemType::DEFAULT, vendor, id, align}; +} + +TEST(CanSourceSatisfyTargetTest, CpuSourceHostAccessibleTarget) { + EXPECT_FALSE(utils::CanSourceSatisfyTarget(Cpu(), HostAccessible(kTestVendor1, 0))); +} + +TEST(CanSourceSatisfyTargetTest, HostAccessibleSourceCpuTarget) { + EXPECT_TRUE(utils::CanSourceSatisfyTarget(HostAccessible(kTestVendor1, 0), Cpu())); +} + +// src == tgt early return: identical devices are always compatible +TEST(CanSourceSatisfyTargetTest, BothHostAccessibleSameDevice) { + auto dev = HostAccessible(kTestVendor1, 0, 16); + EXPECT_TRUE(utils::CanSourceSatisfyTarget(dev, dev)); +} + +TEST(CanSourceSatisfyTargetTest, BothHostAccessibleDifferentId) { + EXPECT_FALSE(utils::CanSourceSatisfyTarget( + HostAccessible(kTestVendor1, 0), HostAccessible(kTestVendor1, 1))); +} + +TEST(CanSourceSatisfyTargetTest, BothHostAccessibleDifferentVendor) { + EXPECT_FALSE(utils::CanSourceSatisfyTarget( + HostAccessible(kTestVendor1, 0), HostAccessible(kTestVendor2, 0))); +} + +TEST(CanSourceSatisfyTargetTest, HostAccessibleToDefaultSameDevice) { + EXPECT_TRUE(utils::CanSourceSatisfyTarget( + HostAccessible(kTestVendor1, 0), Default(kTestVendor1, 0))); +} + +TEST(CanSourceSatisfyTargetTest, HostAccessibleToDefaultAlignmentSatisfied) { + // src alignment >= tgt 
alignment: compatible + EXPECT_TRUE(utils::CanSourceSatisfyTarget( + HostAccessible(kTestVendor1, 0, 64), Default(kTestVendor1, 0, 32))); +} + +TEST(CanSourceSatisfyTargetTest, HostAccessibleToDefaultAlignmentInsufficient) { + // src alignment < tgt alignment: incompatible + EXPECT_FALSE(utils::CanSourceSatisfyTarget( + HostAccessible(kTestVendor1, 0, 16), Default(kTestVendor1, 0, 64))); +} + +TEST(CanSourceSatisfyTargetTest, HostAccessibleToDefaultSrcAlignmentZero) { + EXPECT_FALSE(utils::CanSourceSatisfyTarget( + HostAccessible(kTestVendor1, 0, 0), Default(kTestVendor1, 0, 64))); +} + +TEST(CanSourceSatisfyTargetTest, HostAccessibleToDefaultTgtAlignmentZero) { + // 0 = unspecified, treated as wildcard + EXPECT_TRUE(utils::CanSourceSatisfyTarget( + HostAccessible(kTestVendor1, 0, 16), Default(kTestVendor1, 0, 0))); +} + +TEST(CanSourceSatisfyTargetTest, HostAccessibleToDefaultDifferentDeviceId) { + EXPECT_FALSE(utils::CanSourceSatisfyTarget( + HostAccessible(kTestVendor1, 0), Default(kTestVendor1, 1))); +} + +TEST(CanSourceSatisfyTargetTest, DefaultToHostAccessibleRejected) { + // Reversed direction: CPU cannot read DEFAULT (device-only) memory + EXPECT_FALSE(utils::CanSourceSatisfyTarget( + Default(kTestVendor1, 0), HostAccessible(kTestVendor1, 0))); +} + +TEST(CanSourceSatisfyTargetTest, DefaultToDefaultRejected) { + EXPECT_FALSE(utils::CanSourceSatisfyTarget( + Default(kTestVendor1, 0), Default(kTestVendor2, 0))); +} + +// Early return: identical CPU devices are always compatible. +TEST(CanSourceSatisfyTargetTest, CpuToCpuIdentical) { + EXPECT_TRUE(utils::CanSourceSatisfyTarget(Cpu(), Cpu())); +} + +// Early return: identical DEFAULT devices on the same physical device are compatible. 
+TEST(CanSourceSatisfyTargetTest, DefaultToDefaultSameDevice) { + EXPECT_TRUE(utils::CanSourceSatisfyTarget(Default(kTestVendor1, 0), Default(kTestVendor1, 0))); +} + +// Both HOST_ACCESSIBLE, same physical device — alignment variations (not the early-return path +// because src and tgt differ in alignment, so src != tgt). +TEST(CanSourceSatisfyTargetTest, BothHostAccessibleSameDeviceAlignmentSatisfied) { + EXPECT_TRUE(utils::CanSourceSatisfyTarget( + HostAccessible(kTestVendor1, 0, 64), HostAccessible(kTestVendor1, 0, 32))); +} + +TEST(CanSourceSatisfyTargetTest, BothHostAccessibleSameDeviceAlignmentInsufficient) { + EXPECT_FALSE(utils::CanSourceSatisfyTarget( + HostAccessible(kTestVendor1, 0, 16), HostAccessible(kTestVendor1, 0, 32))); +} + +TEST(CanSourceSatisfyTargetTest, BothHostAccessibleSameDeviceSrcAlignmentZero) { + EXPECT_FALSE(utils::CanSourceSatisfyTarget( + HostAccessible(kTestVendor1, 0, 0), HostAccessible(kTestVendor1, 0, 32))); +} + +// HOST_ACCESSIBLE → DEFAULT: same device id but different vendor fails is_same_physical_device. +TEST(CanSourceSatisfyTargetTest, HostAccessibleToDefaultDifferentVendor) { + EXPECT_FALSE(utils::CanSourceSatisfyTarget( + HostAccessible(kTestVendor1, 0), Default(kTestVendor2, 0))); +} + +} // namespace test +} // namespace onnxruntime