9292#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
9393#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
9494
95- // gemv parameters
96- #define WEBGPU_GEMV_WG_SIZE 256
97- // Must be multiple of 4 to work with vectorized paths, and must divide gemv wg size
98- #define WEBGPU_GEMV_OUTPUTS_PER_WG 16
99- #define WEBGPU_GEMV_TILE_K 128
95+ // Matrix-vector multiplication parameters
96+ #define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
97+ // Must be a multiple of 4 to work with vectorized paths, and must evenly divide WEBGPU_MUL_MAT_VEC_WG_SIZE
98+ #define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64
99+ #define WEBGPU_MUL_MAT_VEC_TILE_K 256
100100
101101/* End Constants */
102102
@@ -278,7 +278,7 @@ struct webgpu_context_struct {
278278 webgpu_pipeline memset_pipeline;
279279
280280 std::map<int , std::map<int , std::map<int , webgpu_pipeline>>> mul_mat_pipelines; // src0_type, src1_type, vectorized
281- std::map<int , std::map<int , std::map<int , webgpu_pipeline>>> gemv_pipelines ; // src0_type, src1_type, vectorized
281+ std::map<int , std::map<int , std::map<int , webgpu_pipeline>>> mul_mat_vec_pipelines ; // src0_type, src1_type, vectorized
282282
283283 webgpu_pipeline mul_mat_pipeline[30 ][2 ];
284284 webgpu_pipeline set_rows_pipeline[1 ][2 ]; // dst->type, vectorized
@@ -957,6 +957,7 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
957957 switch (src0->type ) {
958958 case GGML_TYPE_F32:
959959 case GGML_TYPE_F16:
960+ case GGML_TYPE_Q4_0:
960961 use_fast = true ;
961962 break ;
962963 default :
@@ -970,9 +971,11 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
970971 if (use_fast) {
971972 int vectorized = src0->ne [0 ] % 4 == 0 && dst->ne [0 ] % 4 == 0 && dst->ne [1 ] % 4 == 0 ;
972973 if (dst->ne [1 ] == 1 ) {
973- pipeline = ctx->gemv_pipelines [src0->type ][src1->type ][vectorized];
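974+ // Vectorized mul_mat_vec pipelines are only registered for non-quantized src0 types, so fall back to the scalar variant for quantized ones. NOTE(review): the `< 2` check below presumably selects F32/F16 by enum value; prefer named GGML_TYPE_* constants - confirm against ggml.h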
975+ vectorized = vectorized && (src0->type < 2 );
976+ pipeline = ctx->mul_mat_vec_pipelines [src0->type ][src1->type ][vectorized];
974977 uint32_t batches = dst->ne [2 ] * dst->ne [3 ];
975- uint32_t output_groups = (dst->ne [0 ] + WEBGPU_GEMV_OUTPUTS_PER_WG - 1 ) / WEBGPU_GEMV_OUTPUTS_PER_WG ;
978+ uint32_t output_groups = (dst->ne [0 ] + WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG - 1 ) / WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG ;
976979 uint32_t total_wg = output_groups * batches;
977980 wg_x = total_wg % ctx->limits .maxComputeWorkgroupsPerDimension ;
978981 wg_y = (total_wg + ctx->limits .maxComputeWorkgroupsPerDimension - 1 ) /
@@ -1777,6 +1780,10 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
17771780 ggml_webgpu_process_shader_repls (wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls);
17781781 std::string proc_mul_mat_subgroup_matrix_f16_f16_vec =
17791782 ggml_webgpu_process_shader_repls (wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls);
1783+ std::string proc_mul_mat_subgroup_matrix_q4_0_f32 =
1784+ ggml_webgpu_process_shader_repls (wgsl_mul_mat_subgroup_matrix_q4_0_f32, sg_matrix_repls);
1785+ std::string proc_mul_mat_subgroup_matrix_q4_0_f32_vec =
1786+ ggml_webgpu_process_shader_repls (wgsl_mul_mat_subgroup_matrix_q4_0_f32_vec, sg_matrix_repls);
17801787
17811788 webgpu_ctx->mul_mat_pipelines [GGML_TYPE_F32][GGML_TYPE_F32][0 ] = ggml_webgpu_create_pipeline2 (
17821789 webgpu_ctx->device , proc_mul_mat_subgroup_matrix_f32_f32.c_str (), " mul_mat_subgroup_matrix_f32_f32" );
@@ -1793,6 +1800,11 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
17931800 webgpu_ctx->mul_mat_pipelines [GGML_TYPE_F16][GGML_TYPE_F16][1 ] =
17941801 ggml_webgpu_create_pipeline2 (webgpu_ctx->device , proc_mul_mat_subgroup_matrix_f16_f16_vec.c_str (),
17951802 " mul_mat_subgroup_matrix_f16_f16_vec" );
1803+ webgpu_ctx->mul_mat_pipelines [GGML_TYPE_Q4_0][GGML_TYPE_F32][0 ] = ggml_webgpu_create_pipeline2 (
1804+ webgpu_ctx->device , proc_mul_mat_subgroup_matrix_q4_0_f32.c_str (), " mul_mat_subgroup_matrix_q4_0_f32" );
1805+ webgpu_ctx->mul_mat_pipelines [GGML_TYPE_Q4_0][GGML_TYPE_F32][1 ] =
1806+ ggml_webgpu_create_pipeline2 (webgpu_ctx->device , proc_mul_mat_subgroup_matrix_q4_0_f32_vec.c_str (),
1807+ " mul_mat_subgroup_matrix_q4_0_f32_vec" );
17961808 } else {
17971809 std::vector<wgpu::ConstantEntry> mul_mat_reg_tile_constants (3 );
17981810 mul_mat_reg_tile_constants[0 ].key = " TILE_K" ;
@@ -1820,6 +1832,10 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
18201832 ggml_webgpu_process_shader_repls (wgsl_mul_mat_reg_tile_f16_f16, reg_repls);
18211833 std::string proc_mul_mat_reg_tile_f16_f16_vec =
18221834 ggml_webgpu_process_shader_repls (wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls);
1835+ std::string proc_mul_mat_reg_tile_q4_0_f32 =
1836+ ggml_webgpu_process_shader_repls (wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls);
1837+ std::string proc_mul_mat_reg_tile_q4_0_f32_vec =
1838+ ggml_webgpu_process_shader_repls (wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls);
18231839
18241840 webgpu_ctx->mul_mat_pipelines [GGML_TYPE_F32][GGML_TYPE_F32][0 ] =
18251841 ggml_webgpu_create_pipeline2 (webgpu_ctx->device , proc_mul_mat_reg_tile_f32_f32.c_str (),
@@ -1839,28 +1855,37 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
18391855 webgpu_ctx->mul_mat_pipelines [GGML_TYPE_F16][GGML_TYPE_F16][1 ] =
18401856 ggml_webgpu_create_pipeline2 (webgpu_ctx->device , proc_mul_mat_reg_tile_f16_f16_vec.c_str (),
18411857 " mul_mat_reg_tile_f16_f16_vec" , mul_mat_reg_tile_constants);
1858+ webgpu_ctx->mul_mat_pipelines [GGML_TYPE_Q4_0][GGML_TYPE_F32][0 ] =
1859+ ggml_webgpu_create_pipeline2 (webgpu_ctx->device , proc_mul_mat_reg_tile_q4_0_f32.c_str (),
1860+ " mul_mat_reg_tile_q4_0_f32" , mul_mat_reg_tile_constants);
1861+ webgpu_ctx->mul_mat_pipelines [GGML_TYPE_Q4_0][GGML_TYPE_F32][1 ] =
1862+ ggml_webgpu_create_pipeline2 (webgpu_ctx->device , proc_mul_mat_reg_tile_q4_0_f32_vec.c_str (),
1863+ " mul_mat_reg_tile_q4_0_f32_vec" , mul_mat_reg_tile_constants);
1864+
18421865 }
18431866
1844- std::vector<wgpu::ConstantEntry> gemv_constants (3 );
1845- gemv_constants[0 ].key = " WORKGROUP_SIZE" ;
1846- gemv_constants[0 ].value = WEBGPU_GEMV_WG_SIZE;
1847- gemv_constants[1 ].key = " TILE_K" ;
1848- gemv_constants[1 ].value = WEBGPU_GEMV_TILE_K;
1849- gemv_constants[2 ].key = " OUTPUTS_PER_WG" ;
1850- gemv_constants[2 ].value = WEBGPU_GEMV_OUTPUTS_PER_WG;
1851-
1852- webgpu_ctx->gemv_pipelines [GGML_TYPE_F32][GGML_TYPE_F32][0 ] =
1853- ggml_webgpu_create_pipeline2 (webgpu_ctx->device , wgsl_gemv_f32_f32, " gemv_f32_f32" , gemv_constants);
1854- webgpu_ctx->gemv_pipelines [GGML_TYPE_F32][GGML_TYPE_F32][1 ] =
1855- ggml_webgpu_create_pipeline2 (webgpu_ctx->device , wgsl_gemv_f32_f32_vec, " gemv_f32_f32_vec" , gemv_constants);
1856- webgpu_ctx->gemv_pipelines [GGML_TYPE_F16][GGML_TYPE_F32][0 ] =
1857- ggml_webgpu_create_pipeline2 (webgpu_ctx->device , wgsl_gemv_f16_f32, " gemv_f16_f32" , gemv_constants);
1858- webgpu_ctx->gemv_pipelines [GGML_TYPE_F16][GGML_TYPE_F32][1 ] =
1859- ggml_webgpu_create_pipeline2 (webgpu_ctx->device , wgsl_gemv_f16_f32_vec, " gemv_f16_f32_vec" , gemv_constants);
1860- webgpu_ctx->gemv_pipelines [GGML_TYPE_F16][GGML_TYPE_F16][0 ] =
1861- ggml_webgpu_create_pipeline2 (webgpu_ctx->device , wgsl_gemv_f16_f16, " gemv_f16_f16" , gemv_constants);
1862- webgpu_ctx->gemv_pipelines [GGML_TYPE_F16][GGML_TYPE_F16][1 ] =
1863- ggml_webgpu_create_pipeline2 (webgpu_ctx->device , wgsl_gemv_f16_f16_vec, " gemv_f16_f16_vec" , gemv_constants);
1867+ std::vector<wgpu::ConstantEntry> mul_mat_vec_constants (3 );
1868+ mul_mat_vec_constants[0 ].key = " WORKGROUP_SIZE" ;
1869+ mul_mat_vec_constants[0 ].value = WEBGPU_MUL_MAT_VEC_WG_SIZE;
1870+ mul_mat_vec_constants[1 ].key = " TILE_K" ;
1871+ mul_mat_vec_constants[1 ].value = WEBGPU_MUL_MAT_VEC_TILE_K;
1872+ mul_mat_vec_constants[2 ].key = " OUTPUTS_PER_WG" ;
1873+ mul_mat_vec_constants[2 ].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
1874+
1875+ webgpu_ctx->mul_mat_vec_pipelines [GGML_TYPE_F32][GGML_TYPE_F32][0 ] =
1876+ ggml_webgpu_create_pipeline2 (webgpu_ctx->device , wgsl_mul_mat_vec_f32_f32, " mul_mat_vec_f32_f32" , mul_mat_vec_constants);
1877+ webgpu_ctx->mul_mat_vec_pipelines [GGML_TYPE_F32][GGML_TYPE_F32][1 ] =
1878+ ggml_webgpu_create_pipeline2 (webgpu_ctx->device , wgsl_mul_mat_vec_f32_f32_vec, " mul_mat_vec_f32_f32_vec" , mul_mat_vec_constants);
1879+ webgpu_ctx->mul_mat_vec_pipelines [GGML_TYPE_F16][GGML_TYPE_F32][0 ] =
1880+ ggml_webgpu_create_pipeline2 (webgpu_ctx->device , wgsl_mul_mat_vec_f16_f32, " mul_mat_vec_f16_f32" , mul_mat_vec_constants);
1881+ webgpu_ctx->mul_mat_vec_pipelines [GGML_TYPE_F16][GGML_TYPE_F32][1 ] =
1882+ ggml_webgpu_create_pipeline2 (webgpu_ctx->device , wgsl_mul_mat_vec_f16_f32_vec, " mul_mat_vec_f16_f32_vec" , mul_mat_vec_constants);
1883+ webgpu_ctx->mul_mat_vec_pipelines [GGML_TYPE_F16][GGML_TYPE_F16][0 ] =
1884+ ggml_webgpu_create_pipeline2 (webgpu_ctx->device , wgsl_mul_mat_vec_f16_f16, " mul_mat_vec_f16_f16" , mul_mat_vec_constants);
1885+ webgpu_ctx->mul_mat_vec_pipelines [GGML_TYPE_F16][GGML_TYPE_F16][1 ] =
1886+ ggml_webgpu_create_pipeline2 (webgpu_ctx->device , wgsl_mul_mat_vec_f16_f16_vec, " mul_mat_vec_f16_f16_vec" , mul_mat_vec_constants);
1887+ webgpu_ctx->mul_mat_vec_pipelines [GGML_TYPE_Q4_0][GGML_TYPE_F32][0 ] =
1888+ ggml_webgpu_create_pipeline2 (webgpu_ctx->device , wgsl_mul_mat_vec_q4_0_f32, " mul_mat_vec_q4_0_f32" , mul_mat_vec_constants);
18641889}
18651890
18661891static void ggml_webgpu_init_set_rows_pipeline (webgpu_context & webgpu_ctx) {
0 commit comments