Skip to content

Commit 71c7a4a

Browse files
committed
Update dawn version and move to portable subgroup size
1 parent cf0c536 commit 71c7a4a

File tree

4 files changed

+55
-48
lines changed

4 files changed

+55
-48
lines changed

.github/workflows/build.yml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -161,15 +161,15 @@ jobs:
161161
- name: Dawn Dependency
162162
id: dawn-depends
163163
run: |
164-
DAWN_VERSION="v1.0.0"
164+
DAWN_VERSION="v2.0.0"
165165
DAWN_OWNER="reeselevine"
166166
DAWN_REPO="dawn"
167-
DAWN_ASSET_NAME="Dawn-a1a6b45cced25a3b7f4fb491e0ae70796cc7f22b-macos-latest-Release.tar.gz"
167+
DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release.zip"
168168
echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
169-
curl -L -o artifact.tar.gz \
169+
curl -L -o artifact.zip \
170170
"https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
171171
mkdir dawn
172-
tar -xvf artifact.tar.gz -C dawn --strip-components=1
172+
unzip artifact.zip -d dawn
173173
174174
- name: Build
175175
id: cmake_build
@@ -521,15 +521,15 @@ jobs:
521521
id: dawn-depends
522522
run: |
523523
sudo apt-get install -y libxrandr-dev libxinerama-dev libxcursor-dev mesa-common-dev libx11-xcb-dev libxi-dev
524-
DAWN_VERSION="v1.0.0"
524+
DAWN_VERSION="v2.0.0"
525525
DAWN_OWNER="reeselevine"
526526
DAWN_REPO="dawn"
527-
DAWN_ASSET_NAME="Dawn-a1a6b45cced25a3b7f4fb491e0ae70796cc7f22b-ubuntu-latest-Release.tar.gz"
527+
DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release.zip"
528528
echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
529-
curl -L -o artifact.tar.gz \
529+
curl -L -o artifact.zip \
530530
"https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
531531
mkdir dawn
532-
tar -xvf artifact.tar.gz -C dawn --strip-components=1
532+
unzip artifact.zip -d dawn
533533
534534
- name: Build
535535
id: cmake_build

ggml/src/ggml-webgpu/ggml-webgpu.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1755,7 +1755,7 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
17551755

17561756
if (webgpu_ctx->supports_subgroup_matrix) {
17571757
std::vector<std::pair<std::string, std::string>> sg_matrix_repls;
1758-
sg_matrix_repls.emplace_back("WEBGPU_SUBGROUP_SIZE", std::to_string(webgpu_ctx->subgroup_size));
1758+
sg_matrix_repls.emplace_back("WEBGPU_MAX_SUBGROUP_SIZE", std::to_string(webgpu_ctx->subgroup_size));
17591759
sg_matrix_repls.emplace_back("WEBGPU_TILE_K", std::to_string(WEBGPU_MUL_MAT_TILE_K));
17601760
sg_matrix_repls.emplace_back("WEBGPU_SUBGROUP_M", std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M));
17611761
sg_matrix_repls.emplace_back("WEBGPU_SUBGROUP_N", std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N));
@@ -2398,14 +2398,15 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
23982398
// Unfortunately, WebGPU allows info.subgroup{Min/Max}Size to be different, and even on devices
23992399
// where it is consistent, e.g., Apple M-series GPUs, the min/max sizes report different values.
24002400
// Therefore, hardcoding the subgroup size to 32 for now for development.
2401-
ctx->subgroup_size = 32;
2401+
ctx->subgroup_size = info.subgroupMaxSize;
24022402
ctx->supports_subgroup_matrix =
24032403
valid_subgroup_matrix_config && ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
24042404

24052405
// Initialize device
24062406
std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16,
24072407
wgpu::FeatureName::ImplicitDeviceSynchronization };
24082408
if (ctx->supports_subgroup_matrix) {
2409+
required_features.push_back(wgpu::FeatureName::Subgroups);
24092410
required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
24102411
}
24112412

ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -141,22 +141,20 @@ var<workgroup> src0_shmem: array<{{SRC0_TYPE}}, TILE_SRC0_SHMEM/{{VEC_SIZE}}>;
141141
var<workgroup> src1_shmem: array<{{SRC1_TYPE}}, TILE_SRC1_SHMEM/{{VEC_SIZE}}>;
142142

143143
@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
144-
fn main(@builtin(global_invocation_id) global_id: vec3<u32>,
144+
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
145145
@builtin(local_invocation_id) local_id: vec3<u32>) {
146146

147147
let thread_id = local_id.x;
148148
let local_m = get_local_m(thread_id);
149149
let local_n = get_local_n(thread_id);
150150

151-
let wg_linear = global_id.x / TOTAL_WORKGROUP_SIZE;
152-
153151
let wg_n_count = (params.n + WORKGROUP_SIZE_N * TILE_N - 1u) / (WORKGROUP_SIZE_N * TILE_N);
154152
let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M);
155153
let wg_per_matrix = wg_m_count * wg_n_count;
156154

157-
let batch_idx = wg_linear / wg_per_matrix;
155+
let batch_idx = wg_id.x / wg_per_matrix;
158156

159-
let wg_in_batch = wg_linear % wg_per_matrix;
157+
let wg_in_batch = wg_id.x % wg_per_matrix;
160158
let wg_m = wg_in_batch % wg_m_count;
161159
let wg_n = wg_in_batch / wg_m_count;
162160

ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl

Lines changed: 41 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ fn store_dst(shmem_idx: u32, dst_idx: u32) {
105105
#define(SHADER)
106106
diagnostic(off, chromium.subgroup_matrix_uniformity);
107107
enable f16;
108+
enable subgroups;
108109
enable chromium_experimental_subgroup_matrix;
109110

110111
struct MulMatParams {
@@ -138,7 +139,11 @@ DECLS
138139
// current Dawn version type definitions/matrix load requirements for constant memory sizes.
139140
const SUBGROUP_M = {{WEBGPU_SUBGROUP_M}}u;
140141
const SUBGROUP_N = {{WEBGPU_SUBGROUP_N}}u;
141-
const SUBGROUP_SIZE = {{WEBGPU_SUBGROUP_SIZE}}u;
142+
// For portability we assume the max subgroup size, meaning some subgroups will be masked out if the
143+
// runtime subgroup size is smaller.
144+
const MAX_SUBGROUP_SIZE = {{WEBGPU_MAX_SUBGROUP_SIZE}}u;
145+
146+
const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N;
142147

143148
const SUBGROUP_MATRIX_M_SIZE = {{WEBGPU_SG_MAT_M_SIZE}}u;
144149
const SUBGROUP_MATRIX_N_SIZE = {{WEBGPU_SG_MAT_N_SIZE}}u;
@@ -152,7 +157,7 @@ const TILE_K = {{WEBGPU_TILE_K}}u;
152157
const WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
153158
const WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
154159

155-
const TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * SUBGROUP_SIZE;
160+
const TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * MAX_SUBGROUP_SIZE;
156161
const TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
157162
const TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
158163

@@ -164,23 +169,21 @@ const SHMEM_SIZE = max(TILE_SRC0_SHMEM + TILE_SRC1_SHMEM, SG_MAT_ACCUM_SHMEM);
164169
var<workgroup> shmem: array<f16, SHMEM_SIZE>;
165170

166171
@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
167-
fn main(@builtin(global_invocation_id) global_id: vec3<u32>,
172+
fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
168173
@builtin(local_invocation_id) local_id: vec3<u32>,
169174
@builtin(subgroup_id) subgroup_id: u32) {
170175

171176
let thread_id = local_id.x;
172177
let subgroup_m = subgroup_id % SUBGROUP_M;
173178
let subgroup_n = subgroup_id / SUBGROUP_M;
174179

175-
let wg_linear = global_id.x / TOTAL_WORKGROUP_SIZE;
176-
177180
let wg_m_count = (params.m + WG_M_SG_TILE_SIZE - 1) / WG_M_SG_TILE_SIZE;
178181
let wg_n_count = (params.n + WG_N_SG_TILE_SIZE - 1) / WG_N_SG_TILE_SIZE;
179182
let wg_per_matrix = wg_m_count * wg_n_count;
180183

181-
let batch_idx = wg_linear / wg_per_matrix;
184+
let batch_idx = wg_id.x / wg_per_matrix;
182185

183-
let wg_in_batch = wg_linear % wg_per_matrix;
186+
let wg_in_batch = wg_id.x % wg_per_matrix;
184187
let wg_m = wg_in_batch % wg_m_count;
185188
let wg_n = wg_in_batch / wg_m_count;
186189

@@ -230,29 +233,32 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>,
230233

231234
workgroupBarrier();
232235

233-
for (var k_inner = 0u; k_inner < TILE_K; k_inner += SUBGROUP_MATRIX_K_SIZE) {
236+
if (subgroup_id < EXPECTED_SUBGROUPS) {
234237

235-
let src0_shmem_idx_base = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE * TILE_K + k_inner;
236-
var src0_sg_mats: array<subgroup_matrix_left<f16, SUBGROUP_MATRIX_K_SIZE, SUBGROUP_MATRIX_M_SIZE>, SUBGROUP_MATRIX_M>;
237-
for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {
238-
src0_sg_mats[m] = subgroupMatrixLoad<subgroup_matrix_left<f16, SUBGROUP_MATRIX_K_SIZE, SUBGROUP_MATRIX_M_SIZE>>(
239-
&shmem,
240-
src0_shmem_idx_base + m * SUBGROUP_MATRIX_M_SIZE * TILE_K,
241-
false,
242-
TILE_K
243-
);
244-
}
238+
for (var k_inner = 0u; k_inner < TILE_K; k_inner += SUBGROUP_MATRIX_K_SIZE) {
245239

246-
let src1_shmem_idx_base = subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE * TILE_K + k_inner;
247-
for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) {
248-
let src1_sg_mat = subgroupMatrixLoad<subgroup_matrix_right<f16, SUBGROUP_MATRIX_N_SIZE, SUBGROUP_MATRIX_K_SIZE>>(
249-
&shmem,
250-
TILE_SRC0_SHMEM + src1_shmem_idx_base + n * SUBGROUP_MATRIX_N_SIZE * TILE_K,
251-
true,
252-
TILE_K
253-
);
240+
let src0_shmem_idx_base = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE * TILE_K + k_inner;
241+
var src0_sg_mats: array<subgroup_matrix_left<f16, SUBGROUP_MATRIX_K_SIZE, SUBGROUP_MATRIX_M_SIZE>, SUBGROUP_MATRIX_M>;
254242
for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {
255-
acc_sg_mat[m][n] = subgroupMatrixMultiplyAccumulate(src0_sg_mats[m], src1_sg_mat, acc_sg_mat[m][n]);
243+
src0_sg_mats[m] = subgroupMatrixLoad<subgroup_matrix_left<f16, SUBGROUP_MATRIX_K_SIZE, SUBGROUP_MATRIX_M_SIZE>>(
244+
&shmem,
245+
src0_shmem_idx_base + m * SUBGROUP_MATRIX_M_SIZE * TILE_K,
246+
false,
247+
TILE_K
248+
);
249+
}
250+
251+
let src1_shmem_idx_base = subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE * TILE_K + k_inner;
252+
for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) {
253+
let src1_sg_mat = subgroupMatrixLoad<subgroup_matrix_right<f16, SUBGROUP_MATRIX_N_SIZE, SUBGROUP_MATRIX_K_SIZE>>(
254+
&shmem,
255+
TILE_SRC0_SHMEM + src1_shmem_idx_base + n * SUBGROUP_MATRIX_N_SIZE * TILE_K,
256+
true,
257+
TILE_K
258+
);
259+
for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {
260+
acc_sg_mat[m][n] = subgroupMatrixMultiplyAccumulate(src0_sg_mats[m], src1_sg_mat, acc_sg_mat[m][n]);
261+
}
256262
}
257263
}
258264
}
@@ -268,12 +274,14 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>,
268274
let tile_row_base_local = subgroup_n * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
269275
let tile_col_base_local = subgroup_m * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
270276

271-
for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) {
272-
for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {
273-
let local_row = tile_row_base_local + n * SUBGROUP_MATRIX_N_SIZE;
274-
let local_col = tile_col_base_local + m * SUBGROUP_MATRIX_M_SIZE;
275-
let out_base = local_row * WG_TILE_STRIDE + local_col;
276-
subgroupMatrixStore(&shmem, out_base, acc_sg_mat[m][n], true, WG_TILE_STRIDE);
277+
if (subgroup_id < EXPECTED_SUBGROUPS) { // 2-5% performance hit :(
278+
for (var n = 0u; n < SUBGROUP_MATRIX_N; n++) {
279+
for (var m = 0u; m < SUBGROUP_MATRIX_M; m++) {
280+
let local_row = tile_row_base_local + n * SUBGROUP_MATRIX_N_SIZE;
281+
let local_col = tile_col_base_local + m * SUBGROUP_MATRIX_M_SIZE;
282+
let out_base = local_row * WG_TILE_STRIDE + local_col;
283+
subgroupMatrixStore(&shmem, out_base, acc_sg_mat[m][n], true, WG_TILE_STRIDE);
284+
}
277285
}
278286
}
279287

0 commit comments

Comments
 (0)