Skip to content

Commit e54b394

Browse files
authored
CUDA/HIP: fix ssm_scan on devices where warp size is not 32 (#14196)
1 parent 2c2caa4 commit e54b394

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

ggml/src/ggml-cuda/ssm-scan.cu

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ __global__ void __launch_bounds__(splitD, 2)
1010
float * __restrict__ dst, const int64_t L) {
1111
GGML_UNUSED(src1_nb0);
1212
GGML_UNUSED(src2_nb0);
13+
14+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1315
const int bidx = blockIdx.x; // split along B
1416
const int bidy = blockIdx.y; // split along D
1517
const int tid = threadIdx.x;
@@ -44,16 +46,16 @@ __global__ void __launch_bounds__(splitD, 2)
4446
if (N == 16) {
4547
#pragma unroll
4648
for (size_t i = 0; i < splitD / 4; i += 2) {
47-
float value = A_block[(wid * warpSize + i) * stride_A + wtid];
49+
float value = A_block[(wid * warp_size + i) * stride_A + wtid];
4850
// todo: bank conflict
4951
// I am always confused with how to use the swizzling method to solve
5052
// bank conflit. Hoping somebody can tell me.
51-
smem_A[(wid * warpSize + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
53+
smem_A[(wid * warp_size + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
5254
}
5355
#pragma unroll
5456
for (size_t i = 0; i < splitD / 4; i += 2) {
55-
float value = s0_block[(wid * warpSize + i) * stride_s0 + wtid];
56-
smem_s0[(wid * warpSize + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
57+
float value = s0_block[(wid * warp_size + i) * stride_s0 + wtid];
58+
smem_s0[(wid * warp_size + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
5759
}
5860
}
5961

0 commit comments

Comments
 (0)