Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_ep_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2567,6 +2567,42 @@ struct OrtEp {
* \since Version 1.27.
*/
ORT_API2_STATUS(OnSessionInitializationEnd, _In_ OrtEp* this_ptr);

/** \brief Get the EP's default memory device.
*
* The EP's default memory device identifies the hardware the EP operates on. ORT uses it to:
* - Determine if data copies are needed between EPs (inserting memcpy nodes at EP boundaries)
* - Determine if the EP is CPU-based (which affects synchronization and data transfer decisions)
* - Bind execution streams to the correct device
*
* An OrtMemoryDevice is obtained from an OrtMemoryInfo via `OrtEpApi::MemoryInfo_GetMemoryDevice()`.
* Typically, an EP creates OrtMemoryInfo instances and registers them with its OrtEpDevice(s) via
* `OrtEpApi::EpDevice_AddAllocatorInfo()`. The OrtMemoryDevice returned here must correspond to an
* OrtMemoryInfo registered as an `OrtDeviceAllocator` entry (either `OrtDeviceMemoryType_DEFAULT` or
* `OrtDeviceMemoryType_HOST_ACCESSIBLE`). An OrtMemoryDevice from an `OrtReadOnlyAllocator` entry is
* not accepted as the EP's default/identity device.
*
* The returned pointer must remain valid for the lifetime of the OrtEp instance
* (typically by storing the parent OrtMemoryInfo as a member of the EP).
*
* If this function is not implemented (NULL), or if it sets `device` to NULL, ORT infers
* the default memory device from the `OrtDeviceAllocator` entry with `OrtDeviceMemoryType_DEFAULT`
* registered via `EpDevice_AddAllocatorInfo`. In this fallback case, all OrtEpDevice instances must
* use the same `OrtDeviceMemoryType_DEFAULT` OrtMemoryInfo (or ORT cannot determine which device to
* use). If no such entry is registered, the EP defaults to a CPU memory device.
*
* \param[in] this_ptr The OrtEp instance.
* \param[out] device Set to the EP's default OrtMemoryDevice, or NULL to use the default behavior (described above).
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \note Implementation of this function is optional. If set to NULL (not implemented), ORT
* infers the default memory device using the default behavior described above.
*
* \since Version 1.27.
*/
Comment thread
ericcraw marked this conversation as resolved.
ORT_API2_STATUS(GetDefaultMemoryDevice, _In_ const OrtEp* this_ptr,
_Outptr_result_maybenull_ const OrtMemoryDevice** device);
};

/** \brief The function signature that ORT will call to create OrtEpFactory instances.
Expand Down
7 changes: 1 addition & 6 deletions onnxruntime/core/framework/allocation_planner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -913,12 +913,7 @@ class PlannerImpl {
ProcessDef(index, node_output);
OrtDevice output_device = exec_provider->GetOrtDeviceByMemType(p_kernel_def->OutputMemoryType(i));
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
// Downstream nodes of certain providers may require a CPU accessible location override
// to make sure the EP does not incur an unnecessary copy.
// We only do it for CPU based EPs. We are not likely to encounter
// non CPU devices here since they are already taken care of by using MemCpy nodes earlier.
// However, we still ignore them.
if (output_device.Type() == OrtDevice::CPU) {
if (output_device.UsesCpuMemory()) {
Comment thread
ericcraw marked this conversation as resolved.
const auto& output_name = node_output->Name();
const auto consumers = graph_viewer_.GetConsumerNodes(output_name);
for (const auto* consumer : consumers) {
Expand Down
107 changes: 76 additions & 31 deletions onnxruntime/core/framework/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,53 @@ bool ProviderIsCpuBased(const IExecutionProvider& provider) {
return provider.GetDevice().Type() == OrtDevice::CPU;
}

// Returns true if src memory can satisfy tgt's requirements without a data copy.
//
// HOST_ACCESSIBLE → DEFAULT is valid: the device can access HOST_ACCESSIBLE memory directly.
// DEFAULT → HOST_ACCESSIBLE is NOT valid: HOST_ACCESSIBLE implies CPU consumers, and DEFAULT
// memory is device-only — the CPU cannot read it.
//
// For the mixed case, src alignment must meet tgt's minimum requirement.
// Alignment 0 on tgt means "no alignment requirement". Alignment 0 on src means "unknown" and
// does not satisfy a non-zero tgt alignment requirement.
// Returns true if src memory can satisfy tgt's requirements without a data copy.
//
// HOST_ACCESSIBLE -> DEFAULT is valid: the device can access HOST_ACCESSIBLE memory directly.
// DEFAULT -> HOST_ACCESSIBLE is NOT valid: HOST_ACCESSIBLE implies CPU consumers, and DEFAULT
// memory is device-only -- the CPU cannot read it.
//
// For the mixed case, src alignment must meet tgt's minimum requirement.
// Alignment 0 on tgt means "no alignment requirement". Alignment 0 on src means "unknown" and
// does not satisfy a non-zero tgt alignment requirement.
bool CanSourceSatisfyTarget(const OrtDevice& src, const OrtDevice& tgt) {
  // Identical devices are always compatible.
  if (src == tgt) {
    return true;
  }

  // A source that is not CPU-accessible (i.e. device-only DEFAULT memory) cannot serve any
  // non-identical target: neither the CPU nor another device placement can read it directly.
  if (!src.UsesCpuMemory()) {
    return false;
  }

  // Alignment 0 on tgt means "unspecified" -- compatible with any source alignment.
  // Alignment 0 on src means "unknown" -- does not satisfy a non-zero tgt requirement.
  const bool is_alignment_satisfied = tgt.GetAlignment() == 0 ||
                                      src.GetAlignment() >= tgt.GetAlignment();

  const bool is_same_physical_device = src.Type() == tgt.Type() &&
                                       src.Vendor() == tgt.Vendor() &&
                                       src.Id() == tgt.Id();

  if (tgt.UsesCpuMemory()) {
    // A plain CPU target can read from any CPU or HOST_ACCESSIBLE source, regardless of the
    // source's physical device.
    if (tgt.Type() == OrtDevice::CPU) {
      return is_alignment_satisfied;
    }

    // Both src and tgt are HOST_ACCESSIBLE memory on some device: require the same physical device.
    return is_same_physical_device && is_alignment_satisfied;
  }

  // CPU-accessible source, device-only (DEFAULT) target: a HOST_ACCESSIBLE source can serve a
  // DEFAULT target on the same physical device -- the device can DMA from HOST_ACCESSIBLE
  // memory directly. The reverse (DEFAULT -> HOST_ACCESSIBLE) was rejected above.
  return is_same_physical_device && is_alignment_satisfied;
}

bool IsMemcpyNode(const Node& node) {
return node.Domain() == kOnnxDomain &&
(node.OpType() == "MemcpyFromHost" || node.OpType() == "MemcpyToHost");
Expand Down Expand Up @@ -117,6 +164,28 @@ const std::string& GetNodeInputProviderType(const SessionState::NodeInfo& info)
return required_provider_type;
}

// Populate device_fetches for the output-copy path.
// When the user pre-allocates a fetch buffer, reuse it directly as the EP's output buffer if
// the user's buffer (tgt) can satisfy the EP's output device (src) requirements — i.e.,
// CanSourceSatisfyTarget(tgt, src). This avoids a post-execution copy.
// Otherwise inserts an empty placeholder for the EP to allocate into.
// Populate device_fetches for the output-copy path.
//
// When the user pre-allocates a fetch buffer, reuse it directly as the EP's output buffer if
// the user's buffer (target_device) can satisfy the EP's output device (source_device)
// requirements -- i.e. CanSourceSatisfyTarget(tgt, src). This avoids a post-execution copy.
// Otherwise an empty placeholder OrtValue is inserted for the EP to allocate into, and the
// result is copied to the user's buffer after execution.
static void PopulateDeviceFetches(gsl::span<const MLValueCopyInfo> fetch_copy_info,
                                  const std::vector<OrtValue>& fetches,
                                  std::vector<OrtValue>& device_fetches) {
  // There must be copy info for every requested fetch.
  ORT_ENFORCE(fetch_copy_info.size() >= fetches.size());

  device_fetches.clear();
  device_fetches.reserve(fetches.size());

  for (size_t i = 0; i < fetches.size(); ++i) {
    const auto& ep_output_device = fetch_copy_info[i].source_device;  // where the EP produces the output
    const auto& user_device = fetch_copy_info[i].target_device;       // where the user wants the output

    // NOTE: arguments are intentionally reversed relative to the member names -- we ask whether
    // the user's pre-allocated buffer (the copy target) is directly usable as the EP's output
    // buffer (the copy source).
    if (fetches[i].IsAllocated() && CanSourceSatisfyTarget(user_device, ep_output_device)) {
      device_fetches.push_back(fetches[i]);
    } else {
      // Temporary value: the EP allocates it, and the result is copied out afterwards.
      device_fetches.push_back({});
    }
  }
}

// Copy MLValue. Uses DataTransferManager for device copy if necessary. If copy_tensor_pairs/copy_sparse_pairs is provided,
// src/dst pairs that need a device copy are added to copy_pairs so copying can be batches by the DataTransferManager
// implementation for performance reasons.
Expand All @@ -132,8 +201,9 @@ static Status BatchOrCopyMLValue(const SessionState& session_state,
std::vector<IDataTransfer::SrcDstPair>* copy_tensor_pairs = nullptr)
#endif
{
// same device so direct copy
if (copy_info.source_device == copy_info.target_device) {
// No data transfer needed if devices are identical, or the source can satisfy the target
// (HOST_ACCESSIBLE source serving a DEFAULT target on the same physical device).
if (CanSourceSatisfyTarget(copy_info.source_device, copy_info.target_device)) {
target_mlvalue = source_mlvalue;
return Status::OK();
}
Expand Down Expand Up @@ -324,7 +394,7 @@ static bool FinalizeCopyInfoForFeeds(gsl::span<const OrtDevice> feed_locations,
for (size_t i = 0, end = feed_locations.size(); i < end; ++i) {
copy_info[i].source_device = feed_locations[i];

if (copy_info[i].source_device != copy_info[i].target_device) {
if (!CanSourceSatisfyTarget(copy_info[i].source_device, copy_info[i].target_device)) {
copy_needed = true;
}
}
Expand All @@ -345,7 +415,7 @@ static bool FinalizeCopyInfoForFetches(gsl::span<const OrtDevice* const>& fetch_
copy_info[i].target_device = *alloc_info;
}

if (copy_info[i].source_device != copy_info[i].target_device) {
if (!CanSourceSatisfyTarget(copy_info[i].source_device, copy_info[i].target_device)) {
copy_needed = true;
}
}
Expand Down Expand Up @@ -652,22 +722,9 @@ ExecuteGraphImpl(const SessionState& session_state,
feeds_to_use = device_feeds;
}

auto num_outputs = fetches.size();
const auto& fetch_copy_info = feeds_fetches_manager.GetFetchesDeviceCopyInfo();

if (device_copy_checks.output_copy_needed == DeviceCopyCheck::Copy) {
// need intermediate fetches. use pre-allocated fetches where possible.
device_fetches.reserve(num_outputs);

for (size_t i = 0; i < num_outputs; ++i) {
if (fetch_copy_info[i].source_device == fetch_copy_info[i].target_device && fetches[i].IsAllocated()) {
device_fetches.push_back(fetches[i]);
} else {
// use temporary value
device_fetches.push_back({});
}
}

PopulateDeviceFetches(fetch_copy_info, fetches, device_fetches);
p_fetches = &device_fetches;
}

Expand Down Expand Up @@ -808,22 +865,10 @@ common::Status ExecutePartialGraphImpl(const SessionState& session_state, FeedsF
p_feeds = device_feeds;
}

auto num_outputs = fetches.size();
const auto& fetch_copy_info = feeds_fetches_manager.GetFetchesDeviceCopyInfo();

if (device_copy_checks.output_copy_needed == DeviceCopyCheck::Copy) {
// need intermediate fetches. use pre-allocated fetches where possible.
device_fetches.reserve(num_outputs);

for (size_t i = 0; i < num_outputs; ++i) {
if (fetch_copy_info[i].source_device == fetch_copy_info[i].target_device && fetches[i].IsAllocated()) {
device_fetches.push_back(fetches[i]);
} else {
// use temporary value
device_fetches.push_back({});
}
}

PopulateDeviceFetches(fetch_copy_info, fetches, device_fetches);
p_fetches = &device_fetches;
}

Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/core/framework/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ bool ProviderIsCpuBased(const IExecutionProvider& provider);

bool IsMemcpyNode(const Node& node);

// Returns true if src memory can satisfy tgt's requirements without a data copy.
// HOST_ACCESSIBLE -> DEFAULT is valid (device can access HOST_ACCESSIBLE memory directly).
// DEFAULT -> HOST_ACCESSIBLE is NOT valid (CPU cannot read device-only memory).
bool CanSourceSatisfyTarget(const OrtDevice& src, const OrtDevice& tgt);

common::Status CopyOneInputAcrossDevices(const SessionState& session_state, const std::string& input_name,
const OrtValue& orig_mlvalue, OrtValue& new_mlvalue);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,21 @@ struct PluginEpMetaDefNameFunctor {
// PluginExecutionProvider
//

static OrtDevice GetOrtDeviceForPluginEp(gsl::span<const OrtEpDevice* const> ep_devices) {
// Get the OrtDevice from OrtEpDevice.device_memory_info if it is set. Otherwise, we set it to CPU.
// If there are multiple OrtEpDevice instances, the device_memory_info must be consistent for all.
static OrtDevice GetOrtDeviceForPluginEp(const OrtEp& ep, gsl::span<const OrtEpDevice* const> ep_devices) {
// Resolve the EP's default device. If the EP implements GetDefaultMemoryDevice, use its
// answer directly. Otherwise infer from OrtEpDevice.device_memory_info (and enforce that
// all OrtEpDevice instances agree).

ORT_ENFORCE(!ep_devices.empty()); // Should not be possible to create an EP without OrtEpDevices.

if (ep.ort_version_supported >= 27 && ep.GetDefaultMemoryDevice != nullptr) {
Comment thread
ericcraw marked this conversation as resolved.
const OrtMemoryDevice* memory_device = nullptr;
Ort::ThrowOnError(ep.GetDefaultMemoryDevice(&ep, &memory_device));
if (memory_device != nullptr) {
return *static_cast<const OrtDevice*>(memory_device);
}
}

const OrtMemoryInfo* device_memory_info = ep_devices[0]->device_memory_info;

// Check assertion that all OrtEpDevice instances must have equivalent device_memory_infos
Expand Down Expand Up @@ -169,7 +178,7 @@ PluginExecutionProvider::PluginExecutionProvider(UniqueOrtEp ep, const OrtSessio
gsl::span<const OrtEpDevice* const> ep_devices,
std::shared_ptr<KernelRegistry> kernel_registry,
const logging::Logger& logger)
: IExecutionProvider(ep->GetName(ep.get()), GetOrtDeviceForPluginEp(ep_devices),
: IExecutionProvider(ep->GetName(ep.get()), GetOrtDeviceForPluginEp(*ep, ep_devices),
std::vector<const OrtEpDevice*>(ep_devices.begin(), ep_devices.end()), logger),
ort_ep_(std::move(ep)),
Comment thread
ericcraw marked this conversation as resolved.
ep_factory_(ep_factory),
Expand Down
9 changes: 9 additions & 0 deletions onnxruntime/test/autoep/library/example_plugin_ep/ep.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ ExampleEp::ExampleEp(ExampleEpFactory& factory, const std::string& name, const C
CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; // optional. can be nullptr
GetCompiledModelCompatibilityInfo = GetCompiledModelCompatibilityInfoImpl; // compatibility info for compiled models
Sync = SyncImpl; // optional. can be nullptr
GetDefaultMemoryDevice = GetDefaultMemoryDeviceImpl; // optional. can be nullptr

IGNORE_ORTSTATUS(ort_api.Logger_LogMessage(&logger_,
OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO,
Expand Down Expand Up @@ -602,6 +603,14 @@ OrtStatus* ORT_API_CALL ExampleEp::SyncImpl(_In_ OrtEp* this_ptr) noexcept {
return nullptr;
}

/*static*/
OrtStatus* ORT_API_CALL ExampleEp::GetDefaultMemoryDeviceImpl(_In_ const OrtEp* this_ptr,
                                                              _Outptr_ const OrtMemoryDevice** device) noexcept {
  // Report the EP's identity device: the OrtMemoryDevice backing the factory's default
  // OrtMemoryInfo. The factory owns that OrtMemoryInfo, so the returned pointer stays valid
  // for the lifetime of this EP instance, as the OrtEp API requires.
  const auto& example_ep = *static_cast<const ExampleEp*>(this_ptr);
  const OrtMemoryInfo* default_mem_info = example_ep.factory_.GetDefaultMemoryInfo();
  *device = example_ep.ep_api.MemoryInfo_GetMemoryDevice(default_mem_info);
  return nullptr;  // success
}

//
// Implementation of ExampleNodeComputeInfo
//
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/test/autoep/library/example_plugin_ep/ep.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ class ExampleEp : public OrtEp, public ApiPtrs {

static OrtStatus* ORT_API_CALL SyncImpl(_In_ OrtEp* this_ptr) noexcept;

static OrtStatus* ORT_API_CALL GetDefaultMemoryDeviceImpl(_In_ const OrtEp* this_ptr,
_Outptr_ const OrtMemoryDevice** device) noexcept;

OrtStatus* CreateEpContextNodes(gsl::span<const OrtNode*> fused_nodes,
/*out*/ gsl::span<OrtNode*> ep_context_nodes);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs {
return vendor_id_;
}

// Accessor for the factory's default OrtMemoryInfo. Consumed by
// ExampleEp::GetDefaultMemoryDeviceImpl to report the EP's default memory device.
// Non-owning: the caller must not free the returned pointer.
const OrtMemoryInfo* GetDefaultMemoryInfo() const {
  return default_memory_info_;
}

const OrtLogger& default_logger_; // default logger for the EP factory

private:
Expand Down
Loading