@@ -248,7 +248,7 @@ struct webgpu_context_struct {
248248
249249 webgpu_pipeline memset_pipeline;
250250 webgpu_pipeline mul_mat_pipeline[30 ][2 ];
251- webgpu_pipeline set_rows_pipeline[1 ][2 ]; // dst->type, vectorized (0 for vectorized, 1 for non vectorized)
251+ webgpu_pipeline set_rows_pipeline[1 ][2 ]; // dst->type, vectorized (0 for non-vectorized, 1 for vectorized)
252252 webgpu_pipeline get_rows_pipeline[30 ];
253253 webgpu_pipeline get_rows_f32_no_vec_pipeline;
254254 webgpu_pipeline cpy_pipeline[2 ][2 ]; // src type, dst type
@@ -766,15 +766,15 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
766766 { .binding = 3 , .buffer = error_bufs.dev_buf , .offset = 0 , .size = error_bufs.dev_buf .GetSize () }
767767 };
768768
769- size_t max_wg_size = ctx->max_wg_size_x ;
770- // number of threads needed with vec4 = (total number of rows in matrix) * (number of elements in a row / 4)
771- uint32_t threads = (src->ne [1 ] * src->ne [2 ] * src->ne [3 ]) * (src->ne [0 ] / 4 );
769+ size_t max_wg_size = ctx->max_wg_size_x ;
772770
773- webgpu_pipeline pipeline = ctx->set_rows_pipeline [0 ][0 ];
771+ int vectorized = src->ne [0 ] % 4 == 0 ;
772+ webgpu_pipeline pipeline = ctx->set_rows_pipeline [0 ][vectorized];
774773 // if not evenly divisible by 4, use the non-vectorized version
775- if (src->ne [0 ] % 4 != 0 ) {
776- pipeline = ctx->set_rows_pipeline [0 ][1 ];
777- // threads = number of elements
774+ uint32_t threads;
775+ if (vectorized) {
776+ threads = (src->ne [1 ] * src->ne [2 ] * src->ne [3 ]) * (src->ne [0 ] / 4 );
777+ } else {
778778 threads = src->ne [0 ] * src->ne [1 ] * src->ne [2 ] * src->ne [3 ];
779779 }
780780
@@ -1631,11 +1631,10 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
16311631}
16321632
16331633static void ggml_webgpu_init_set_rows_pipeline (webgpu_context & webgpu_ctx) {
1634- // create_pipeline(device, pipeline, shader_code, label, constants)
1635- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->set_rows_pipeline [0 ][1 ], wgsl_set_rows_f16, " set_rows_f16" ,
1636- ggml_webgpu_wg_size_entry (webgpu_ctx->max_wg_size_x ));
1637- ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->set_rows_pipeline [0 ][0 ], wgsl_set_rows_f16_vec, " set_rows_f16_vec" ,
1638- ggml_webgpu_wg_size_entry (webgpu_ctx->max_wg_size_x ));
1634+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->set_rows_pipeline [0 ][0 ], wgsl_set_rows_f16,
1635+ " set_rows_f16" , ggml_webgpu_wg_size_entry (webgpu_ctx->max_wg_size_x ));
1636+ ggml_webgpu_create_pipeline (webgpu_ctx->device , webgpu_ctx->set_rows_pipeline [0 ][1 ], wgsl_set_rows_f16_vec,
1637+ " set_rows_f16_vec" , ggml_webgpu_wg_size_entry (webgpu_ctx->max_wg_size_x ));
16391638}
16401639
16411640static void ggml_webgpu_init_get_rows_pipeline (webgpu_context & webgpu_ctx) {
0 commit comments