@@ -87,15 +87,38 @@ static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
     return tensor->ne[dim];
 }

+template <typename Variant, typename Ret, typename... Args, std::size_t... Is>
+constexpr bool variant_any_invocable_impl(std::index_sequence<Is...>) {
+    using V = std::remove_reference_t<Variant>;
+    return (std::is_invocable_r_v<
+                Ret,
+                std::variant_alternative_t<Is, V>,
+                Args...> || ...);
+}
+
+template <typename Variant, typename Ret, typename... Args>
+constexpr bool variant_any_invocable_v =
+    variant_any_invocable_impl<Variant, Ret, Args...>(
+        std::make_index_sequence<
+            std::variant_size_v<std::remove_reference_t<Variant>>>{});
+
 template <typename Ret, typename Variant, typename... Args>
-static Ret variant_call(const Variant & var, Args&&... args) {
-    return std::visit([&](auto && func) -> Ret {
-        if constexpr (std::is_invocable_r_v<Ret, decltype(func), Args...>) {
-            return func(std::forward<Args>(args)...);
-        } else {
-            throw std::runtime_error("Invalid function type in variant_call");
-        }
-    }, var);
+static inline Ret variant_call(Variant && var, Args&&... args) {
+    static_assert(variant_any_invocable_v<std::remove_reference_t<Variant>, Ret, Args...>,
+                  "No alternative in Variant is invocable with the provided arguments and return type.");
+
+    return std::visit(
+        [&](auto && f) -> Ret {
+            using F = std::decay_t<decltype(f)>;
+            if constexpr (std::is_invocable_r_v<Ret, F, Args...>) {
+                return std::invoke(std::forward<decltype(f)>(f), std::forward<Args>(args)...);
+            } else {
+                GGML_ABORT("Invalid function type in variant_call");
+                GGML_UNREACHABLE();
+            }
+        },
+        std::forward<Variant>(var)
+    );
 }

 namespace ggml::cpu::kleidiai {
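
Note (not part of the patch): a minimal sketch of how the new compile-time guard behaves, assuming the variant_call and variant_any_invocable_v templates above are in scope; the function-pointer types and values below are hypothetical.

    #include <cstddef>
    #include <cstdio>
    #include <variant>

    // Two alternative callable signatures, similar in spirit to the function
    // pointers held by the packing-info structs.
    using fn2 = size_t (*)(size_t m, size_t k);
    using fn3 = size_t (*)(size_t m, size_t k, size_t bl);
    using size_fn_variant = std::variant<fn2, fn3>;

    static size_t packed_size_2(size_t m, size_t k) { return m * k * sizeof(float); }

    static void demo() {
        size_fn_variant v = packed_size_2;

        // OK: fn2 is invocable as size_t(size_t, size_t), so the static_assert
        // passes and std::visit dispatches to the stored alternative.
        const size_t s = variant_call<size_t>(v, size_t(4), size_t(8));
        std::printf("%zu\n", s);

        // Would not compile: no alternative accepts a single argument, so the
        // new static_assert fires instead of reaching the old runtime throw.
        // variant_call<size_t>(v, size_t(4));
    }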
@@ -138,7 +161,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         if (kernels->rhs_type == GGML_TYPE_Q4_0) {
             size = variant_call<size_t>(lhs_info->packed_size, m, k, QK4_0, mr, kr, sr);
         } else if (kernels->rhs_type == GGML_TYPE_F16) {
-            size = variant_call<size_t>(lhs_info->packed_size, m, k, mr, kr, sr) +
+            const int64_t lhs_batch_size0 = op->src[1]->ne[2];
+            const int64_t rhs_batch_size0 = op->src[0]->ne[2];
+            const int64_t r = lhs_batch_size0 / rhs_batch_size0;
+            size = variant_call<size_t>(lhs_info->packed_size, m * r, k, mr, kr, sr) +
                    variant_call<size_t>(kernels->rhs_info.packed_size, n, k) +
                    k * n * sizeof(float) + n * sizeof(float);
         } else {
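
Note (not part of the patch): a tiny standalone illustration of the new sizing term, with made-up shapes; m * r is the row count the LHS packing buffer must now cover when src1 carries more batches than src0.

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t m = 32;                  // rows of src1 per batch (ne[1]), hypothetical
        const int64_t lhs_batch_size0 = 8;     // op->src[1]->ne[2], hypothetical
        const int64_t rhs_batch_size0 = 2;     // op->src[0]->ne[2], hypothetical
        const int64_t r = lhs_batch_size0 / rhs_batch_size0;   // broadcast ratio, 4 here

        // The patch sizes the LHS packing area for m * r rows instead of m,
        // i.e. variant_call<size_t>(lhs_info->packed_size, m * r, k, mr, kr, sr).
        std::printf("LHS packing sized for %lld rows (previously %lld)\n",
                    (long long)(m * r), (long long)m);
        return 0;
    }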
@@ -148,7 +174,6 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         return true;
     }

-
     bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
         if (dst->op == GGML_OP_MUL_MAT) {
             if (dst->src[0]->type == GGML_TYPE_Q4_0) {
@@ -165,8 +190,6 @@ class tensor_traits : public ggml::cpu::tensor_traits {
     }

     bool compute_forward_fp16(ggml_compute_params * params, struct ggml_tensor * dst) {
-        static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT;
-
         const ggml_tensor * src0 = dst->src[0];
         const ggml_tensor * src1 = dst->src[1];

@@ -175,7 +198,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
         GGML_ASSERT(kernels);

-        bool is_gemv = src1->ne[1] == 1;
+        const bool is_gemv = src1->ne[1] == 1;
         kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
         lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
         GGML_ASSERT(kernel);
@@ -185,27 +208,30 @@ class tensor_traits : public ggml::cpu::tensor_traits {

         const int64_t lhs_batch_size0 = ne12;
         const int64_t rhs_batch_size0 = ne02;
-        const int64_t batch_size = rhs_batch_size0;
+        const int64_t batch_size = lhs_batch_size0;

+        GGML_ASSERT(rhs_batch_size0 > 0);
+        GGML_ASSERT(lhs_batch_size0 % rhs_batch_size0 == 0);
         const int64_t r = lhs_batch_size0 / rhs_batch_size0;

-        const int64_t m = ne11 * r;
-        const int64_t n = ne01;
-        const int64_t k = ne00;
+        const int64_t m_group = ne11;
+        const int64_t m = m_group;
+        const int64_t n = ne01;
+        const int64_t k = ne00;

         const size_t lhs_stride = src1->nb[1];
         const size_t rhs_stride = src0->nb[1];
         const size_t dst_stride = dst->nb[1];

-        const int64_t mr = static_cast<int64_t>(kernel->get_mr());
-        const int64_t nr = static_cast<int64_t>(kernel->get_nr());
-        const int64_t kr = static_cast<int64_t>(kernel->get_kr());
-        const int64_t sr = static_cast<int64_t>(kernel->get_sr());
+        const int64_t mr = (int64_t) kernel->get_mr();
+        const int64_t nr = (int64_t) kernel->get_nr();
+        const int64_t kr = (int64_t) kernel->get_kr();
+        const int64_t sr = (int64_t) kernel->get_sr();

-        const size_t lhs_packed_size = variant_call<size_t>(lhs_info->packed_size, m, k, mr, kr, sr);
-        const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, n, k);
-        const size_t kxn_size        = k * n * sizeof(float);
-        const size_t bias_size       = n * sizeof(float);
+        const size_t lhs_packed_size = variant_call<size_t>(lhs_info->packed_size, (size_t)m, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
+        const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, (size_t)n, (size_t)k);
+        const size_t kxn_size        = (size_t)k * (size_t)n * sizeof(float);
+        const size_t bias_size       = (size_t)n * sizeof(float);

         const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size;
         GGML_ASSERT(wsize_required <= params->wsize);
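
Note (not part of the patch): the four sizes above partition the per-op scratch buffer. The sketch below assumes a sequential carve-out from params->wdata; only the bias offset is confirmed by the `bias = rhs_kxn + kxn_size` line in the next hunk, the rest is an inference.

    #include <cstddef>
    #include <cstdint>

    struct f16_workspace_view {
        uint8_t * lhs_packed;   // lhs_packed_size bytes: packed activations
        uint8_t * rhs_packed;   // rhs_packed_size bytes: packed, transposed weights
        uint8_t * rhs_kxn;      // kxn_size bytes: fp32 k-by-n transpose scratch
        uint8_t * bias;         // bias_size bytes: zero bias fed to the RHS packer
    };

    // Assumed layout; order and names are illustrative only.
    static f16_workspace_view carve_workspace(uint8_t * wdata,
                                              size_t lhs_packed_size,
                                              size_t rhs_packed_size,
                                              size_t kxn_size) {
        f16_workspace_view v;
        v.lhs_packed = wdata;
        v.rhs_packed = v.lhs_packed + lhs_packed_size;
        v.rhs_kxn    = v.rhs_packed + rhs_packed_size;
        v.bias       = v.rhs_kxn + kxn_size;   // matches the visible bias offset
        return v;
    }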
@@ -216,82 +242,102 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         uint8_t * bias = rhs_kxn + kxn_size;

         for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
-            const uint8_t * lhs_batch = static_cast<const uint8_t *>(src1->data) + batch_idx * m * lhs_stride;
-            const uint8_t * rhs_batch = static_cast<const uint8_t *>(src0->data) + batch_idx * n * rhs_stride;
-            uint8_t * dst_batch = static_cast<uint8_t *>(dst->data) + batch_idx * m * dst_stride;
+            const int64_t rhs_batch_idx = batch_idx / r;
+            const uint8_t * rhs_batch_base = static_cast<const uint8_t *>(src0->data) + rhs_batch_idx * src0->nb[2];
+            uint8_t * dst_batch_base = static_cast<uint8_t *>(dst->data) + batch_idx * dst->nb[2];

-            // LHS packing
+            // LHS packing (threaded over m, honoring mr alignment and KV groups)
             {
                 const int64_t m_roundup_mr = kai_roundup(m, mr);
                 const int64_t num_threads = KAI_MIN(m_roundup_mr / mr, nth);

                 if (ith < num_threads) {
-                    const int64_t num_m_per_thread0 = round_down(m_roundup_mr / num_threads, mr);
+                    const int64_t num_m_per_thread0 = round_down((size_t)(m_roundup_mr / num_threads), (size_t)mr);
                     const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0;

-                    const int64_t m_start = ith * num_m_per_thread0;
-                    const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
+                    const int64_t m_start = ith * num_m_per_thread0;
+                    const int64_t m_count = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
+
+                    // Base packed offset (aligned) and per-row stride in bytes
+                    const size_t base_packed_off = variant_call<size_t>(
+                        lhs_info->get_packed_offset, (size_t)m_start, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
+                    const size_t next_block_off = variant_call<size_t>(
+                        lhs_info->get_packed_offset, (size_t)(m_start + mr), (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
+                    const size_t row_stride_bytes = (next_block_off - base_packed_off) / (size_t)mr;
+
+                    int64_t remaining = m_count;
+                    int64_t cur = m_start;
+
+                    while (remaining > 0) {
+                        const int64_t row_in_group = cur;
+                        const int64_t avail = m_group - row_in_group;
+                        const int64_t take = std::min(avail, remaining);

-                    const size_t lhs_offset = variant_call<size_t>(kernels->gemm.get_lhs_offset, m_start, lhs_stride);
-                    const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, mr, kr, sr);
+                        const uint8_t * lhs_batch_base = static_cast<const uint8_t *>(src1->data) + batch_idx * src1->nb[2];
+                        const void * src_ptr = lhs_batch_base + (size_t)row_in_group * lhs_stride;
+                        const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes;
+                        void * dst_ptr = lhs_packed + dst_off;

-                    const void * src_ptr = static_cast<const uint8_t *>(lhs_batch) + lhs_offset;
-                    void * dst_ptr = static_cast<uint8_t *>(lhs_packed) + lhs_packed_offset;
+                        variant_call<void>(lhs_info->pack_func,
+                            (size_t)take, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr,
+                            /* m_idx_start */ 0, src_ptr, lhs_stride, dst_ptr);

-                    variant_call<void>(lhs_info->pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
+                        cur += take;
+                        remaining -= take;
+                    }
                 }
             }

-            // RHS packing
-            if (first_to_arrive.test_and_set(std::memory_order_acquire) == false) {
-                // First thread to reach this point handles RHS packing
-                memset(bias, 0, n * sizeof(float));
-                transpose_f32kxn_f16nxk(n, k, reinterpret_cast<float *>(rhs_kxn),
-                                        reinterpret_cast<const uint16_t *>(rhs_batch), rhs_stride);
-
-                variant_call<void>(kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, n * sizeof(float),
-                                   rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr);
+            // RHS packing (single thread), then synchronize
+            if (ith == 0) {
+                memset(bias, 0, (size_t)n * sizeof(float));
+                transpose_f32kxn_f16nxk((size_t)n, (size_t)k,
+                                        reinterpret_cast<float *>(rhs_kxn),
+                                        reinterpret_cast<const uint16_t *>(rhs_batch_base),
+                                        rhs_stride);
+
+                variant_call<void>(kernels->rhs_info.pack_func,
+                    /* num_groups */ 1, (size_t)n, (size_t)k, (size_t)nr, (size_t)kr, (size_t)sr,
+                    /* rhs_stride (bytes) */ (size_t)(n * sizeof(float)),
+                    rhs_kxn, bias, nullptr, rhs_packed, /* extra_bytes */ 0, /* params */ nullptr);
             }

             ggml_barrier(params->threadpool);

-            first_to_arrive.clear(std::memory_order_release);
-
-            // Perform the matmul
+            // Matmul (threaded over n)
             {
-                const int64_t m_to_process = m;
-                const int64_t m_start = 0;
-
-                const int64_t n_step = static_cast<int64_t>(kernel->get_n_step());
-                int64_t num_threads = KAI_MIN(n / n_step, nth);
-                if (num_threads <= 0) {
-                    num_threads = 1;
+                const int64_t n_step = (int64_t) kernel->get_n_step();
+                int64_t num_threads_n = KAI_MIN(n / n_step, nth);
+                if (num_threads_n <= 0) {
+                    num_threads_n = 1;
                 }

-                if (ith < num_threads) {
-                    const int64_t num_n_per_thread0 = round_down(n / num_threads, n_step);
-                    const int64_t num_n_per_threadN_1 = n - (num_threads - 1) * num_n_per_thread0;
+                if (ith < num_threads_n) {
+                    const int64_t num_n_per_thread0 = round_down((size_t)(n / num_threads_n), (size_t)n_step);
+                    const int64_t num_n_per_threadN_1 = n - (num_threads_n - 1) * num_n_per_thread0;

                     const int64_t n_start = ith * num_n_per_thread0;
-                    const int64_t n_to_process = (ith == num_threads - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
+                    const int64_t n_to_process = (ith == num_threads_n - 1) ? num_n_per_threadN_1 : num_n_per_thread0;

-                    const size_t lhs_packed_offset = variant_call<size_t>(kernel->get_lhs_offset, m_start, k);
-                    const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k);
-                    const size_t dst_offset = kernel->get_dst_offset(m_start, n_start, dst_stride);
+                    // LHS packed base at row 0 (consistent with packing above)
+                    const size_t lhs_packed_offset0 = variant_call<size_t>(
+                        lhs_info->get_packed_offset, (size_t)0, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
+                    const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, (size_t)n_start, (size_t)k);
+                    const size_t dst_offset = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride);

-                    const void * lhs_ptr = lhs_packed + lhs_packed_offset;
+                    const void * lhs_ptr = lhs_packed + lhs_packed_offset0;
                     const void * rhs_ptr = rhs_packed + rhs_packed_offset;
-                    float * dst_ptr = reinterpret_cast<float *>(dst_batch + dst_offset);
+                    float * dst_ptr = reinterpret_cast<float *>(dst_batch_base + dst_offset);

-                    variant_call<void>(kernel->run_kernel, m_to_process, n_to_process, k, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
+                    variant_call<void>(kernel->run_kernel,
+                        (size_t)m, (size_t)n_to_process, (size_t)k,
+                        lhs_ptr, rhs_ptr,
+                        dst_ptr, dst_stride, sizeof(float),
+                        -FLT_MAX, FLT_MAX);
                 }
             }

             if (batch_idx != batch_size - 1) {
-                // This barrier is necessary when the batch size is larger than 1. While processing a batch,
-                // the work data buffer (params->wdata) is used as temporary storage which means that only
-                // a single batch can be processed at any given time. No barrier is needed for the last
-                // batch since GGML inserts a barrier between the execution of every operator.
                 ggml_barrier(params->threadpool);
             }
         }
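
Note (not part of the patch): a self-contained sketch of the thread-splitting scheme both the LHS-packing and matmul sections use; round_down_to stands in for the backend's round_down helper, and the shapes in main are made up.

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <utility>
    #include <vector>

    // Largest multiple of `step` that is <= value (stand-in for round_down).
    static int64_t round_down_to(int64_t value, int64_t step) {
        return (value / step) * step;
    }

    // Every thread but the last gets a step-aligned chunk; the last thread
    // absorbs the remainder, mirroring num_m_per_threadN_1 / num_n_per_threadN_1.
    static std::vector<std::pair<int64_t, int64_t>> split_rows(int64_t m, int64_t step, int64_t nth) {
        const int64_t m_roundup = ((m + step - 1) / step) * step;   // like kai_roundup(m, step)
        int64_t num_threads = std::min(m_roundup / step, nth);
        if (num_threads <= 0) {
            num_threads = 1;
        }
        const int64_t per_thread0 = round_down_to(m_roundup / num_threads, step);

        std::vector<std::pair<int64_t, int64_t>> chunks;   // (start, count) per thread
        for (int64_t ith = 0; ith < num_threads; ++ith) {
            const int64_t start = ith * per_thread0;
            const int64_t count = (ith == num_threads - 1) ? m - start : per_thread0;
            chunks.emplace_back(start, count);
        }
        return chunks;
    }

    int main() {
        // e.g. m = 70 rows, mr = 16, 4 threads -> chunks of 16, 16, 16, 22
        for (const auto & c : split_rows(70, 16, 4)) {
            std::printf("start=%lld count=%lld\n", (long long)c.first, (long long)c.second);
        }
        return 0;
    }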