15 changes: 14 additions & 1 deletion dflash/src/dflash_graph.h
@@ -1,6 +1,9 @@
// Shared inputs/outputs for the DFlash draft graph builder.
#pragma once

#include <cstdint>
#include <vector>

#include "ggml.h"

namespace dflash27b {
@@ -13,11 +16,15 @@ struct DraftGraphInputs {
ggml_tensor * target_hidden_cat; // [5*hidden, ctx_len, 1] f32
ggml_tensor * positions_q; // [q_len] i32 values [ctx_len..ctx_len+q_len-1]
ggml_tensor * positions_k; // [ctx_len+q_len] i32 values [0..ctx_len+q_len-1]
// Optional SWA mask for long-context sliding-attention layers.
// Shape [kv_len, q_len] or padded [kv_pad, q_pad], type F16, values
// 0 for visible positions and -inf for masked positions.
ggml_tensor * attn_mask = nullptr;
// Optional: if non-null, the graph projects final hidden states through
// this LM head (shape [hidden, vocab]) and returns logits instead of
// hidden states. Used for DFlash integration where the draft shares the
// target's lm_head.
- ggml_tensor * lm_head;
ggml_tensor * lm_head = nullptr;
};

struct DraftGraphOutputs {
@@ -30,4 +37,10 @@ DraftGraphOutputs build_draft_graph(
const DraftWeights & w,
const DraftGraphInputs & in);

bool draft_graph_needs_swa_mask(const DraftWeights & w, int ctx_len);
void build_draft_swa_mask(std::vector<uint16_t> & out,
int ctx_len,
int q_len,
int swa_window);

} // namespace dflash27b
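
To make the mask layout concrete, here is a toy standalone driver (not part of the PR, and assuming dflash_graph.h and its ggml dependency are on the include path) that builds the mask for a small case and prints each query row, with '0' for visible key positions and 'x' for masked ones:

// Toy illustration of the row-major [q_len][total_k] mask layout.
#include <cstdio>
#include <vector>
#include "dflash_graph.h"

int main() {
    const int ctx_len = 4, q_len = 2, swa_window = 3;
    std::vector<uint16_t> mask;
    dflash27b::build_draft_swa_mask(mask, ctx_len, q_len, swa_window);
    const int total_k = ctx_len + q_len;
    for (int q = 0; q < q_len; ++q) {
        for (int k = 0; k < total_k; ++k) {
            std::putchar(mask[(size_t)q * total_k + k] == 0x0000 ? '0' : 'x');
        }
        std::putchar('\n'); // one row per query at absolute position ctx_len + q
    }
}

This prints "x00000" and "xx0000": query 0 sits at absolute position 4 and sees keys 1..5, query 1 sits at position 5 and sees keys 2..5, and both see the whole draft block since attention over the block is non-causal.
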
66 changes: 65 additions & 1 deletion dflash/src/qwen3_dflash_graph.cpp
@@ -31,10 +31,46 @@
#include "internal.h"
#include "dflash_graph.h"

#include <algorithm>
#include <cmath>
#include <cstdio>

namespace dflash27b {

bool draft_graph_needs_swa_mask(const DraftWeights & w, int ctx_len) {
if (w.swa_window <= 0) {
return false;
}
const int total_k = ctx_len + DFLASH27B_DRAFT_BLOCK_SIZE;
if (total_k <= w.swa_window) {
return false;
}
for (int il = 0; il < w.n_layer; ++il) {
if (w.layers[il].is_swa) {
return true;
}
}
return false;
}

void build_draft_swa_mask(std::vector<uint16_t> & out,
int ctx_len,
int q_len,
int swa_window) {
    static constexpr uint16_t F16_ZERO    = 0x0000; // +0.0 in IEEE-754 binary16
    static constexpr uint16_t F16_NEG_INF = 0xFC00; // -inf in IEEE-754 binary16

const int total_k = ctx_len + q_len;
out.assign((size_t)total_k * q_len, F16_NEG_INF);
    for (int q = 0; q < q_len; ++q) {
        const int abs_q = ctx_len + q; // absolute position of this query
        const int min_k = std::max(0, abs_q - swa_window);
        // Visible keys: the sliding window behind abs_q plus all later
        // positions (attention over the draft block is non-causal).
        for (int k = min_k; k < total_k; ++k) {
            out[(size_t)q * total_k + k] = F16_ZERO;
        }
}
}

DraftGraphOutputs build_draft_graph(
ggml_context * ctx,
const DraftWeights & w,
@@ -118,8 +154,36 @@ DraftGraphOutputs build_draft_graph(
V = ggml_cont (ctx, V);

        // ── 2f. Non-causal flash attention; GQA broadcast handled internally.
        // For SWA layers (Qwen3.6 draft): apply a sliding-window mask
        // limiting the context K/V to the last `swa_window` positions.
        const float scale = 1.0f / std::sqrt((float)head_dim);
-       ggml_tensor * attn = ggml_flash_attn_ext(ctx, Q, K, V, /*mask=*/nullptr,
ggml_tensor * attn_mask = nullptr;
if (L.is_swa && w.swa_window > 0 && total_k > w.swa_window) {
if (!in.attn_mask) {
set_last_error("build_draft_graph: SWA layer requires a non-null attn_mask");
return {};
}
if (in.attn_mask->type != GGML_TYPE_F16) {
char buf[128];
std::snprintf(buf, sizeof(buf),
"build_draft_graph: SWA attn_mask must be F16, got %s",
ggml_type_name(in.attn_mask->type));
set_last_error(buf);
return {};
}
if (in.attn_mask->ne[0] < total_k || in.attn_mask->ne[1] < q_len) {
char buf[160];
std::snprintf(buf, sizeof(buf),
"build_draft_graph: SWA attn_mask too small (%lld x %lld, need >= %d x %d)",
(long long)in.attn_mask->ne[0],
(long long)in.attn_mask->ne[1],
total_k, q_len);
set_last_error(buf);
return {};
}
attn_mask = in.attn_mask;
}
ggml_tensor * attn = ggml_flash_attn_ext(ctx, Q, K, V, attn_mask,
scale, /*max_bias=*/0.0f,
/*logit_softcap=*/0.0f);
// attn result: [n_embd_v=head_dim, n_head, n_batch=q_len, 1]
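
The raw constants above use the IEEE-754 binary16 encodings directly: 0x0000 is +0.0, and 0xFC00 (sign bit set, exponent all ones, mantissa zero) is -inf. A quick standalone sanity check, using only the public ggml_fp16_to_fp32 conversion helper from ggml.h:

#include <cstdio>
#include "ggml.h"

int main() {
    std::printf("%f %f\n",
                ggml_fp16_to_fp32((ggml_fp16_t) 0x0000),  // +0.0
                ggml_fp16_to_fp32((ggml_fp16_t) 0xFC00)); // -inf
    // expected output: 0.000000 -inf
}
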
12 changes: 12 additions & 0 deletions dflash/test/smoke_draft_graph.cpp
@@ -85,6 +85,12 @@ int main(int argc, char ** argv) {
ggml_tensor * target_hid = ggml_new_tensor_3d(gctx, GGML_TYPE_F32, fc_in, ctx_len, 1);
ggml_tensor * pos_q = ggml_new_tensor_1d(gctx, GGML_TYPE_I32, q_len);
ggml_tensor * pos_k = ggml_new_tensor_1d(gctx, GGML_TYPE_I32, ctx_len + q_len);
ggml_tensor * attn_mask = nullptr;
if (draft_graph_needs_swa_mask(w, ctx_len)) {
attn_mask = ggml_new_tensor_2d(gctx, GGML_TYPE_F16, ctx_len + q_len, q_len);
ggml_set_name(attn_mask, "draft_swa_mask");
ggml_set_input(attn_mask);
}
ggml_set_name(noise_embed, "noise_embed");
ggml_set_name(target_hid, "target_hidden_cat");
ggml_set_name(pos_q, "positions_q");
@@ -101,6 +107,7 @@
gi.target_hidden_cat = target_hid;
gi.positions_q = pos_q;
gi.positions_k = pos_k;
gi.attn_mask = attn_mask;

DraftGraphOutputs go = build_draft_graph(gctx, w, gi);
if (!go.hidden_states) { std::fprintf(stderr, "build_draft_graph returned null\n"); return 1; }
@@ -141,6 +148,11 @@ int main(int argc, char ** argv) {
for (int i = 0; i < ctx_len + q_len; i++) pk[i] = i;
ggml_backend_tensor_set(pos_k, pk.data(), 0, sizeof(int32_t) * pk.size());
}
if (attn_mask) {
std::vector<uint16_t> mask;
build_draft_swa_mask(mask, ctx_len, q_len, w.swa_window);
ggml_backend_tensor_set(attn_mask, mask.data(), 0, sizeof(uint16_t) * mask.size());
}

// ── 7. Compute
auto status = ggml_backend_graph_compute(backend, gf);
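
If the smoke test should also validate the mask contents rather than just upload it, a hypothetical helper along these lines (not in this PR) could assert the per-row visibility count implied by build_draft_swa_mask:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical check: the query at absolute position ctx_len + q sees every
// key from max(0, ctx_len + q - swa_window) to the end of the sequence.
static void check_swa_mask(const std::vector<uint16_t> & mask,
                           int ctx_len, int q_len, int swa_window) {
    const int total_k = ctx_len + q_len;
    for (int q = 0; q < q_len; ++q) {
        int visible = 0;
        for (int k = 0; k < total_k; ++k) {
            if (mask[(size_t)q * total_k + k] == 0x0000) visible++;
        }
        assert(visible == total_k - std::max(0, ctx_len + q - swa_window));
    }
}
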
12 changes: 12 additions & 0 deletions dflash/test/test_vs_oracle.cpp
@@ -117,6 +117,12 @@ int main(int argc, char ** argv) {
ggml_tensor * target_hid = ggml_new_tensor_3d(gctx, GGML_TYPE_F32, m.fc_in, m.ctx_len, 1);
ggml_tensor * pos_q = ggml_new_tensor_1d(gctx, GGML_TYPE_I32, m.q_len);
ggml_tensor * pos_k = ggml_new_tensor_1d(gctx, GGML_TYPE_I32, m.ctx_len + m.q_len);
ggml_tensor * attn_mask = nullptr;
if (draft_graph_needs_swa_mask(w, m.ctx_len)) {
attn_mask = ggml_new_tensor_2d(gctx, GGML_TYPE_F16, m.ctx_len + m.q_len, m.q_len);
ggml_set_name(attn_mask, "draft_swa_mask");
ggml_set_input(attn_mask);
}
ggml_set_name(noise_embed, "noise_embed");
ggml_set_name(target_hid, "target_hidden_cat");
ggml_set_name(pos_q, "positions_q");
@@ -132,6 +138,7 @@
gi.target_hidden_cat = target_hid;
gi.positions_q = pos_q;
gi.positions_k = pos_k;
gi.attn_mask = attn_mask;
DraftGraphOutputs go = build_draft_graph(gctx, w, gi);
if (!go.hidden_states) return 1;
ggml_set_output(go.hidden_states);
@@ -154,6 +161,11 @@
for (int i = 0; i < m.ctx_len + m.q_len; i++) pk[i] = i;
ggml_backend_tensor_set(pos_q, pq.data(), 0, sizeof(int32_t) * pq.size());
ggml_backend_tensor_set(pos_k, pk.data(), 0, sizeof(int32_t) * pk.size());
if (attn_mask) {
std::vector<uint16_t> mask;
build_draft_swa_mask(mask, m.ctx_len, m.q_len, w.swa_window);
ggml_backend_tensor_set(attn_mask, mask.data(), 0, sizeof(uint16_t) * mask.size());
}

// Compute
auto status = ggml_backend_graph_compute(backend, gf);
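
For completeness, the other optional input touched by this PR is lm_head. A sketch of how a caller might wire it, continuing the test setup above; target_lm_head is a hypothetical [hidden, vocab] tensor borrowed from the target model:

// Sketch only: per the dflash_graph.h comment, with lm_head set the graph
// projects the final hidden states through it, so go.hidden_states carries
// vocab logits instead of hidden states.
DraftGraphInputs gi;
gi.noise_embed       = noise_embed;
gi.target_hidden_cat = target_hid;
gi.positions_q       = pos_q;
gi.positions_k       = pos_k;
gi.attn_mask         = attn_mask;       // may stay nullptr for short contexts
gi.lm_head           = target_lm_head;  // hypothetical; defaults to nullptr
DraftGraphOutputs go = build_draft_graph(gctx, w, gi);
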