diff --git a/dflash/src/dflash_graph.h b/dflash/src/dflash_graph.h index 304ff8e3..5b5204b6 100644 --- a/dflash/src/dflash_graph.h +++ b/dflash/src/dflash_graph.h @@ -1,6 +1,9 @@ // Shared inputs/outputs for the DFlash draft graph builder. #pragma once +#include <cstdint> +#include <vector> + #include "ggml.h" namespace dflash27b { @@ -13,11 +16,15 @@ struct DraftGraphInputs { ggml_tensor * target_hidden_cat;// [5*hidden, ctx_len, 1] f32 ggml_tensor * positions_q; // [q_len] i32 values [ctx_len..ctx_len+q_len-1] ggml_tensor * positions_k; // [ctx_len+q_len] i32 values [0..ctx_len+q_len-1] + // Optional SWA mask for long-context sliding-attention layers. + // Shape [kv_len, q_len] or padded [kv_pad, q_pad], type F16, values + // 0 for visible positions and -inf for masked positions. + ggml_tensor * attn_mask = nullptr; // Optional: if non-null, the graph projects final hidden states through // this LM head (shape [hidden, vocab]) and returns logits instead of // hidden states. Used for DFlash integration where the draft shares the // target's lm_head. 
- ggml_tensor * lm_head; + ggml_tensor * lm_head = nullptr; }; struct DraftGraphOutputs { @@ -30,4 +37,10 @@ DraftGraphOutputs build_draft_graph( const DraftWeights & w, const DraftGraphInputs & in); +bool draft_graph_needs_swa_mask(const DraftWeights & w, int ctx_len); +void build_draft_swa_mask(std::vector<uint16_t> & out, + int ctx_len, + int q_len, + int swa_window); + } // namespace dflash27b diff --git a/dflash/src/qwen3_dflash_graph.cpp b/dflash/src/qwen3_dflash_graph.cpp index 638bd873..2060d00e 100644 --- a/dflash/src/qwen3_dflash_graph.cpp +++ b/dflash/src/qwen3_dflash_graph.cpp @@ -31,10 +31,46 @@ #include "internal.h" #include "dflash_graph.h" +#include <algorithm> #include <cmath> +#include <cstdio> namespace dflash27b { +bool draft_graph_needs_swa_mask(const DraftWeights & w, int ctx_len) { + if (w.swa_window <= 0) { + return false; + } + const int total_k = ctx_len + DFLASH27B_DRAFT_BLOCK_SIZE; + if (total_k <= w.swa_window) { + return false; + } + for (int il = 0; il < w.n_layer; ++il) { + if (w.layers[il].is_swa) { + return true; + } + } + return false; +} + +void build_draft_swa_mask(std::vector<uint16_t> & out, + int ctx_len, + int q_len, + int swa_window) { + static constexpr uint16_t F16_ZERO = 0x0000; + static constexpr uint16_t F16_NEG_INF = 0xFC00; + + const int total_k = ctx_len + q_len; + out.assign((size_t)total_k * q_len, F16_NEG_INF); + for (int q = 0; q < q_len; ++q) { + const int abs_q = ctx_len + q; + const int min_k = std::max(0, abs_q - swa_window); + for (int k = min_k; k < total_k; ++k) { + out[(size_t)q * total_k + k] = F16_ZERO; + } + } +} + DraftGraphOutputs build_draft_graph( ggml_context * ctx, const DraftWeights & w, const DraftGraphInputs & in); @@ -118,8 +154,36 @@ V = ggml_cont (ctx, V); // ── 2f. Non-causal flash attention; GQA broadcast handled internally. + // For SWA layers (Qwen3.6 draft): apply sliding window mask + // limiting context K/V to the last `swa_window` positions. 
const float scale = 1.0f / std::sqrt((float)head_dim); - ggml_tensor * attn = ggml_flash_attn_ext(ctx, Q, K, V, /*mask=*/nullptr, + ggml_tensor * attn_mask = nullptr; + if (L.is_swa && w.swa_window > 0 && total_k > w.swa_window) { + if (!in.attn_mask) { + set_last_error("build_draft_graph: SWA layer requires a non-null attn_mask"); + return {}; + } + if (in.attn_mask->type != GGML_TYPE_F16) { + char buf[128]; + std::snprintf(buf, sizeof(buf), + "build_draft_graph: SWA attn_mask must be F16, got %s", + ggml_type_name(in.attn_mask->type)); + set_last_error(buf); + return {}; + } + if (in.attn_mask->ne[0] < total_k || in.attn_mask->ne[1] < q_len) { + char buf[160]; + std::snprintf(buf, sizeof(buf), + "build_draft_graph: SWA attn_mask too small (%lld x %lld, need >= %d x %d)", + (long long)in.attn_mask->ne[0], + (long long)in.attn_mask->ne[1], + total_k, q_len); + set_last_error(buf); + return {}; + } + attn_mask = in.attn_mask; + } + ggml_tensor * attn = ggml_flash_attn_ext(ctx, Q, K, V, attn_mask, scale, /*max_bias=*/0.0f, /*logit_softcap=*/0.0f); // attn result: [n_embd_v=head_dim, n_head, n_batch=q_len, 1] diff --git a/dflash/test/smoke_draft_graph.cpp b/dflash/test/smoke_draft_graph.cpp index 16672216..8a1fa02e 100644 --- a/dflash/test/smoke_draft_graph.cpp +++ b/dflash/test/smoke_draft_graph.cpp @@ -85,6 +85,12 @@ int main(int argc, char ** argv) { ggml_tensor * target_hid = ggml_new_tensor_3d(gctx, GGML_TYPE_F32, fc_in, ctx_len, 1); ggml_tensor * pos_q = ggml_new_tensor_1d(gctx, GGML_TYPE_I32, q_len); ggml_tensor * pos_k = ggml_new_tensor_1d(gctx, GGML_TYPE_I32, ctx_len + q_len); + ggml_tensor * attn_mask = nullptr; + if (draft_graph_needs_swa_mask(w, ctx_len)) { + attn_mask = ggml_new_tensor_2d(gctx, GGML_TYPE_F16, ctx_len + q_len, q_len); + ggml_set_name(attn_mask, "draft_swa_mask"); + ggml_set_input(attn_mask); + } ggml_set_name(noise_embed, "noise_embed"); ggml_set_name(target_hid, "target_hidden_cat"); ggml_set_name(pos_q, "positions_q"); @@ -101,6 +107,7 
@@ int main(int argc, char ** argv) { gi.target_hidden_cat = target_hid; gi.positions_q = pos_q; gi.positions_k = pos_k; + gi.attn_mask = attn_mask; DraftGraphOutputs go = build_draft_graph(gctx, w, gi); if (!go.hidden_states) { std::fprintf(stderr, "build_draft_graph returned null\n"); return 1; } @@ -141,6 +148,11 @@ int main(int argc, char ** argv) { for (int i = 0; i < ctx_len + q_len; i++) pk[i] = i; ggml_backend_tensor_set(pos_k, pk.data(), 0, sizeof(int32_t) * pk.size()); } + if (attn_mask) { + std::vector<uint16_t> mask; + build_draft_swa_mask(mask, ctx_len, q_len, w.swa_window); + ggml_backend_tensor_set(attn_mask, mask.data(), 0, sizeof(uint16_t) * mask.size()); + } // ── 7. Compute auto status = ggml_backend_graph_compute(backend, gf); diff --git a/dflash/test/test_vs_oracle.cpp b/dflash/test/test_vs_oracle.cpp index b0c247d9..352139bc 100644 --- a/dflash/test/test_vs_oracle.cpp +++ b/dflash/test/test_vs_oracle.cpp @@ -117,6 +117,12 @@ int main(int argc, char ** argv) { ggml_tensor * target_hid = ggml_new_tensor_3d(gctx, GGML_TYPE_F32, m.fc_in, m.ctx_len, 1); ggml_tensor * pos_q = ggml_new_tensor_1d(gctx, GGML_TYPE_I32, m.q_len); ggml_tensor * pos_k = ggml_new_tensor_1d(gctx, GGML_TYPE_I32, m.ctx_len + m.q_len); + ggml_tensor * attn_mask = nullptr; + if (draft_graph_needs_swa_mask(w, m.ctx_len)) { + attn_mask = ggml_new_tensor_2d(gctx, GGML_TYPE_F16, m.ctx_len + m.q_len, m.q_len); + ggml_set_name(attn_mask, "draft_swa_mask"); + ggml_set_input(attn_mask); + } ggml_set_name(noise_embed, "noise_embed"); ggml_set_name(target_hid, "target_hidden_cat"); ggml_set_name(pos_q, "positions_q"); @@ -132,6 +138,7 @@ int main(int argc, char ** argv) { gi.target_hidden_cat = target_hid; gi.positions_q = pos_q; gi.positions_k = pos_k; + gi.attn_mask = attn_mask; DraftGraphOutputs go = build_draft_graph(gctx, w, gi); if (!go.hidden_states) return 1; ggml_set_output(go.hidden_states); @@ -154,6 +161,11 @@ int main(int argc, char ** argv) { for (int i = 0; i < m.ctx_len + m.q_len; 
i++) pk[i] = i; ggml_backend_tensor_set(pos_q, pq.data(), 0, sizeof(int32_t) * pq.size()); ggml_backend_tensor_set(pos_k, pk.data(), 0, sizeof(int32_t) * pk.size()); + if (attn_mask) { + std::vector<uint16_t> mask; + build_draft_swa_mask(mask, m.ctx_len, m.q_len, w.swa_window); + ggml_backend_tensor_set(attn_mask, mask.data(), 0, sizeof(uint16_t) * mask.size()); + } // Compute auto status = ggml_backend_graph_compute(backend, gf);