Skip to content

Commit 092c645

Browse files
committed
metal : add basic cumulative sum kernel example
1 parent b931493 commit 092c645

File tree

5 files changed

+192
-6
lines changed

5 files changed

+192
-6
lines changed

gpu/metal/.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
kernel.air
2-
kernel.metallib
2+
*.metallib
33
simple
44
simple-source
55
.build
6+
cumsum
7+
*.dSYM

gpu/metal/Makefile

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,9 @@ simple: src/simple.mm src/kernel.metal
1111
simple-source: src/simple-source.mm
1212
$(CC) $(CFLAGS) $< -o $@
1313

14+
cumsum: src/cumsum.mm src/cumsum.metal
15+
xcrun metal -c src/cumsum.metal -o - | xcrun metallib - -o cumsum.metallib
16+
$(CC) $(CFLAGS) src/cumsum.mm -o $@
17+
1418
clean:
15-
rm -f simple* *.air *.metallib
19+
${RM} -f simple* *.air *.metallib

gpu/metal/src/cumsum.metal

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#include <metal_stdlib>
2+
using namespace metal;
3+
4+
kernel void cumsum_scan(
5+
device const float * input [[buffer(0)]],
6+
device float * output [[buffer(1)]],
7+
constant uint& count [[buffer(2)]],
8+
constant uint& step [[buffer(3)]],
9+
uint gid [[thread_position_in_grid]]) {
10+
11+
if (gid >= count) return;
12+
13+
if (step == 0) {
14+
output[gid] = input[gid];
15+
} else {
16+
if (gid >= step) {
17+
output[gid] = input[gid] + input[gid - step];
18+
} else {
19+
output[gid] = input[gid];
20+
}
21+
}
22+
}

gpu/metal/src/cumsum.mm

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
#import <Foundation/Foundation.h>
2+
#import <Metal/Metal.h>
3+
4+
int main(int argc, const char * argv[]) {
5+
@autoreleasepool {
6+
NSArray<id<MTLDevice>>* devices = MTLCopyAllDevices();
7+
id<MTLDevice> device = nil;
8+
for (id<MTLDevice> availableDevice in devices) {
9+
if ([availableDevice.name isEqualToString:@"Apple M3"]) {
10+
device = availableDevice;
11+
NSLog(@"Using Metal device: %@", device.name);
12+
break;
13+
}
14+
}
15+
if (!device) {
16+
NSLog(@"Apple M3 GPU is not available.");
17+
return -1;
18+
}
19+
20+
NSError *error = nil;
21+
NSString *libraryPath = @"cumsum.metallib";
22+
NSURL *libraryURL = [NSURL fileURLWithPath:libraryPath];
23+
id<MTLLibrary> defaultLibrary = [device newLibraryWithURL:libraryURL error:&error];
24+
if (!defaultLibrary) {
25+
NSLog(@"Failed to load the library. Error: %@", error.localizedDescription);
26+
return -1;
27+
}
28+
29+
NSLog(@"Functions in library:");
30+
for (NSString *name in defaultLibrary.functionNames) {
31+
NSLog(@"%@", name);
32+
}
33+
34+
id<MTLFunction> cumsum = [defaultLibrary newFunctionWithName:@"cumsum_scan"];
35+
if (!cumsum) {
36+
NSLog(@"Failed to find the kernel function.");
37+
return -1;
38+
}
39+
NSLog(@"Kernel function: %@", cumsum.name);
40+
41+
id<MTLComputePipelineState> computePipelineState = [device newComputePipelineStateWithFunction:cumsum error:&error];
42+
if (!computePipelineState) {
43+
NSLog(@"Failed to create compute pipeline state. Error: %@", error.localizedDescription);
44+
return -1;
45+
}
46+
47+
float input_data[] = {1, 2, 3, 4, 5, 6, 7, 8};
48+
uint32_t count = sizeof(input_data) / sizeof(float);
49+
printf("Number of elements: %u\n", count);
50+
51+
// First buffer is the input buffer which we create with the input data.
52+
id<MTLBuffer> input_buffer = [device newBufferWithBytes:input_data
53+
length:count * sizeof(float)
54+
options:MTLResourceStorageModeShared];
55+
56+
// Notice that this buffer does not have any input data, only a length as this is the
57+
// output buffer.
58+
id<MTLBuffer> output_buffer = [device newBufferWithLength:count * sizeof(float)
59+
options:MTLResourceStorageModeShared];
60+
61+
// And then we have another buffer which is also created with a value, which is the
62+
// count, or number of elements to process.
63+
id<MTLBuffer> count_buffer = [device newBufferWithBytes:&count
64+
length:sizeof(uint32_t)
65+
options:MTLResourceStorageModeShared];
66+
67+
id<MTLCommandQueue> command_queue = [device newCommandQueue];
68+
69+
// Create pointer to the input and output buffers so that we can swap them later.
70+
id<MTLBuffer> src_buffer = input_buffer;
71+
id<MTLBuffer> dst_buffer = output_buffer;
72+
uint32_t num_passes = 0;
73+
// This is just left shifting 1 until it is >= count, so we start with
74+
// s=0, 1u << 0 = 1
75+
// s=1, 1u << 1 = 2
76+
// s=2, 1u << 2 = 4
77+
// s=3, 1u << 3 = 8
78+
for (uint32_t s = 0; (1u << s) < count; s++) {
79+
num_passes++;
80+
}
81+
// And we also need to include the initial copy of the first element.
82+
num_passes++;
83+
printf("Number of passes: %u\n", num_passes);
84+
85+
86+
// So we iterate passses time (currently 4 for 8 elements).
87+
for (uint32_t pass = 0; pass < num_passes; pass++) {
88+
uint32_t step = (pass == 0) ? 0 : (1u << (pass - 1));
89+
90+
// Recall that cumsum is the cumulative sum so the result from one pass will become
91+
// input to the second pass and so on.
92+
// pass = 0, just copy the input to the output:
93+
// [1 2 3 4 5 6 7 8]
94+
95+
// pass = 1, add each element with the element 1 position back:
96+
// [1 2 3 4 5 6 7 8]
97+
// [1 1+2 2+3 3+4 4+5 5+6 6+7 7+8]
98+
// [1 3 5 7 9 11 13 15]
99+
100+
// pass = 2, add each element with the element 2 position back:
101+
// [1 3 5 7 9 11 13 15 ]
102+
// [1 3 1+5 3+7 5+9 7+11 9+13 11+15]
103+
// [1 3 6 10 14 18 22 26 ]
104+
105+
// pass = 3, add each element with the element 4 position back:
106+
// [1 3 6 10 14 18 22 26 ]
107+
// [1 3 6 10 1+14 3+18 6+22 10+26]
108+
// [1 3 6 10 15 21 28 36 ]
109+
110+
// Sow we create a buffer with the current step value:
111+
id<MTLBuffer> step_buffer = [device newBufferWithBytes:&step
112+
length:sizeof(uint32_t)
113+
options:MTLResourceStorageModeShared];
114+
115+
id<MTLCommandBuffer> command_buffer = [command_queue commandBuffer];
116+
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
117+
118+
// Set the pipeline state and the buffers:
119+
[encoder setComputePipelineState:computePipelineState];
120+
[encoder setBuffer:src_buffer offset:0 atIndex:0];
121+
[encoder setBuffer:dst_buffer offset:0 atIndex:1];
122+
[encoder setBuffer:count_buffer offset:0 atIndex:2];
123+
[encoder setBuffer:step_buffer offset:0 atIndex:3];
124+
125+
printf("Creating %d threads for pass %u with step %u\n", count, pass, step);
126+
// Will have 8 threads for each pass.
127+
MTLSize gridSize = MTLSizeMake(count, 1, 1);
128+
NSUInteger threadGroupSize = MIN(256, count);
129+
MTLSize threadgroupSize = MTLSizeMake(threadGroupSize, 1, 1);
130+
131+
[encoder dispatchThreads:gridSize threadsPerThreadgroup:threadgroupSize];
132+
[encoder endEncoding];
133+
134+
[command_buffer commit];
135+
[command_buffer waitUntilCompleted];
136+
137+
// Swap the src and input buffers so that we keep using the latest result as input.
138+
id<MTLBuffer> temp = src_buffer;
139+
// Set src_buffer to the latest output buffer
140+
src_buffer = dst_buffer;
141+
dst_buffer = temp;
142+
}
143+
144+
// Result is in src_buffer (we swapped after last pass)
145+
float *result = (float *)[src_buffer contents];
146+
147+
printf("Input: ");
148+
for (uint32_t i = 0; i < count; i++) {
149+
printf("%2.0f ", input_data[i]);
150+
}
151+
printf("\nCumSum: ");
152+
for (uint32_t i = 0; i < count; i++) {
153+
printf("%2.0f ", result[i]);
154+
}
155+
printf("\n");
156+
}
157+
return 0;
158+
}

gpu/metal/src/kernel.metal

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
#include <metal_stdlib>
22
using namespace metal;
33

4-
kernel void simple_multiply(const device float* input [[buffer(0)]],
5-
device float* output [[buffer(1)]],
6-
device int* debug_buffer [[buffer(2)]],
7-
constant int& some_constant [[buffer(3)]],
4+
kernel void simple_multiply(const device float* input [[buffer(0)]],
5+
device float* output [[buffer(1)]],
6+
device int* debug_buffer [[buffer(2)]],
7+
constant int& some_constant [[buffer(3)]],
88
uint id [[thread_position_in_grid]]) {
99
output[id] = input[id] * 2.0;
1010

0 commit comments

Comments
 (0)