diff --git a/src/layer/vulkan/shader/gemm.comp b/src/layer/vulkan/shader/gemm.comp
index 093374f1cac..66aba7ce531 100644
--- a/src/layer/vulkan/shader/gemm.comp
+++ b/src/layer/vulkan/shader/gemm.comp
@@ -1,9 +1,10 @@
-// Copyright 2023 Tencent
+// Copyright 2025 Tencent
 // SPDX-License-Identifier: BSD-3-Clause
 
 #version 450
 
 #define LOCAL_MEMORY_UNROLL_INCH 8
+#define STAGES 2
 
 layout (constant_id = 0) const float alpha = 1.f;
 layout (constant_id = 1) const float beta = 1.f;
@@ -43,12 +44,14 @@ layout (push_constant) uniform parameter
 } p;
 
 #if NCNN_shader_local_memory
-shared lfp tmp_a[8][LOCAL_MEMORY_UNROLL_INCH][2];
-shared lfp tmp_b[8][LOCAL_MEMORY_UNROLL_INCH][2];
+// Double-buffer shared memory (ping-pong between stages)
+shared lfp tmp_a[STAGES][LOCAL_MEMORY_UNROLL_INCH][LOCAL_MEMORY_UNROLL_INCH][2]; // [stage][MtileRows][Ktile][rows(2)]
+shared lfp tmp_b[STAGES][LOCAL_MEMORY_UNROLL_INCH][LOCAL_MEMORY_UNROLL_INCH][2]; // [stage][NtileCols][Ktile][cols(2)]
 #endif
 
 void main()
 {
+    // Each invocation computes a 2x2 micro-tile: (gy,gy+1) x (gx,gx+1)
     int gx = int(gl_GlobalInvocationID.x) * 2;
     int gy = int(gl_GlobalInvocationID.y) * 2;
     int gz = int(gl_GlobalInvocationID.z);
@@ -63,8 +66,8 @@
     afp sum2 = afp(0.f);
     afp sum3 = afp(0.f);
 
+    // Preload C with beta, honoring broadcast
     const int broadcast_type_C = constantC == 1 ? constant_broadcast_type_C : p.broadcast_type_C;
-
     if (broadcast_type_C == 0)
     {
         sum0 = buffer_ld1(C_blob_data, 0);
@@ -103,104 +106,130 @@ void main()
 #if NCNN_shader_local_memory
     const int NN = psc(K);
 
-    const int lx = int(gl_LocalInvocationID.x);
+    const int lx = int(gl_LocalInvocationID.x); // 0..LOCAL_MEMORY_UNROLL_INCH-1
     const int ly = int(gl_LocalInvocationID.y);
 
-    int k = 0;
-    for (; k + (LOCAL_MEMORY_UNROLL_INCH - 1) < NN; k += LOCAL_MEMORY_UNROLL_INCH)
+    int stage = 0;
+
+    // Prologue: load first K-tile (size currLen) into stage 0
+    int k_base = 0;
+    int currLen = min(LOCAL_MEMORY_UNROLL_INCH, NN - k_base);
+
+    // Guarded loads to avoid OOB on tail
+    if (currLen > 0)
     {
+        if (lx < currLen)
         {
             if (transA == 1)
             {
-                const int ai = (k + lx) * p.A_hstep + gy;
-                tmp_a[ly][lx][0] = sfp2lfp(buffer_ld1(A_blob_data, ai));
-                tmp_a[ly][lx][1] = sfp2lfp(buffer_ld1(A_blob_data, ai + 1));
+                // A^T: (K x M) laid out with stride A_hstep per K
+                const int ai0 = (k_base + lx) * p.A_hstep + gy;
+                tmp_a[stage][ly][lx][0] = sfp2lfp(buffer_ld1(A_blob_data, ai0));
+                tmp_a[stage][ly][lx][1] = sfp2lfp(buffer_ld1(A_blob_data, ai0 + 1));
             }
             else
            {
-                const int ai = gy * p.A_hstep + (k + lx);
-                tmp_a[ly][lx][0] = sfp2lfp(buffer_ld1(A_blob_data, ai));
-                tmp_a[ly][lx][1] = sfp2lfp(buffer_ld1(A_blob_data, ai + p.A_hstep));
+                // A: (M x K) with row stride A_hstep
+                const int ai0 = gy * p.A_hstep + (k_base + lx);
+                tmp_a[stage][ly][lx][0] = sfp2lfp(buffer_ld1(A_blob_data, ai0));
+                tmp_a[stage][ly][lx][1] = sfp2lfp(buffer_ld1(A_blob_data, ai0 + p.A_hstep));
             }
+        }
 
+        if (ly < currLen)
+        {
             if (transB == 1)
             {
-                const int bi = gx * p.B_hstep + (k + ly);
-                tmp_b[lx][ly][0] = sfp2lfp(buffer_ld1(B_blob_data, bi));
-                tmp_b[lx][ly][1] = sfp2lfp(buffer_ld1(B_blob_data, bi + p.B_hstep));
+                // B^T: (N x K) with row stride B_hstep per N
+                const int bi0 = gx * p.B_hstep + (k_base + ly);
+                tmp_b[stage][lx][ly][0] = sfp2lfp(buffer_ld1(B_blob_data, bi0));
+                tmp_b[stage][lx][ly][1] = sfp2lfp(buffer_ld1(B_blob_data, bi0 + p.B_hstep));
             }
             else
            {
-                const int bi = (k + ly) * p.B_hstep + gx;
-                tmp_b[lx][ly][0] = sfp2lfp(buffer_ld1(B_blob_data, bi));
-                tmp_b[lx][ly][1] = sfp2lfp(buffer_ld1(B_blob_data, bi + 1));
+                // B: (K x N) with row stride B_hstep per K
+                const int bi0 = (k_base + ly) * p.B_hstep + gx;
+                tmp_b[stage][lx][ly][0] = sfp2lfp(buffer_ld1(B_blob_data, bi0));
+                tmp_b[stage][lx][ly][1] = sfp2lfp(buffer_ld1(B_blob_data, bi0 + 1));
            }
        }
+    }
 
-        barrier();
-
-        for (int k4 = 0; k4 < LOCAL_MEMORY_UNROLL_INCH; k4++)
-        {
-            afp a0 = lfp2afp(tmp_a[ly][k4][0]);
-            afp a1 = lfp2afp(tmp_a[ly][k4][1]);
-
-            afp b0 = lfp2afp(tmp_b[lx][k4][0]);
-            afp b1 = lfp2afp(tmp_b[lx][k4][1]);
+    barrier();
 
-            sum0 += a0 * b0;
-            sum1 += a0 * b1;
-            sum2 += a1 * b0;
-            sum3 += a1 * b1;
-        }
+    k_base += currLen;
 
-        barrier();
-    }
-
-    if (k < NN)
+    // Main loop: double-buffer like v8.comp
+    while (k_base < NN)
    {
-        const int remain = NN - k;
+        const int nextStage = stage ^ 1;
+        const int nextLen = min(LOCAL_MEMORY_UNROLL_INCH, NN - k_base);
 
-        if (lx < remain)
+        // Preload next tile into nextStage
+        if (lx < nextLen)
        {
            if (transA == 1)
            {
-                const int ai = (k + lx) * p.A_hstep + gy;
-                tmp_a[ly][lx][0] = sfp2lfp(buffer_ld1(A_blob_data, ai));
-                tmp_a[ly][lx][1] = sfp2lfp(buffer_ld1(A_blob_data, ai + 1));
+                const int ai = (k_base + lx) * p.A_hstep + gy;
+                tmp_a[nextStage][ly][lx][0] = sfp2lfp(buffer_ld1(A_blob_data, ai));
+                tmp_a[nextStage][ly][lx][1] = sfp2lfp(buffer_ld1(A_blob_data, ai + 1));
            }
            else
            {
-                const int ai = gy * p.A_hstep + (k + lx);
-                tmp_a[ly][lx][0] = sfp2lfp(buffer_ld1(A_blob_data, ai));
-                tmp_a[ly][lx][1] = sfp2lfp(buffer_ld1(A_blob_data, ai + p.A_hstep));
+                const int ai = gy * p.A_hstep + (k_base + lx);
+                tmp_a[nextStage][ly][lx][0] = sfp2lfp(buffer_ld1(A_blob_data, ai));
+                tmp_a[nextStage][ly][lx][1] = sfp2lfp(buffer_ld1(A_blob_data, ai + p.A_hstep));
            }
        }
 
-        if (ly < remain)
+        if (ly < nextLen)
        {
            if (transB == 1)
            {
-                const int bi = gx * p.B_hstep + (k + ly);
-                tmp_b[lx][ly][0] = sfp2lfp(buffer_ld1(B_blob_data, bi));
-                tmp_b[lx][ly][1] = sfp2lfp(buffer_ld1(B_blob_data, bi + p.B_hstep));
+                const int bi = gx * p.B_hstep + (k_base + ly);
+                tmp_b[nextStage][lx][ly][0] = sfp2lfp(buffer_ld1(B_blob_data, bi));
+                tmp_b[nextStage][lx][ly][1] = sfp2lfp(buffer_ld1(B_blob_data, bi + p.B_hstep));
            }
            else
            {
-                const int bi = (k + ly) * p.B_hstep + gx;
-                tmp_b[lx][ly][0] = sfp2lfp(buffer_ld1(B_blob_data, bi));
-                tmp_b[lx][ly][1] = sfp2lfp(buffer_ld1(B_blob_data, bi + 1));
+                const int bi = (k_base + ly) * p.B_hstep + gx;
+                tmp_b[nextStage][lx][ly][0] = sfp2lfp(buffer_ld1(B_blob_data, bi));
+                tmp_b[nextStage][lx][ly][1] = sfp2lfp(buffer_ld1(B_blob_data, bi + 1));
            }
        }
 
-        barrier();
+        // Compute on current stage while next loads are in flight
+        for (int k4 = 0; k4 < currLen; k4++)
+        {
+            afp a0 = lfp2afp(tmp_a[stage][ly][k4][0]);
+            afp a1 = lfp2afp(tmp_a[stage][ly][k4][1]);
+
+            afp b0 = lfp2afp(tmp_b[stage][lx][k4][0]);
+            afp b1 = lfp2afp(tmp_b[stage][lx][k4][1]);
+
+            sum0 += a0 * b0;
+            sum1 += a0 * b1;
+            sum2 += a1 * b0;
+            sum3 += a1 * b1;
+        }
+
+        barrier(); // switch buffers safely
 
-        for (int k4 = 0; k4 < remain; k4++)
+        stage = nextStage;
+        k_base += nextLen;
+        currLen = nextLen;
+    }
+
+    // Epilogue: compute the final loaded tile
+    if (currLen > 0)
+    {
+        for (int k4 = 0; k4 < currLen; k4++)
        {
-            afp a0 = lfp2afp(tmp_a[ly][k4][0]);
-            afp a1 = lfp2afp(tmp_a[ly][k4][1]);
+            afp a0 = lfp2afp(tmp_a[stage][ly][k4][0]);
+            afp a1 = lfp2afp(tmp_a[stage][ly][k4][1]);
 
-            afp b0 = lfp2afp(tmp_b[lx][k4][0]);
-            afp b1 = lfp2afp(tmp_b[lx][k4][1]);
+            afp b0 = lfp2afp(tmp_b[stage][lx][k4][0]);
+            afp b1 = lfp2afp(tmp_b[stage][lx][k4][1]);
 
             sum0 += a0 * b0;
             sum1 += a0 * b1;
@@ -209,6 +238,7 @@ void main()
        }
    }
 #else
+    // Fallback: no local memory path
     for (int k = 0; k < psc(K); k++)
     {
         afp a0;
@@ -246,18 +276,20 @@ void main()
         sum2 += a1 * b0;
         sum3 += a1 * b1;
     }
 #endif
 
 #if NCNN_shader_local_memory
     if (gx >= psc(N) || gy >= psc(M) || gz >= 1)
         return;
 #endif
 
+    // Scale by alpha
     sum0 *= afp(alpha);
     sum1 *= afp(alpha);
     sum2 *= afp(alpha);
     sum3 *= afp(alpha);
 
+    // Store output (with bounds checks and optional transpose)
     if (output_transpose == 1)
     {
         const int gi = gx * p.outhstep + gy;
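
Not part of the patch: a minimal, self-contained GLSL sketch of the same ping-pong (double-buffered) shared-memory pipeline, stripped down to a per-workgroup sum so the load/compute overlap is easy to see. Every identifier in it (TILE, in_blob, out_blob, data, result, p.K) is an illustrative placeholder, not an ncnn name.

#version 450

// Illustrative sketch only: double-buffered shared-memory pipeline.
// Intended as a single-workgroup dispatch; each workgroup walks
// data[0..K) in TILE-sized chunks and writes one partial sum.

#define TILE 8

layout (local_size_x = TILE, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer in_blob { float data[]; };
layout (binding = 1) writeonly buffer out_blob { float result[]; };

layout (push_constant) uniform parameter { int K; } p;

// Two stages: load one tile while summing the other
shared float tile[2][TILE];

void main()
{
    const int lx = int(gl_LocalInvocationID.x);

    float sum = 0.f;

    int stage = 0;
    int k = 0;
    int len = min(TILE, p.K - k);

    // Prologue: fill stage 0
    if (lx < len) tile[stage][lx] = data[k + lx];
    barrier();
    k += len;

    while (k < p.K)
    {
        const int next = stage ^ 1;
        const int nlen = min(TILE, p.K - k);

        // Issue loads for the next tile ...
        if (lx < nlen) tile[next][lx] = data[k + lx];

        // ... while accumulating the tile loaded previously
        for (int i = 0; i < len; i++) sum += tile[stage][i];

        barrier(); // next tile now visible; current tile free to overwrite

        stage = next;
        k += nlen;
        len = nlen;
    }

    // Epilogue: the last loaded tile
    for (int i = 0; i < len; i++) sum += tile[stage][i];

    if (lx == 0) result[int(gl_WorkGroupID.x)] = sum;
}

The structural point mirrors the patch: the old loop issued two barriers per K-tile (one after the shared loads, one after the accumulation), while the pipelined loop needs only one barrier per tile plus a single prologue barrier, because the loads for tile k+1 overlap the math on tile k.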