WebGPU EP: resolve the allocator's BufferManager lazily and create the graph
buffer manager on first capture.

Summary of what this patch changes:

* GpuBufferAllocator stores a std::function getter instead of a
  `const BufferManager&` and calls it on every Alloc() and Free(). The existing
  `const BufferManager&` constructor is kept and delegates to the new one with
  a lambda that captures the manager's address. `mapped_at_creation_` is still
  computed once, in the constructor, from whichever manager the getter returns
  at that moment.
* ComputeContextBase gains FlushAndWait(): it flushes through
  BufferManagerAccessor::Get(*this) and then returns
  WebGpuContext::WaitForQueueIdle(). BufferManagerAccessor befriends
  ComputeContextBase so that call is permitted.
* WebGpuContext gains WaitForQueueIdle(), which waits on
  device_queue_.OnSubmittedWorkDone(WaitAnyOnly, ...) and ORT_ENFORCEs a
  Success status inside the callback.
* WebGpuExecutionProvider no longer creates graph_buffer_mgr_ in its
  constructor. OnRunStart() creates it when it is still null at capture begin
  and sets graph_buffer_mgr_active_; OnRunEnd() clears the flag.
  BufferManager() returns the graph manager only while the flag is set.
  ReplayGraph() enforces that the graph manager exists. The default preferred
  allocator is built with a `[this]` getter that calls BufferManager().

Review notes (to be confirmed against code outside this patch):

* NOTE(review): Free() resolves the manager at free time. A buffer allocated
  while graph_buffer_mgr_active_ is true can therefore be released through a
  different manager than the one that created it (and vice versa). Confirm
  BufferManager::Release() tolerates cross-manager release.
* NOTE(review): the `[this]` getter handed to the default allocator is only
  valid while the execution provider is alive. Confirm no allocator returned by
  CreatePreferredAllocators() can outlive the provider.
* NOTE(review): ORT_ENFORCE inside the OnSubmittedWorkDone callback throws from
  within a WebGPU callback. Confirm this matches how other callbacks in this
  provider report failure.

diff --git a/onnxruntime/core/providers/webgpu/allocator.cc b/onnxruntime/core/providers/webgpu/allocator.cc
index 3e1b87821fe2f..08cc9455fd34b 100644
--- a/onnxruntime/core/providers/webgpu/allocator.cc
+++ b/onnxruntime/core/providers/webgpu/allocator.cc
@@ -9,14 +9,18 @@ namespace onnxruntime {
 namespace webgpu {
 
 GpuBufferAllocator::GpuBufferAllocator(const BufferManager& buffer_manager, bool is_read_only_allocator)
+    : GpuBufferAllocator([buffer_manager_ptr = &buffer_manager]() -> const BufferManager& { return *buffer_manager_ptr; }, is_read_only_allocator) {
+}
+
+GpuBufferAllocator::GpuBufferAllocator(std::function<const BufferManager&()> buffer_manager_getter, bool is_read_only_allocator)
     : IAllocator(
           OrtMemoryInfo(WEBGPU_BUFFER,
                         is_read_only_allocator ? OrtAllocatorType::OrtReadOnlyAllocator
                                                : OrtAllocatorType::OrtDeviceAllocator,
                         WebGpuDevice,
                         OrtMemTypeDefault)),
-      buffer_manager_{buffer_manager},
-      mapped_at_creation_{is_read_only_allocator && buffer_manager.SupportsUMA()} {
+      buffer_manager_getter_{std::move(buffer_manager_getter)},
+      mapped_at_creation_{is_read_only_allocator && buffer_manager_getter_().SupportsUMA()} {
 }
 
 void* GpuBufferAllocator::Alloc(size_t size) {
@@ -26,15 +30,17 @@ void* GpuBufferAllocator::Alloc(size_t size) {
 
   stats_.num_allocs++;
 
+  const auto& buffer_manager = buffer_manager_getter_();
+
   wgpu::BufferUsage usage = mapped_at_creation_
                                 ? wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapWrite
                                 : wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Indirect;
-  return buffer_manager_.Create(size, usage);
+  return buffer_manager.Create(size, usage);
 }
 
 void GpuBufferAllocator::Free(void* p) {
   if (p != nullptr) {
-    buffer_manager_.Release(static_cast<WGPUBuffer>(p));
+    buffer_manager_getter_().Release(static_cast<WGPUBuffer>(p));
     stats_.num_allocs--;
   }
 }
diff --git a/onnxruntime/core/providers/webgpu/allocator.h b/onnxruntime/core/providers/webgpu/allocator.h
index 74b3d669fcf3b..fadfc8c86cfc4 100644
--- a/onnxruntime/core/providers/webgpu/allocator.h
+++ b/onnxruntime/core/providers/webgpu/allocator.h
@@ -3,6 +3,8 @@
 
 #pragma once
 
+#include <functional>
+
 #include "core/framework/allocator.h"
 #include "core/framework/ortdevice.h"
 
@@ -19,6 +21,7 @@ inline constexpr OrtDevice WebGpuDevice{OrtDevice::GPU,
 class GpuBufferAllocator : public IAllocator {
  public:
   GpuBufferAllocator(const BufferManager& buffer_manager, bool is_read_only_allocator);
+  GpuBufferAllocator(std::function<const BufferManager&()> buffer_manager_getter, bool is_read_only_allocator);
 
   virtual void* Alloc(size_t size) override;
   virtual void Free(void* p) override;
@@ -26,7 +29,7 @@ class GpuBufferAllocator : public IAllocator {
 
  private:
   AllocatorStats stats_;
-  const BufferManager& buffer_manager_;
+  std::function<const BufferManager&()> buffer_manager_getter_;
   bool mapped_at_creation_;
 };
 
diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h
index 632e04a36c7bf..e12b42d931c2f 100644
--- a/onnxruntime/core/providers/webgpu/compute_context.h
+++ b/onnxruntime/core/providers/webgpu/compute_context.h
@@ -34,10 +34,12 @@ class ComputeContextBase {
  public:
   // Nested accessor class to provide controlled access to BufferManager
   class BufferManagerAccessor {
-    // access to BufferManager is limited to class WebGpuContext.
-    // This ensures no access to BufferManager from other classes, avoiding
-    // potential misuse.
+    // Access to BufferManager is limited to WebGpuContext and ComputeContextBase.
+    // ComputeContextBase needs it for FlushAndWait(), which routes through the
+    // currently-active buffer manager. This narrow allow-list prevents
+    // arbitrary classes from reaching into BufferManager directly.
     friend class WebGpuContext;
+    friend class ComputeContextBase;
 
    private:
     static const webgpu::BufferManager& Get(const ComputeContextBase& context);
@@ -121,6 +123,11 @@ class ComputeContextBase {
     return webgpu_context_.Run(*this, program);
   }
 
+  inline Status FlushAndWait() {
+    webgpu_context_.Flush(BufferManagerAccessor::Get(*this));
+    return webgpu_context_.WaitForQueueIdle();
+  }
+
  protected:
   WebGpuContext& webgpu_context_;
   const WebGpuExecutionProvider& ep_;
diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc
index ada9a2e8ab692..58e71de1fa211 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_context.cc
+++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc
@@ -184,6 +184,16 @@ Status WebGpuContext::Wait(wgpu::Future f) {
   return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to wait for the operation:", uint32_t(status));
 }
 
+Status WebGpuContext::WaitForQueueIdle() {
+  return Wait(device_queue_.OnSubmittedWorkDone(
+      wgpu::CallbackMode::WaitAnyOnly,
+      [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
+        ORT_ENFORCE(status == wgpu::QueueWorkDoneStatus::Success,
+                    "Failed to wait for submitted WebGPU work: ",
+                    std::string_view{message});
+      }));
+}
+
 Status WebGpuContext::Run(ComputeContextBase& context, const ProgramBase& program) {
   const auto& inputs = program.Inputs();
   const auto& outputs = program.Outputs();
diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h
index 021c7f383a6d7..fb6da131e45fe 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_context.h
+++ b/onnxruntime/core/providers/webgpu/webgpu_context.h
@@ -165,6 +165,7 @@ class WebGpuContextFactory {
 class WebGpuContext final {
  public:
   Status Wait(wgpu::Future f);
+  Status WaitForQueueIdle();
 
   const wgpu::Device& Device() const { return device_; }
 
diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
index d1cde04277938..ff1270938b639 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
+++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
@@ -578,16 +578,6 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id,
       enable_int64_{config.enable_graph_capture || config.enable_int64},
       multi_rotary_cache_concat_offset_{config.multi_rotary_cache_concat_offset},
       prepack_allocator_{std::make_shared<webgpu::GpuBufferAllocator>(context_.InitializerBufferManager(), false)} {
-  // If graph capture is enabled, create a dedicated buffer manager for graph mode
-  if (enable_graph_capture_) {
-    // Create buffer manager for graph capture mode with appropriate cache modes
-    graph_buffer_mgr_ = webgpu::BufferManagerFactory::Create(
-        context_,
-        webgpu::BufferCacheMode::Graph,
-        webgpu::BufferCacheMode::GraphSimple,
-        webgpu::BufferCacheMode::Disabled);
-  }
-
   if (config.enable_pix_capture) {
 #if defined(ENABLE_PIX_FOR_WEBGPU_EP)
     // set pix frame generator
@@ -603,7 +593,7 @@ std::vector<AllocatorPtr> WebGpuExecutionProvider::CreatePreferredAllocators() {
       // allocator for initializers
       std::make_unique<webgpu::GpuBufferAllocator>(context_.InitializerBufferManager(), true),
       // default allocator
-      std::make_unique<webgpu::GpuBufferAllocator>(BufferManager(), false),
+      std::make_unique<webgpu::GpuBufferAllocator>([this]() -> const webgpu::BufferManager& { return BufferManager(); }, false),
   };
 }
 
@@ -773,6 +763,14 @@ Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_op
   }
 
   if (graph_annotation_id != -1 && IsGraphCaptureAllowed() && !IsGraphCaptured(graph_annotation_id)) {
+    if (!graph_buffer_mgr_) {
+      graph_buffer_mgr_ = webgpu::BufferManagerFactory::Create(
+          context_,
+          webgpu::BufferCacheMode::Graph,
+          webgpu::BufferCacheMode::GraphSimple,
+          webgpu::BufferCacheMode::Disabled);
+    }
+    graph_buffer_mgr_active_ = true;
     context_.CaptureBegin(&captured_commands_, *graph_buffer_mgr_);
   }
   m_current_graph_annotation_id = graph_annotation_id;
@@ -794,6 +792,8 @@ Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxrunti
     }
   }
 
+  graph_buffer_mgr_active_ = false;
+
   if (session_profiler_ && session_profiler_->Enabled()) {
     // Session-level profiling: collect into profiler's own events storage.
     context_.CollectProfilingData(session_profiler_->GpuEvents());
@@ -825,6 +825,7 @@ bool WebGpuExecutionProvider::IsGraphCaptured(int graph_annotation_id) const {
 
 Status WebGpuExecutionProvider::ReplayGraph(int graph_annotation_id) {
   ORT_ENFORCE(IsGraphCaptured(graph_annotation_id));
+  ORT_ENFORCE(graph_buffer_mgr_ != nullptr, "Graph buffer manager must exist before replay.");
   // TODO: enable profiling in run level
   if (session_profiler_ && session_profiler_->Enabled()) {
     context_.StartProfiling();
@@ -838,7 +839,7 @@ Status WebGpuExecutionProvider::ReplayGraph(int graph_annotation_id) {
 }
 
 webgpu::BufferManager& WebGpuExecutionProvider::BufferManager() const {
-  if (graph_buffer_mgr_) {
+  if (graph_buffer_mgr_active_ && graph_buffer_mgr_) {
     return *graph_buffer_mgr_;
   } else {
     return context_.BufferManager();
diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h
index d1e2231dbba6f..69171df6a8a45 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h
+++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h
@@ -127,6 +127,7 @@ class WebGpuExecutionProvider : public IExecutionProvider {
   bool enable_int64_ = false;
   uint32_t multi_rotary_cache_concat_offset_ = 0;
   bool is_graph_captured_ = false;
+  bool graph_buffer_mgr_active_ = false;
   int regular_run_count_before_graph_capture_ = 0;
   const int min_num_runs_before_cuda_graph_capture_ = 1;  // Required regular runs before graph capture for any necessary allocations.
   int m_current_graph_annotation_id = 0;