@@ -147,7 +147,15 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
147
147
148
148
enum ggml_metal_kernel_type {
149
149
GGML_METAL_KERNEL_TYPE_ADD,
150
+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,
151
+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_4,
152
+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,
153
+ GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,
150
154
GGML_METAL_KERNEL_TYPE_ADD_ROW,
155
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2,
156
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4,
157
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6,
158
+ GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8,
151
159
GGML_METAL_KERNEL_TYPE_SUB,
152
160
GGML_METAL_KERNEL_TYPE_SUB_ROW,
153
161
GGML_METAL_KERNEL_TYPE_MUL,
@@ -1129,7 +1137,15 @@ @implementation GGMLMetalClass
1129
1137
// simd_sum and simd_max requires MTLGPUFamilyApple7
1130
1138
1131
1139
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD, add, true );
1140
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, add_fuse_2, true );
1141
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_4, add_fuse_4, true );
1142
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, add_fuse_6, true );
1143
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, add_fuse_8, true );
1132
1144
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true );
1145
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2, add_row_fuse_2, true );
1146
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4, add_row_fuse_4, true );
1147
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6, add_row_fuse_6, true );
1148
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8, add_row_fuse_8, true );
1133
1149
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUB, sub, true );
1134
1150
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SUB_ROW, sub_row, true );
1135
1151
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL, mul, true );
@@ -1875,7 +1891,22 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1875
1891
}
1876
1892
}
1877
1893
1878
- static bool ggml_metal_encode_node (
1894
+ static bool nodes_are_same (
1895
+ const struct ggml_tensor * a,
1896
+ const struct ggml_tensor * b) {
1897
+ return
1898
+ a->type == b->type &&
1899
+ a->ne [0 ] == b->ne [0 ] &&
1900
+ a->ne [1 ] == b->ne [1 ] &&
1901
+ a->ne [2 ] == b->ne [2 ] &&
1902
+ a->ne [3 ] == b->ne [3 ] &&
1903
+ a->nb [0 ] == b->nb [0 ] &&
1904
+ a->nb [1 ] == b->nb [1 ] &&
1905
+ a->nb [2 ] == b->nb [2 ] &&
1906
+ a->nb [3 ] == b->nb [3 ];
1907
+ }
1908
+
1909
+ static int ggml_metal_encode_node (
1879
1910
ggml_backend_t backend,
1880
1911
int idx,
1881
1912
id <MTLComputeCommandEncoder > encoder,
@@ -1885,7 +1916,12 @@ static bool ggml_metal_encode_node(
1885
1916
1886
1917
struct ggml_cgraph * gf = ctx->gf ;
1887
1918
1888
- struct ggml_tensor * node = ggml_graph_node (gf, idx);
1919
+ enum ggml_op ops[8 ];
1920
+
1921
+ struct ggml_tensor ** nodes = ggml_graph_nodes (gf);
1922
+ struct ggml_tensor * node = nodes[idx];
1923
+
1924
+ struct ggml_tensor ** fuse = nodes + idx + 1 ;
1889
1925
1890
1926
// GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
1891
1927
@@ -1895,7 +1931,7 @@ static bool ggml_metal_encode_node(
1895
1931
struct ggml_tensor * dst = node;
1896
1932
1897
1933
if (ggml_is_empty (dst)) {
1898
- return true ;
1934
+ return 1 ;
1899
1935
}
1900
1936
1901
1937
switch (dst->op ) {
@@ -1906,7 +1942,7 @@ static bool ggml_metal_encode_node(
1906
1942
case GGML_OP_PERMUTE:
1907
1943
{
1908
1944
// noop -> next node
1909
- } return true ;
1945
+ } return 1 ;
1910
1946
default :
1911
1947
{
1912
1948
} break ;
@@ -1973,7 +2009,9 @@ static bool ggml_metal_encode_node(
1973
2009
id <MTLBuffer > id_src2 = src2 ? ggml_metal_get_buffer (src2, &offs_src2) : nil ;
1974
2010
id <MTLBuffer > id_dst = dst ? ggml_metal_get_buffer (dst, &offs_dst) : nil ;
1975
2011
1976
- #if 0
2012
+ int n_fuse = 1 ;
2013
+
2014
+ #if 1
1977
2015
GGML_LOG_INFO (" %s : op - %s \n " , __func__, ggml_op_name (dst->op ));
1978
2016
if (src0) {
1979
2017
GGML_LOG_INFO (" %s : src0 - %4s [%5lld , %5lld , %5lld , %5lld ] [%5lld , %5lld , %5lld , %5lld ], %d , %s \n " , __func__, ggml_type_name (src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
@@ -2050,14 +2088,50 @@ static bool ggml_metal_encode_node(
2050
2088
2051
2089
id <MTLComputePipelineState > pipeline = nil ;
2052
2090
2091
+ {
2092
+ ops[0 ] = GGML_OP_ADD;
2093
+ ops[1 ] = GGML_OP_ADD;
2094
+ ops[2 ] = GGML_OP_ADD;
2095
+ ops[3 ] = GGML_OP_ADD;
2096
+ ops[4 ] = GGML_OP_ADD;
2097
+ ops[5 ] = GGML_OP_ADD;
2098
+ ops[6 ] = GGML_OP_ADD;
2099
+ ops[7 ] = GGML_OP_ADD;
2100
+
2101
+ for (n_fuse = 8 ; n_fuse > 1 ; --n_fuse) {
2102
+ if (n_fuse % 2 == 1 ) {
2103
+ continue ;
2104
+ }
2105
+ if (ggml_can_fuse (gf, idx, ops, n_fuse)) {
2106
+ if (nodes_are_same (node->src [1 ], fuse[0 ]->src [1 ]) &&
2107
+ nodes_are_same (node->src [1 ], fuse[n_fuse - 2 ]->src [1 ])) {
2108
+ break ;
2109
+ }
2110
+ }
2111
+ }
2112
+ }
2113
+
2053
2114
if (ggml_nelements (src1) == ne10 && ggml_is_contiguous (src1) && ne00 % 4 == 0 && ne10 % 4 == 0 ) {
2054
2115
GGML_ASSERT (ggml_is_contiguous (src0));
2055
2116
2056
2117
// src1 is a row
2057
2118
GGML_ASSERT (ne11 == 1 );
2058
2119
2059
2120
switch (dst->op ) {
2060
- case GGML_OP_ADD: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline ; break ;
2121
+ case GGML_OP_ADD:
2122
+ {
2123
+ switch (n_fuse) {
2124
+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_2].pipeline ; break ;
2125
+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_4].pipeline ; break ;
2126
+ case 6 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_6].pipeline ; break ;
2127
+ case 8 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW_FUSE_8].pipeline ; break ;
2128
+ default :
2129
+ {
2130
+ GGML_ASSERT (n_fuse == 1 );
2131
+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline ;
2132
+ }
2133
+ }
2134
+ } break ;
2061
2135
case GGML_OP_SUB: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline ; break ;
2062
2136
case GGML_OP_MUL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline ; break ;
2063
2137
case GGML_OP_DIV: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline ; break ;
@@ -2067,7 +2141,21 @@ static bool ggml_metal_encode_node(
2067
2141
bcast_row = true ;
2068
2142
} else {
2069
2143
switch (dst->op ) {
2070
- case GGML_OP_ADD: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD].pipeline ; break ;
2144
+ case GGML_OP_ADD:
2145
+ {
2146
+ GGML_LOG_INFO (" XXXXXXXXXXXXXXXXXXXXXXXXX n_fuse = %d \n " , n_fuse);
2147
+ switch (n_fuse) {
2148
+ case 2 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_2].pipeline ; break ;
2149
+ case 4 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_4].pipeline ; break ;
2150
+ case 6 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_6].pipeline ; break ;
2151
+ case 8 : pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD_FUSE_8].pipeline ; break ;
2152
+ default :
2153
+ {
2154
+ GGML_ASSERT (n_fuse == 1 );
2155
+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_ADD].pipeline ; break ;
2156
+ }
2157
+ }
2158
+ } break ;
2071
2159
case GGML_OP_SUB: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SUB].pipeline ; break ;
2072
2160
case GGML_OP_MUL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL].pipeline ; break ;
2073
2161
case GGML_OP_DIV: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_DIV].pipeline ; break ;
@@ -2107,7 +2195,16 @@ static bool ggml_metal_encode_node(
2107
2195
[encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
2108
2196
[encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
2109
2197
[encoder setBuffer: id_src1 offset: offs_src1 atIndex: 2 ];
2110
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 3 ];
2198
+ for (int f = 0 ; f < n_fuse - 1 ; ++f) {
2199
+ id_src1 = ggml_metal_get_buffer (fuse[f]->src [1 ], &offs_src1);
2200
+
2201
+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 3 + f];
2202
+
2203
+ if (f + 1 == n_fuse - 1 ) {
2204
+ id_dst = ggml_metal_get_buffer (fuse[f], &offs_dst);
2205
+ }
2206
+ }
2207
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 + n_fuse];
2111
2208
2112
2209
if (bcast_row) {
2113
2210
const int64_t n = ggml_nelements (dst)/4 ;
@@ -2674,7 +2771,7 @@ static bool ggml_metal_encode_node(
2674
2771
id<MTLBuffer> h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
2675
2772
if (!h_src0) {
2676
2773
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
2677
- return false ;
2774
+ return 0 ;
2678
2775
}
2679
2776
2680
2777
offs_src0 = 0;
@@ -3550,7 +3647,7 @@ static bool ggml_metal_encode_node(
3550
3647
id <MTLBuffer > h_src1 = ggml_metal_mem_pool_alloc (mem_pool, s_src1);
3551
3648
if (!h_src1) {
3552
3649
GGML_LOG_ERROR (" %s : failed to allocate buffer from memory pool, size = %zu \n " , __func__, s_src1);
3553
- return false ;
3650
+ return 0 ;
3554
3651
}
3555
3652
3556
3653
const int64_t neh0 = ne0;
@@ -3566,15 +3663,15 @@ static bool ggml_metal_encode_node(
3566
3663
id <MTLBuffer > h_dst = ggml_metal_mem_pool_alloc (mem_pool, s_dst);
3567
3664
if (!h_dst) {
3568
3665
GGML_LOG_ERROR (" %s : failed to allocate buffer from memory pool, size = %zu \n " , __func__, s_dst);
3569
- return false ;
3666
+ return 0 ;
3570
3667
}
3571
3668
3572
3669
// tokens per expert
3573
3670
const size_t s_tpe = ggml_type_size (GGML_TYPE_I32)*ne02;
3574
3671
id <MTLBuffer > h_tpe = ggml_metal_mem_pool_alloc (mem_pool, s_tpe);
3575
3672
if (!h_tpe) {
3576
3673
GGML_LOG_ERROR (" %s : failed to allocate buffer from memory pool, size = %zu \n " , __func__, s_tpe);
3577
- return false ;
3674
+ return 0 ;
3578
3675
}
3579
3676
3580
3677
// id map
@@ -3583,7 +3680,7 @@ static bool ggml_metal_encode_node(
3583
3680
id <MTLBuffer > h_ids = ggml_metal_mem_pool_alloc (mem_pool, s_ids);
3584
3681
if (!h_ids) {
3585
3682
GGML_LOG_ERROR (" %s : failed to allocate buffer from memory pool, size = %zu \n " , __func__, s_ids);
3586
- return false ;
3683
+ return 0 ;
3587
3684
}
3588
3685
3589
3686
{
@@ -5442,7 +5539,7 @@ static bool ggml_metal_encode_node(
5442
5539
}
5443
5540
}
5444
5541
5445
- return true ;
5542
+ return n_fuse ;
5446
5543
}
5447
5544
5448
5545
static enum ggml_status ggml_metal_graph_compute (
@@ -5948,20 +6045,22 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
5948
6045
struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs [cb_idx].mem_pool ;
5949
6046
ggml_metal_mem_pool_reset (mem_pool);
5950
6047
5951
- for (int idx = node_start; idx < node_end; ++idx ) {
6048
+ for (int idx = node_start; idx < node_end;) {
5952
6049
if (should_capture) {
5953
6050
[encoder pushDebugGroup: [NSString stringWithCString: ggml_op_desc (ggml_graph_node (ctx->gf, idx)) encoding: NSUTF8StringEncoding]];
5954
6051
}
5955
6052
5956
- const bool res = ggml_metal_encode_node (backend, idx, encoder, mem_pool);
6053
+ const int res = ggml_metal_encode_node (backend, idx, encoder, mem_pool);
5957
6054
5958
6055
if (should_capture) {
5959
6056
[encoder popDebugGroup ];
5960
6057
}
5961
6058
5962
- if (! res) {
6059
+ if (res == 0 ) {
5963
6060
break ;
5964
6061
}
6062
+
6063
+ idx += res;
5965
6064
}
5966
6065
5967
6066
[encoder endEncoding ];
0 commit comments