Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 91 additions & 59 deletions src/layer/vulkan/shader/gemm.comp
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
// Copyright 2023 Tencent
// Copyright 2025 Tencent
// SPDX-License-Identifier: BSD-3-Clause

#version 450

#define LOCAL_MEMORY_UNROLL_INCH 8
#define STAGES 2

layout (constant_id = 0) const float alpha = 1.f;
layout (constant_id = 1) const float beta = 1.f;
Expand Down Expand Up @@ -43,12 +44,14 @@ layout (push_constant) uniform parameter
} p;

#if NCNN_shader_local_memory
shared lfp tmp_a[8][LOCAL_MEMORY_UNROLL_INCH][2];
shared lfp tmp_b[8][LOCAL_MEMORY_UNROLL_INCH][2];
// Double-buffer shared memory (ping-pong between stages)
shared lfp tmp_a[STAGES][LOCAL_MEMORY_UNROLL_INCH][LOCAL_MEMORY_UNROLL_INCH][2]; // [stage][MtileRows][Ktile][rows(2)]
shared lfp tmp_b[STAGES][LOCAL_MEMORY_UNROLL_INCH][LOCAL_MEMORY_UNROLL_INCH][2]; // [stage][NtileCols][Ktile][cols(2)]
#endif

void main()
{
// Each invocation computes a 2x2 micro-tile: (gy,gy+1) x (gx,gx+1)
int gx = int(gl_GlobalInvocationID.x) * 2;
int gy = int(gl_GlobalInvocationID.y) * 2;
int gz = int(gl_GlobalInvocationID.z);
Expand All @@ -63,8 +66,8 @@ void main()
afp sum2 = afp(0.f);
afp sum3 = afp(0.f);

// Preload C with beta, honoring broadcast
const int broadcast_type_C = constantC == 1 ? constant_broadcast_type_C : p.broadcast_type_C;

if (broadcast_type_C == 0)
{
sum0 = buffer_ld1(C_blob_data, 0);
Expand Down Expand Up @@ -103,104 +106,130 @@ void main()
#if NCNN_shader_local_memory
const int NN = psc(K);

const int lx = int(gl_LocalInvocationID.x);
const int lx = int(gl_LocalInvocationID.x); // 0..LOCAL_MEMORY_UNROLL_INCH-1
const int ly = int(gl_LocalInvocationID.y);

int k = 0;
for (; k + (LOCAL_MEMORY_UNROLL_INCH - 1) < NN; k += LOCAL_MEMORY_UNROLL_INCH)
int stage = 0;

// Prologue: load first K-tile (size currLen) into stage 0
int k_base = 0;
int currLen = min(LOCAL_MEMORY_UNROLL_INCH, NN - k_base);

// Guard loads along K only: a short tail tile (currLen < LOCAL_MEMORY_UNROLL_INCH) must not read past K
if (currLen > 0)
{
if (lx < currLen)
{
if (transA == 1)
{
const int ai = (k + lx) * p.A_hstep + gy;
tmp_a[ly][lx][0] = sfp2lfp(buffer_ld1(A_blob_data, ai));
tmp_a[ly][lx][1] = sfp2lfp(buffer_ld1(A_blob_data, ai + 1));
// A^T: (K x M) laid out with stride A_hstep per K
const int ai0 = (k_base + lx) * p.A_hstep + gy;
tmp_a[stage][ly][lx][0] = sfp2lfp(buffer_ld1(A_blob_data, ai0));
tmp_a[stage][ly][lx][1] = sfp2lfp(buffer_ld1(A_blob_data, ai0 + 1));
}
else
{
const int ai = gy * p.A_hstep + (k + lx);
tmp_a[ly][lx][0] = sfp2lfp(buffer_ld1(A_blob_data, ai));
tmp_a[ly][lx][1] = sfp2lfp(buffer_ld1(A_blob_data, ai + p.A_hstep));
// A: (M x K) with row stride A_hstep
const int ai0 = gy * p.A_hstep + (k_base + lx);
tmp_a[stage][ly][lx][0] = sfp2lfp(buffer_ld1(A_blob_data, ai0));
tmp_a[stage][ly][lx][1] = sfp2lfp(buffer_ld1(A_blob_data, ai0 + p.A_hstep));
}
}

if (ly < currLen)
{
if (transB == 1)
{
const int bi = gx * p.B_hstep + (k + ly);
tmp_b[lx][ly][0] = sfp2lfp(buffer_ld1(B_blob_data, bi));
tmp_b[lx][ly][1] = sfp2lfp(buffer_ld1(B_blob_data, bi + p.B_hstep));
// B^T: (N x K) with row stride B_hstep per N
const int bi0 = gx * p.B_hstep + (k_base + ly);
tmp_b[stage][lx][ly][0] = sfp2lfp(buffer_ld1(B_blob_data, bi0));
tmp_b[stage][lx][ly][1] = sfp2lfp(buffer_ld1(B_blob_data, bi0 + p.B_hstep));
}
else
{
const int bi = (k + ly) * p.B_hstep + gx;
tmp_b[lx][ly][0] = sfp2lfp(buffer_ld1(B_blob_data, bi));
tmp_b[lx][ly][1] = sfp2lfp(buffer_ld1(B_blob_data, bi + 1));
// B: (K x N) with row stride B_hstep per K
const int bi0 = (k_base + ly) * p.B_hstep + gx;
tmp_b[stage][lx][ly][0] = sfp2lfp(buffer_ld1(B_blob_data, bi0));
tmp_b[stage][lx][ly][1] = sfp2lfp(buffer_ld1(B_blob_data, bi0 + 1));
}
}
}

barrier();

for (int k4 = 0; k4 < LOCAL_MEMORY_UNROLL_INCH; k4++)
{
afp a0 = lfp2afp(tmp_a[ly][k4][0]);
afp a1 = lfp2afp(tmp_a[ly][k4][1]);

afp b0 = lfp2afp(tmp_b[lx][k4][0]);
afp b1 = lfp2afp(tmp_b[lx][k4][1]);
barrier();

sum0 += a0 * b0;
sum1 += a0 * b1;
sum2 += a1 * b0;
sum3 += a1 * b1;
}
k_base += currLen;

barrier();
}

if (k < NN)
// Main loop: double-buffered (like v8.comp) - load the next K-tile into the other stage, compute on the current stage, then swap
while (k_base < NN)
{
const int remain = NN - k;
const int nextStage = stage ^ 1;
const int nextLen = min(LOCAL_MEMORY_UNROLL_INCH, NN - k_base);

if (lx < remain)
// Preload next tile into nextStage
if (lx < nextLen)
{
if (transA == 1)
{
const int ai = (k + lx) * p.A_hstep + gy;
tmp_a[ly][lx][0] = sfp2lfp(buffer_ld1(A_blob_data, ai));
tmp_a[ly][lx][1] = sfp2lfp(buffer_ld1(A_blob_data, ai + 1));
const int ai = (k_base + lx) * p.A_hstep + gy;
tmp_a[nextStage][ly][lx][0] = sfp2lfp(buffer_ld1(A_blob_data, ai));
tmp_a[nextStage][ly][lx][1] = sfp2lfp(buffer_ld1(A_blob_data, ai + 1));
}
else
{
const int ai = gy * p.A_hstep + (k + lx);
tmp_a[ly][lx][0] = sfp2lfp(buffer_ld1(A_blob_data, ai));
tmp_a[ly][lx][1] = sfp2lfp(buffer_ld1(A_blob_data, ai + p.A_hstep));
const int ai = gy * p.A_hstep + (k_base + lx);
tmp_a[nextStage][ly][lx][0] = sfp2lfp(buffer_ld1(A_blob_data, ai));
tmp_a[nextStage][ly][lx][1] = sfp2lfp(buffer_ld1(A_blob_data, ai + p.A_hstep));
}
}

if (ly < remain)
if (ly < nextLen)
{
if (transB == 1)
{
const int bi = gx * p.B_hstep + (k + ly);
tmp_b[lx][ly][0] = sfp2lfp(buffer_ld1(B_blob_data, bi));
tmp_b[lx][ly][1] = sfp2lfp(buffer_ld1(B_blob_data, bi + p.B_hstep));
const int bi = gx * p.B_hstep + (k_base + ly);
tmp_b[nextStage][lx][ly][0] = sfp2lfp(buffer_ld1(B_blob_data, bi));
tmp_b[nextStage][lx][ly][1] = sfp2lfp(buffer_ld1(B_blob_data, bi + p.B_hstep));
}
else
{
const int bi = (k + ly) * p.B_hstep + gx;
tmp_b[lx][ly][0] = sfp2lfp(buffer_ld1(B_blob_data, bi));
tmp_b[lx][ly][1] = sfp2lfp(buffer_ld1(B_blob_data, bi + 1));
const int bi = (k_base + ly) * p.B_hstep + gx;
tmp_b[nextStage][lx][ly][0] = sfp2lfp(buffer_ld1(B_blob_data, bi));
tmp_b[nextStage][lx][ly][1] = sfp2lfp(buffer_ld1(B_blob_data, bi + 1));
}
}

barrier();
// Compute on the current stage while the next tile's loads may still be in flight; nextStage is not read until after the barrier below
for (int k4 = 0; k4 < currLen; k4++)
{
afp a0 = lfp2afp(tmp_a[stage][ly][k4][0]);
afp a1 = lfp2afp(tmp_a[stage][ly][k4][1]);

afp b0 = lfp2afp(tmp_b[stage][lx][k4][0]);
afp b1 = lfp2afp(tmp_b[stage][lx][k4][1]);

sum0 += a0 * b0;
sum1 += a0 * b1;
sum2 += a1 * b0;
sum3 += a1 * b1;
}

barrier(); // switch buffers safely

for (int k4 = 0; k4 < remain; k4++)
stage = nextStage;
k_base += nextLen;
currLen = nextLen;
}

// Epilogue: compute the final loaded tile
if (currLen > 0)
{
for (int k4 = 0; k4 < currLen; k4++)
{
afp a0 = lfp2afp(tmp_a[ly][k4][0]);
afp a1 = lfp2afp(tmp_a[ly][k4][1]);
afp a0 = lfp2afp(tmp_a[stage][ly][k4][0]);
afp a1 = lfp2afp(tmp_a[stage][ly][k4][1]);

afp b0 = lfp2afp(tmp_b[lx][k4][0]);
afp b1 = lfp2afp(tmp_b[lx][k4][1]);
afp b0 = lfp2afp(tmp_b[stage][lx][k4][0]);
afp b1 = lfp2afp(tmp_b[stage][lx][k4][1]);

sum0 += a0 * b0;
sum1 += a0 * b1;
Expand All @@ -209,6 +238,7 @@ void main()
}
}
#else
// Fallback: no local memory path
for (int k = 0; k < psc(K); k++)
{
afp a0;
Expand Down Expand Up @@ -246,18 +276,20 @@ void main()
sum2 += a1 * b0;
sum3 += a1 * b1;
}
#endif
#endif

#if NCNN_shader_local_memory
if (gx >= psc(N) || gy >= psc(M) || gz >= 1)
return;
#endif

// Scale by alpha
sum0 *= afp(alpha);
sum1 *= afp(alpha);
sum2 *= afp(alpha);
sum3 *= afp(alpha);

// Store output (with bounds checks and optional transpose)
if (output_transpose == 1)
{
const int gi = gx * p.outhstep + gy;
Expand Down
Loading