diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 94c6a128..101a5d1f 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -864,6 +864,18 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head] v = ggml_cast(ctx, v, GGML_TYPE_F16); + if (mask != nullptr) { + mask = ggml_transpose(ctx, mask); + + if (mask->ne[1] < GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD)) { + LOG_DEBUG("mask dims %ld, %ld, %ld, %ld\n", mask->ne[0], mask->ne[1], mask->ne[2], mask->ne[3]); + LOG_DEBUG("needs padding, padding from %ld to %ld\n", mask->ne[1], GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD)); + mask = ggml_pad(ctx, mask, 0, GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) - mask->ne[1], 0, 0); + } + + mask = ggml_cast(ctx, mask, GGML_TYPE_F16); + } + kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0, 0); ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);