Problem
We currently hard-code CHUNK_SIZE = 16 in shader.rs and metal_msm.rs, which forces the Metal kernels to create more columns than necessary. Benchmarks show that the fixed size favors large inputs but penalizes small MSM workloads (e.g., fewer than 65,536 points). The ZPrize 2023 WebGPU reference implementation switches to CHUNK_SIZE = 4 for inputs smaller than 2^16 points, giving a lighter MSM configuration and markedly better performance on smaller datasets.
Reference:
https://github.com/z-prize/2023-entries/blob/6cc68aeb63071d90817aeff4b55b34444fae42a8/prize-2-msm-wasm/webgpu-only/tal-derei-koh-wei-jie/src/submission/submission.ts#L80
Details / Proposed Solution
- Expose CHUNK_SIZE as a compile-time or runtime parameter
  - Use 4 when num_points < 2**16; keep 16 otherwise.
- Update dispatch logic in host code (Metal & WebGPU)
  - Recalculate num_columns, threads_per_threadgroup, and threadgroup_count based on the selected CHUNK_SIZE.
- Audit shader code
  - Replace hard-coded literals with constants passed via the push-constants / uniform buffer.
  - Validate loop bounds to prevent infinite loops or out-of-bounds memory writes when CHUNK_SIZE = 4.
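The selection rule and the derived dispatch parameters above could be sketched on the host side roughly as follows. This is a minimal illustration, not the existing code in metal_msm.rs: the MsmConfig struct, the function names, and the formulas num_columns = 2^CHUNK_SIZE, num_rows = ceil(num_points / num_columns), and num_subtasks = ceil(scalar_bits / CHUNK_SIZE) are assumptions and should be checked against the reference implementation and the current kernels.

```rust
/// Hypothetical container for the parameters that depend on CHUNK_SIZE.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MsmConfig {
    pub chunk_size: u32,
    pub num_columns: u32,
    pub num_rows: u32,
    pub num_subtasks: u32,
}

/// Inputs with fewer points than this use the small chunk size (2^16, per the issue).
const SMALL_INPUT_THRESHOLD: usize = 1 << 16;

/// Ceiling division for non-zero divisors.
fn ceil_div(a: u64, b: u64) -> u64 {
    assert!(b > 0, "divisor must be non-zero");
    (a + b - 1) / b
}

/// Pick CHUNK_SIZE from the input length: 4 for small inputs, 16 otherwise.
pub fn select_chunk_size(num_points: usize) -> u32 {
    if num_points < SMALL_INPUT_THRESHOLD {
        4
    } else {
        16
    }
}

/// Derive the values that must be recomputed whenever CHUNK_SIZE changes.
/// The formulas here are assumptions made for illustration.
pub fn msm_config(num_points: usize, scalar_bits: u32) -> MsmConfig {
    let chunk_size = select_chunk_size(num_points);
    let num_columns = 1u32 << chunk_size;
    let num_rows = ceil_div(num_points as u64, num_columns as u64) as u32;
    let num_subtasks = ceil_div(scalar_bits as u64, chunk_size as u64) as u32;
    MsmConfig {
        chunk_size,
        num_columns,
        num_rows,
        num_subtasks,
    }
}

/// Number of threadgroups needed to cover `total_threads` in a 1D dispatch.
pub fn threadgroup_count(total_threads: u32, threads_per_threadgroup: u32) -> u32 {
    ceil_div(total_threads as u64, threads_per_threadgroup as u64) as u32
}
```

Keeping the selection in one function means the Metal and WebGPU hosts can share the same threshold, and the resulting MsmConfig values are what would be written into the uniform buffer instead of literals in the shaders.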
Acceptance Criteria
- cargo test --release test_metal_msm_pipeline -- --nocapture passes for both CHUNK_SIZE = 4 and 16.
- CHUNK_SIZE = 4 shows ≥ 10% speed-up vs the current master on an M-series GPU.