diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 4443420132f2b..7f0e8c67f1325 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1005,8 +1005,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - //cb(inp->self_kq_mask, "KQ_mask", -1); + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1143,8 +1142,7 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con auto inp = std::make_unique(hparams, cparams); // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch - inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); - //cb(inp_kq_mask, "KQ_mask", -1); + inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); ggml_set_input(inp->kq_mask); inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask; @@ -1209,7 +1207,7 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1343,7 +1341,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const { const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train; - inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); ggml_set_input(inp->cross_kq_mask); inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask; @@ -1457,7 +1455,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1471,7 +1469,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); ggml_set_input(inp->self_kq_mask_swa); inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; diff --git a/src/llama-graph.h b/src/llama-graph.h index c8b74a14741e2..7bdf656768a0c 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -228,8 +228,8 @@ class llm_graph_input_attn_no_cache : public llm_graph_input_i { ggml_tensor * get_kq_mask() const { return kq_mask_cnv; } - ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch] - ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch] + ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch, 1, 1] + ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch, 1, 1] const llama_hparams & hparams; const llama_cparams & cparams; @@ -257,8 +257,8 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i { ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch] + ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1] const llama_hparams & hparams; const llama_cparams & cparams; @@ -293,10 +293,10 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i { ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch] ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch] - ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch] - ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch] + ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1] + ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch, 1, 1] + ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch, 1, 1] const llama_hparams & hparams; const llama_cparams & cparams; @@ -313,8 +313,8 @@ class llm_graph_input_attn_cross : public llm_graph_input_i { ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; } - ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch] - ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch] + ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1] + ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1] const llama_cross * cross = nullptr; }; @@ -343,8 +343,8 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i { ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch] + ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1] const llama_hparams & hparams; const llama_cparams & cparams;