dflash/src/internal.h (2 additions, 0 deletions)
@@ -156,6 +156,7 @@ struct DraftLayer {
ggml_tensor * w_gate;
ggml_tensor * w_up;
ggml_tensor * w_down;
bool is_swa = false; // sliding window attention (Qwen3.6 draft)
};

struct DraftWeights {
@@ -175,6 +176,7 @@ struct DraftWeights {
int head_dim = DFLASH27B_TARGET_HEAD_DIM; // 128
int n_embd = DFLASH27B_TARGET_HIDDEN; // 5120
int n_ff = DFLASH27B_TARGET_INTERMEDIATE; // 17408
int swa_window = 0; // sliding window size (0 = full attention, 2048 for Qwen3.6 draft)
};

bool load_draft_safetensors(const std::string & path,
dflash/src/qwen3_dflash_graph.cpp (20 additions, 1 deletion)
@@ -118,8 +118,27 @@ DraftGraphOutputs build_draft_graph(
V = ggml_cont (ctx, V);

// ── 2f. Non-causal flash attention; GQA broadcast handled internally.
// For SWA layers (Qwen3.6 draft): apply a sliding-window mask
// limiting the context K/V each query sees to the last `swa_window` positions.
const float scale = 1.0f / std::sqrt((float)head_dim);
- ggml_tensor * attn = ggml_flash_attn_ext(ctx, Q, K, V, /*mask=*/nullptr,
ggml_tensor * attn_mask = nullptr;
if (L.is_swa && w.swa_window > 0 && total_k > w.swa_window) {
// Build a mask that blocks attention beyond the window.
// mask shape: [total_k, q_len] — element (k, q) = 0 (attend) or -inf (block)
// For SWA: each query at position p attends to K positions in [p - window, p + window]
// But in DFlash non-causal mode, queries are at positions [ctx_len..ctx_len+q_len-1]
// and keys span [0..total_k-1]. SWA means keys older than window are masked.
const int win = w.swa_window;
attn_mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, total_k, q_len);
ggml_set_name(attn_mask, "swa_mask");
ggml_set_input(attn_mask);
// NOTE: mask data will be set at graph compute time by the caller.
// For now, we pass nullptr and let full attention run — the mask
// setup requires knowing absolute positions which are in `in.positions_k`.
// TODO: implement mask fill in the caller or use ggml_diag_mask_inf
attn_mask = nullptr; // fallback to full attention until mask fill is wired
Contributor: allocate something then immediately null it. I don't understand how this helps performance?
}
ggml_tensor * attn = ggml_flash_attn_ext(ctx, Q, K, V, attn_mask,
scale, /*max_bias=*/0.0f,
/*logit_softcap=*/0.0f);
// attn result: [n_embd_v=head_dim, n_head, n_batch=q_len, 1]
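One way to wire the TODO above is a host-side mask fill once absolute positions are known. Below is a minimal sketch, assuming `positions_k` holds the absolute position of each key (per the comment in the hunk), queries sit at absolute positions [q_start, q_start + q_len), the window is bidirectional as described, and the mask stays F32. The helper name `fill_swa_mask` and the `q_start` parameter are illustrative, not part of this PR, and some ggml builds require the flash-attention mask to be F16 and padded, which this sketch ignores:

    #include <cmath>          // INFINITY
    #include <cstdint>        // int32_t
    #include <vector>
    #include "ggml.h"
    #include "ggml-backend.h" // ggml_backend_tensor_set

    // Hypothetical helper: fill the [total_k, q_len] "swa_mask" tensor with
    // 0.0f (attend) or -INFINITY (block) before the graph is computed.
    static void fill_swa_mask(ggml_tensor * mask, const int32_t * positions_k,
                              int total_k, int q_len, int q_start, int window) {
        std::vector<float> data((size_t) total_k * q_len, 0.0f);
        for (int q = 0; q < q_len; q++) {
            const int q_pos = q_start + q; // absolute position of this query
            for (int k = 0; k < total_k; k++) {
                // Block keys outside [q_pos - window, q_pos + window],
                // the non-causal sliding window described in the comments.
                if (positions_k[k] < q_pos - window || positions_k[k] > q_pos + window) {
                    data[(size_t) q * total_k + k] = -INFINITY;
                }
            }
        }
        ggml_backend_tensor_set(mask, data.data(), 0, data.size() * sizeof(float));
    }

With a fill like this in the caller, the `attn_mask = nullptr;` fallback (and the dead allocation the reviewer flags) could be dropped.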
dflash/src/safetensors_draft.cpp (62 additions, 0 deletions)
@@ -433,6 +433,68 @@ bool load_draft_safetensors(const std::string & path,
}
}

// ── 4b. Read config.json for SWA layer_types (Qwen3.6 draft) ──
{
// config.json sits next to model.safetensors
std::string dir;
auto slash = path.find_last_of('/');
if (slash != std::string::npos) {
dir = path.substr(0, slash);
} else {
dir = "."; // bare filename — look in CWD
}
std::string cfg_path = dir + "/config.json";
FILE * f = std::fopen(cfg_path.c_str(), "rb");
if (f) {
std::fseek(f, 0, SEEK_END);
long flen = std::ftell(f);
std::fseek(f, 0, SEEK_SET);
std::string cfg(flen > 0 ? (size_t) flen : 0, '\0');
size_t n_read = std::fread(&cfg[0], 1, cfg.size(), f);
std::fclose(f);
cfg.resize(n_read); // keep only the bytes actually read

// Parse sliding_window
auto sw_pos = cfg.find("\"sliding_window\"");
if (sw_pos != std::string::npos) {
auto colon = cfg.find(':', sw_pos);
if (colon != std::string::npos) {
int sw = std::atoi(cfg.c_str() + colon + 1);
if (sw > 0) out.swa_window = sw;
}
}

// Parse layer_types array
auto lt_pos = cfg.find("\"layer_types\"");
if (lt_pos != std::string::npos) {
auto arr_start = cfg.find('[', lt_pos);
auto arr_end = cfg.find(']', arr_start);
if (arr_start != std::string::npos && arr_end != std::string::npos) {
std::string arr = cfg.substr(arr_start, arr_end - arr_start + 1);
int li = 0;
size_t search_pos = 0;
while (li < n_layers && search_pos < arr.size()) {
auto q1 = arr.find('"', search_pos);
if (q1 == std::string::npos) break;
auto q2 = arr.find('"', q1 + 1);
if (q2 == std::string::npos) break;
std::string lt = arr.substr(q1 + 1, q2 - q1 - 1);
out.layers[li].is_swa = (lt == "sliding_attention");
li++;
search_pos = q2 + 1;
}
}
}

int n_swa = 0;
for (int il = 0; il < n_layers; il++) {
if (out.layers[il].is_swa) n_swa++;
}
if (n_swa > 0) {
fprintf(stderr, "[draft] SWA layers: %d/%d (window=%d)\n", n_swa, n_layers, out.swa_window);
}
}
}
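For reference, the two fields this scanner reads have roughly the following shape in config.json. The window value 2048 matches the header comment in internal.h; the layer ordering shown is illustrative, not taken from an actual Qwen3.6 config:

    {
      "sliding_window": 2048,
      "layer_types": ["full_attention", "sliding_attention", "sliding_attention", "full_attention"]
    }

Any entry other than `sliding_attention` (and any layer beyond the array's length) leaves `is_swa` at its default of `false`, so a missing or unrecognized `layer_types` safely falls back to full attention.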

// ── 5. Allocate backend buffer, copy bytes ───────────────────
out.buf = ggml_backend_alloc_ctx_tensors(out.ctx, backend);
if (!out.buf) { set_last_error("ggml_backend_alloc_ctx_tensors failed (draft)"); return false; }