diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index a92495dbb0070..57e77ea6f101d 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -1980,15 +1980,16 @@ extern "C" {
 
 #define GGML_KQ_MASK_PAD 64
 
-    // q:    [n_embd_k, n_batch,     n_head,    ne3]
-    // k:    [n_embd_k, n_kv,        n_head_kv, ne3]
-    // v:    [n_embd_v, n_kv,        n_head_kv, ne3] !! not transposed !!
-    // mask: [n_kv,     n_batch_pad, ne32,      1]   !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
-    // res:  [n_embd_v, n_head,      n_batch,   ne3] !! permuted !!
+    // q:    [n_embd_k, n_batch,     n_head,    ne3 ]
+    // k:    [n_embd_k, n_kv,        n_head_kv, ne3 ]
+    // v:    [n_embd_v, n_kv,        n_head_kv, ne3 ] !! not transposed !!
+    // mask: [n_kv,     n_batch_pad, ne32,      ne33] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
+    // res:  [n_embd_v, n_head,      n_batch,   ne3 ] !! permuted !!
     //
     // broadcast:
     //   n_head % n_head_kv == 0
-    //   ne3    % ne32      == 0
+    //   n_head % ne32      == 0
+    //   ne3    % ne33      == 0
     //
     GGML_API struct ggml_tensor * ggml_flash_attn_ext(
             struct ggml_context * ctx,
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 2ae0721e9d171..11c180eec6567 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -7799,7 +7799,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             memset(VKQ32, 0, DV*sizeof(float));
         }
 
-        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq3%mask->ne[2])*mask->nb[2]) : NULL;
+        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;
 
         // k indices
         const int ik3 = iq3 / rk3;
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 0fa7fb953d815..f726f5838730e 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3377,7 +3377,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 return false;
             }
             // TODO: support broadcast
-            // ref: https://github.com/ggml-org/llama.cpp/pull/14435
+            // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but
+            // the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
            if (op->src[0]->ne[3] != 1) {
                 return false;
             }
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index 8c2a971927414..f9679db3fee04 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -230,8 +230,10 @@ typedef struct {
     uint64_t nb22;
     uint64_t nb23;
     int32_t  ne32;
+    int32_t  ne33;
     uint64_t nb31;
     uint64_t nb32;
+    uint64_t nb33;
     int32_t  ne1;
     int32_t  ne2;
     float    scale;
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 1a0968ed4d231..25ce21349d746 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -4989,8 +4989,10 @@ static bool ggml_metal_encode_node(
                 /*.nb22  =*/ nb22,
                 /*.nb23  =*/ nb23,
                 /*.ne32  =*/ ne32,
+                /*.ne33  =*/ ne33,
                 /*.nb31  =*/ nb31,
                 /*.nb32  =*/ nb32,
+                /*.nb33  =*/ nb33,
                 /*.ne1   =*/ ne1,
                 /*.ne2   =*/ ne2,
                 /*.scale =*/ scale,
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index dfff669726b94..a77d436970d74 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -3784,7 +3784,7 @@ kernel void kernel_flash_attn_ext(
            // load the mask in shared memory
            #pragma unroll(Q)
            for (short j = 0; j < Q; ++j) {
-                device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq3%args.ne32)*args.nb32);
+                device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
 
                const float m = pm[ic + tiisg];
 
@@ -4270,7 +4270,7 @@ kernel void kernel_flash_attn_ext_vec(
    const bool has_mask = mask != q;
 
    // pointer to the mask
-    device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq3%args.ne32)*args.nb32);
+    device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
 
    float slope = 1.0f;
 
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 25f70127a62c5..ccb0d47b6759b 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -10265,6 +10265,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
                 return false;
             }
+            // TODO: support broadcast
+            // note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14449, but
+            // the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
+            if (op->src[0]->ne[3] != 1 || (op->src[3] && op->src[3]->ne[2] != 1)) {
+                return false;
+            }
             // It's straightforward to support different K/V dequant, but would
             // significantly increase the number of pipelines
             if (op->src[1]->type != op->src[2]->type) {
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index f1245c50c9299..86996fff3d69a 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -3666,7 +3666,6 @@ static struct ggml_tensor * ggml_soft_max_impl(
     if (mask) {
         GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
         GGML_ASSERT(ggml_is_contiguous(mask));
-        GGML_ASSERT(ggml_is_3d(mask));
         GGML_ASSERT(mask->ne[0] == a->ne[0]);
         GGML_ASSERT(mask->ne[1] >= a->ne[1]);
         GGML_ASSERT(a->ne[2]%mask->ne[2] == 0);
@@ -4696,12 +4695,12 @@ struct ggml_tensor * ggml_flash_attn_ext(
 
     if (mask) {
         GGML_ASSERT(ggml_is_contiguous(mask));
-        GGML_ASSERT(mask->ne[2] == q->ne[3]);
         GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
                 "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
         //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
 
-        GGML_ASSERT(q->ne[3] % mask->ne[2] == 0);
+        GGML_ASSERT(q->ne[2] % mask->ne[2] == 0);
+        GGML_ASSERT(q->ne[3] % mask->ne[3] == 0);
     }
 
     if (max_bias > 0.0f) {
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index a4f94c2594038..1ad9517a7ed48 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -3607,7 +3607,7 @@ struct test_flash_attn_ext : public test_case {
 
         ggml_tensor * m = nullptr;
         if (mask) {
-            m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), nr23[1], 1);
+            m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), nr23[0], nr23[1]);
             ggml_set_name(m, "m");
         }
 
@@ -4720,7 +4720,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                 test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {1, 1}, scale, max_bias));
 
                 if (ne0 <= 32 && ne1 <= 32) {
-                    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, mask, m_prec, {3, 1}, scale, max_bias));
+                    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 3}, mask, m_prec, {3, 1}, scale, max_bias));
                     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {2, 3}, scale, max_bias));
                 }
             }
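
For reference, a minimal usage sketch of the new mask broadcast rules from the updated ggml.h comment (n_head % ne32 == 0 and ne3 % ne33 == 0). The helper name and all dimension values below are hypothetical, chosen only to satisfy the asserts in ggml_flash_attn_ext(); this is an illustration, not part of the change itself.

    #include <math.h>
    #include "ggml.h"

    // sketch only: assumes `ctx` is an already-initialized ggml_context with enough memory
    static struct ggml_tensor * build_fattn_broadcast_mask(struct ggml_context * ctx) {
        const int64_t n_embd_k = 128, n_embd_v = 128;  // K/V head sizes
        const int64_t n_batch  = 7,   n_kv      = 512;
        const int64_t n_head   = 8,   n_head_kv = 2;   // n_head % n_head_kv == 0 (GQA)
        const int64_t ne3      = 4;

        // the mask can now be broadcast across query heads (ne32) and across ne3 (ne33)
        const int64_t ne32 = 2;                        // n_head % ne32 == 0
        const int64_t ne33 = 1;                        // ne3    % ne33 == 0

        struct ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd_k, n_batch, n_head,    ne3);
        struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_embd_k, n_kv,    n_head_kv, ne3);
        struct ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_embd_v, n_kv,    n_head_kv, ne3);

        // mask: [n_kv, n_batch_pad, ne32, ne33], padded to GGML_KQ_MASK_PAD as the kernel requires
        struct ggml_tensor * m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16,
                n_kv, GGML_PAD(n_batch, GGML_KQ_MASK_PAD), ne32, ne33);

        // res: [n_embd_v, n_head, n_batch, ne3] (permuted)
        return ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf((float) n_embd_k), 0.0f, 0.0f);
    }

Note that in this change only the CPU and Metal kernels index the mask with both ne32 and ne33; the CUDA and Vulkan supports_op checks above still reject these broadcast cases.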