@@ -105,6 +105,7 @@ fn store_dst(shmem_idx: u32, dst_idx: u32) {
105105#define (SHADER )
106106diagnostic (off , chromium . subgroup_matrix_uniformity );
107107enable f16 ;
108+ enable subgroups ;
108109enable chromium_experimental_subgroup_matrix ;
109110
110111struct MulMatParams {
@@ -138,7 +139,11 @@ DECLS
138139// current Dawn version type definitions/matrix load requirements for constant memory sizes.
139140const SUBGROUP_M = {{WEBGPU_SUBGROUP_M }}u ;
140141const SUBGROUP_N = {{WEBGPU_SUBGROUP_N }}u ;
141- const SUBGROUP_SIZE = {{WEBGPU_SUBGROUP_SIZE }}u ;
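142+ // For portability the workgroup is sized for the maximum subgroup size. If the runtime subgroup size is
143+ // smaller, the workgroup contains more than EXPECTED_SUBGROUPS subgroups, and the extra subgroups
143+ // (subgroup_id >= EXPECTED_SUBGROUPS) are masked out of the subgroup-matrix work.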
144+ const MAX_SUBGROUP_SIZE = {{WEBGPU_MAX_SUBGROUP_SIZE }}u ;
145+
146+ const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N ;
142147
143148const SUBGROUP_MATRIX_M_SIZE = {{WEBGPU_SG_MAT_M_SIZE }}u ;
144149const SUBGROUP_MATRIX_N_SIZE = {{WEBGPU_SG_MAT_N_SIZE }}u ;
@@ -152,7 +157,7 @@ const TILE_K = {{WEBGPU_TILE_K}}u;
152157const WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE ;
153158const WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE ;
154159
155- const TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * SUBGROUP_SIZE ;
160+ const TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * MAX_SUBGROUP_SIZE ;
156161const TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE ;
157162const TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE ;
158163
@@ -164,23 +169,21 @@ const SHMEM_SIZE = max(TILE_SRC0_SHMEM + TILE_SRC1_SHMEM, SG_MAT_ACCUM_SHMEM);
164169var <workgroup > shmem : array <f16 , SHMEM_SIZE >;
165170
166171@compute @workgroup_size (TOTAL_WORKGROUP_SIZE )
167- fn main (@builtin (global_invocation_id ) global_id : vec3 <u32 >,
172+ fn main (@builtin (workgroup_id ) wg_id : vec3 <u32 >,
168173 @builtin (local_invocation_id ) local_id : vec3 <u32 >,
169174 @builtin (subgroup_id ) subgroup_id : u32 ) {
170175
171176 let thread_id = local_id . x ;
172177 let subgroup_m = subgroup_id % SUBGROUP_M ;
173178 let subgroup_n = subgroup_id / SUBGROUP_M ;
174179
175- let wg_linear = global_id . x / TOTAL_WORKGROUP_SIZE ;
176-
177180 let wg_m_count = (params . m + WG_M_SG_TILE_SIZE - 1 ) / WG_M_SG_TILE_SIZE ;
178181 let wg_n_count = (params . n + WG_N_SG_TILE_SIZE - 1 ) / WG_N_SG_TILE_SIZE ;
179182 let wg_per_matrix = wg_m_count * wg_n_count ;
180183
181- let batch_idx = wg_linear / wg_per_matrix ;
184+ let batch_idx = wg_id . x / wg_per_matrix ;
182185
183- let wg_in_batch = wg_linear % wg_per_matrix ;
186+ let wg_in_batch = wg_id . x % wg_per_matrix ;
184187 let wg_m = wg_in_batch % wg_m_count ;
185188 let wg_n = wg_in_batch / wg_m_count ;
186189
@@ -230,29 +233,32 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>,
230233
231234 workgroupBarrier ();
232235
233- for ( var k_inner = 0u ; k_inner < TILE_K ; k_inner += SUBGROUP_MATRIX_K_SIZE ) {
236+ if ( subgroup_id < EXPECTED_SUBGROUPS ) {
234237
235- let src0_shmem_idx_base = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE * TILE_K + k_inner ;
236- var src0_sg_mats : array <subgroup_matrix_left <f16 , SUBGROUP_MATRIX_K_SIZE , SUBGROUP_MATRIX_M_SIZE >, SUBGROUP_MATRIX_M >;
237- for (var m = 0u ; m < SUBGROUP_MATRIX_M ; m ++ ) {
238- src0_sg_mats [m ] = subgroupMatrixLoad <subgroup_matrix_left <f16 , SUBGROUP_MATRIX_K_SIZE , SUBGROUP_MATRIX_M_SIZE >>(
239- & shmem ,
240- src0_shmem_idx_base + m * SUBGROUP_MATRIX_M_SIZE * TILE_K ,
241- false ,
242- TILE_K
243- );
244- }
238+ for (var k_inner = 0u ; k_inner < TILE_K ; k_inner += SUBGROUP_MATRIX_K_SIZE ) {
245239
246- let src1_shmem_idx_base = subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE * TILE_K + k_inner ;
247- for (var n = 0u ; n < SUBGROUP_MATRIX_N ; n ++ ) {
248- let src1_sg_mat = subgroupMatrixLoad <subgroup_matrix_right <f16 , SUBGROUP_MATRIX_N_SIZE , SUBGROUP_MATRIX_K_SIZE >>(
249- & shmem ,
250- TILE_SRC0_SHMEM + src1_shmem_idx_base + n * SUBGROUP_MATRIX_N_SIZE * TILE_K ,
251- true ,
252- TILE_K
253- );
240+ let src0_shmem_idx_base = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE * TILE_K + k_inner ;
241+ var src0_sg_mats : array <subgroup_matrix_left <f16 , SUBGROUP_MATRIX_K_SIZE , SUBGROUP_MATRIX_M_SIZE >, SUBGROUP_MATRIX_M >;
254242 for (var m = 0u ; m < SUBGROUP_MATRIX_M ; m ++ ) {
255- acc_sg_mat [m ][n ] = subgroupMatrixMultiplyAccumulate (src0_sg_mats [m ], src1_sg_mat , acc_sg_mat [m ][n ]);
243+ src0_sg_mats [m ] = subgroupMatrixLoad <subgroup_matrix_left <f16 , SUBGROUP_MATRIX_K_SIZE , SUBGROUP_MATRIX_M_SIZE >>(
244+ & shmem ,
245+ src0_shmem_idx_base + m * SUBGROUP_MATRIX_M_SIZE * TILE_K ,
246+ false ,
247+ TILE_K
248+ );
249+ }
250+
251+ let src1_shmem_idx_base = subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE * TILE_K + k_inner ;
252+ for (var n = 0u ; n < SUBGROUP_MATRIX_N ; n ++ ) {
253+ let src1_sg_mat = subgroupMatrixLoad <subgroup_matrix_right <f16 , SUBGROUP_MATRIX_N_SIZE , SUBGROUP_MATRIX_K_SIZE >>(
254+ & shmem ,
255+ TILE_SRC0_SHMEM + src1_shmem_idx_base + n * SUBGROUP_MATRIX_N_SIZE * TILE_K ,
256+ true ,
257+ TILE_K
258+ );
259+ for (var m = 0u ; m < SUBGROUP_MATRIX_M ; m ++ ) {
260+ acc_sg_mat [m ][n ] = subgroupMatrixMultiplyAccumulate (src0_sg_mats [m ], src1_sg_mat , acc_sg_mat [m ][n ]);
261+ }
256262 }
257263 }
258264 }
@@ -268,12 +274,14 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>,
268274 let tile_row_base_local = subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE ;
269275 let tile_col_base_local = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE ;
270276
271- for (var n = 0u ; n < SUBGROUP_MATRIX_N ; n ++ ) {
272- for (var m = 0u ; m < SUBGROUP_MATRIX_M ; m ++ ) {
273- let local_row = tile_row_base_local + n * SUBGROUP_MATRIX_N_SIZE ;
274- let local_col = tile_col_base_local + m * SUBGROUP_MATRIX_M_SIZE ;
275- let out_base = local_row * WG_TILE_STRIDE + local_col ;
276- subgroupMatrixStore (& shmem , out_base , acc_sg_mat [m ][n ], true , WG_TILE_STRIDE );
277+ if (subgroup_id < EXPECTED_SUBGROUPS ) { // 2-5% performance hit :(
278+ for (var n = 0u ; n < SUBGROUP_MATRIX_N ; n ++ ) {
279+ for (var m = 0u ; m < SUBGROUP_MATRIX_M ; m ++ ) {
280+ let local_row = tile_row_base_local + n * SUBGROUP_MATRIX_N_SIZE ;
281+ let local_col = tile_col_base_local + m * SUBGROUP_MATRIX_M_SIZE ;
282+ let out_base = local_row * WG_TILE_STRIDE + local_col ;
283+ subgroupMatrixStore (& shmem , out_base , acc_sg_mat [m ][n ], true , WG_TILE_STRIDE );
284+ }
277285 }
278286 }
279287
0 commit comments