Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions onnxruntime/core/providers/webgpu/allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,18 @@
namespace webgpu {

GpuBufferAllocator::GpuBufferAllocator(const BufferManager& buffer_manager, bool is_read_only_allocator)
: GpuBufferAllocator([buffer_manager_ptr = &buffer_manager]() -> const BufferManager& { return *buffer_manager_ptr; }, is_read_only_allocator) {
}

GpuBufferAllocator::GpuBufferAllocator(std::function<const BufferManager&()> buffer_manager_getter, bool is_read_only_allocator)
: IAllocator(
OrtMemoryInfo(WEBGPU_BUFFER,
is_read_only_allocator ? OrtAllocatorType::OrtReadOnlyAllocator
: OrtAllocatorType::OrtDeviceAllocator,
WebGpuDevice,
OrtMemTypeDefault)),
buffer_manager_{buffer_manager},
mapped_at_creation_{is_read_only_allocator && buffer_manager.SupportsUMA()} {
buffer_manager_getter_{std::move(buffer_manager_getter)},

Check warning on line 22 in onnxruntime/core/providers/webgpu/allocator.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <utility> for move [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webgpu/allocator.cc:22: Add #include <utility> for move [build/include_what_you_use] [4]
mapped_at_creation_{is_read_only_allocator && buffer_manager_getter_().SupportsUMA()} {
}

void* GpuBufferAllocator::Alloc(size_t size) {
Expand All @@ -26,15 +30,17 @@

stats_.num_allocs++;

const auto& buffer_manager = buffer_manager_getter_();

wgpu::BufferUsage usage = mapped_at_creation_ ? wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapWrite
: wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Indirect;

return buffer_manager_.Create(size, usage);
return buffer_manager.Create(size, usage);
}

void GpuBufferAllocator::Free(void* p) {
if (p != nullptr) {
buffer_manager_.Release(static_cast<WGPUBuffer>(p));
buffer_manager_getter_().Release(static_cast<WGPUBuffer>(p));
stats_.num_allocs--;
}
}
Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/core/providers/webgpu/allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

#pragma once

#include <functional>

#include "core/framework/allocator.h"
#include "core/framework/ortdevice.h"

Expand All @@ -19,14 +21,15 @@ inline constexpr OrtDevice WebGpuDevice{OrtDevice::GPU,
class GpuBufferAllocator : public IAllocator {
public:
GpuBufferAllocator(const BufferManager& buffer_manager, bool is_read_only_allocator);
GpuBufferAllocator(std::function<const BufferManager&()> buffer_manager_getter, bool is_read_only_allocator);

virtual void* Alloc(size_t size) override;
virtual void Free(void* p) override;
void GetStats(AllocatorStats* stats) override;

private:
AllocatorStats stats_;
const BufferManager& buffer_manager_;
std::function<const BufferManager&()> buffer_manager_getter_;
bool mapped_at_creation_;
};

Expand Down
13 changes: 10 additions & 3 deletions onnxruntime/core/providers/webgpu/compute_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,12 @@ class ComputeContextBase {
public:
// Nested accessor class to provide controlled access to BufferManager
class BufferManagerAccessor {
// access to BufferManager is limited to class WebGpuContext.
// This ensures no access to BufferManager from other classes, avoiding
// potential misuse.
// Access to BufferManager is limited to WebGpuContext and ComputeContextBase.
// ComputeContextBase needs it for FlushAndWait(), which routes through the
// currently-active buffer manager. This narrow allow-list prevents
// arbitrary classes from reaching into BufferManager directly.
friend class WebGpuContext;
friend class ComputeContextBase;

private:
static const webgpu::BufferManager& Get(const ComputeContextBase& context);
Expand Down Expand Up @@ -121,6 +123,11 @@ class ComputeContextBase {
return webgpu_context_.Run(*this, program);
}

inline Status FlushAndWait() {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is FlushAndWait() used for?

webgpu_context_.Flush(BufferManagerAccessor::Get(*this));
return webgpu_context_.WaitForQueueIdle();
}
Comment thread
hariharans29 marked this conversation as resolved.

protected:
WebGpuContext& webgpu_context_;
const WebGpuExecutionProvider& ep_;
Expand Down
10 changes: 10 additions & 0 deletions onnxruntime/core/providers/webgpu/webgpu_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,16 @@ Status WebGpuContext::Wait(wgpu::Future f) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to wait for the operation:", uint32_t(status));
}

Status WebGpuContext::WaitForQueueIdle() {
return Wait(device_queue_.OnSubmittedWorkDone(
wgpu::CallbackMode::WaitAnyOnly,
[](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
ORT_ENFORCE(status == wgpu::QueueWorkDoneStatus::Success,
"Failed to wait for submitted WebGPU work: ",
std::string_view{message});
}));
}
Comment thread
hariharans29 marked this conversation as resolved.

Status WebGpuContext::Run(ComputeContextBase& context, const ProgramBase& program) {
const auto& inputs = program.Inputs();
const auto& outputs = program.Outputs();
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/providers/webgpu/webgpu_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ class WebGpuContextFactory {
class WebGpuContext final {
public:
Status Wait(wgpu::Future f);
Status WaitForQueueIdle();

const wgpu::Device& Device() const { return device_; }

Expand Down
25 changes: 13 additions & 12 deletions onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -578,16 +578,6 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id,
enable_int64_{config.enable_graph_capture || config.enable_int64},
multi_rotary_cache_concat_offset_{config.multi_rotary_cache_concat_offset},
prepack_allocator_{std::make_shared<webgpu::GpuBufferAllocator>(context_.InitializerBufferManager(), false)} {
// If graph capture is enabled, create a dedicated buffer manager for graph mode
if (enable_graph_capture_) {
// Create buffer manager for graph capture mode with appropriate cache modes
graph_buffer_mgr_ = webgpu::BufferManagerFactory::Create(
context_,
webgpu::BufferCacheMode::Graph,
webgpu::BufferCacheMode::GraphSimple,
webgpu::BufferCacheMode::Disabled);
}

if (config.enable_pix_capture) {
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
// set pix frame generator
Expand All @@ -603,7 +593,7 @@ std::vector<AllocatorPtr> WebGpuExecutionProvider::CreatePreferredAllocators() {
// allocator for initializers
std::make_unique<webgpu::GpuBufferAllocator>(context_.InitializerBufferManager(), true),
// default allocator
std::make_unique<webgpu::GpuBufferAllocator>(BufferManager(), false),
std::make_unique<webgpu::GpuBufferAllocator>([this]() -> const webgpu::BufferManager& { return BufferManager(); }, false),
};
}

Expand Down Expand Up @@ -773,6 +763,14 @@ Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_op
}

if (graph_annotation_id != -1 && IsGraphCaptureAllowed() && !IsGraphCaptured(graph_annotation_id)) {
if (!graph_buffer_mgr_) {
graph_buffer_mgr_ = webgpu::BufferManagerFactory::Create(
context_,
webgpu::BufferCacheMode::Graph,
webgpu::BufferCacheMode::GraphSimple,
webgpu::BufferCacheMode::Disabled);
}
graph_buffer_mgr_active_ = true;
context_.CaptureBegin(&captured_commands_, *graph_buffer_mgr_);
}
m_current_graph_annotation_id = graph_annotation_id;
Expand All @@ -794,6 +792,8 @@ Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxrunti
}
}

graph_buffer_mgr_active_ = false;

if (session_profiler_ && session_profiler_->Enabled()) {
// Session-level profiling: collect into profiler's own events storage.
context_.CollectProfilingData(session_profiler_->GpuEvents());
Expand Down Expand Up @@ -825,6 +825,7 @@ bool WebGpuExecutionProvider::IsGraphCaptured(int graph_annotation_id) const {

Status WebGpuExecutionProvider::ReplayGraph(int graph_annotation_id) {
ORT_ENFORCE(IsGraphCaptured(graph_annotation_id));
ORT_ENFORCE(graph_buffer_mgr_ != nullptr, "Graph buffer manager must exist before replay.");
// TODO: enable profiling in run level
if (session_profiler_ && session_profiler_->Enabled()) {
context_.StartProfiling();
Expand All @@ -838,7 +839,7 @@ Status WebGpuExecutionProvider::ReplayGraph(int graph_annotation_id) {
}

webgpu::BufferManager& WebGpuExecutionProvider::BufferManager() const {
if (graph_buffer_mgr_) {
if (graph_buffer_mgr_active_ && graph_buffer_mgr_) {
return *graph_buffer_mgr_;
} else {
return context_.BufferManager();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ class WebGpuExecutionProvider : public IExecutionProvider {
bool enable_int64_ = false;
uint32_t multi_rotary_cache_concat_offset_ = 0;
bool is_graph_captured_ = false;
bool graph_buffer_mgr_active_ = false;
int regular_run_count_before_graph_capture_ = 0;
const int min_num_runs_before_cuda_graph_capture_ = 1; // Required regular runs before graph capture for any necessary allocations.
int m_current_graph_annotation_id = 0;
Expand Down
Loading