Skip to content

Commit eb7150a

Browse files
committed
Move sg matrix stores to correct file
1 parent a46d093 commit eb7150a

File tree

2 files changed

+28
-19
lines changed

2 files changed

+28
-19
lines changed

ggml/src/ggml-webgpu/wgsl-shaders/mat_mul_decls.tmpl

Lines changed: 0 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -5,25 +5,12 @@ fn store_shmem(val: vec4<f16>, idx: u32) {
55
shmem[idx + 2] = val.z;
66
shmem[idx + 3] = val.w;
77
}
8-
9-
fn store_dst(shmem_idx: u32, dst_idx: u32) {
10-
dst[dst_idx] = vec4<f32>(
11-
f32(shmem[shmem_idx]),
12-
f32(shmem[shmem_idx + 1]),
13-
f32(shmem[shmem_idx + 2]),
14-
f32(shmem[shmem_idx + 3])
15-
);
16-
}
178
#enddecl(SHMEM_VEC)
189

1910
#decl(SHMEM_SCALAR)
2011
fn store_shmem(val: f16, idx: u32) {
2112
shmem[idx] = val;
2213
}
23-
24-
fn store_dst(shmem_idx: u32, dst_idx: u32) {
25-
dst[dst_idx] = f32(shmem[shmem_idx]);
26-
}
2714
#enddecl(SHMEM_SCALAR)
2815

2916
#decl(INIT_SHMEM_FLOAT)

ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl

Lines changed: 28 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
"SHMEM_TYPE" : "vec4<f16>",
1010
"VEC_SIZE" : "4",
1111
},
12-
"DECLS": ["SHMEM_VEC", "INIT_SHMEM_FLOAT"]
12+
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SHMEM_FLOAT"]
1313
},
1414
{
1515
"SHADER_SUFFIX": "f32_f32",
@@ -20,7 +20,7 @@
2020
"SHMEM_TYPE" : "f16",
2121
"VEC_SIZE" : "1",
2222
},
23-
"DECLS": ["SHMEM_SCALAR", "INIT_SHMEM_FLOAT"]
23+
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SHMEM_FLOAT"]
2424
},
2525
{
2626
"SHADER_SUFFIX": "f16_f32_vec",
@@ -31,7 +31,7 @@
3131
"SHMEM_TYPE" : "vec4<f16>",
3232
"VEC_SIZE" : "4",
3333
},
34-
"DECLS": ["SHMEM_VEC", "INIT_SHMEM_FLOAT"]
34+
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SHMEM_FLOAT"]
3535
},
3636
{
3737
"SHADER_SUFFIX": "f16_f32",
@@ -42,7 +42,7 @@
4242
"SHMEM_TYPE" : "f16",
4343
"VEC_SIZE" : "1",
4444
},
45-
"DECLS": ["SHMEM_SCALAR", "INIT_SHMEM_FLOAT"]
45+
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SHMEM_FLOAT"]
4646
},
4747
{
4848
"SHADER_SUFFIX": "f16_f16_vec",
@@ -53,7 +53,7 @@
5353
"SHMEM_TYPE" : "vec4<f16>",
5454
"VEC_SIZE" : "4",
5555
},
56-
"DECLS": ["SHMEM_VEC", "INIT_SHMEM_FLOAT"]
56+
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SHMEM_FLOAT"]
5757
},
5858
{
5959
"SHADER_SUFFIX": "f16_f16",
@@ -64,12 +64,34 @@
6464
"SHMEM_TYPE" : "f16",
6565
"VEC_SIZE" : "1",
6666
},
67-
"DECLS": ["SHMEM_SCALAR", "INIT_SHMEM_FLOAT"]
67+
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SHMEM_FLOAT"]
6868
}
6969
]
7070

7171
#end(VARIANTS)
7272

73+
#define(DECLS)
74+
75+
#decl(VEC)
76+
fn store_dst(shmem_idx: u32, dst_idx: u32) {
77+
dst[dst_idx] = vec4<f32>(
78+
f32(shmem[shmem_idx]),
79+
f32(shmem[shmem_idx + 1]),
80+
f32(shmem[shmem_idx + 2]),
81+
f32(shmem[shmem_idx + 3])
82+
);
83+
}
84+
#enddecl(VEC)
85+
86+
#decl(SCALAR)
87+
fn store_dst(shmem_idx: u32, dst_idx: u32) {
88+
dst[dst_idx] = f32(shmem[shmem_idx]);
89+
}
90+
#enddecl(SCALAR)
91+
92+
93+
#end(DECLS)
94+
7395
#define(SHADER)
7496
diagnostic(off, chromium.subgroup_matrix_uniformity);
7597
enable f16;

0 commit comments

Comments
 (0)