Skip to content

Commit 23bc8a3

Browse files
committed
metal : fuse add
1 parent 576c82e commit 23bc8a3

File tree

5 files changed

+344
-22
lines changed

5 files changed

+344
-22
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 116 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,15 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
147147

148148
enum ggml_metal_kernel_type {
149149
GGML_METAL_KERNEL_TYPE_ADD,
150+
GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,
151+
GGML_METAL_KERNEL_TYPE_ADD_FUSE_4,
152+
GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,
153+
GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,
150154
GGML_METAL_KERNEL_TYPE_ADD_ROW,
155+
GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2,
156+
GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4,
157+
GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6,
158+
GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8,
151159
GGML_METAL_KERNEL_TYPE_SUB,
152160
GGML_METAL_KERNEL_TYPE_SUB_ROW,
153161
GGML_METAL_KERNEL_TYPE_MUL,
@@ -1129,7 +1137,15 @@ @implementation GGMLMetalClass
11291137
// simd_sum and simd_max requires MTLGPUFamilyApple7
11301138

11311139
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
1140+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, add_fuse_2, true);
1141+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_4, add_fuse_4, true);
1142+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, add_fuse_6, true);
1143+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, add_fuse_8, true);
11321144
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
1145+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2, add_row_fuse_2, true);
1146+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4, add_row_fuse_4, true);
1147+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6, add_row_fuse_6, true);
1148+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8, add_row_fuse_8, true);
11331149
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true);
11341150
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true);
11351151
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
@@ -1875,7 +1891,22 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
18751891
}
18761892
}
18771893

1878-
static bool ggml_metal_encode_node(
1894+
static bool nodes_are_same(
1895+
const struct ggml_tensor * a,
1896+
const struct ggml_tensor * b) {
1897+
return
1898+
a->type == b->type &&
1899+
a->ne[0] == b->ne[0] &&
1900+
a->ne[1] == b->ne[1] &&
1901+
a->ne[2] == b->ne[2] &&
1902+
a->ne[3] == b->ne[3] &&
1903+
a->nb[0] == b->nb[0] &&
1904+
a->nb[1] == b->nb[1] &&
1905+
a->nb[2] == b->nb[2] &&
1906+
a->nb[3] == b->nb[3];
1907+
}
1908+
1909+
static int ggml_metal_encode_node(
18791910
ggml_backend_t backend,
18801911
int idx,
18811912
id<MTLComputeCommandEncoder> encoder,
@@ -1885,7 +1916,12 @@ static bool ggml_metal_encode_node(
18851916

18861917
struct ggml_cgraph * gf = ctx->gf;
18871918

1888-
struct ggml_tensor * node = ggml_graph_node(gf, idx);
1919+
enum ggml_op ops[8];
1920+
1921+
struct ggml_tensor ** nodes = ggml_graph_nodes(gf);
1922+
struct ggml_tensor * node = nodes[idx];
1923+
1924+
struct ggml_tensor ** fuse = nodes + idx + 1;
18891925

18901926
//GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
18911927

@@ -1895,7 +1931,7 @@ static bool ggml_metal_encode_node(
18951931
struct ggml_tensor * dst = node;
18961932

18971933
if (ggml_is_empty(dst)) {
1898-
return true;
1934+
return 1;
18991935
}
19001936

19011937
switch (dst->op) {
@@ -1906,7 +1942,7 @@ static bool ggml_metal_encode_node(
19061942
case GGML_OP_PERMUTE:
19071943
{
19081944
// noop -> next node
1909-
} return true;
1945+
} return 1;
19101946
default:
19111947
{
19121948
} break;
@@ -1973,7 +2009,9 @@ static bool ggml_metal_encode_node(
19732009
id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
19742010
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
19752011

1976-
#if 0
2012+
int n_fuse = 1;
2013+
2014+
#if 1
19772015
GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
19782016
if (src0) {
19792017
GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
@@ -2050,14 +2088,50 @@ static bool ggml_metal_encode_node(
20502088

20512089
id<MTLComputePipelineState> pipeline = nil;
20522090

2091+
{
2092+
ops[0] = GGML_OP_ADD;
2093+
ops[1] = GGML_OP_ADD;
2094+
ops[2] = GGML_OP_ADD;
2095+
ops[3] = GGML_OP_ADD;
2096+
ops[4] = GGML_OP_ADD;
2097+
ops[5] = GGML_OP_ADD;
2098+
ops[6] = GGML_OP_ADD;
2099+
ops[7] = GGML_OP_ADD;
2100+
2101+
for (n_fuse = 8; n_fuse > 1; --n_fuse) {
2102+
if (n_fuse % 2 == 1) {
2103+
continue;
2104+
}
2105+
if (ggml_can_fuse(gf, idx, ops, n_fuse)) {
2106+
if (nodes_are_same(node->src[1], fuse[0]->src[1]) &&
2107+
nodes_are_same(node->src[1], fuse[n_fuse - 2]->src[1])) {
2108+
break;
2109+
}
2110+
}
2111+
}
2112+
}
2113+
20532114
if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
20542115
GGML_ASSERT(ggml_is_contiguous(src0));
20552116

20562117
// src1 is a row
20572118
GGML_ASSERT(ne11 == 1);
20582119

20592120
switch (dst->op) {
2060-
case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
2121+
case GGML_OP_ADD:
2122+
{
2123+
switch (n_fuse) {
2124+
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2].pipeline; break;
2125+
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4].pipeline; break;
2126+
case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6].pipeline; break;
2127+
case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8].pipeline; break;
2128+
default:
2129+
{
2130+
GGML_ASSERT(n_fuse == 1);
2131+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline;
2132+
}
2133+
}
2134+
} break;
20612135
case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
20622136
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
20632137
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
@@ -2067,7 +2141,21 @@ static bool ggml_metal_encode_node(
20672141
bcast_row = true;
20682142
} else {
20692143
switch (dst->op) {
2070-
case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
2144+
case GGML_OP_ADD:
2145+
{
2146+
GGML_LOG_INFO("XXXXXXXXXXXXXXXXXXXXXXXXX n_fuse = %d\n", n_fuse);
2147+
switch (n_fuse) {
2148+
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_2].pipeline; break;
2149+
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_4].pipeline; break;
2150+
case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_6].pipeline; break;
2151+
case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_8].pipeline; break;
2152+
default:
2153+
{
2154+
GGML_ASSERT(n_fuse == 1);
2155+
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
2156+
}
2157+
}
2158+
} break;
20712159
case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
20722160
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
20732161
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
@@ -2107,7 +2195,16 @@ static bool ggml_metal_encode_node(
21072195
[encoder setBytes:&args length:sizeof(args) atIndex:0];
21082196
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
21092197
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
2110-
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
2198+
for (int f = 0; f < n_fuse - 1; ++f) {
2199+
id_src1 = ggml_metal_get_buffer(fuse[f]->src[1], &offs_src1);
2200+
2201+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:3 + f];
2202+
2203+
if (f + 1 == n_fuse - 1) {
2204+
id_dst = ggml_metal_get_buffer(fuse[f], &offs_dst);
2205+
}
2206+
}
2207+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2 + n_fuse];
21112208

21122209
if (bcast_row) {
21132210
const int64_t n = ggml_nelements(dst)/4;
@@ -2674,7 +2771,7 @@ static bool ggml_metal_encode_node(
26742771
id<MTLBuffer> h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
26752772
if (!h_src0) {
26762773
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
2677-
return false;
2774+
return 0;
26782775
}
26792776

26802777
offs_src0 = 0;
@@ -3550,7 +3647,7 @@ static bool ggml_metal_encode_node(
35503647
id<MTLBuffer> h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1);
35513648
if (!h_src1) {
35523649
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
3553-
return false;
3650+
return 0;
35543651
}
35553652

35563653
const int64_t neh0 = ne0;
@@ -3566,15 +3663,15 @@ static bool ggml_metal_encode_node(
35663663
id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst);
35673664
if (!h_dst) {
35683665
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
3569-
return false;
3666+
return 0;
35703667
}
35713668

35723669
// tokens per expert
35733670
const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
35743671
id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
35753672
if (!h_tpe) {
35763673
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe);
3577-
return false;
3674+
return 0;
35783675
}
35793676

35803677
// id map
@@ -3583,7 +3680,7 @@ static bool ggml_metal_encode_node(
35833680
id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
35843681
if (!h_ids) {
35853682
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
3586-
return false;
3683+
return 0;
35873684
}
35883685

35893686
{
@@ -5442,7 +5539,7 @@ static bool ggml_metal_encode_node(
54425539
}
54435540
}
54445541

5445-
return true;
5542+
return n_fuse;
54465543
}
54475544

54485545
static enum ggml_status ggml_metal_graph_compute(
@@ -5948,20 +6045,22 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
59486045
struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
59496046
ggml_metal_mem_pool_reset(mem_pool);
59506047

5951-
for (int idx = node_start; idx < node_end; ++idx) {
6048+
for (int idx = node_start; idx < node_end;) {
59526049
if (should_capture) {
59536050
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
59546051
}
59556052

5956-
const bool res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
6053+
const int res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
59576054

59586055
if (should_capture) {
59596056
[encoder popDebugGroup];
59606057
}
59616058

5962-
if (!res) {
6059+
if (res == 0) {
59636060
break;
59646061
}
6062+
6063+
idx += res;
59656064
}
59666065

59676066
[encoder endEncoding];

0 commit comments

Comments
 (0)