|
| 1 | +#import <Foundation/Foundation.h> |
| 2 | +#import <Metal/Metal.h> |
| 3 | + |
// Returns the number of kernel dispatches needed to scan `count` elements:
// one initial copy pass (step 0) plus one pass for every power-of-two step
// that is smaller than `count` (1, 2, 4, ...).
static uint32_t cumsum_pass_count(uint32_t count) {
    uint32_t passes = 1;  // Pass 0 only copies the input to the output.
    // The stride is 64-bit so that doubling it cannot wrap around (or shift by
    // the full width of the type) when `count` is larger than 2^31.
    for (uint64_t stride = 1; stride < count; stride <<= 1) {
        passes++;
    }
    return passes;
}

// Host driver for a multi-pass cumulative sum (inclusive prefix sum) on the GPU.
//
// Loads the `cumsum_scan` kernel from `cumsum.metallib` (resolved relative to
// the current working directory), runs it once per pass while ping-ponging
// between two buffers, and prints the input together with the final result.
//
// Buffer bindings used for every dispatch:
//   index 0: source values for this pass
//   index 1: destination values for this pass
//   index 2: number of elements (uint32_t)
//   index 3: step for this pass (uint32_t)
//
// Returns 0 on success and -1 when the device, library, kernel, pipeline, any
// Metal resource, or a GPU pass fails.
int main(int argc, const char * argv[]) {
    @autoreleasepool {
        // Select the GPU by its exact reported name.
        // NOTE(review): this is an exact match, so a device reporting a longer
        // name (for example one with a suffix) is not selected - confirm that
        // this is intended.
        NSArray<id<MTLDevice>> *devices = MTLCopyAllDevices();
        id<MTLDevice> device = nil;
        for (id<MTLDevice> availableDevice in devices) {
            if ([availableDevice.name isEqualToString:@"Apple M3"]) {
                device = availableDevice;
                NSLog(@"Using Metal device: %@", device.name);
                break;
            }
        }
        if (!device) {
            NSLog(@"Apple M3 GPU is not available.");
            return -1;
        }

        NSError *error = nil;
        NSString *libraryPath = @"cumsum.metallib";
        NSURL *libraryURL = [NSURL fileURLWithPath:libraryPath];
        id<MTLLibrary> defaultLibrary = [device newLibraryWithURL:libraryURL error:&error];
        if (!defaultLibrary) {
            NSLog(@"Failed to load the library. Error: %@", error.localizedDescription);
            return -1;
        }

        NSLog(@"Functions in library:");
        for (NSString *name in defaultLibrary.functionNames) {
            NSLog(@"%@", name);
        }

        id<MTLFunction> cumsum = [defaultLibrary newFunctionWithName:@"cumsum_scan"];
        if (!cumsum) {
            NSLog(@"Failed to find the kernel function.");
            return -1;
        }
        NSLog(@"Kernel function: %@", cumsum.name);

        id<MTLComputePipelineState> computePipelineState =
            [device newComputePipelineStateWithFunction:cumsum error:&error];
        if (!computePipelineState) {
            NSLog(@"Failed to create compute pipeline state. Error: %@", error.localizedDescription);
            return -1;
        }

        float input_data[] = {1, 2, 3, 4, 5, 6, 7, 8};
        uint32_t count = sizeof(input_data) / sizeof(float);
        printf("Number of elements: %u\n", count);

        // The first buffer is the input buffer, created with a copy of the input data.
        id<MTLBuffer> input_buffer = [device newBufferWithBytes:input_data
                                                         length:count * sizeof(float)
                                                        options:MTLResourceStorageModeShared];

        // The output buffer is created with a length only; it has no initial contents.
        id<MTLBuffer> output_buffer = [device newBufferWithLength:count * sizeof(float)
                                                          options:MTLResourceStorageModeShared];

        // The element count is the same for every pass, so it is uploaded once and reused.
        id<MTLBuffer> count_buffer = [device newBufferWithBytes:&count
                                                         length:sizeof(uint32_t)
                                                        options:MTLResourceStorageModeShared];

        id<MTLCommandQueue> command_queue = [device newCommandQueue];

        // Every one of these can come back nil; fail here with a message rather
        // than encoding a dispatch with missing bindings.
        if (!input_buffer || !output_buffer || !count_buffer || !command_queue) {
            NSLog(@"Failed to create the Metal buffers or the command queue.");
            return -1;
        }

        // Aliases for the two data buffers so that they can be swapped after every pass.
        id<MTLBuffer> src_buffer = input_buffer;
        id<MTLBuffer> dst_buffer = output_buffer;

        // For 8 elements the steps are 1, 2 and 4, plus the initial copy: 4 passes.
        uint32_t num_passes = cumsum_pass_count(count);
        printf("Number of passes: %u\n", num_passes);

        // The grid and threadgroup sizes do not change between passes. The
        // threadgroup width keeps the cap of 256 and is additionally clamped to
        // what this pipeline supports.
        NSUInteger threadGroupSize = MIN((NSUInteger)256, (NSUInteger)count);
        threadGroupSize = MIN(threadGroupSize, computePipelineState.maxTotalThreadsPerThreadgroup);
        MTLSize gridSize = MTLSizeMake(count, 1, 1);
        MTLSize threadgroupSize = MTLSizeMake(threadGroupSize, 1, 1);

        for (uint32_t pass = 0; pass < num_passes; pass++) {
            uint32_t step = (pass == 0) ? 0 : (1u << (pass - 1));

            // The result of one pass is the input of the next one. The kernel is
            // not part of this file; the walkthrough below is what it is expected
            // to produce for each step.
            //
            // pass = 0 (step 0), just copy the input to the output:
            // [1 2 3 4 5 6 7 8]
            //
            // pass = 1 (step 1), add to each element the element 1 position back:
            // [1 2   3   4   5   6   7   8  ]
            // [1 1+2 2+3 3+4 4+5 5+6 6+7 7+8]
            // [1 3   5   7   9   11  13  15 ]
            //
            // pass = 2 (step 2), add to each element the element 2 positions back:
            // [1 3 5   7   9   11   13   15   ]
            // [1 3 1+5 3+7 5+9 7+11 9+13 11+15]
            // [1 3 6   10  14  18   22   26   ]
            //
            // pass = 3 (step 4), add to each element the element 4 positions back:
            // [1 3 6 10 14   18   22   26   ]
            // [1 3 6 10 1+14 3+18 6+22 10+26]
            // [1 3 6 10 15   21   28   36   ]

            id<MTLCommandBuffer> command_buffer = [command_queue commandBuffer];
            id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
            if (!command_buffer || !encoder) {
                NSLog(@"Failed to create the command buffer or encoder for pass %u.", pass);
                return -1;
            }

            // Set the pipeline state and the buffers.
            [encoder setComputePipelineState:computePipelineState];
            [encoder setBuffer:src_buffer offset:0 atIndex:0];
            [encoder setBuffer:dst_buffer offset:0 atIndex:1];
            [encoder setBuffer:count_buffer offset:0 atIndex:2];
            // The step is a few bytes of single-use data, so it is passed inline
            // instead of allocating a new buffer for every pass.
            [encoder setBytes:&step length:sizeof(step) atIndex:3];

            printf("Creating %u threads for pass %u with step %u\n", count, pass, step);

            // One thread per element for each pass.
            [encoder dispatchThreads:gridSize threadsPerThreadgroup:threadgroupSize];
            [encoder endEncoding];

            // Block until this pass has finished before the buffers are swapped
            // and reused by the next pass.
            [command_buffer commit];
            [command_buffer waitUntilCompleted];

            if (command_buffer.status == MTLCommandBufferStatusError) {
                NSLog(@"GPU execution failed in pass %u. Error: %@",
                      pass, command_buffer.error.localizedDescription);
                return -1;
            }

            // Swap the src and dst buffers so that the latest result becomes the next input.
            id<MTLBuffer> temp = src_buffer;
            src_buffer = dst_buffer;
            dst_buffer = temp;
        }

        // The result is in src_buffer because the buffers were swapped after the last pass.
        float *result = (float *)[src_buffer contents];

        // The original values are printed from the host array: the GPU input
        // buffer is reused as a destination after the first pass.
        printf("Input: ");
        for (uint32_t i = 0; i < count; i++) {
            printf("%2.0f ", input_data[i]);
        }
        printf("\nCumSum: ");
        for (uint32_t i = 0; i < count; i++) {
            printf("%2.0f ", result[i]);
        }
        printf("\n");
    }
    return 0;
}
0 commit comments