1313
1414#include < condition_variable>
1515#include < cstring>
16+ #include < iostream>
1617#include < mutex>
18+ #include < string>
1719#include < vector>
1820
1921#ifdef GGML_WEBGPU_DEBUG
@@ -61,7 +63,8 @@ struct webgpu_param_bufs {
6163struct webgpu_param_buf_pool {
6264 std::vector<webgpu_param_bufs> free;
6365
64- std::mutex mutex;
66+ std::mutex mutex;
67+
6568 std::condition_variable cv;
6669
6770 void init (wgpu::Device device) {
@@ -108,19 +111,18 @@ struct webgpu_param_buf_pool {
108111
109112// All the base objects needed to run operations on a WebGPU device
110113struct webgpu_context_struct {
111- wgpu::Instance instance;
112- wgpu::Adapter adapter;
113- wgpu::Device device;
114- wgpu::Queue queue;
115- wgpu::Limits limits;
116- wgpu::SupportedFeatures features;
117-
118- std::recursive_mutex submit_mutex;
114+ wgpu::Instance instance;
115+ wgpu::Adapter adapter;
116+ wgpu::Device device;
117+ wgpu::Queue queue;
118+ wgpu::Limits limits;
119+
120+ std::recursive_mutex mutex;
119121 std::mutex get_tensor_mutex;
120122 std::mutex init_mutex;
121- bool device_init = false ;
122123
123- // Parameter buffer pool
124+ bool device_init = false ;
125+
124126 webgpu_param_buf_pool param_buf_pool;
125127
126128 wgpu::ComputePipeline memset_pipeline;
@@ -134,36 +136,33 @@ struct webgpu_context_struct {
134136
135137 // Command buffers which need to be submitted
136138 std::vector<wgpu::CommandBuffer> staged_command_bufs;
139+
137140 // Parameter buffers associated with the staged command buffers
138- std::vector<webgpu_param_bufs> staged_param_bufs;
141+ std::vector<webgpu_param_bufs> staged_param_bufs;
139142};
140143
141144typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
142145
143146struct ggml_backend_webgpu_reg_context {
144147 webgpu_context webgpu_ctx;
145-
146- size_t device_count;
147- const char * name;
148+ size_t device_count;
149+ const char * name;
148150};
149151
150152struct ggml_backend_webgpu_device_context {
151153 webgpu_context webgpu_ctx;
152-
153- std::string device_name;
154- std::string device_desc;
154+ std::string device_name;
155+ std::string device_desc;
155156};
156157
157158struct ggml_backend_webgpu_context {
158159 webgpu_context webgpu_ctx;
159-
160- std::string name;
160+ std::string name;
161161};
162162
163163struct ggml_backend_webgpu_buffer_context {
164164 webgpu_context webgpu_ctx;
165-
166- wgpu::Buffer buffer;
165+ wgpu::Buffer buffer;
167166
168167 ggml_backend_webgpu_buffer_context (webgpu_context ctx, wgpu::Buffer buf) :
169168 webgpu_ctx (std::move(ctx)),
@@ -180,10 +179,13 @@ static void ggml_webgpu_create_pipeline(wgpu::Device &
180179 const char * label,
181180 const std::vector<wgpu::ConstantEntry> & constants = {}) {
182181 WEBGPU_LOG_DEBUG (" ggml_webgpu_create_pipeline()" );
182+
183183 wgpu::ShaderSourceWGSL shader_source;
184184 shader_source.code = shader_code;
185+
185186 wgpu::ShaderModuleDescriptor shader_desc;
186- shader_desc.nextInChain = &shader_source;
187+ shader_desc.nextInChain = &shader_source;
188+
187189 wgpu::ShaderModule shader_module = device.CreateShaderModule (&shader_desc);
188190
189191 wgpu::ComputePipelineDescriptor pipeline_desc;
@@ -210,8 +212,9 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device,
210212 buffer_desc.usage = usage;
211213 buffer_desc.label = label;
212214 buffer_desc.mappedAtCreation = false ;
215+
213216 // TODO: error handling
214- buffer = device.CreateBuffer (&buffer_desc);
217+ buffer = device.CreateBuffer (&buffer_desc);
215218}
216219
217220/* * End WebGPU object initializations */
@@ -231,8 +234,7 @@ static void ggml_backend_webgpu_wait_on_submission(webgpu_context & ctx) {
231234}
232235
233236static void ggml_backend_webgpu_submit_queue (webgpu_context & ctx) {
234- std::lock_guard<std::recursive_mutex> lock (ctx->submit_mutex );
235-
237+ std::lock_guard<std::recursive_mutex> lock (ctx->mutex );
236238 ctx->queue .Submit (ctx->staged_command_bufs .size (), ctx->staged_command_bufs .data ());
237239 ctx->staged_command_bufs .clear ();
238240 std::vector<webgpu_param_bufs> staged_param_bufs = std::move (ctx->staged_param_bufs );
@@ -274,6 +276,8 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
274276 bool submit_imm = false ) {
275277 webgpu_param_bufs params_bufs = ctx->param_buf_pool .alloc_bufs ();
276278
279+ std::lock_guard<std::recursive_mutex> lock (ctx->mutex );
280+
277281 ggml_backend_webgpu_map_buffer (ctx, params_bufs.host_buf , wgpu::MapMode::Write, 0 , params_bufs.host_buf .GetSize ());
278282 uint32_t * _params = (uint32_t *) params_bufs.host_buf .GetMappedRange ();
279283 for (size_t i = 0 ; i < params.size (); i++) {
@@ -315,7 +319,6 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
315319 });
316320 } else {
317321 // Enqueue commands and only submit if we have enough staged commands
318- std::lock_guard<std::recursive_mutex> lock (ctx->submit_mutex );
319322 ctx->staged_command_bufs .push_back (commands);
320323 ctx->staged_param_bufs .push_back (params_bufs);
321324 if (ctx->staged_command_bufs .size () == WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
@@ -540,10 +543,12 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe
540543 WEBGPU_LOG_DEBUG (" ggml_backend_webgpu_buffer_memset_tensor(" << buffer << " , " << tensor << " , " << value << " , "
541544 << offset << " , " << size << " )" );
542545
543- ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context ;
544- size_t total_offset = webgpu_tensor_offset (tensor) + tensor->view_offs + offset;
546+ ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context ;
547+
548+ size_t total_offset = webgpu_tensor_offset (tensor) + tensor->view_offs + offset;
549+
545550 // This is a trick to set all bytes of a u32 to the same 1 byte value.
546- uint32_t val32 = (uint32_t ) value * 0x01010101 ;
551+ uint32_t val32 = (uint32_t ) value * 0x01010101 ;
547552 ggml_backend_webgpu_buffer_memset (buf_ctx->webgpu_ctx , buf_ctx->buffer , val32, total_offset, size);
548553}
549554
@@ -559,13 +564,16 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
559564
560565 size_t total_offset = webgpu_tensor_offset (tensor) + tensor->view_offs + offset;
561566
567+ std::lock_guard<std::recursive_mutex> lock (webgpu_ctx->mutex );
562568 webgpu_ctx->queue .WriteBuffer (buf_ctx->buffer , total_offset, data, (size / 4 ) * 4 );
563569
564570 if (size % 4 != 0 ) {
565571 // If size is not a multiple of 4, we need to memset the remaining bytes
566- size_t remaining_size = size % 4 ;
572+ size_t remaining_size = size % 4 ;
573+
567574 // pack the remaining bytes into a uint32_t
568- uint32_t val32 = 0 ;
575+ uint32_t val32 = 0 ;
576+
569577 for (size_t i = 0 ; i < remaining_size; i++) {
570578 ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i];
571579 }
@@ -613,8 +621,12 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
613621 wgpu::CommandEncoder encoder = device.CreateCommandEncoder ();
614622 encoder.CopyBufferToBuffer (buf_ctx->buffer , total_offset, webgpu_ctx->get_tensor_staging_buf , 0 , final_size);
615623 wgpu::CommandBuffer commands = encoder.Finish ();
616- // Submit the command buffer to the queue
617- webgpu_ctx->queue .Submit (1 , &commands);
624+
625+ {
626+ std::lock_guard<std::recursive_mutex> submit_lock (webgpu_ctx->mutex );
627+ // Submit the command buffer to the queue
628+ webgpu_ctx->queue .Submit (1 , &commands);
629+ }
618630
619631 // Map the staging buffer to read the data
620632 ggml_backend_webgpu_map_buffer (webgpu_ctx, webgpu_ctx->get_tensor_staging_buf , wgpu::MapMode::Read, 0 , final_size);
@@ -628,7 +640,6 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
628640
629641static void ggml_backend_webgpu_buffer_clear (ggml_backend_buffer_t buffer, uint8_t value) {
630642 WEBGPU_LOG_DEBUG (" ggml_backend_webgpu_buffer_clear(" << buffer << " , " << (uint32_t ) value << " )" );
631-
632643 ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context ;
633644 ggml_backend_webgpu_buffer_memset (buf_ctx->webgpu_ctx , buf_ctx->buffer , value, 0 , buffer->size );
634645}
@@ -764,10 +775,11 @@ static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, co
764775 std::lock_guard<std::mutex> lock (webgpu_ctx->init_mutex );
765776 if (!webgpu_ctx->device_init ) {
766777 // Initialize device
767- wgpu::DeviceDescriptor dev_desc;
778+ std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16 };
779+ wgpu::DeviceDescriptor dev_desc;
768780 dev_desc.requiredLimits = &webgpu_ctx->limits ;
769- dev_desc.requiredFeatures = webgpu_ctx-> features . features ;
770- dev_desc.requiredFeatureCount = webgpu_ctx-> features . featureCount ;
781+ dev_desc.requiredFeatures = required_features. data () ;
782+ dev_desc.requiredFeatureCount = required_features. size () ;
771783 dev_desc.SetDeviceLostCallback (
772784 wgpu::CallbackMode::AllowSpontaneous,
773785 [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
@@ -920,7 +932,6 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
920932 GGML_ASSERT (ctx->adapter != nullptr );
921933
922934 ctx->adapter .GetLimits (&ctx->limits );
923- ctx->adapter .GetFeatures (&ctx->features );
924935
925936 wgpu::AdapterInfo info{};
926937 ctx->adapter .GetInfo (&info);
0 commit comments