diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 19f4d59e59747..248fa378ef0f1 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -2450,6 +2450,7 @@ static bool ggml_metal_encode_node( nth *= 2; } + nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup); nth = MIN(nth, ne00); ggml_metal_kargs_sum_rows args = { @@ -3780,6 +3781,7 @@ static bool ggml_metal_encode_node( nth *= 2; } + nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup); nth = MIN(nth, ne00/4); ggml_metal_kargs_rms_norm args = { @@ -3816,6 +3818,7 @@ static bool ggml_metal_encode_node( nth *= 2; } + nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup); nth = MIN(nth, ne00/4); ggml_metal_kargs_l2_norm args = { @@ -3888,6 +3891,7 @@ static bool ggml_metal_encode_node( nth *= 2; } + nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup); nth = MIN(nth, ne00/4); ggml_metal_kargs_norm args = { @@ -4974,8 +4978,39 @@ static bool ggml_metal_encode_node( default: GGML_ABORT("not implemented"); } + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); + + // TODO: support + //const int32_t nk00 = ne00/ggml_blck_size(dst->type); + const int32_t nk00 = ne00; + + int nth = 32; // SIMD width + + while (nth < nk00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { + nth *= 2; + } + + nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup); + + // when rows are small, we can batch them together in a single threadgroup + int nrptg = 1; + + // TODO: relax this constraint in the future + if (ggml_blck_size(src0->type) == 1 && ggml_blck_size(dst->type) == 1) { + if (nth > nk00) { + nrptg = (nth + nk00 - 1)/nk00; + nth = nk00; + + if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) { + nrptg--; + } + } + } + + nth = MIN(nth, nk00); + ggml_metal_kargs_cpy args = { - /*.ne00 =*/ ne00, + /*.ne00 =*/ nk00, /*.ne01 =*/ ne01, /*.ne02 =*/ ne02, /*.ne03 =*/ ne03, @@ -4998,11 +5033,7 @@ static bool ggml_metal_encode_node( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); - int nth = MIN(1024, ne00/ggml_blck_size(src0->type)); - - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; - + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)]; } break; case GGML_OP_SET: { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 3da19879b4b36..f028276068ef4 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4306,11 +4306,16 @@ kernel void kernel_cpy( device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { + ushort3 tptg[[threads_per_threadgroup]]) { const int i03 = tgpig[2]; const int i02 = tgpig[1]; - const int i01 = tgpig[0]; + const int i01 = tgpig[0]*tptg.y + tiitg/tptg.x; + + if (i01 >= args.ne01) { + return; + } const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; @@ -4321,7 +4326,7 @@ kernel void kernel_cpy( device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) { + for (int64_t i00 = tiitg%tptg.x; i00 < args.ne00; i00 += tptg.x) { device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); dst_data[i00] = (T1) src[0]; }