Commit f1eb1cb

kleidiai : fix work size and threads sync for fp16 (ggml-org#16246)
1 parent de41f2b commit f1eb1cb

File tree: 2 files changed (+119, -72 lines)

ggml/src/ggml-cpu/CMakeLists.txt

Lines changed: 3 additions & 2 deletions
@@ -513,9 +513,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
 
         # Fetch KleidiAI sources:
         include(FetchContent)
-        set(KLEIDIAI_COMMIT_TAG "v1.13.0")
+        set(KLEIDIAI_COMMIT_TAG "v1.14.0")
         set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
-        set(KLEIDIAI_ARCHIVE_MD5 "d82a8de939d9814621a5ba23907bdac1")
+        set(KLEIDIAI_ARCHIVE_MD5 "45e110675d93f99f82c23a1afcca76bc")
 
         if (POLICY CMP0135)
             cmake_policy(SET CMP0135 NEW)
@@ -592,6 +592,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c
+                ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa_asm.S
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c
                 ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c
                 ${KLEIDIAI_SRC}/kai/kai_common_sme_asm.S)

ggml/src/ggml-cpu/kleidiai/kleidiai.cpp

Lines changed: 116 additions & 70 deletions
@@ -87,15 +87,38 @@ static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
     return tensor->ne[dim];
 }
 
+template <typename Variant, typename Ret, typename... Args, std::size_t... Is>
+constexpr bool variant_any_invocable_impl(std::index_sequence<Is...>) {
+    using V = std::remove_reference_t<Variant>;
+    return (std::is_invocable_r_v<
+                Ret,
+                std::variant_alternative_t<Is, V>,
+                Args...> || ...);
+}
+
+template <typename Variant, typename Ret, typename... Args>
+constexpr bool variant_any_invocable_v =
+    variant_any_invocable_impl<Variant, Ret, Args...>(
+        std::make_index_sequence<
+            std::variant_size_v<std::remove_reference_t<Variant>>>{});
+
 template<typename Ret, typename Variant, typename... Args>
-static Ret variant_call(const Variant & var, Args&&... args) {
-    return std::visit([&](auto&& func) -> Ret {
-        if constexpr (std::is_invocable_r_v<Ret, decltype(func), Args...>) {
-            return func(std::forward<Args>(args)...);
-        } else {
-            throw std::runtime_error("Invalid function type in variant_call");
-        }
-    }, var);
+static inline Ret variant_call(Variant && var, Args&&... args) {
+    static_assert(variant_any_invocable_v<std::remove_reference_t<Variant>, Ret, Args...>,
+                  "No alternative in Variant is invocable with the provided arguments and return type.");
+
+    return std::visit(
+        [&](auto && f) -> Ret {
+            using F = std::decay_t<decltype(f)>;
+            if constexpr (std::is_invocable_r_v<Ret, F, Args...>) {
+                return std::invoke(std::forward<decltype(f)>(f), std::forward<Args>(args)...);
+            } else {
+                GGML_ABORT("Invalid function type in variant_call");
+                GGML_UNREACHABLE();
+            }
+        },
+        std::forward<Variant>(var)
+    );
 }
 
 namespace ggml::cpu::kleidiai {
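
The new variant_any_invocable_v trait moves the "no matching alternative" failure from run time to compile time: variant_call now static_asserts that at least one alternative of the kernel-function variant can be invoked with the requested arguments and return type. Below is a minimal, self-contained sketch of the same dispatch pattern outside of ggml; the names (any_invocable_v, call, fn_a, fn_b) are illustrative only, and the sketch throws where the real code calls GGML_ABORT.

// Standalone sketch of a variant-of-function-pointers dispatcher with a
// compile-time "at least one alternative matches" check (illustrative names).
#include <cstddef>
#include <functional>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <variant>

template <typename Variant, typename Ret, typename... Args, std::size_t... Is>
constexpr bool any_invocable_impl(std::index_sequence<Is...>) {
    using V = std::remove_reference_t<Variant>;
    // True if at least one alternative of the variant is callable as Ret(Args...).
    return (std::is_invocable_r_v<Ret, std::variant_alternative_t<Is, V>, Args...> || ...);
}

template <typename Variant, typename Ret, typename... Args>
constexpr bool any_invocable_v =
    any_invocable_impl<Variant, Ret, Args...>(
        std::make_index_sequence<std::variant_size_v<std::remove_reference_t<Variant>>>{});

template <typename Ret, typename Variant, typename... Args>
Ret call(Variant && var, Args&&... args) {
    // Reject at compile time if no alternative could ever satisfy the call.
    static_assert(any_invocable_v<std::remove_reference_t<Variant>, Ret, Args...>,
                  "no alternative is invocable with these arguments");
    return std::visit(
        [&](auto && f) -> Ret {
            using F = std::decay_t<decltype(f)>;
            if constexpr (std::is_invocable_r_v<Ret, F, Args...>) {
                return std::invoke(std::forward<decltype(f)>(f), std::forward<Args>(args)...);
            } else {
                // Runtime guard for alternatives that do not match this signature.
                throw std::runtime_error("variant holds a non-matching alternative");
            }
        },
        std::forward<Variant>(var));
}

int main() {
    using fn_a = std::size_t (*)(std::size_t);
    using fn_b = std::size_t (*)(std::size_t, std::size_t);
    std::variant<fn_a, fn_b> v = static_cast<fn_a>([](std::size_t m) { return m * 2; });
    return static_cast<int>(call<std::size_t>(v, std::size_t{21}));   // dispatches to fn_a
}
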
@@ -138,7 +161,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         if (kernels->rhs_type == GGML_TYPE_Q4_0) {
             size = variant_call<size_t>(lhs_info->packed_size, m, k, QK4_0, mr, kr, sr);
         } else if (kernels->rhs_type == GGML_TYPE_F16) {
-            size = variant_call<size_t>(lhs_info->packed_size, m, k, mr, kr, sr) +
+            const int64_t lhs_batch_size0 = op->src[1]->ne[2];
+            const int64_t rhs_batch_size0 = op->src[0]->ne[2];
+            const int64_t r = lhs_batch_size0 / rhs_batch_size0;
+            size = variant_call<size_t>(lhs_info->packed_size, m * r, k, mr, kr, sr) +
                    variant_call<size_t>(kernels->rhs_info.packed_size, n, k) +
                    k * n * sizeof(float) + n * sizeof(float);
         } else {
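
The work-size change above accounts for RHS broadcasting: when src1 carries r times as many batches as src0, the packed LHS region has to hold m * r rows, so the previous estimate based on m alone could undersize params->wdata. The sketch below only illustrates that arithmetic; fake_lhs_packed_size and fake_rhs_packed_size are placeholder cost models, not the KleidiAI packed_size callbacks.

#include <cstdint>
#include <cstdio>

// Placeholder cost models, stand-ins for lhs_info->packed_size and
// kernels->rhs_info.packed_size, which KleidiAI supplies at runtime.
static size_t fake_lhs_packed_size(size_t m, size_t k) { return m * k * sizeof(uint16_t); }
static size_t fake_rhs_packed_size(size_t n, size_t k) { return n * k * sizeof(uint16_t); }

// Workspace needed by the fp16 path for one operator evaluation:
//   packed LHS (all m * r broadcast rows) + packed RHS + f32 scratch (k x n) + bias (n).
static size_t fp16_work_size(int64_t m, int64_t n, int64_t k,
                             int64_t lhs_batches, int64_t rhs_batches) {
    const int64_t r = lhs_batches / rhs_batches;   // RHS is broadcast r times over the LHS batches
    return fake_lhs_packed_size((size_t)(m * r), (size_t)k) +
           fake_rhs_packed_size((size_t)n, (size_t)k) +
           (size_t)k * (size_t)n * sizeof(float) +  // transposed f32 copy of the RHS
           (size_t)n * sizeof(float);               // zero bias row
}

int main() {
    // Example: m = 32 rows per batch, n = k = 4096, 8 LHS batches broadcast over 1 RHS batch.
    std::printf("work size: %zu bytes\n", fp16_work_size(32, 4096, 4096, 8, 1));
    return 0;
}
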
@@ -148,7 +174,6 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         return true;
     }
 
-
     bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
         if (dst->op == GGML_OP_MUL_MAT) {
             if (dst->src[0]->type == GGML_TYPE_Q4_0) {
@@ -165,8 +190,6 @@ class tensor_traits : public ggml::cpu::tensor_traits {
     }
 
     bool compute_forward_fp16(ggml_compute_params * params, struct ggml_tensor * dst) {
-        static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT;
-
         const ggml_tensor * src0 = dst->src[0];
         const ggml_tensor * src1 = dst->src[1];
 
@@ -175,7 +198,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
         GGML_ASSERT(kernels);
 
-        bool is_gemv = src1->ne[1] == 1;
+        const bool is_gemv = src1->ne[1] == 1;
         kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
         lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
         GGML_ASSERT(kernel);
@@ -185,27 +208,30 @@ class tensor_traits : public ggml::cpu::tensor_traits {
 
         const int64_t lhs_batch_size0 = ne12;
         const int64_t rhs_batch_size0 = ne02;
-        const int64_t batch_size = rhs_batch_size0;
+        const int64_t batch_size = lhs_batch_size0;
 
+        GGML_ASSERT(rhs_batch_size0 > 0);
+        GGML_ASSERT(lhs_batch_size0 % rhs_batch_size0 == 0);
         const int64_t r = lhs_batch_size0 / rhs_batch_size0;
 
-        const int64_t m = ne11 * r;
-        const int64_t n = ne01;
-        const int64_t k = ne00;
+        const int64_t m_group = ne11;
+        const int64_t m = m_group;
+        const int64_t n = ne01;
+        const int64_t k = ne00;
 
         const size_t lhs_stride = src1->nb[1];
         const size_t rhs_stride = src0->nb[1];
         const size_t dst_stride = dst->nb[1];
 
-        const int64_t mr = static_cast<int64_t>(kernel->get_mr());
-        const int64_t nr = static_cast<int64_t>(kernel->get_nr());
-        const int64_t kr = static_cast<int64_t>(kernel->get_kr());
-        const int64_t sr = static_cast<int64_t>(kernel->get_sr());
+        const int64_t mr = (int64_t) kernel->get_mr();
+        const int64_t nr = (int64_t) kernel->get_nr();
+        const int64_t kr = (int64_t) kernel->get_kr();
+        const int64_t sr = (int64_t) kernel->get_sr();
 
-        const size_t lhs_packed_size = variant_call<size_t>(lhs_info->packed_size, m, k, mr, kr, sr);
-        const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, n, k);
-        const size_t kxn_size = k * n * sizeof(float);
-        const size_t bias_size = n * sizeof(float);
+        const size_t lhs_packed_size = variant_call<size_t>(lhs_info->packed_size, (size_t)m, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
+        const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, (size_t)n, (size_t)k);
+        const size_t kxn_size = (size_t)k * (size_t)n * sizeof(float);
+        const size_t bias_size = (size_t)n * sizeof(float);
 
         const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size;
         GGML_ASSERT(wsize_required <= params->wsize);
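
For context, the fp16 path slices one shared scratch buffer into four consecutive regions whose sizes are exactly the four terms summed into wsize_required. The carve-up below is a sketch inferred from those sizes and from the bias = rhs_kxn + kxn_size context line in the next hunk; fp16_workspace and carve_workspace are illustrative names, not part of the source.

#include <cassert>
#include <cstddef>
#include <cstdint>

// One shared scratch buffer, split into consecutive regions:
//   [ packed LHS | packed RHS | f32 copy of RHS (k x n) | bias (n floats) ]
struct fp16_workspace {
    uint8_t * lhs_packed;
    uint8_t * rhs_packed;
    uint8_t * rhs_kxn;
    uint8_t * bias;
};

static fp16_workspace carve_workspace(void * wdata, size_t wsize,
                                      size_t lhs_packed_size, size_t rhs_packed_size,
                                      size_t kxn_size, size_t bias_size) {
    // Mirrors the GGML_ASSERT above: the four regions must fit in the work buffer.
    assert(lhs_packed_size + rhs_packed_size + kxn_size + bias_size <= wsize);
    fp16_workspace ws;
    ws.lhs_packed = static_cast<uint8_t *>(wdata);
    ws.rhs_packed = ws.lhs_packed + lhs_packed_size;
    ws.rhs_kxn    = ws.rhs_packed + rhs_packed_size;
    ws.bias       = ws.rhs_kxn + kxn_size;
    return ws;
}

int main() {
    unsigned char buf[1 << 16];
    fp16_workspace ws = carve_workspace(buf, sizeof(buf), 4096, 8192, 16384, 256);
    return ws.bias > ws.lhs_packed ? 0 : 1;
}
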
@@ -216,82 +242,102 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         uint8_t * bias = rhs_kxn + kxn_size;
 
         for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
-            const uint8_t * lhs_batch = static_cast<const uint8_t *>(src1->data) + batch_idx * m * lhs_stride;
-            const uint8_t * rhs_batch = static_cast<const uint8_t *>(src0->data) + batch_idx * n * rhs_stride;
-            uint8_t * dst_batch = static_cast<uint8_t *>(dst->data) + batch_idx * m * dst_stride;
+            const int64_t rhs_batch_idx = batch_idx / r;
+            const uint8_t * rhs_batch_base = static_cast<const uint8_t *>(src0->data) + rhs_batch_idx * src0->nb[2];
+            uint8_t * dst_batch_base = static_cast<uint8_t *>(dst->data) + batch_idx * dst->nb[2];
 
-            // LHS packing
+            // LHS packing (threaded over m, honoring mr alignment and KV groups)
             {
                 const int64_t m_roundup_mr = kai_roundup(m, mr);
                 const int64_t num_threads = KAI_MIN(m_roundup_mr / mr, nth);
 
                 if (ith < num_threads) {
-                    const int64_t num_m_per_thread0 = round_down(m_roundup_mr / num_threads, mr);
+                    const int64_t num_m_per_thread0 = round_down((size_t)(m_roundup_mr / num_threads), (size_t)mr);
                     const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0;
 
-                    const int64_t m_start = ith * num_m_per_thread0;
-                    const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
+                    const int64_t m_start = ith * num_m_per_thread0;
+                    const int64_t m_count = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
+
+                    // Base packed offset (aligned) and per-row stride in bytes
+                    const size_t base_packed_off = variant_call<size_t>(
+                        lhs_info->get_packed_offset, (size_t)m_start, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
+                    const size_t next_block_off = variant_call<size_t>(
+                        lhs_info->get_packed_offset, (size_t)(m_start + mr), (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
+                    const size_t row_stride_bytes = (next_block_off - base_packed_off) / (size_t)mr;
+
+                    int64_t remaining = m_count;
+                    int64_t cur = m_start;
+
+                    while (remaining > 0) {
+                        const int64_t row_in_group = cur;
+                        const int64_t avail = m_group - row_in_group;
+                        const int64_t take = std::min(avail, remaining);
 
-                    const size_t lhs_offset = variant_call<size_t>(kernels->gemm.get_lhs_offset, m_start, lhs_stride);
-                    const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, mr, kr, sr);
+                        const uint8_t * lhs_batch_base = static_cast<const uint8_t *>(src1->data) + batch_idx * src1->nb[2];
+                        const void * src_ptr = lhs_batch_base + (size_t)row_in_group * lhs_stride;
+                        const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes;
+                        void * dst_ptr = lhs_packed + dst_off;
 
-                    const void * src_ptr = static_cast<const uint8_t *>(lhs_batch) + lhs_offset;
-                    void * dst_ptr = static_cast<uint8_t *>(lhs_packed) + lhs_packed_offset;
+                        variant_call<void>(lhs_info->pack_func,
+                                           (size_t)take, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr,
+                                           /*m_idx_start*/ 0, src_ptr, lhs_stride, dst_ptr);
 
-                    variant_call<void>(lhs_info->pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
+                        cur += take;
+                        remaining -= take;
+                    }
                 }
             }
 
-            // RHS packing
-            if (first_to_arrive.test_and_set(std::memory_order_acquire) == false) {
-                // First thread to reach this point handles RHS packing
-                memset(bias, 0, n * sizeof(float));
-                transpose_f32kxn_f16nxk(n, k, reinterpret_cast<float *>(rhs_kxn),
-                                        reinterpret_cast<const uint16_t *>(rhs_batch), rhs_stride);
-
-                variant_call<void>(kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, n * sizeof(float),
-                                   rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr);
+            // RHS packing (single thread), then synchronize
+            if (ith == 0) {
+                memset(bias, 0, (size_t)n * sizeof(float));
+                transpose_f32kxn_f16nxk((size_t)n, (size_t)k,
+                                        reinterpret_cast<float *>(rhs_kxn),
+                                        reinterpret_cast<const uint16_t *>(rhs_batch_base),
+                                        rhs_stride);
+
+                variant_call<void>(kernels->rhs_info.pack_func,
+                                   /*num_groups*/ 1, (size_t)n, (size_t)k, (size_t)nr, (size_t)kr, (size_t)sr,
+                                   /*rhs_stride (bytes)*/ (size_t)(n * sizeof(float)),
+                                   rhs_kxn, bias, nullptr, rhs_packed, /*extra_bytes*/ 0, /*params*/ nullptr);
             }
 
             ggml_barrier(params->threadpool);
 
-            first_to_arrive.clear(std::memory_order_release);
-
-            // Perform the matmul
+            // Matmul (threaded over n)
             {
-                const int64_t m_to_process = m;
-                const int64_t m_start = 0;
-
-                const int64_t n_step = static_cast<int64_t>(kernel->get_n_step());
-                int64_t num_threads = KAI_MIN(n / n_step, nth);
-                if (num_threads <= 0) {
-                    num_threads = 1;
+                const int64_t n_step = (int64_t) kernel->get_n_step();
+                int64_t num_threads_n = KAI_MIN(n / n_step, nth);
+                if (num_threads_n <= 0) {
+                    num_threads_n = 1;
                 }
 
-                if (ith < num_threads) {
-                    const int64_t num_n_per_thread0 = round_down(n / num_threads, n_step);
-                    const int64_t num_n_per_threadN_1 = n - (num_threads - 1) * num_n_per_thread0;
+                if (ith < num_threads_n) {
+                    const int64_t num_n_per_thread0 = round_down((size_t)(n / num_threads_n), (size_t)n_step);
+                    const int64_t num_n_per_threadN_1 = n - (num_threads_n - 1) * num_n_per_thread0;
 
                     const int64_t n_start = ith * num_n_per_thread0;
-                    const int64_t n_to_process = (ith == num_threads - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
+                    const int64_t n_to_process = (ith == num_threads_n - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
 
-                    const size_t lhs_packed_offset = variant_call<size_t>(kernel->get_lhs_offset, m_start, k);
-                    const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k);
-                    const size_t dst_offset = kernel->get_dst_offset(m_start, n_start, dst_stride);
+                    // LHS packed base at row 0 (consistent with packing above)
+                    const size_t lhs_packed_offset0 = variant_call<size_t>(
+                        lhs_info->get_packed_offset, (size_t)0, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr);
+                    const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, (size_t)n_start, (size_t)k);
+                    const size_t dst_offset = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride);
 
-                    const void * lhs_ptr = lhs_packed + lhs_packed_offset;
+                    const void * lhs_ptr = lhs_packed + lhs_packed_offset0;
                     const void * rhs_ptr = rhs_packed + rhs_packed_offset;
-                    float * dst_ptr = reinterpret_cast<float *>(dst_batch + dst_offset);
+                    float * dst_ptr = reinterpret_cast<float *>(dst_batch_base + dst_offset);
 
-                    variant_call<void>(kernel->run_kernel, m_to_process, n_to_process, k, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
+                    variant_call<void>(kernel->run_kernel,
+                                       (size_t)m, (size_t)n_to_process, (size_t)k,
+                                       lhs_ptr, rhs_ptr,
+                                       dst_ptr, dst_stride, sizeof(float),
+                                       -FLT_MAX, FLT_MAX);
                 }
             }
 
             if (batch_idx != batch_size - 1) {
-                // This barrier is necessary when the batch size is larger than 1. While processing a batch,
-                // the work data buffer (params->wdata) is used as temporary storage which means that only
-                // a single batch can be processed at any given time. No barrier is needed for the last
-                // batch since GGML inserts a barrier between the execution of every operator.
                 ggml_barrier(params->threadpool);
             }
         }
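
The threading fix visible in this hunk replaces the first_to_arrive atomic flag with a fixed leader: thread 0 packs the RHS while the other threads finish their LHS slices, every thread then waits on ggml_barrier before the matmul phase reads the packed buffers, and an extra barrier separates batches because they reuse the same workspace. The standalone sketch below reproduces that leader-plus-barrier shape, with std::barrier standing in for ggml_barrier; it is an analogy, not the ggml implementation.

// Sketch of the "thread 0 prepares shared data, everyone waits on a barrier" pattern.
#include <barrier>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
    const int nth = 4;
    std::vector<int> shared(16, 0);
    std::barrier sync(nth);

    auto worker = [&](int ith) {
        if (ith == 0) {
            // Leader fills the shared buffer (analogue of the single-threaded RHS packing).
            for (std::size_t i = 0; i < shared.size(); ++i) shared[i] = (int)i;
        }
        sync.arrive_and_wait();   // analogue of ggml_barrier(params->threadpool)

        // All threads may now safely read the shared data (analogue of the matmul phase).
        int local_sum = 0;
        for (int v : shared) local_sum += v;
        std::printf("thread %d sees sum %d\n", ith, local_sum);
    };

    std::vector<std::thread> pool;
    for (int ith = 0; ith < nth; ++ith) pool.emplace_back(worker, ith);
    for (auto & t : pool) t.join();
    return 0;
}
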
