Skip to content

Commit 32ca54e

Browse files
committed
Change logic for set_rows pipelines
1 parent 2cc96eb commit 32ca54e

File tree

1 file changed

+12
-13
lines changed

1 file changed

+12
-13
lines changed

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ struct webgpu_context_struct {
248248

249249
webgpu_pipeline memset_pipeline;
250250
webgpu_pipeline mul_mat_pipeline[30][2];
251-
webgpu_pipeline set_rows_pipeline[1][2]; // dst->type, vectorized (0 for vectorized, 1 for non vectorized)
251+
webgpu_pipeline set_rows_pipeline[1][2]; // dst->type, vectorized (0 for non-vectorized, 1 for vectorized)
252252
webgpu_pipeline get_rows_pipeline[30];
253253
webgpu_pipeline get_rows_f32_no_vec_pipeline;
254254
webgpu_pipeline cpy_pipeline[2][2]; // src type, dst type
@@ -766,15 +766,15 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
766766
{ .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() }
767767
};
768768

769-
size_t max_wg_size = ctx->max_wg_size_x;
770-
// number of threads needed with vec4 = (total number of rows in matrix) * (number of elements in a row / 4)
771-
uint32_t threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4);
769+
size_t max_wg_size = ctx->max_wg_size_x;
772770

773-
webgpu_pipeline pipeline = ctx->set_rows_pipeline[0][0];
771+
int vectorized = src->ne[0] % 4 == 0;
772+
webgpu_pipeline pipeline = ctx->set_rows_pipeline[0][vectorized];
774773
// if not evenly divisible by 4, use the non-vectorized version
775-
if (src->ne[0] % 4 != 0) {
776-
pipeline = ctx->set_rows_pipeline[0][1];
777-
// threads = number of elements
774+
uint32_t threads;
775+
if (vectorized) {
776+
threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4);
777+
} else {
778778
threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
779779
}
780780

@@ -1631,11 +1631,10 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
16311631
}
16321632

16331633
static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
1634-
// create_pipeline(device, pipeline, shader_code, label, constants)
1635-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline[0][1], wgsl_set_rows_f16, "set_rows_f16",
1636-
ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x));
1637-
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline[0][0], wgsl_set_rows_f16_vec, "set_rows_f16_vec",
1638-
ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x));
1634+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline[0][0], wgsl_set_rows_f16,
1635+
"set_rows_f16", ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x));
1636+
ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->set_rows_pipeline[0][1], wgsl_set_rows_f16_vec,
1637+
"set_rows_f16_vec", ggml_webgpu_wg_size_entry(webgpu_ctx->max_wg_size_x));
16391638
}
16401639

16411640
static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {

0 commit comments

Comments
 (0)