Skip to content

Commit cd69689

Browse files
committed
1 parent a64a411 commit cd69689

File tree

2 files changed

+84
-162
lines changed

2 files changed

+84
-162
lines changed

ggml-metal.m

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@
7676
GGML_METAL_DECL_KERNEL(rms_norm);
7777
GGML_METAL_DECL_KERNEL(norm);
7878
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
79-
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
8079
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
8180
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
8281
GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
@@ -220,7 +219,6 @@ @implementation GGMLMetalClass
220219
GGML_METAL_ADD_KERNEL(rms_norm);
221220
GGML_METAL_ADD_KERNEL(norm);
222221
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
223-
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
224222
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
225223
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
226224
GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
@@ -286,7 +284,6 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
286284
GGML_METAL_DEL_KERNEL(rms_norm);
287285
GGML_METAL_DEL_KERNEL(norm);
288286
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
289-
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
290287
GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
291288
GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
292289
GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
@@ -871,11 +868,7 @@ void ggml_metal_graph_compute(
871868
{
872869
nth0 = 32;
873870
nth1 = 1;
874-
if (ne11 * ne12 < 4) {
875-
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
876-
} else {
877-
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
878-
}
871+
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
879872
} break;
880873
case GGML_TYPE_Q4_0:
881874
{
@@ -927,8 +920,8 @@ void ggml_metal_graph_compute(
927920
GGML_ASSERT(ne02 == 1);
928921
GGML_ASSERT(ne12 == 1);
929922

930-
nth0 = 4; //1;
931-
nth1 = 8; //32;
923+
nth0 = 2;
924+
nth1 = 32;
932925
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
933926
} break;
934927
case GGML_TYPE_Q5_K:
@@ -976,12 +969,9 @@ void ggml_metal_graph_compute(
976969
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
977970

978971
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
979-
src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) {
972+
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
980973
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
981974
}
982-
else if (src0t == GGML_TYPE_Q4_K) {
983-
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
984-
}
985975
else if (src0t == GGML_TYPE_Q3_K) {
986976
#ifdef GGML_QKK_64
987977
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
@@ -995,8 +985,8 @@ void ggml_metal_graph_compute(
995985
else if (src0t == GGML_TYPE_Q6_K) {
996986
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
997987
} else {
998-
int64_t ny = (ne11 + 3)/4;
999-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
988+
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
989+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
1000990
}
1001991
}
1002992
} break;
@@ -1223,4 +1213,4 @@ void ggml_metal_graph_compute(
12231213
}
12241214

12251215
}
1226-
}
1216+
}

0 commit comments

Comments
 (0)