Commit b8012ec

Fix thread-safe implementation
1 parent bfff27f commit b8012ec

File tree

1 file changed: +50 -39 lines changed

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 50 additions & 39 deletions
@@ -13,7 +13,9 @@

 #include <condition_variable>
 #include <cstring>
+#include <iostream>
 #include <mutex>
+#include <string>
 #include <vector>

 #ifdef GGML_WEBGPU_DEBUG
@@ -61,7 +63,8 @@ struct webgpu_param_bufs {
 struct webgpu_param_buf_pool {
     std::vector<webgpu_param_bufs> free;

-    std::mutex              mutex;
+    std::mutex mutex;
+
     std::condition_variable cv;

     void init(wgpu::Device device) {
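Note: the pool pairs its mutex with a condition variable so an allocation can block until another thread returns a buffer, e.g. from a submit-completion callback. A minimal standalone sketch of that blocking-pool pattern (std-only, illustrative names, not the backend's actual pool):

    #include <condition_variable>
    #include <mutex>
    #include <thread>
    #include <vector>

    struct pool {
        std::vector<int>        free_items;
        std::mutex              mutex;
        std::condition_variable cv;

        // Blocks until free() makes an item available.
        int alloc() {
            std::unique_lock<std::mutex> lock(mutex);
            cv.wait(lock, [this] { return !free_items.empty(); });
            int item = free_items.back();
            free_items.pop_back();
            return item;
        }

        void free(int item) {
            std::lock_guard<std::mutex> lock(mutex);
            free_items.push_back(item);
            cv.notify_one(); // wake one blocked alloc()
        }
    };

    int main() {
        pool p;
        std::thread returner([&p] { p.free(7); }); // stands in for a completion callback
        int item = p.alloc();                      // blocks until the item arrives
        returner.join();
        return item == 7 ? 0 : 1;
    }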
@@ -108,19 +111,18 @@ struct webgpu_param_buf_pool {

 // All the base objects needed to run operations on a WebGPU device
 struct webgpu_context_struct {
-    wgpu::Instance          instance;
-    wgpu::Adapter           adapter;
-    wgpu::Device            device;
-    wgpu::Queue             queue;
-    wgpu::Limits            limits;
-    wgpu::SupportedFeatures features;
-
-    std::recursive_mutex submit_mutex;
+    wgpu::Instance instance;
+    wgpu::Adapter  adapter;
+    wgpu::Device   device;
+    wgpu::Queue    queue;
+    wgpu::Limits   limits;
+
+    std::recursive_mutex mutex;
     std::mutex get_tensor_mutex;
     std::mutex init_mutex;
-    bool device_init = false;

-    // Parameter buffer pool
+    bool device_init = false;
+
     webgpu_param_buf_pool param_buf_pool;

     wgpu::ComputePipeline memset_pipeline;
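Note: the dedicated submit_mutex is folded into one std::recursive_mutex here. A recursive mutex can be re-acquired by the thread that already holds it, presumably needed because ggml_backend_webgpu_submit_queue locks it and is also reached from ggml_backend_webgpu_build_and_enqueue while that function holds the same lock (see the hunks below). A minimal sketch of the re-entrancy this permits (std-only, illustrative):

    #include <mutex>

    std::recursive_mutex m;

    void submit() {
        // Re-locking on the same thread just bumps the ownership count;
        // a plain std::mutex would deadlock here.
        std::lock_guard<std::recursive_mutex> lock(m);
    }

    void build_and_enqueue() {
        std::lock_guard<std::recursive_mutex> lock(m);
        submit(); // nested acquisition on the same thread is fine
    }

    int main() {
        build_and_enqueue();
    }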
@@ -134,36 +136,33 @@ struct webgpu_context_struct {

     // Command buffers which need to be submitted
     std::vector<wgpu::CommandBuffer> staged_command_bufs;
+
     // Parameter buffers associated with the staged command buffers
-    std::vector<webgpu_param_bufs>   staged_param_bufs;
+    std::vector<webgpu_param_bufs> staged_param_bufs;
 };

 typedef std::shared_ptr<webgpu_context_struct> webgpu_context;

 struct ggml_backend_webgpu_reg_context {
     webgpu_context webgpu_ctx;
-
-    size_t       device_count;
-    const char * name;
+    size_t         device_count;
+    const char *   name;
 };

 struct ggml_backend_webgpu_device_context {
     webgpu_context webgpu_ctx;
-
-    std::string device_name;
-    std::string device_desc;
+    std::string    device_name;
+    std::string    device_desc;
 };

 struct ggml_backend_webgpu_context {
     webgpu_context webgpu_ctx;
-
-    std::string name;
+    std::string    name;
 };

 struct ggml_backend_webgpu_buffer_context {
     webgpu_context webgpu_ctx;
-
-    wgpu::Buffer buffer;
+    wgpu::Buffer   buffer;

     ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf) :
         webgpu_ctx(std::move(ctx)),
@@ -180,10 +179,13 @@ static void ggml_webgpu_create_pipeline(wgpu::Device &
                                         const char * label,
                                         const std::vector<wgpu::ConstantEntry> & constants = {}) {
     WEBGPU_LOG_DEBUG("ggml_webgpu_create_pipeline()");
+
     wgpu::ShaderSourceWGSL shader_source;
     shader_source.code = shader_code;
+
     wgpu::ShaderModuleDescriptor shader_desc;
-    shader_desc.nextInChain          = &shader_source;
+    shader_desc.nextInChain = &shader_source;
+
     wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc);

     wgpu::ComputePipelineDescriptor pipeline_desc;
@@ -210,8 +212,9 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
     buffer_desc.usage            = usage;
     buffer_desc.label            = label;
     buffer_desc.mappedAtCreation = false;
+
     // TODO: error handling
-    buffer                       = device.CreateBuffer(&buffer_desc);
+    buffer = device.CreateBuffer(&buffer_desc);
 }

 /** End WebGPU object initializations */
@@ -231,8 +234,7 @@ static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) {
 }

 static void ggml_backend_webgpu_submit_queue(webgpu_context & ctx) {
-    std::lock_guard<std::recursive_mutex> lock(ctx->submit_mutex);
-
+    std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
     ctx->queue.Submit(ctx->staged_command_bufs.size(), ctx->staged_command_bufs.data());
     ctx->staged_command_bufs.clear();
     std::vector<webgpu_param_bufs> staged_param_bufs = std::move(ctx->staged_param_bufs);
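Note: under the lock, staged_param_bufs is moved into a local vector, so the whole batch leaves the shared context in one step and the staging vectors are immediately ready for the next batch. A std-only sketch of this move-out-under-lock pattern (illustrative names):

    #include <cstdio>
    #include <mutex>
    #include <vector>

    std::mutex       mtx;
    std::vector<int> staged; // shared state guarded by mtx

    void flush() {
        std::vector<int> batch;
        {
            std::lock_guard<std::mutex> lock(mtx);
            batch = std::move(staged); // take the batch; 'staged' is left empty in practice
        }
        // Process the batch outside the critical section,
        // e.g. hand it to a completion callback.
        for (int v : batch) {
            printf("%d\n", v);
        }
    }

    int main() {
        {
            std::lock_guard<std::mutex> lock(mtx);
            staged = { 1, 2, 3 };
        }
        flush();
    }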
@@ -274,6 +276,8 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
                                                   bool submit_imm = false) {
     webgpu_param_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();

+    std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
+
     ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize());
     uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
     for (size_t i = 0; i < params.size(); i++) {
@@ -315,7 +319,6 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
         });
     } else {
         // Enqueue commands and only submit if we have enough staged commands
-        std::lock_guard<std::recursive_mutex> lock(ctx->submit_mutex);
         ctx->staged_command_bufs.push_back(commands);
         ctx->staged_param_bufs.push_back(params_bufs);
         if (ctx->staged_command_bufs.size() == WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
@@ -540,10 +543,12 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe
     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", "
                      << offset << ", " << size << ")");

-    ggml_backend_webgpu_buffer_context * buf_ctx      = (ggml_backend_webgpu_buffer_context *) buffer->context;
-    size_t                               total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
+    ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
+
+    size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
+
     // This is a trick to set all bytes of a u32 to the same 1 byte value.
-    uint32_t                             val32        = (uint32_t) value * 0x01010101;
+    uint32_t val32 = (uint32_t) value * 0x01010101;
     ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, total_offset, size);
 }
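Note: the "trick" in the comment is byte replication: multiplying an 8-bit value by 0x01010101 copies it into all four bytes of the uint32_t, so a memset that writes whole 32-bit words can still produce a 1-byte fill pattern. A quick check:

    #include <cassert>
    #include <cinttypes>
    #include <cstdint>
    #include <cstdio>

    int main() {
        uint8_t  value = 0xAB;
        uint32_t val32 = (uint32_t) value * 0x01010101;
        assert(val32 == 0xABABABAB);        // every byte now equals 'value'
        printf("0x%08" PRIX32 "\n", val32); // prints 0xABABABAB
    }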

@@ -559,13 +564,16 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,

     size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;

+    std::lock_guard<std::recursive_mutex> lock(webgpu_ctx->mutex);
     webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);

     if (size % 4 != 0) {
         // If size is not a multiple of 4, we need to memset the remaining bytes
-        size_t   remaining_size = size % 4;
+        size_t remaining_size = size % 4;
+
         // pack the remaining bytes into a uint32_t
-        uint32_t val32          = 0;
+        uint32_t val32 = 0;
+
         for (size_t i = 0; i < remaining_size; i++) {
             ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i];
         }
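Note: WebGPU's WriteBuffer only accepts write sizes that are multiples of 4, which is why the aligned prefix is written first ((size / 4) * 4) and the 1-3 trailing bytes are packed into a uint32_t. A standalone sketch of the packing step, mirroring the loop above (the printed value assumes little-endian):

    #include <cinttypes>
    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint8_t data[7]        = { 1, 2, 3, 4, 5, 6, 7 };
        size_t        size           = sizeof(data);
        size_t        remaining_size = size % 4; // 3 trailing bytes: 5, 6, 7

        // Pack the trailing bytes into the low bytes of a uint32_t.
        uint32_t val32 = 0;
        for (size_t i = 0; i < remaining_size; i++) {
            ((uint8_t *) &val32)[i] = data[size - remaining_size + i];
        }
        printf("0x%08" PRIX32 "\n", val32); // 0x00070605 on little-endian
    }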
@@ -613,8 +621,12 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
     wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
     encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, webgpu_ctx->get_tensor_staging_buf, 0, final_size);
     wgpu::CommandBuffer commands = encoder.Finish();
-    // Submit the command buffer to the queue
-    webgpu_ctx->queue.Submit(1, &commands);
+
+    {
+        std::lock_guard<std::recursive_mutex> submit_lock(webgpu_ctx->mutex);
+        // Submit the command buffer to the queue
+        webgpu_ctx->queue.Submit(1, &commands);
+    }

     // Map the staging buffer to read the data
     ggml_backend_webgpu_map_buffer(webgpu_ctx, webgpu_ctx->get_tensor_staging_buf, wgpu::MapMode::Read, 0, final_size);
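Note: the new braces limit submit_lock to the Submit call, so the mutex is already released by the time the blocking map of the staging buffer runs, and other threads can submit work while this one waits for the readback. A minimal sketch of a block-scoped lock (illustrative):

    #include <mutex>

    std::recursive_mutex m;

    void read_back() {
        // ... encode and finish the copy commands ...
        {
            std::lock_guard<std::recursive_mutex> submit_lock(m);
            // submit the command buffer while holding the lock
        } // submit_lock is destroyed here, releasing the mutex
        // ... map the staging buffer and block until the data is ready ...
    }

    int main() {
        read_back();
    }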
@@ -628,7 +640,6 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,

 static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << (uint32_t) value << ")");
-
     ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
     ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, value, 0, buffer->size);
 }
@@ -764,10 +775,11 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co
     std::lock_guard<std::mutex> lock(webgpu_ctx->init_mutex);
     if (!webgpu_ctx->device_init) {
         // Initialize device
-        wgpu::DeviceDescriptor dev_desc;
+        std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16 };
+        wgpu::DeviceDescriptor dev_desc;
         dev_desc.requiredLimits       = &webgpu_ctx->limits;
-        dev_desc.requiredFeatures     = webgpu_ctx->features.features;
-        dev_desc.requiredFeatureCount = webgpu_ctx->features.featureCount;
+        dev_desc.requiredFeatures     = required_features.data();
+        dev_desc.requiredFeatureCount = required_features.size();
         dev_desc.SetDeviceLostCallback(
             wgpu::CallbackMode::AllowSpontaneous,
             [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
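Note: wgpu::DeviceDescriptor stores requiredFeatures as a raw pointer plus a count rather than copying the list, so required_features must stay alive until the device request that consumes dev_desc is made. A std-only analog of that lifetime constraint (hypothetical descriptor type, not the wgpu API):

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Hypothetical descriptor that, like wgpu::DeviceDescriptor, borrows
    // a pointer + count instead of owning the feature list.
    struct device_desc {
        const int * required_features      = nullptr;
        size_t      required_feature_count = 0;
    };

    void request_device(const device_desc & desc) {
        for (size_t i = 0; i < desc.required_feature_count; i++) {
            printf("feature %d\n", desc.required_features[i]);
        }
    }

    int main() {
        std::vector<int> features = { 42 }; // must outlive the request below
        device_desc desc;
        desc.required_features      = features.data();
        desc.required_feature_count = features.size();
        request_device(desc); // the pointer into 'features' is still valid here
    }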
@@ -920,7 +932,6 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
     GGML_ASSERT(ctx->adapter != nullptr);

     ctx->adapter.GetLimits(&ctx->limits);
-    ctx->adapter.GetFeatures(&ctx->features);

     wgpu::AdapterInfo info{};
     ctx->adapter.GetInfo(&info);
