76
76
GGML_METAL_DECL_KERNEL (rms_norm);
77
77
GGML_METAL_DECL_KERNEL (norm);
78
78
GGML_METAL_DECL_KERNEL (mul_mat_f16_f32);
79
- GGML_METAL_DECL_KERNEL (mul_mat_f16_f32_1row);
80
79
GGML_METAL_DECL_KERNEL (mul_mat_q4_0_f32);
81
80
GGML_METAL_DECL_KERNEL (mul_mat_q4_1_f32);
82
81
GGML_METAL_DECL_KERNEL (mul_mat_q8_0_f32);
@@ -220,7 +219,6 @@ @implementation GGMLMetalClass
220
219
GGML_METAL_ADD_KERNEL (rms_norm);
221
220
GGML_METAL_ADD_KERNEL (norm);
222
221
GGML_METAL_ADD_KERNEL (mul_mat_f16_f32);
223
- GGML_METAL_ADD_KERNEL (mul_mat_f16_f32_1row);
224
222
GGML_METAL_ADD_KERNEL (mul_mat_q4_0_f32);
225
223
GGML_METAL_ADD_KERNEL (mul_mat_q4_1_f32);
226
224
GGML_METAL_ADD_KERNEL (mul_mat_q8_0_f32);
@@ -286,7 +284,6 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
286
284
GGML_METAL_DEL_KERNEL (rms_norm);
287
285
GGML_METAL_DEL_KERNEL (norm);
288
286
GGML_METAL_DEL_KERNEL (mul_mat_f16_f32);
289
- GGML_METAL_DEL_KERNEL (mul_mat_f16_f32_1row);
290
287
GGML_METAL_DEL_KERNEL (mul_mat_q4_0_f32);
291
288
GGML_METAL_DEL_KERNEL (mul_mat_q4_1_f32);
292
289
GGML_METAL_DEL_KERNEL (mul_mat_q8_0_f32);
@@ -871,11 +868,7 @@ void ggml_metal_graph_compute(
871
868
{
872
869
nth0 = 32 ;
873
870
nth1 = 1 ;
874
- if (ne11 * ne12 < 4 ) {
875
- [encoder setComputePipelineState: ctx->pipeline_mul_mat_f16_f32_1row];
876
- } else {
877
- [encoder setComputePipelineState: ctx->pipeline_mul_mat_f16_f32];
878
- }
871
+ [encoder setComputePipelineState: ctx->pipeline_mul_mat_f16_f32];
879
872
} break ;
880
873
case GGML_TYPE_Q4_0:
881
874
{
@@ -927,8 +920,8 @@ void ggml_metal_graph_compute(
927
920
GGML_ASSERT (ne02 == 1 );
928
921
GGML_ASSERT (ne12 == 1 );
929
922
930
- nth0 = 4 ; // 1 ;
931
- nth1 = 8 ; // 32;
923
+ nth0 = 2 ;
924
+ nth1 = 32 ;
932
925
[encoder setComputePipelineState: ctx->pipeline_mul_mat_q4_K_f32];
933
926
} break ;
934
927
case GGML_TYPE_Q5_K:
@@ -976,12 +969,9 @@ void ggml_metal_graph_compute(
976
969
[encoder setBytes: &gqa length: sizeof (gqa) atIndex: 17 ];
977
970
978
971
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
979
- src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
972
+ src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
980
973
[encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 7 )/8 , ne11, ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
981
974
}
982
- else if (src0t == GGML_TYPE_Q4_K) {
983
- [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 3 )/4 , ne11, ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
984
- }
985
975
else if (src0t == GGML_TYPE_Q3_K) {
986
976
#ifdef GGML_QKK_64
987
977
[encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 1 )/2 , ne11, ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
@@ -995,8 +985,8 @@ void ggml_metal_graph_compute(
995
985
else if (src0t == GGML_TYPE_Q6_K) {
996
986
[encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 1 )/2 , ne11, ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
997
987
} else {
998
- int64_t ny = (ne11 + 3 )/ 4 ;
999
- [encoder dispatchThreadgroups: MTLSizeMake (ne01, ny , ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
988
+ [encoder setThreadgroupMemoryLength: nth0* sizeof ( float ) atIndex: 0 ] ;
989
+ [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne11 , ne12) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
1000
990
}
1001
991
}
1002
992
} break ;
@@ -1223,4 +1213,4 @@ void ggml_metal_graph_compute(
1223
1213
}
1224
1214
1225
1215
}
1226
- }
1216
+ }
0 commit comments