@@ -15,8 +15,8 @@ As an example we can look at [kernel.metal](../gpu/metal/src/kernel.metal) which
 
 using namespace metal;
 
-kernel void simpleMultiply(const device float* input [[buffer(0)]],
-                           device float* output [[buffer(1)]],
+kernel void simpleMultiply(const device float* input [[buffer(0)]],
+                           device float* output [[buffer(1)]],
                            uint id [[thread_position_in_grid]]) {
     output[id] = input[id] * 2.0;
 }
@@ -57,8 +57,8 @@ using `#[...]`.
 
 Let's look at the `simpleMultiply` kernel:
 ```c++
-kernel void simpleMultiply(const device float* input [[buffer(0)]],
-                           device float* output [[buffer(1)]],
+kernel void simple_multiply(const device float* input [[buffer(0)]],
+                            device float* output [[buffer(1)]],
                             uint id [[thread_position_in_grid]]) {
 ```
 The `[[buffer(0)]]` attribute specifies that this parameter is bound to buffer 0.
@@ -89,7 +89,8 @@ parameter is part of the method name and the additional parameters need to be na
 like error above.
 
 ### Exploration
-In the example [project](../gpu/metal) we have the kernel in [kernel.metal](../gpu/metal/src/kernel.metal). We compile this using `metal`:
+In the example [project](../gpu/metal) we have the kernel in [kernel.metal](../gpu/metal/src/kernel.metal).
+We compile this using `metal`:
 ```console
 $ xcrun metal --help
 OVERVIEW: clang LLVM compiler
@@ -107,7 +108,7 @@ OPTIONS:
 Notice that this says clang and the output is very similar to a normal LLVM toolchain. So I'm
 guessing that Metal is a frontend to clang.
 
-So we compile the kernal into an object file:
+So we compile the kernel into an object file:
 ```console
 $ xcrun metal -c src/kernel.metal -o kernel.air
 ```
@@ -201,6 +202,21 @@ functions).
 library. When we call `makeFunction`, that is similar to using `dlsym` to get a function pointer
 from a dynamic library.
 
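+As a rough Objective-C sketch of these loading steps (hypothetical library file
+name, error handling elided; `makeFunction` is the Swift spelling of
+Objective-C's `newFunctionWithName:`):
+```objc
+#import <Metal/Metal.h>
+
+// Load a compiled .metallib and look up a kernel function by name,
+// analogous to dlopen/dlsym for a dynamic library.
+NSError* error = nil;
+id<MTLDevice> device = MTLCreateSystemDefaultDevice();
+id<MTLLibrary> library = [device newLibraryWithURL:
+    [NSURL fileURLWithPath:@"kernel.metallib"] error:&error];
+id<MTLFunction> kernelFunction = [library newFunctionWithName:@"simpleMultiply"];
+```
+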
+For additional details and a brush-up on Objective-C syntax see [simple.mm](../gpu/metal/src/simple.mm).
+
+Lambdas/closures in Objective-C are done using blocks. For example:
+```objc
+^(size_t iter) { ... }
+```
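+
+A slightly fuller sketch (hypothetical, not from the example project): a block
+can capture surrounding variables, and Metal uses blocks for callbacks such as
+command buffer completion handlers:
+```objc
+int factor = 2;
+// A block that captures `factor` from the enclosing scope.
+int (^scale)(int) = ^(int x) { return x * factor; };
+// A block passed as a callback (assumes an existing `commandBuffer`).
+[commandBuffer addCompletedHandler:^(id<MTLCommandBuffer> cb) {
+    NSLog(@"GPU work finished with status: %lu", (unsigned long) cb.status);
+}];
+```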
+
+### MTLComputePipelineState
+This is what compiles the AIR to run on the GPU, and is similar to compiling from PTX to SASS in CUDA.
+The returned object is an optimized, ready-to-run GPU program.
+
+```objc
+id<MTLComputePipelineState> computePipelineState = [device newComputePipelineStateWithFunction:kernelFunction error:&error];
+```
+This is an expensive operation and should be done once and reused.
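+
+As a rough sketch of reusing that pipeline state for a dispatch (hypothetical
+buffer and count variables; `dispatchThreads` assumes a device that supports
+non-uniform threadgroup sizes):
+```objc
+id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
+[encoder setComputePipelineState:computePipelineState]; // reuse the cached state
+[encoder setBuffer:inputBuffer offset:0 atIndex:0];     // matches [[buffer(0)]]
+[encoder setBuffer:outputBuffer offset:0 atIndex:1];    // matches [[buffer(1)]]
+MTLSize gridSize = MTLSizeMake(elementCount, 1, 1);
+NSUInteger w = MIN(computePipelineState.maxTotalThreadsPerThreadgroup, elementCount);
+[encoder dispatchThreads:gridSize threadsPerThreadgroup:MTLSizeMake(w, 1, 1)];
+[encoder endEncoding];
+```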
 
 
 ### GGML_USE_METAL
@@ -236,3 +252,70 @@ function(ggml_add_backend backend)
   endif()
 endfunction()
 ```
+
+### Adding a new operation to the Metal backend
+Apart from implementing the actual operation in a Metal kernel, we also need to enable the
+operation in the Metal backend itself, which involves the steps below.
+
+#### Add a new struct for the operation
+This is done by adding a new struct in ggml/src/ggml-metal/ggml-metal-impl.h:
+```c
+typedef struct {
+    float   repeat;
+    float   freq;
+    float   present;
+    int32_t n_vocab;
+} ggml_metal_kargs_penalties;
+```
+
+#### Add device support for the new operation
+This is done by adding a new case in ggml_metal_device_supports_op in
+ggml/src/ggml-metal/ggml-metal-device.m:
+```objc
+bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_tensor * op) {
+    ...
+    switch (op->op) {
+        ...
+        case GGML_OP_PENALTIES:
+            return op->src[0]->type == GGML_TYPE_F32 && // logits
+                   op->src[1]->type == GGML_TYPE_I32 && // history
+                   op->src[2]->type == GGML_TYPE_I32;   // n_history
+        ...
+    }
+}
+```
+
+#### Add the operation
+First add the operation to the operations header file ggml/src/ggml-metal/ggml-metal-op.h:
+```c++
+int ggml_metal_op_penalties(ggml_metal_op_t ctx, int idx);
+```
+And then add a case to the ggml_metal_op_encode_impl function:
+```c++
+static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
+    ...
+    switch (node->op) {
+        ...
+        case GGML_OP_PENALTIES:
+            {
+                n_fuse = ggml_metal_op_penalties(ctx, idx);
+            } break;
+        ...
+    }
+}
+```
+And we add this function to the same file:
+```c++
+int ggml_metal_op_penalties(ggml_metal_op_t ctx, int idx) {
+    ...
+}
+```
+And the kernel itself is in ggml/src/ggml-metal/ggml-metal.metal:
+```metal
+kernel void kernel_penalties_f32(
+        constant ggml_metal_kargs_penalties & args,
+        device const float * logits,        // src[0] - logits to penalize
+        device const int   * history,       // src[1] - token history
+        device const int   * n_history_ptr, // src[2] - number of valid tokens in history
+        device       float * dst,           // output - penalized logits
+        uint tpig [[thread_position_in_grid]]) {
+    ...
+}
+```
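+
+As a rough sketch (my own, not the actual implementation) of what such a kernel
+body could look like, applying repeat/frequency/presence penalties with one
+thread per vocabulary entry:
+```metal
+kernel void kernel_penalties_f32(
+        constant ggml_metal_kargs_penalties & args,
+        device const float * logits,
+        device const int   * history,
+        device const int   * n_history_ptr,
+        device       float * dst,
+        uint tpig [[thread_position_in_grid]]) {
+    if (tpig >= (uint) args.n_vocab) return;
+
+    // count how often this token appears in the history
+    const int n_history = n_history_ptr[0];
+    int count = 0;
+    for (int i = 0; i < n_history; i++) {
+        if (history[i] == (int) tpig) count++;
+    }
+
+    float logit = logits[tpig];
+    if (count > 0) {
+        // repeat penalty: shrink positive logits, grow negative ones
+        logit = logit > 0.0f ? logit / args.repeat : logit * args.repeat;
+        // frequency and presence penalties
+        logit -= (float) count * args.freq + args.present;
+    }
+    dst[tpig] = logit;
+}
+```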