Commit ef513da

docs: add more details on Metal and GGML Metal backend
1 parent bc08769

1 file changed: notes/metal.md (+89 −6)
@@ -15,8 +15,8 @@ As an example we can look at [kernel.metal](../gpu/metal/src/kernel.metal) which

 using namespace metal;

-kernel void simpleMultiply(const device float* input [[buffer(0)]],
-                           device float* output [[buffer(1)]],
+kernel void simpleMultiply(const device float* input [[buffer(0)]],
+                           device float* output [[buffer(1)]],
                            uint id [[thread_position_in_grid]]) {
     output[id] = input[id] * 2.0;
 }
@@ -57,8 +57,8 @@ using `#[...]`.

 Let's look at the `simpleMultiply` kernel:
 ```c++
-kernel void simpleMultiply(const device float* input [[buffer(0)]],
-                           device float* output [[buffer(1)]],
+kernel void simple_multiply(const device float* input [[buffer(0)]],
+                            device float* output [[buffer(1)]],
                             uint id [[thread_position_in_grid]]) {
 ```
 The `[[buffer(0)]]` specifies that this parameter is bound to buffer 0.
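On the host side, the index in `[[buffer(n)]]` is the same index used when binding a buffer on the compute command encoder. A minimal sketch (the `encoder`, `inputBuffer` and `outputBuffer` names are illustrative, not from the notes):
```objc
// Bind Metal buffers to the indices declared as [[buffer(n)]] in the kernel.
[encoder setBuffer:inputBuffer  offset:0 atIndex:0]; // matches [[buffer(0)]]
[encoder setBuffer:outputBuffer offset:0 atIndex:1]; // matches [[buffer(1)]]
```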
@@ -89,7 +89,8 @@ parameter is part of the method name and the additional parameters need to be named
 like error above.

 ### Exploration
-In the example [project](../gpu/metal) we have the kernel in [kernel.metal](../gpu/metal/src/kernel.metal). We compile this using `metal`:
+In the example [project](../gpu/metal) we have the kernel in [kernel.metal](../gpu/metal/src/kernel.metal).
+We compile this using `metal`:
 ```console
 $ xcrun metal --help
 OVERVIEW: clang LLVM compiler
@@ -107,7 +108,7 @@ OPTIONS:
 Notice that this says clang and the output is very similar to a normal llvm tool chain. So I'm
 guessing that Metal is a frontend to clang.

-So we compile the kernal into an object file:
+So we compile the kernel into an object file:
 ```console
 $ xcrun metal -c src/kernel.metal -o kernel.air
 ```
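To make the compiled kernel loadable at runtime, the `.air` object is typically packed into a `.metallib` library using the `metallib` tool (the output file name here is just an example):
```console
$ xcrun metallib kernel.air -o default.metallib
```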
@@ -201,6 +202,21 @@ functions).
 library. When we call `makeFunction`, that is similar to using `dlsym` to get a function pointer
 from a dynamic library.

+For additional details and a brush-up on objective-c syntax see [simple.mm](../gpu/metal/src/simple.mm).
+
+Lambdas/closures in objective-c are done using blocks. For example:
+```objc
+^(size_t iter) { ... }
+```
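A minimal self-contained sketch of a block (not from the notes), showing that a block can be stored in a variable and invoked like a function:
```objc
#import <Foundation/Foundation.h>

int main(void) {
    // A block taking a size_t, roughly a lambda/closure in other languages.
    void (^body)(size_t) = ^(size_t iter) {
        NSLog(@"iteration %zu", iter);
    };

    for (size_t i = 0; i < 3; i++) {
        body(i); // invoke the block like a regular function
    }
    return 0;
}
```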
+
+### MTLComputePipelineState
+This is what compiles the AIR to run on the GPU, and is similar to compiling from PTX to SASS in CUDA.
+The returned object is an optimized, ready-to-run GPU program.
+
+```objective-c
+id<MTLComputePipelineState> computePipelineState = [device newComputePipelineStateWithFunction:kernelFunction error:&error];
+```
+This is an expensive operation and should be done once and reused.
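A rough sketch of how this fits together, creating the pipeline state once up front and then reusing it for each dispatch (the library setup and function name are illustrative, based on the example kernel above):
```objc
#import <Metal/Metal.h>

// One-time, expensive setup:
id<MTLDevice>   device  = MTLCreateSystemDefaultDevice();
id<MTLLibrary>  library = [device newDefaultLibrary];
id<MTLFunction> kernelFunction = [library newFunctionWithName:@"simple_multiply"];
NSError *error = nil;
id<MTLComputePipelineState> pipeline =
    [device newComputePipelineStateWithFunction:kernelFunction error:&error];

// Per-dispatch work, which reuses the pipeline state:
id<MTLCommandQueue> queue = [device newCommandQueue];
id<MTLCommandBuffer> commandBuffer = [queue commandBuffer];
id<MTLComputeCommandEncoder> encoder = [commandBuffer computeCommandEncoder];
[encoder setComputePipelineState:pipeline];
// ... setBuffer:/dispatchThreads: calls go here ...
[encoder endEncoding];
[commandBuffer commit];
[commandBuffer waitUntilCompleted];
```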


 ### GGML_USE_METAL
@@ -236,3 +252,70 @@ function(ggml_add_backend backend)
 endif()
 endfunction()
 ```
+
+### Adding a new operation to the metal backend
+Apart from implementing the actual operation in a metal kernel we also need to enable the
+operation in the backend itself, which involves the steps below.
+
+#### Add a new struct for the operation
+This is done by adding a new struct in ggml/src/ggml-metal/ggml-metal-impl.h:
+```c
+typedef struct {
+    float   repeat;
+    float   freq;
+    float   present;
+    int32_t n_vocab;
+} ggml_metal_kargs_penalties;
+```
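ggml wires this struct up to the kernel internally, but conceptually it reaches the kernel's `constant` parameter the way any small argument struct does in Metal, via something like `setBytes` (a sketch, not ggml's actual code; the encoder, index, and values are illustrative):
```objc
// Copy the small kargs struct to the encoder; it arrives in the kernel as
// the `constant ggml_metal_kargs_penalties & args` parameter.
ggml_metal_kargs_penalties args = { 1.1f, 0.02f, 0.02f, 32000 }; // example values
[encoder setBytes:&args length:sizeof(args) atIndex:0];
```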
+
+#### Add device support for the new operation
+This is done by adding a new case in ggml_metal_device_supports_op in
+ggml/src/ggml-metal/ggml-metal-device.m:
+```objc
+bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_tensor * op) {
+    ...
+    switch (op->op) {
+        ...
+        case GGML_OP_PENALTIES:
+            return op->src[0]->type == GGML_TYPE_F32 && // logits
+                   op->src[1]->type == GGML_TYPE_I32 && // history
+                   op->src[2]->type == GGML_TYPE_I32;   // n_history
+        ...
+    }
+}
+```
+
+#### Add the operation
+First add the operation to the operations header file ggml/src/ggml-metal/ggml-metal-op.h:
+```c++
+int ggml_metal_op_penalties (ggml_metal_op_t ctx, int idx);
+```
+And then add a case to the ggml_metal_op_encode_impl function in the implementation file:
+```c++
+static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
+    ...
+    switch (node->op) {
+        ...
+        case GGML_OP_PENALTIES:
+            {
+                n_fuse = ggml_metal_op_penalties(ctx, idx);
+            } break;
+        ...
+    }
+}
+```
+And we add this function to the same file:
+```c++
+int ggml_metal_op_penalties(ggml_metal_op_t ctx, int idx) {
+    ...
+}
+```
+And the kernel itself is in ggml/src/ggml-metal/ggml-metal.metal:
+```metal
+kernel void kernel_penalties_f32(
+        constant ggml_metal_kargs_penalties & args,
+        device const float * logits,        // src[0] - logits to penalize
+        device const int   * history,       // src[1] - token history
+        device const int   * n_history_ptr, // src[2] - number of valid tokens in history
+        device       float * dst,           // output - penalized logits
+        uint tpig[[thread_position_in_grid]]) {
+    ...
+}
+```
