Skip to content

Commit 4ec09e4

Browse files
committed
Working q4_0
1 parent eb7150a commit 4ec09e4

File tree

5 files changed

+246
-74
lines changed

5 files changed

+246
-74
lines changed

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 53 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,11 @@
9292
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
9393
#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2
9494

95-
// gemv parameters
96-
#define WEBGPU_GEMV_WG_SIZE 256
97-
// Must be multiple of 4 to work with vectorized paths, and must divide gemv wg size
98-
#define WEBGPU_GEMV_OUTPUTS_PER_WG 16
99-
#define WEBGPU_GEMV_TILE_K 128
95+
// Matrix-vector multiplication parameters
96+
#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
97+
// Must be multiple of 4 to work with vectorized paths, and must divide mul_mat_vec wg size
98+
#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64
99+
#define WEBGPU_MUL_MAT_VEC_TILE_K 256
100100

101101
/* End Constants */
102102

@@ -278,7 +278,7 @@ struct webgpu_context_struct {
278278
webgpu_pipeline memset_pipeline;
279279

280280
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> mul_mat_pipelines; // src0_type, src1_type, vectorized
281-
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> gemv_pipelines; // src0_type, src1_type, vectorized
281+
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> mul_mat_vec_pipelines; // src0_type, src1_type, vectorized
282282

283283
webgpu_pipeline mul_mat_pipeline[30][2];
284284
webgpu_pipeline set_rows_pipeline[1][2]; // dst->type, vectorized
@@ -957,6 +957,7 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
957957
switch (src0->type) {
958958
case GGML_TYPE_F32:
959959
case GGML_TYPE_F16:
960+
case GGML_TYPE_Q4_0:
960961
use_fast = true;
961962
break;
962963
default:
@@ -970,9 +971,11 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
970971
if (use_fast) {
971972
int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0;
972973
if (dst->ne[1] == 1) {
973-
pipeline = ctx->gemv_pipelines[src0->type][src1->type][vectorized];
974+
// We don't support vectorized mul_mat_vec for quantized types
975+
vectorized = vectorized && (src0->type < 2);
976+
pipeline = ctx->mul_mat_vec_pipelines[src0->type][src1->type][vectorized];
974977
uint32_t batches = dst->ne[2] * dst->ne[3];
975-
uint32_t output_groups = (dst->ne[0] + WEBGPU_GEMV_OUTPUTS_PER_WG - 1) / WEBGPU_GEMV_OUTPUTS_PER_WG;
978+
uint32_t output_groups = (dst->ne[0] + WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG - 1) / WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
976979
uint32_t total_wg = output_groups * batches;
977980
wg_x = total_wg % ctx->limits.maxComputeWorkgroupsPerDimension;
978981
wg_y = (total_wg + ctx->limits.maxComputeWorkgroupsPerDimension - 1) /
@@ -1777,6 +1780,10 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
17771780
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls);
17781781
std::string proc_mul_mat_subgroup_matrix_f16_f16_vec =
17791782
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls);
1783+
std::string proc_mul_mat_subgroup_matrix_q4_0_f32 =
1784+
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32, sg_matrix_repls);
1785+
std::string proc_mul_mat_subgroup_matrix_q4_0_f32_vec =
1786+
ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32_vec, sg_matrix_repls);
17801787

17811788
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
17821789
webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f32_f32.c_str(), "mul_mat_subgroup_matrix_f32_f32");
@@ -1793,6 +1800,11 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
17931800
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] =
17941801
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_f16_f16_vec.c_str(),
17951802
"mul_mat_subgroup_matrix_f16_f16_vec");
1803+
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline2(
1804+
webgpu_ctx->device, proc_mul_mat_subgroup_matrix_q4_0_f32.c_str(), "mul_mat_subgroup_matrix_q4_0_f32");
1805+
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] =
1806+
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_subgroup_matrix_q4_0_f32_vec.c_str(),
1807+
"mul_mat_subgroup_matrix_q4_0_f32_vec");
17961808
} else {
17971809
std::vector<wgpu::ConstantEntry> mul_mat_reg_tile_constants(3);
17981810
mul_mat_reg_tile_constants[0].key = "TILE_K";
@@ -1820,6 +1832,10 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
18201832
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls);
18211833
std::string proc_mul_mat_reg_tile_f16_f16_vec =
18221834
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls);
1835+
std::string proc_mul_mat_reg_tile_q4_0_f32 =
1836+
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls);
1837+
std::string proc_mul_mat_reg_tile_q4_0_f32_vec =
1838+
ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls);
18231839

18241840
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] =
18251841
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f32_f32.c_str(),
@@ -1839,28 +1855,37 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
18391855
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] =
18401856
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_f16_f16_vec.c_str(),
18411857
"mul_mat_reg_tile_f16_f16_vec", mul_mat_reg_tile_constants);
1858+
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] =
1859+
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_q4_0_f32.c_str(),
1860+
"mul_mat_reg_tile_q4_0_f32", mul_mat_reg_tile_constants);
1861+
webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] =
1862+
ggml_webgpu_create_pipeline2(webgpu_ctx->device, proc_mul_mat_reg_tile_q4_0_f32_vec.c_str(),
1863+
"mul_mat_reg_tile_q4_0_f32_vec", mul_mat_reg_tile_constants);
1864+
18421865
}
18431866

1844-
std::vector<wgpu::ConstantEntry> gemv_constants(3);
1845-
gemv_constants[0].key = "WORKGROUP_SIZE";
1846-
gemv_constants[0].value = WEBGPU_GEMV_WG_SIZE;
1847-
gemv_constants[1].key = "TILE_K";
1848-
gemv_constants[1].value = WEBGPU_GEMV_TILE_K;
1849-
gemv_constants[2].key = "OUTPUTS_PER_WG";
1850-
gemv_constants[2].value = WEBGPU_GEMV_OUTPUTS_PER_WG;
1851-
1852-
webgpu_ctx->gemv_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] =
1853-
ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_gemv_f32_f32, "gemv_f32_f32", gemv_constants);
1854-
webgpu_ctx->gemv_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] =
1855-
ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_gemv_f32_f32_vec, "gemv_f32_f32_vec", gemv_constants);
1856-
webgpu_ctx->gemv_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] =
1857-
ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_gemv_f16_f32, "gemv_f16_f32", gemv_constants);
1858-
webgpu_ctx->gemv_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] =
1859-
ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_gemv_f16_f32_vec, "gemv_f16_f32_vec", gemv_constants);
1860-
webgpu_ctx->gemv_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] =
1861-
ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_gemv_f16_f16, "gemv_f16_f16", gemv_constants);
1862-
webgpu_ctx->gemv_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] =
1863-
ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_gemv_f16_f16_vec, "gemv_f16_f16_vec", gemv_constants);
1867+
std::vector<wgpu::ConstantEntry> mul_mat_vec_constants(3);
1868+
mul_mat_vec_constants[0].key = "WORKGROUP_SIZE";
1869+
mul_mat_vec_constants[0].value = WEBGPU_MUL_MAT_VEC_WG_SIZE;
1870+
mul_mat_vec_constants[1].key = "TILE_K";
1871+
mul_mat_vec_constants[1].value = WEBGPU_MUL_MAT_VEC_TILE_K;
1872+
mul_mat_vec_constants[2].key = "OUTPUTS_PER_WG";
1873+
mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
1874+
1875+
webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] =
1876+
ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants);
1877+
webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] =
1878+
ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants);
1879+
webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] =
1880+
ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants);
1881+
webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] =
1882+
ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants);
1883+
webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] =
1884+
ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants);
1885+
webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] =
1886+
ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants);
1887+
webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] =
1888+
ggml_webgpu_create_pipeline2(webgpu_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants);
18641889
}
18651890

18661891
static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {

ggml/src/ggml-webgpu/wgsl-shaders/mat_mul_decls.tmpl

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ fn store_shmem(val: f16, idx: u32) {
1313
}
1414
#enddecl(SHMEM_SCALAR)
1515

16-
#decl(INIT_SHMEM_FLOAT)
16+
#decl(INIT_SRC0_SHMEM_FLOAT)
1717

1818
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
1919
for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
@@ -30,6 +30,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
3030
}
3131
}
3232

33+
#enddecl(INIT_SRC0_SHMEM_FLOAT)
34+
35+
#decl(INIT_SRC1_SHMEM)
36+
3337
fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {
3438
for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
3539
let tile_n = elem_idx / TILE_K;
@@ -45,5 +49,50 @@ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u3
4549
}
4650
}
4751

48-
#enddecl(INIT_SHMEM_FLOAT)
52+
#enddecl(INIT_SRC1_SHMEM)
53+
54+
#decl(INIT_SRC0_SHMEM_Q4_0)
55+
56+
const BLOCK_SIZE = 32u;
57+
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
58+
override BLOCKS_K = TILE_K/BLOCK_SIZE;
59+
const NQ = 16u;
60+
const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights
61+
const WEIGHTS_PER_F16 = 4u; // 4 weights per f16
62+
const F16_PER_THREAD = NQ / WEIGHTS_PER_F16;
63+
64+
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
65+
for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) {
66+
let blck_idx = i / BLOCK_SIZE;
67+
let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16;
68+
let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u;
69+
70+
let tile_m = blck_idx / BLOCKS_K;
71+
let global_m = offset_m + tile_m;
72+
let block_k = blck_idx % BLOCKS_K;
73+
let global_k = k_outer / BLOCK_SIZE + block_k;
74+
75+
if (global_m < params.m && global_k < params.k / BLOCK_SIZE) {
76+
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
77+
let scale_idx = src0_idx * F16_PER_BLOCK;
78+
let d = src0[scale_idx];
79+
80+
for (var j = 0u; j < F16_PER_THREAD; j += 2) {
81+
let q_0 = src0[scale_idx + 1u + block_offset + j];
82+
let q_1 = src0[scale_idx + 1u + block_offset + j + 1];
83+
84+
let q_packed = bitcast<u32>(vec2(q_0, q_1));
85+
for (var k = 0u; k < 4u; k++) {
86+
let q_byte = get_byte(q_packed, k);
87+
let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
88+
let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
89+
shmem[shmem_idx + j * 2 + k] = q_lo;
90+
shmem[shmem_idx + j * 2 + k + 16u] = q_hi;
91+
}
92+
}
93+
}
94+
}
95+
}
96+
97+
#enddecl(INIT_SRC0_SHMEM_Q4_0)
4998

ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
"SRC1_TYPE" : "vec4<f32>",
88
"DST_TYPE" : "vec4<f32>",
99
"SHMEM_TYPE" : "vec4<f16>",
10-
"VEC_SIZE" : "4",
10+
"VEC_SIZE" : 4,
1111
},
12-
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SHMEM_FLOAT"]
12+
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
1313
},
1414
{
1515
"SHADER_SUFFIX": "f32_f32",
@@ -18,9 +18,9 @@
1818
"SRC1_TYPE" : "f32",
1919
"DST_TYPE" : "f32",
2020
"SHMEM_TYPE" : "f16",
21-
"VEC_SIZE" : "1",
21+
"VEC_SIZE" : 1,
2222
},
23-
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SHMEM_FLOAT"]
23+
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
2424
},
2525
{
2626
"SHADER_SUFFIX": "f16_f32_vec",
@@ -29,9 +29,9 @@
2929
"SRC1_TYPE" : "vec4<f32>",
3030
"DST_TYPE" : "vec4<f32>",
3131
"SHMEM_TYPE" : "vec4<f16>",
32-
"VEC_SIZE" : "4",
32+
"VEC_SIZE" : 4,
3333
},
34-
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SHMEM_FLOAT"]
34+
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
3535
},
3636
{
3737
"SHADER_SUFFIX": "f16_f32",
@@ -40,9 +40,9 @@
4040
"SRC1_TYPE" : "f32",
4141
"DST_TYPE" : "f32",
4242
"SHMEM_TYPE" : "f16",
43-
"VEC_SIZE" : "1",
43+
"VEC_SIZE" : 1,
4444
},
45-
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SHMEM_FLOAT"]
45+
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
4646
},
4747
{
4848
"SHADER_SUFFIX": "f16_f16_vec",
@@ -51,9 +51,9 @@
5151
"SRC1_TYPE" : "vec4<f16>",
5252
"DST_TYPE" : "vec4<f32>",
5353
"SHMEM_TYPE" : "vec4<f16>",
54-
"VEC_SIZE" : "4",
54+
"VEC_SIZE" : 4,
5555
},
56-
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SHMEM_FLOAT"]
56+
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
5757
},
5858
{
5959
"SHADER_SUFFIX": "f16_f16",
@@ -62,9 +62,31 @@
6262
"SRC1_TYPE" : "f16",
6363
"DST_TYPE" : "f32",
6464
"SHMEM_TYPE" : "f16",
65-
"VEC_SIZE" : "1",
65+
"VEC_SIZE" : 1,
6666
},
67-
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SHMEM_FLOAT"]
67+
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
68+
},
69+
{
70+
"SHADER_SUFFIX": "q4_0_f32_vec",
71+
"REPLS": {
72+
"SRC0_TYPE" : "f16",
73+
"SRC1_TYPE" : "vec4<f32>",
74+
"DST_TYPE" : "vec4<f32>",
75+
"SHMEM_TYPE" : "vec4<f16>",
76+
"VEC_SIZE" : 4,
77+
},
78+
"DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
79+
},
80+
{
81+
"SHADER_SUFFIX": "q4_0_f32",
82+
"REPLS": {
83+
"SRC0_TYPE" : "f16",
84+
"SRC1_TYPE" : "f32",
85+
"DST_TYPE" : "f32",
86+
"SHMEM_TYPE" : "f16",
87+
"VEC_SIZE" : 1,
88+
},
89+
"DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
6890
}
6991
]
7092

0 commit comments

Comments (0)