Skip to content

Commit 6034953

Browse files
committed
prefix sum
1 parent 5f8b522 commit 6034953

File tree

3 files changed

+245
-0
lines changed

3 files changed

+245
-0
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ My solutions to CUDA challenges on https://leetgpu.com/
1515
[Rainbow Table](https://leetgpu.com/challenges/rainbow-table) | [Link](./rainbow_table.cu) | Easy |
1616
[Reduction](https://leetgpu.com/challenges/reduction) | [Link](./reduction.cu) | Medium |
1717
[Softmax](https://leetgpu.com/challenges/softmax) | [Link](./softmax.cu) | Medium |
18+
[Prefix Sum](https://leetgpu.com/challenges/prefix-sum) | [Link](./prefix_sum.cu) | Medium |
1819
[Dot Product](https://leetgpu.com/challenges/dot-product) | [Link](./dot_product.cu) | Medium |
1920
[Softmax Attention](https://leetgpu.com/challenges/softmax-attention) | [Link](./softmax_attention.cu) | Medium |
2021
[Password Cracking (FNV-1a)](https://leetgpu.com/challenges/password-cracking-fnv-1a) | [Link](./password_cracking_fnv_1a.cu) | Medium |

prefix_sum.cu

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
#include "solve.h"
#include <cuda_runtime.h>

#define FULL_MASK 0xffffffff

// Per-block running totals published by pass 1 (indexed by blockIdx.x) and
// scanned in place by pass 2; pass 3 reads them as offsets. Sized 1024*32 to
// cover the largest grid this solution launches.
__device__ float store[1024*32];

// NOTE(review): s1/s2 are device-global scratch handed to per-block scans;
// every block writes the same array concurrently, which races whenever
// gridDim.x > 1 — per-block __shared__ memory looks like what was intended
// (s2 is not referenced in the visible code). Confirm before relying on this.
__device__ float s1[1024], s2[1024];
9+
10+
// Inclusive prefix sum (scan) of input[0..N) into output[0..N), computed by a
// single thread block.
//
// Preconditions:
//   - blockDim.x == 1024 and N <= 32 * 1024 (so at most 1024 chunk sums exist)
//   - s points to scratch of at least blockDim.x floats, PRIVATE to this block
//
// Algorithm:
//   1. Each warp reduces its 32-element chunks; chunk c's total lands in s[c].
//   2. A two-phase sweep turns s into an inclusive scan of the chunk totals.
//   3. Each element gets an intra-warp inclusive scan (shfl_up) plus the
//      scanned total of all preceding chunks.
//
// When store_value is true, thread 0 additionally publishes this block's
// grand total (output[N-1]) to store[blockIdx.x] for the cross-block pass.
template<bool store_value>
__device__ void prefix_sum_compute(const float* input, float* output, int N, float* s) {
    int tid = threadIdx.x;
    int num_threads = blockDim.x;
    int block_id = blockIdx.x;
    int lane_id = tid % 32;

    s[tid] = 0;
    __syncthreads();

    // Round N up to a multiple of 32 so whole warps enter each loop iteration
    // together and all 32 lanes participate in every *_sync shuffle.
    int loop_bound = (N + 31);
    loop_bound -= (loop_bound % 32);

    // Step 1: per-chunk totals. Chunk c covers elements [32c, 32c + 32);
    // lane 0 of the owning warp writes the chunk total to s[c].
    for (int i = tid; i < loop_bound; i += num_threads) {
        float f = (i < N) ? input[i] : 0.0f;
        // Butterfly reduction across the warp.
        for (int off = 16; off >= 1; off >>= 1) {
            f += __shfl_xor_sync(FULL_MASK, f, off);
        }
        if (lane_id == 0) {
            s[i / 32] = f;
        }
    }
    __syncthreads();

    // Step 2a: up-sweep over the (up to 1024) chunk totals in s.
    int offset = 1;
    for (int d = 512; d > 0; d >>= 1) {
        __syncthreads();
        if (tid < d) {
            int a = (tid + 1) * (offset * 2) - 1 - offset;
            int b = (tid + 1) * (offset * 2) - 1;
            s[b] += s[a];
        }
        offset *= 2;
    }

    // Step 2b: down-sweep; afterwards s[c] holds the inclusive scan of the
    // chunk totals.
    for (int d = 2; d < 1024; d *= 2) {
        offset >>= 1;
        __syncthreads();
        if (tid < d - 1) {
            int a = (tid + 1) * offset - 1;
            int b = (tid + 1) * offset - 1 + offset / 2;
            s[b] += s[a];
        }
    }
    __syncthreads();

    // Step 3: intra-warp inclusive scan of each element, then offset by the
    // scanned total of every chunk before this one.
    for (int i = tid; i < loop_bound; i += num_threads) {
        float f = (i < N) ? input[i] : 0.0f;
        for (int d = 1; d <= 16; d *= 2) {
            float up = __shfl_up_sync(FULL_MASK, f, d);
            if (lane_id >= d) f += up;
        }
        if (i < N) {
            if (i >= 32) {
                f += s[i / 32 - 1];
            }
            output[i] = f;
        }
    }

    if constexpr (store_value) {
        // BUGFIX: output[N-1] is written by whichever thread owned element
        // N-1 in the loop above, not necessarily thread 0 — a barrier is
        // required before thread 0 may read it. (Uniform across the block
        // since store_value is a compile-time constant.)
        __syncthreads();
        if (tid == 0) {
            store[block_id] = output[N - 1];
        }
    }
}
90+
91+
// template<bool store_value>
92+
// __device__ void prefix_sum_compute(const float* input, float* output, int N, float* s) {
93+
// int tid = threadIdx.x;
94+
// int block_id = blockIdx.x;
95+
// int start = tid * 32;
96+
// if (start < N) {
97+
// output[start] = input[start];
98+
// for (int i = start + 1; i < min(N, start + 32); i++) {
99+
// output[i] = output[i-1] + input[i];
100+
// }
101+
// }
102+
// __syncthreads();
103+
// if (tid == 0) {
104+
// for (int i = 32+31; i < N; i += 32) {
105+
// output[i] += output[i-32];
106+
// }
107+
// }
108+
// __syncthreads();
109+
110+
// for (int i = start; i < min(N, start + 31); i++) {
111+
// if (tid != 0) {
112+
// output[i] += output[start - 1];
113+
// }
114+
// }
115+
116+
// if constexpr (store_value) {
117+
// store[block_id] = output[N - 1];
118+
// }
119+
// }
120+
121+
// Pass 1: inclusive scan of each (blockDim.x * 32)-element slice of the
// input, one slice per block; each block's total goes to store[blockIdx.x].
// Launch with blockDim.x == 1024.
// BUGFIX: the original passed the device-global array s1 as scan scratch,
// which every block in the grid wrote concurrently — a cross-block data
// race whenever gridDim.x > 1. Use per-block shared memory instead.
__global__ void prefix_sum_kernel1(const float* input, float* output, int N) {
    __shared__ float s[1024];  // one slot per 32-element chunk (blockDim.x == 1024)

    int num_per_block = blockDim.x * 32;
    int block_id = blockIdx.x;
    int N_this_block = min(num_per_block, N - num_per_block * block_id);
    prefix_sum_compute<true>(input + num_per_block * block_id,
                             output + num_per_block * block_id,
                             N_this_block, s);
}
131+
132+
// Pass 2: inclusive scan of the per-block totals in store[], run as a single
// block so no further cross-block pass is needed for store itself.
// store_value = false: the grand total is not republished anywhere.
__global__ void prefix_sum_kernel2(int N_store) {
    extern __shared__ float s[]; // dynamic shared memory; the launcher passes blockDim.x * sizeof(float)
    prefix_sum_compute<false>(store, store, N_store, s);
}
137+
138+
139+
// Pass 3: add the scanned total of all preceding blocks (store[blockIdx.x-1],
// final after pass 2) to every element of this block's slice. Block 0's slice
// is already final after pass 1.
__global__ void prefix_sum_kernel3(float* output, int N) {
    int tid = threadIdx.x;
    int block_id = blockIdx.x;
    int num_threads = blockDim.x;
    int num_per_block = num_threads * 32;
    int loop_end = min(N, num_per_block * (block_id + 1));
    // first block is already done
    if (block_id > 0) {
        // BUGFIX: was `int store_val`, silently truncating the float partial
        // sum to an integer and corrupting every non-integer result.
        float store_val = store[block_id - 1];
        for (int i = num_per_block * block_id + tid; i < loop_end; i += num_threads) {
            output[i] += store_val;
        }
    }
}
154+
155+
// Host entry point for the three-pass scan: per-slice scans (pass 1), a scan
// of the slice totals (pass 2), then a uniform add of the preceding totals
// (pass 3). input and output are device pointers of length N.
void solve(const float* input, float* output, int N) {
    const int num_threads = 1024;
    const int elems_per_block = 32 * num_threads;  // each thread owns 32 elements
    const int num_blocks = (N + elems_per_block - 1) / elems_per_block;  // ceil-div

    prefix_sum_kernel1<<<num_blocks, num_threads>>>(input, output, N);
    prefix_sum_kernel2<<<1, num_threads, num_threads * sizeof(float)>>>(num_blocks);
    prefix_sum_kernel3<<<num_blocks, num_threads>>>(output, N);
}

prefix_sum2.cu

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
#include "solve.h"
#include <cuda_runtime.h>

#define FULL_MASK 0xffffffff

// Two-level scratch for cross-block totals: store1 holds one running total
// per pass-1 block (supports up to 512*512 elements), store2 one total per
// 512-block group produced by pass 2.
// NOTE(review): store2 is written but not consumed in the visible code —
// the multi-level path appears unfinished; confirm intended input sizes.
__device__ float store1[512*512];
__device__ float store2[512];
8+
9+
// Inclusive prefix sum of input[0..N) into output[0..N) by a single block.
// Preconditions: blockDim.x == 512, N <= 512, and s points to at least
// 512 floats of scratch PRIVATE to this block (the launchers pass
// blockDim.x * sizeof(float) of dynamic shared memory).
// When storer is non-null, thread 0 publishes the block total (s[N-1] after
// the scan) to storer[blockIdx.x] for a later cross-block pass.
__device__ void prefix_sum_compute(const float* input, float* output, int N, float* s, float* storer) {
    int tid = threadIdx.x;
    int block_id = blockIdx.x;

    s[tid] = (tid < N) ? input[tid] : 0.0f;
    __syncthreads();

    // Up-sweep over the 512-element scratch.
    // BUGFIX: the sweep previously started at d = 512 (copied from a
    // 1024-element variant), making the first level write s[b] for b up to
    // 1023 — past the end of the 512-float shared buffer. For 512 elements
    // the first level is d = 256.
    int offset = 1;
    for (int d = 256; d > 0; d >>= 1) {
        __syncthreads();
        if (tid < d) {
            int a = (tid + 1) * (offset * 2) - 1 - offset;
            int b = (tid + 1) * (offset * 2) - 1;
            s[b] += s[a];
        }
        offset *= 2;
    }

    // Down-sweep; afterwards s holds the inclusive scan.
    // BUGFIX: the bound is the element count (512), not 1024.
    for (int d = 2; d < 512; d *= 2) {
        offset >>= 1;
        __syncthreads();
        if (tid < d - 1) {
            int a = (tid + 1) * offset - 1;
            int b = (tid + 1) * offset - 1 + offset / 2;
            s[b] += s[a];
        }
    }
    __syncthreads();

    if (tid < N) output[tid] = s[tid];

    if (storer && tid == 0) {
        storer[block_id] = s[N - 1];
    }
}
47+
48+
// Pass 1: scan each 512-element slice of the input in its own block; the
// slice totals are recorded in store1. Launch with 512 threads and
// 512 * sizeof(float) bytes of dynamic shared memory per block.
__global__ void prefix_sum_kernel1(const float* input, float* output, int N) {
    extern __shared__ float s[];

    const int chunk = 512;
    int base = chunk * blockIdx.x;
    int n_here = min(chunk, N - base);
    prefix_sum_compute(input + base, output + base, n_here, s, store1);
}
57+
58+
// Pass 2: scan store1 (the per-slice totals from pass 1) in 512-element
// chunks; each chunk's total goes to store2.
// NOTE(review): store2 is never scanned or applied anywhere in the visible
// code (pass 3 is a TODO), so cross-chunk results are only complete when
// pass 1 produced at most one 512-entry chunk of totals — confirm the
// intended input size bound before relying on this.
__global__ void prefix_sum_kernel2(int N_store1) {
    extern __shared__ float s[]; // dynamic shared memory; the launcher passes num_threads * sizeof(float)

    int num_per_block = 512;
    int block_id = blockIdx.x;
    int N_this_block = min(num_per_block, N_store1 - num_per_block * block_id);
    prefix_sum_compute(store1 + num_per_block * block_id, store1 + num_per_block * block_id, N_this_block, s, store2);
}
67+
68+
69+
// Pass 3: make each 512-element slice globally correct by adding the scanned
// total of all preceding slices (store1[blockIdx.x - 1], final after pass 2).
// Implements the TODO left in the original.
// NOTE: correct only while pass 2 runs as a single 512-entry chunk
// (num_blocks <= 512, i.e. N <= 512*512 / 512 slices); larger inputs would
// additionally need the store2 level applied.
__global__ void prefix_sum_kernel3(float* output, int N) {
    const int num_per_block = 512;
    int block_id = blockIdx.x;
    if (block_id == 0) return;  // first slice is already final after pass 1

    float add = store1[block_id - 1];
    int start = num_per_block * block_id;
    int end = min(N, start + num_per_block);
    // Grid-stride within the slice so any blockDim works.
    for (int i = start + threadIdx.x; i < end; i += blockDim.x) {
        output[i] += add;
    }
}
73+
74+
// Host entry point. input, output are device pointers of length N.
// Pass 1 scans 512-element slices, pass 2 scans the slice totals held in
// store1 (recording chunk totals to store2), pass 3 applies the preceding
// totals to the output.
// NOTE(review): store2 is produced but never consumed in the visible code,
// so the pipeline is only complete when num_blocks <= 512 — confirm the
// supported N bound before relying on this version.
void solve(const float* input, float* output, int N) {
    int num_threads = 512;
    int num_blocks = (N + num_threads - 1) / num_threads;  // ceil-div: slices of 512
    prefix_sum_kernel1<<<num_blocks, num_threads, num_threads * sizeof(float)>>>(input, output, N);
    int num_blocks2 = (num_blocks + num_threads - 1) / num_threads;
    prefix_sum_kernel2<<<num_blocks2, num_threads, num_threads * sizeof(float)>>>(num_blocks);
    prefix_sum_kernel3<<<num_blocks, 1024>>>(output, N);
}

0 commit comments

Comments
 (0)