11 changes: 11 additions & 0 deletions dflash/scripts/run.py
@@ -83,6 +83,13 @@ def main():
help="Sliding window for FA layers (KV positions). 0 = full "
"attention. Default 2048 (set in C++); only kicks in "
"once kv_cache > window.")
ap.add_argument("--draft-swa", type=int, default=None,
help="Draft SWA window (Qwen3.6 pattern: layers 0..n-2 use "
"sliding window). 0 = disabled. Typical value: 2048.")
ap.add_argument("--draft-ctx-max", type=int, default=None,
help="Draft context cap (default 2048). Raise to let the "
"full-attention layer see more context. Requires "
"--draft-swa < this value to activate SWA truncation.")
ap.add_argument("--max-ctx", type=int, default=0,
help="Override max KV context (default: auto-fit "
"prompt+n_gen+block, aligned to 256). Passing a "
@@ -130,6 +137,10 @@ def main():
env["DFLASH27B_KV_TQ3"] = "1"
if args.fa_window is not None:
env["DFLASH27B_FA_WINDOW"] = str(args.fa_window)
if args.draft_swa is not None:
env["DFLASH27B_DRAFT_SWA"] = str(args.draft_swa)
if args.draft_ctx_max is not None:
env["DFLASH27B_DRAFT_CTX_MAX"] = str(args.draft_ctx_max)

with tempfile.TemporaryDirectory() as tmp:
in_bin = os.path.join(tmp, "prompt.bin")
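For context, the two new flags follow `run.py`'s existing pattern: CLI options are forwarded to the C++ harness through environment variables, and a variable is only set when its flag is given so the C++ defaults stay in effect. A minimal standalone sketch of that plumbing (the harness invocation at the end is hypothetical; the flag and variable names are taken from the diff):

```python
import argparse
import os
import subprocess

ap = argparse.ArgumentParser()
ap.add_argument("--draft-swa", type=int, default=None)
ap.add_argument("--draft-ctx-max", type=int, default=None)
args = ap.parse_args()

env = os.environ.copy()
# Leave the variables unset unless the flags were passed, so the C++
# defaults (SWA disabled; draft ctx cap 4096) apply otherwise.
if args.draft_swa is not None:
    env["DFLASH27B_DRAFT_SWA"] = str(args.draft_swa)
if args.draft_ctx_max is not None:
    env["DFLASH27B_DRAFT_CTX_MAX"] = str(args.draft_ctx_max)

# Hypothetical invocation; the real script builds the command and its
# temp files elsewhere.
subprocess.run(["./test_dflash", "model.safetensors"], env=env, check=True)
```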
2 changes: 2 additions & 0 deletions dflash/src/internal.h
@@ -174,6 +174,7 @@ struct DraftLayer {
ggml_tensor * w_gate;
ggml_tensor * w_up;
ggml_tensor * w_down;
bool is_swa = false; // true for SWA layers (Qwen3.6 pattern)
};

struct DraftWeights {
@@ -193,6 +194,7 @@ struct DraftWeights {
int head_dim = DFLASH27B_TARGET_HEAD_DIM; // 128
int n_embd = DFLASH27B_TARGET_HIDDEN; // 5120
int n_ff = DFLASH27B_TARGET_INTERMEDIATE; // 17408
int swa_window = 0; // sliding window size (0 = disabled)
};

bool load_draft_safetensors(const std::string & path,
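The two new fields in `internal.h` encode the layer split that the `--draft-swa` help text names: every layer except the last is sliding-window, and the last keeps full attention over the whole draft context. A toy illustration of the marking rule (the layer count and context length are hypothetical; the real values come from the loaded checkpoint):

```python
n_layer, swa_window, ctx_len = 4, 2048, 3000   # hypothetical sizes

for il in range(n_layer):
    is_swa = il < n_layer - 1                  # Qwen3.6 pattern: last layer full
    visible = min(ctx_len, swa_window) if is_swa else ctx_len
    print(f"layer {il}: sees {visible} of {ctx_len} context positions")
# layers 0..2 see 2048 positions; layer 3 sees all 3000
```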
49 changes: 33 additions & 16 deletions dflash/src/qwen3_dflash_graph.cpp
@@ -42,13 +42,11 @@ DraftGraphOutputs build_draft_graph(

const int q_len = DFLASH27B_DRAFT_BLOCK_SIZE;
const int ctx_len = in.ctx_len;
const int total_k = ctx_len + q_len;
const int n_head = w.n_head;
const int n_kv = w.n_head_kv;
const int head_dim = w.head_dim;
const float eps = DFLASH27B_RMS_EPS;
const float rope_base = DFLASH27B_ROPE_THETA;
(void)ctx_len; // used only via input tensor shapes

// ── 1. Feature fusion: target_feat = rms_norm(fc @ target_hidden_cat, hidden_norm)
// fc: [5*hidden, hidden] (ggml: ne[0]=5*hidden, ne[1]=hidden)
@@ -65,6 +63,12 @@
for (int il = 0; il < w.n_layer; il++) {
const DraftLayer & L = w.layers[il];

// ── SWA: determine effective context for this layer
const bool use_swa = L.is_swa && w.swa_window > 0 && ctx_len > w.swa_window;
const int eff_ctx = use_swa ? w.swa_window : ctx_len;
const int eff_total_k = eff_ctx + q_len;
const int ctx_offset = use_swa ? (ctx_len - w.swa_window) : 0;

// ── 2a. Attention pre-norm
ggml_tensor * hn = ggml_rms_norm(ctx, h, eps);
hn = ggml_mul(ctx, hn, L.attn_norm);
@@ -78,43 +82,56 @@

// ── 2c. K and V from target_feat AND noise, then concat along sequence
// wk, wv: [hidden, kv_dim=1024]
ggml_tensor * Kctx = ggml_mul_mat(ctx, L.wk, target_feat); // [kv_dim, ctx_len, 1]
ggml_tensor * Kn = ggml_mul_mat(ctx, L.wk, hn); // [kv_dim, q_len, 1]
ggml_tensor * Vctx = ggml_mul_mat(ctx, L.wv, target_feat);
// For SWA layers: window target_feat to last swa_window positions.
ggml_tensor * tf = target_feat;
if (use_swa) {
tf = ggml_view_3d(ctx, target_feat,
w.n_embd, eff_ctx, 1,
target_feat->nb[1], target_feat->nb[2],
target_feat->nb[1] * ctx_offset);
}
ggml_tensor * Kctx = ggml_mul_mat(ctx, L.wk, tf); // [kv_dim, eff_ctx, 1]
ggml_tensor * Kn = ggml_mul_mat(ctx, L.wk, hn); // [kv_dim, q_len, 1]
ggml_tensor * Vctx = ggml_mul_mat(ctx, L.wv, tf);
ggml_tensor * Vn = ggml_mul_mat(ctx, L.wv, hn);

// concat along ne[1] (sequence) — ggml_concat second arg dim=1
ggml_tensor * K = ggml_concat(ctx, Kctx, Kn, 1); // [kv_dim, total_k, 1]
ggml_tensor * K = ggml_concat(ctx, Kctx, Kn, 1); // [kv_dim, eff_total_k, 1]
ggml_tensor * V = ggml_concat(ctx, Vctx, Vn, 1);

// Per-head k_norm
K = ggml_reshape_3d(ctx, K, head_dim, n_kv, total_k);
K = ggml_reshape_3d(ctx, K, head_dim, n_kv, eff_total_k);
K = ggml_rms_norm(ctx, K, eps);
K = ggml_mul (ctx, K, L.k_norm);

V = ggml_reshape_3d(ctx, V, head_dim, n_kv, total_k);
V = ggml_reshape_3d(ctx, V, head_dim, n_kv, eff_total_k);

// ── 2d. RoPE (NEOX, theta=10M)
// Q: positions_q [q_len] values [ctx_len..ctx_len+q_len-1]
// K: positions_k [total_k] values [0..total_k-1]
// Q: positions_q [q_len] values [ctx_len..ctx_len+q_len-1]
// K: positions_k [eff_total_k] — for SWA, starts from ctx_offset
ggml_tensor * pk = in.positions_k;
if (use_swa) {
pk = ggml_view_1d(ctx, in.positions_k, eff_total_k,
ctx_offset * ggml_element_size(in.positions_k));
}
Q = ggml_rope_ext(ctx, Q, in.positions_q, /*freq_factors=*/nullptr,
head_dim, GGML_ROPE_TYPE_NEOX, /*n_ctx_orig=*/0,
rope_base, /*freq_scale=*/1.0f,
/*ext_factor=*/0.0f, /*attn_factor=*/1.0f,
/*beta_fast=*/0.0f, /*beta_slow=*/0.0f);
K = ggml_rope_ext(ctx, K, in.positions_k, nullptr,
K = ggml_rope_ext(ctx, K, pk, nullptr,
head_dim, GGML_ROPE_TYPE_NEOX, 0,
rope_base, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f);

// ── 2e. Permute into the layout flash_attn_ext wants
// q: [n_embd_k=head_dim, n_batch=q_len, n_head, ne3]
// k: [n_embd_k=head_dim, n_kv=total_k, n_head_kv, ne3]
// v: [n_embd_v=head_dim, n_kv=total_k, n_head_kv, ne3] (not transposed)
Q = ggml_permute(ctx, Q, 0, 2, 1, 3); // [head_dim, q_len, n_head, 1]
// k: [n_embd_k=head_dim, n_kv=eff_total_k, n_head_kv, ne3]
// v: [n_embd_v=head_dim, n_kv=eff_total_k, n_head_kv, ne3] (not transposed)
Q = ggml_permute(ctx, Q, 0, 2, 1, 3); // [head_dim, q_len, n_head, 1]
Q = ggml_cont (ctx, Q);
K = ggml_permute(ctx, K, 0, 2, 1, 3); // [head_dim, total_k, n_kv, 1]
K = ggml_permute(ctx, K, 0, 2, 1, 3); // [head_dim, eff_total_k, n_kv, 1]
K = ggml_cont (ctx, K);
V = ggml_permute(ctx, V, 0, 2, 1, 3); // [head_dim, total_k, n_kv, 1]
V = ggml_permute(ctx, V, 0, 2, 1, 3); // [head_dim, eff_total_k, n_kv, 1]
V = ggml_cont (ctx, V);

// ── 2f. Non-causal flash attention; GQA broadcast handled internally.
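To make the per-layer windowing in `build_draft_graph` concrete, here is the same arithmetic replayed as a standalone sketch (the `ctx_len`/`q_len` values are hypothetical). The point to notice is that the windowed K keeps absolute RoPE positions: `positions_k` is sliced at `ctx_offset`, not rebased to zero, so windowed and full-attention layers agree on where each token sits.

```python
# Hypothetical sizes; q_len stands in for DFLASH27B_DRAFT_BLOCK_SIZE.
ctx_len, q_len, swa_window, is_swa = 3000, 8, 2048, True

use_swa = is_swa and swa_window > 0 and ctx_len > swa_window
eff_ctx = swa_window if use_swa else ctx_len            # 2048
eff_total_k = eff_ctx + q_len                           # 2056
ctx_offset = (ctx_len - swa_window) if use_swa else 0   # 952

# K/V come from the last eff_ctx context rows plus the q_len new rows;
# the position slice mirrors the ggml_view_1d on positions_k above.
positions_k = list(range(ctx_len + q_len))[ctx_offset:ctx_offset + eff_total_k]
assert len(positions_k) == eff_total_k
assert positions_k[0] == 952 and positions_k[-1] == 3007
```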
53 changes: 49 additions & 4 deletions dflash/test/test_dflash.cpp
@@ -136,6 +136,8 @@ static constexpr int KQ_MASK_PAD = 32;
static int g_kq_stride_pad = KQ_MASK_PAD; // overridden to 256 when TBQ KV is active
static int g_max_ctx_override = 0; // overridden by --max-ctx=N (default 4096)
static int g_fa_window = 2048; // overridden by DFLASH27B_FA_WINDOW=N
static int g_draft_swa_window = 0; // draft SWA window (0 = disabled); --draft-swa=N
static int g_draft_ctx_max = 4096; // draft context cap; --draft-ctx-max=N
static int align_up(int x, int a) { return ((x + a - 1) / a) * a; }

// F16 encoding for the two values we use: 0 and -inf.
@@ -1551,7 +1553,8 @@ static bool run_target_layer_split_dflash_decode(
}

constexpr int DRAFT_CTX_MAX = 2048;
const int draft_ctx = std::min(committed, std::min(feature_ring.cap, DRAFT_CTX_MAX));
const int draft_ctx = std::min(committed, std::min(feature_ring.cap,
std::max(DRAFT_CTX_MAX, g_draft_ctx_max)));
const int draft_start = committed - draft_ctx;
int mirror_slot0 = 0;
const bool use_mirror_view =
@@ -1809,6 +1812,15 @@ static int run_target_layer_split_harness(
draft_gpu,
(dp.size() >= 5 && dp.substr(dp.size() - 5) == ".gguf")
? "gguf" : "safetensors");
if (g_draft_swa_window > 0) {
draft_weights.swa_window = g_draft_swa_window;
for (int il = 0; il < draft_weights.n_layer - 1; il++) {
draft_weights.layers[il].is_swa = true;
}
std::printf("[target-split] draft SWA layers: %d/%d (window=%d)\n",
draft_weights.n_layer - 1, draft_weights.n_layer,
draft_weights.swa_window);
}
const int cap = std::min(max_ctx, 4096);
if (!draft_feature_mirror_init(feature_ring, draft_backend,
draft_gpu, draft_gpu, cap)) {
@@ -2005,6 +2017,12 @@ int main(int argc, char ** argv) {
if (const char * s = std::getenv("DFLASH27B_FA_WINDOW")) {
g_fa_window = std::max(0, std::atoi(s));
}
if (const char * s = std::getenv("DFLASH27B_DRAFT_SWA")) {
g_draft_swa_window = std::max(0, std::atoi(s));
}
if (const char * s = std::getenv("DFLASH27B_DRAFT_CTX_MAX")) {
g_draft_ctx_max = std::max(0, std::atoi(s));
}
const char * target_path = argv[1];

// ---- Architecture detection ------------------------------------------
@@ -2183,6 +2201,12 @@ int main(int argc, char ** argv) {
else if (std::strncmp(argv[i], "-ctv=", 5) == 0) {
setenv("DFLASH27B_KV_V", argv[i] + 5, 1);
}
else if (std::strncmp(argv[i], "--draft-swa=", 12) == 0) {
g_draft_swa_window = std::max(0, std::atoi(argv[i] + 12));
}
else if (std::strncmp(argv[i], "--draft-ctx-max=", 16) == 0) {
g_draft_ctx_max = std::max(0, std::atoi(argv[i] + 16));
}
}

// The KV type may also have been chosen via -ctk/-ctv, which sets
@@ -2264,10 +2288,10 @@ int main(int argc, char ** argv) {
if (target_split_dflash) target_split_load_draft = true;
if (target_gpus.empty()) target_gpus.push_back(target_gpu);
if (target_gpus.size() == 1) target_gpu = target_gpus[0];
std::printf("[cfg] seq_verify=%d fast_rollback=%d ddtree=%d budget=%d temp=%.2f chain_seed=%d fa_window=%d draft_feature_mirror=%d target_gpu=%d draft_gpu=%d\n",
std::printf("[cfg] seq_verify=%d fast_rollback=%d ddtree=%d budget=%d temp=%.2f chain_seed=%d fa_window=%d draft_swa=%d draft_ctx_max=%d draft_feature_mirror=%d target_gpu=%d draft_gpu=%d\n",
(int)seq_verify, (int)fast_rollback, (int)ddtree_mode,
ddtree_budget, ddtree_temp, (int)ddtree_chain_seed, g_fa_window,
(int)draft_feature_mirror, target_gpu, draft_gpu);
g_draft_swa_window, g_draft_ctx_max, (int)draft_feature_mirror, target_gpu, draft_gpu);

int cuda_device_count = 0;
cudaGetDeviceCount(&cuda_device_count);
@@ -2343,6 +2367,16 @@ int main(int argc, char ** argv) {
}
std::printf("[draft] loaded\n");

// Apply --draft-swa=N: mark layers 0..n-2 as SWA, last layer stays full.
if (g_draft_swa_window > 0) {
dw.swa_window = g_draft_swa_window;
for (int il = 0; il < dw.n_layer - 1; il++) {
dw.layers[il].is_swa = true;
}
std::printf("[draft] SWA layers: %d/%d (window=%d)\n",
dw.n_layer - 1, dw.n_layer, dw.swa_window);
}

const int max_ctx = g_max_ctx_override > 0 ? g_max_ctx_override : 4096;
// Size the ssm_intermediate / conv_input_cache buffers to cover whichever
// verify mode we'll use. DDTree needs room for 1 + ddtree_budget tree nodes.
@@ -2711,6 +2745,11 @@ int main(int argc, char ** argv) {
std::fprintf(stderr, "[unpark] draft: %s\n", dflash27b_last_error());
stream_emit(-1); continue;
}
if (g_draft_swa_window > 0) {
dw.swa_window = g_draft_swa_window;
for (int il = 0; il < dw.n_layer - 1; il++)
dw.layers[il].is_swa = true;
}
draft_parked = false;
std::printf("[unpark] draft restored\n"); std::fflush(stdout);
}
@@ -2801,6 +2840,11 @@ int main(int argc, char ** argv) {
dflash27b_last_error());
stream_emit(-1); continue;
}
if (g_draft_swa_window > 0) {
dw.swa_window = g_draft_swa_window;
for (int il = 0; il < dw.n_layer - 1; il++)
dw.layers[il].is_swa = true;
}
draft_parked = false;
std::printf("[compress] draft restored\n"); std::fflush(stdout);
}
@@ -3451,7 +3495,8 @@ int main(int argc, char ** argv) {
// window are invisible to the draft but still in the target's KV
// cache (the target verify uses the full history).
constexpr int DRAFT_CTX_MAX = 2048;
const int draft_ctx = std::min(committed, DRAFT_CTX_MAX);
const int draft_ctx = std::min(committed,
std::max(DRAFT_CTX_MAX, g_draft_ctx_max));
const int draft_start = committed - draft_ctx;
int mirror_slot0 = 0;
const bool use_mirror_view =
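Finally, a quick check of how the new `--draft-ctx-max` cap composes with the compile-time floor at both call sites above (numbers are the defaults from this diff; `ring_cap` models `feature_ring.cap`, which only applies on the layer-split path):

```python
DRAFT_CTX_MAX = 2048        # unchanged compile-time constant
g_draft_ctx_max = 4096      # default; --draft-ctx-max=N overrides

def draft_ctx(committed, ring_cap=None):
    cap = max(DRAFT_CTX_MAX, g_draft_ctx_max)
    if ring_cap is not None:            # layer-split path also clamps to the ring
        cap = min(cap, ring_cap)
    return min(committed, cap)

assert draft_ctx(1000) == 1000              # short history: everything visible
assert draft_ctx(6000) == 4096              # capped; draft_start = 6000 - 4096
assert draft_ctx(6000, ring_cap=4096) == 4096
```

Because of the `max()` floor, values of `--draft-ctx-max` below 2048 have no effect; the flag can only raise the cap.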