Skip to content

Commit 7951daa

Browse files
authored
Merge pull request #2 from Green-Sky/chroma_fa_fix
fix mask with flash attn
2 parents: 4fdedd5 + 3238fe3; commit: 7951daa

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

ggml_extend.hpp

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -864,6 +864,18 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
864864
v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head]
865865
v = ggml_cast(ctx, v, GGML_TYPE_F16);
866866

867+
if (mask != nullptr) {
868+
mask = ggml_transpose(ctx, mask);
869+
870+
if (mask->ne[1] < GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD)) {
871+
LOG_DEBUG("mask dims %ld, %ld, %ld, %ld\n", mask->ne[0], mask->ne[1], mask->ne[2], mask->ne[3]);
872+
LOG_DEBUG("needs padding, padding from %ld to %ld\n", mask->ne[1], GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD));
873+
mask = ggml_pad(ctx, mask, 0, GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) - mask->ne[1], 0, 0);
874+
}
875+
876+
mask = ggml_cast(ctx, mask, GGML_TYPE_F16);
877+
}
878+
867879
kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0, 0);
868880
ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);
869881

0 commit comments

Comments
 (0)