diff --git a/dflash/deps/llama.cpp b/dflash/deps/llama.cpp index 706cd1f6..e2d98e3b 160000 --- a/dflash/deps/llama.cpp +++ b/dflash/deps/llama.cpp @@ -1 +1 @@ -Subproject commit 706cd1f6b4c8423337e1565c9c3f8ac49678de1b +Subproject commit e2d98e3b7539521466a6408dffb5e39409c29ea6 diff --git a/dflash/scripts/run.py b/dflash/scripts/run.py index 108a7520..9b02892d 100644 --- a/dflash/scripts/run.py +++ b/dflash/scripts/run.py @@ -83,6 +83,13 @@ def main(): help="Sliding window for FA layers (KV positions). 0 = full " "attention. Default 2048 (set in C++); only kicks in " "once kv_cache > window.") + ap.add_argument("--draft-swa", type=int, default=None, + help="Draft SWA window (Qwen3.6 pattern: layers 0..n-2 use " + "sliding window). 0 = disabled. Typical value: 2048.") + ap.add_argument("--draft-ctx-max", type=int, default=None, + help="Draft context cap (default 2048). Raise to let the " + "full-attention layer see more context. Requires " + "--draft-swa < this value to activate SWA truncation.") ap.add_argument("--max-ctx", type=int, default=0, help="Override max KV context (default: auto-fit " "prompt+n_gen+block, aligned to 256). Passing a " @@ -130,6 +137,10 @@ def main(): env["DFLASH27B_KV_TQ3"] = "1" if args.fa_window is not None: env["DFLASH27B_FA_WINDOW"] = str(args.fa_window) + if args.draft_swa is not None: + env["DFLASH27B_DRAFT_SWA"] = str(args.draft_swa) + if args.draft_ctx_max is not None: + env["DFLASH27B_DRAFT_CTX_MAX"] = str(args.draft_ctx_max) with tempfile.TemporaryDirectory() as tmp: in_bin = os.path.join(tmp, "prompt.bin") diff --git a/dflash/src/internal.h b/dflash/src/internal.h index 1a1011ee..80b61dcf 100644 --- a/dflash/src/internal.h +++ b/dflash/src/internal.h @@ -174,6 +174,7 @@ struct DraftLayer { ggml_tensor * w_gate; ggml_tensor * w_up; ggml_tensor * w_down; + bool is_swa = false; // true for SWA layers (Qwen3.6 pattern) }; struct DraftWeights { @@ -193,6 +194,7 @@ struct DraftWeights { int head_dim = DFLASH27B_TARGET_HEAD_DIM; // 128 int n_embd = DFLASH27B_TARGET_HIDDEN; // 5120 int n_ff = DFLASH27B_TARGET_INTERMEDIATE; // 17408 + int swa_window = 0; // sliding window size (0 = disabled) }; bool load_draft_safetensors(const std::string & path, diff --git a/dflash/src/qwen3_dflash_graph.cpp b/dflash/src/qwen3_dflash_graph.cpp index 638bd873..751d4d8a 100644 --- a/dflash/src/qwen3_dflash_graph.cpp +++ b/dflash/src/qwen3_dflash_graph.cpp @@ -42,13 +42,11 @@ DraftGraphOutputs build_draft_graph( const int q_len = DFLASH27B_DRAFT_BLOCK_SIZE; const int ctx_len = in.ctx_len; - const int total_k = ctx_len + q_len; const int n_head = w.n_head; const int n_kv = w.n_head_kv; const int head_dim = w.head_dim; const float eps = DFLASH27B_RMS_EPS; const float rope_base = DFLASH27B_ROPE_THETA; - (void)ctx_len; // used only via input tensor shapes // ── 1. Feature fusion: target_feat = rms_norm(fc @ target_hidden_cat, hidden_norm) // fc: [5*hidden, hidden] (ggml: ne[0]=5*hidden, ne[1]=hidden) @@ -65,6 +63,12 @@ DraftGraphOutputs build_draft_graph( for (int il = 0; il < w.n_layer; il++) { const DraftLayer & L = w.layers[il]; + // ── SWA: determine effective context for this layer + const bool use_swa = L.is_swa && w.swa_window > 0 && ctx_len > w.swa_window; + const int eff_ctx = use_swa ? w.swa_window : ctx_len; + const int eff_total_k = eff_ctx + q_len; + const int ctx_offset = use_swa ? (ctx_len - w.swa_window) : 0; + // ── 2a. Attention pre-norm ggml_tensor * hn = ggml_rms_norm(ctx, h, eps); hn = ggml_mul(ctx, hn, L.attn_norm); @@ -78,43 +82,56 @@ DraftGraphOutputs build_draft_graph( // ── 2c. K and V from target_feat AND noise, then concat along sequence // wk, wv: [hidden, kv_dim=1024] - ggml_tensor * Kctx = ggml_mul_mat(ctx, L.wk, target_feat); // [kv_dim, ctx_len, 1] - ggml_tensor * Kn = ggml_mul_mat(ctx, L.wk, hn); // [kv_dim, q_len, 1] - ggml_tensor * Vctx = ggml_mul_mat(ctx, L.wv, target_feat); + // For SWA layers: window target_feat to last swa_window positions. + ggml_tensor * tf = target_feat; + if (use_swa) { + tf = ggml_view_3d(ctx, target_feat, + w.n_embd, eff_ctx, 1, + target_feat->nb[1], target_feat->nb[2], + target_feat->nb[1] * ctx_offset); + } + ggml_tensor * Kctx = ggml_mul_mat(ctx, L.wk, tf); // [kv_dim, eff_ctx, 1] + ggml_tensor * Kn = ggml_mul_mat(ctx, L.wk, hn); // [kv_dim, q_len, 1] + ggml_tensor * Vctx = ggml_mul_mat(ctx, L.wv, tf); ggml_tensor * Vn = ggml_mul_mat(ctx, L.wv, hn); // concat along ne[1] (sequence) — ggml_concat second arg dim=1 - ggml_tensor * K = ggml_concat(ctx, Kctx, Kn, 1); // [kv_dim, total_k, 1] + ggml_tensor * K = ggml_concat(ctx, Kctx, Kn, 1); // [kv_dim, eff_total_k, 1] ggml_tensor * V = ggml_concat(ctx, Vctx, Vn, 1); // Per-head k_norm - K = ggml_reshape_3d(ctx, K, head_dim, n_kv, total_k); + K = ggml_reshape_3d(ctx, K, head_dim, n_kv, eff_total_k); K = ggml_rms_norm(ctx, K, eps); K = ggml_mul (ctx, K, L.k_norm); - V = ggml_reshape_3d(ctx, V, head_dim, n_kv, total_k); + V = ggml_reshape_3d(ctx, V, head_dim, n_kv, eff_total_k); // ── 2d. RoPE (NEOX, theta=10M) - // Q: positions_q [q_len] values [ctx_len..ctx_len+q_len-1] - // K: positions_k [total_k] values [0..total_k-1] + // Q: positions_q [q_len] values [ctx_len..ctx_len+q_len-1] + // K: positions_k [eff_total_k] — for SWA, starts from ctx_offset + ggml_tensor * pk = in.positions_k; + if (use_swa) { + pk = ggml_view_1d(ctx, in.positions_k, eff_total_k, + ctx_offset * ggml_element_size(in.positions_k)); + } Q = ggml_rope_ext(ctx, Q, in.positions_q, /*freq_factors=*/nullptr, head_dim, GGML_ROPE_TYPE_NEOX, /*n_ctx_orig=*/0, rope_base, /*freq_scale=*/1.0f, /*ext_factor=*/0.0f, /*attn_factor=*/1.0f, /*beta_fast=*/0.0f, /*beta_slow=*/0.0f); - K = ggml_rope_ext(ctx, K, in.positions_k, nullptr, + K = ggml_rope_ext(ctx, K, pk, nullptr, head_dim, GGML_ROPE_TYPE_NEOX, 0, rope_base, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f); // ── 2e. Permute into the layout flash_attn_ext wants // q: [n_embd_k=head_dim, n_batch=q_len, n_head, ne3] - // k: [n_embd_k=head_dim, n_kv=total_k, n_head_kv, ne3] - // v: [n_embd_v=head_dim, n_kv=total_k, n_head_kv, ne3] (not transposed) - Q = ggml_permute(ctx, Q, 0, 2, 1, 3); // [head_dim, q_len, n_head, 1] + // k: [n_embd_k=head_dim, n_kv=eff_total_k, n_head_kv, ne3] + // v: [n_embd_v=head_dim, n_kv=eff_total_k, n_head_kv, ne3] (not transposed) + Q = ggml_permute(ctx, Q, 0, 2, 1, 3); // [head_dim, q_len, n_head, 1] Q = ggml_cont (ctx, Q); - K = ggml_permute(ctx, K, 0, 2, 1, 3); // [head_dim, total_k, n_kv, 1] + K = ggml_permute(ctx, K, 0, 2, 1, 3); // [head_dim, eff_total_k, n_kv, 1] K = ggml_cont (ctx, K); - V = ggml_permute(ctx, V, 0, 2, 1, 3); // [head_dim, total_k, n_kv, 1] + V = ggml_permute(ctx, V, 0, 2, 1, 3); // [head_dim, eff_total_k, n_kv, 1] V = ggml_cont (ctx, V); // ── 2f. Non-causal flash attention; GQA broadcast handled internally. diff --git a/dflash/test/test_dflash.cpp b/dflash/test/test_dflash.cpp index 18532cd7..6f73796c 100644 --- a/dflash/test/test_dflash.cpp +++ b/dflash/test/test_dflash.cpp @@ -136,6 +136,8 @@ static constexpr int KQ_MASK_PAD = 32; static int g_kq_stride_pad = KQ_MASK_PAD; // overridden to 256 when TBQ KV is active static int g_max_ctx_override = 0; // overridden by --max-ctx=N (default 4096) static int g_fa_window = 2048; // overridden by DFLASH27B_FA_WINDOW=N +static int g_draft_swa_window = 0; // draft SWA window (0 = disabled); --draft-swa=N +static int g_draft_ctx_max = 4096; // draft context cap; --draft-ctx-max=N static int align_up(int x, int a) { return ((x + a - 1) / a) * a; } // F16 encoding for the two values we use: 0 and -inf. @@ -1551,7 +1553,8 @@ static bool run_target_layer_split_dflash_decode( } constexpr int DRAFT_CTX_MAX = 2048; - const int draft_ctx = std::min(committed, std::min(feature_ring.cap, DRAFT_CTX_MAX)); + const int draft_ctx = std::min(committed, std::min(feature_ring.cap, + std::max(DRAFT_CTX_MAX, g_draft_ctx_max))); const int draft_start = committed - draft_ctx; int mirror_slot0 = 0; const bool use_mirror_view = @@ -1809,6 +1812,15 @@ static int run_target_layer_split_harness( draft_gpu, (dp.size() >= 5 && dp.substr(dp.size() - 5) == ".gguf") ? "gguf" : "safetensors"); + if (g_draft_swa_window > 0) { + draft_weights.swa_window = g_draft_swa_window; + for (int il = 0; il < draft_weights.n_layer - 1; il++) { + draft_weights.layers[il].is_swa = true; + } + std::printf("[target-split] draft SWA layers: %d/%d (window=%d)\n", + draft_weights.n_layer - 1, draft_weights.n_layer, + draft_weights.swa_window); + } const int cap = std::min(max_ctx, 4096); if (!draft_feature_mirror_init(feature_ring, draft_backend, draft_gpu, draft_gpu, cap)) { @@ -2005,6 +2017,12 @@ int main(int argc, char ** argv) { if (const char * s = std::getenv("DFLASH27B_FA_WINDOW")) { g_fa_window = std::max(0, std::atoi(s)); } + if (const char * s = std::getenv("DFLASH27B_DRAFT_SWA")) { + g_draft_swa_window = std::max(0, std::atoi(s)); + } + if (const char * s = std::getenv("DFLASH27B_DRAFT_CTX_MAX")) { + g_draft_ctx_max = std::max(0, std::atoi(s)); + } const char * target_path = argv[1]; // ---- Architecture detection ------------------------------------------ @@ -2183,6 +2201,12 @@ int main(int argc, char ** argv) { else if (std::strncmp(argv[i], "-ctv=", 5) == 0) { setenv("DFLASH27B_KV_V", argv[i] + 5, 1); } + else if (std::strncmp(argv[i], "--draft-swa=", 12) == 0) { + g_draft_swa_window = std::max(0, std::atoi(argv[i] + 12)); + } + else if (std::strncmp(argv[i], "--draft-ctx-max=", 16) == 0) { + g_draft_ctx_max = std::max(0, std::atoi(argv[i] + 16)); + } } // The KV type may also have been chosen via -ctk/-ctv, which sets @@ -2264,10 +2288,10 @@ int main(int argc, char ** argv) { if (target_split_dflash) target_split_load_draft = true; if (target_gpus.empty()) target_gpus.push_back(target_gpu); if (target_gpus.size() == 1) target_gpu = target_gpus[0]; - std::printf("[cfg] seq_verify=%d fast_rollback=%d ddtree=%d budget=%d temp=%.2f chain_seed=%d fa_window=%d draft_feature_mirror=%d target_gpu=%d draft_gpu=%d\n", + std::printf("[cfg] seq_verify=%d fast_rollback=%d ddtree=%d budget=%d temp=%.2f chain_seed=%d fa_window=%d draft_swa=%d draft_ctx_max=%d draft_feature_mirror=%d target_gpu=%d draft_gpu=%d\n", (int)seq_verify, (int)fast_rollback, (int)ddtree_mode, ddtree_budget, ddtree_temp, (int)ddtree_chain_seed, g_fa_window, - (int)draft_feature_mirror, target_gpu, draft_gpu); + g_draft_swa_window, g_draft_ctx_max, (int)draft_feature_mirror, target_gpu, draft_gpu); int cuda_device_count = 0; cudaGetDeviceCount(&cuda_device_count); @@ -2343,6 +2367,16 @@ int main(int argc, char ** argv) { } std::printf("[draft] loaded\n"); + // Apply --draft-swa=N: mark layers 0..n-2 as SWA, last layer stays full. + if (g_draft_swa_window > 0) { + dw.swa_window = g_draft_swa_window; + for (int il = 0; il < dw.n_layer - 1; il++) { + dw.layers[il].is_swa = true; + } + std::printf("[draft] SWA layers: %d/%d (window=%d)\n", + dw.n_layer - 1, dw.n_layer, dw.swa_window); + } + const int max_ctx = g_max_ctx_override > 0 ? g_max_ctx_override : 4096; // Size the ssm_intermediate / conv_input_cache buffers to cover whichever // verify mode we'll use. DDTree needs room for 1 + ddtree_budget tree nodes. @@ -2711,6 +2745,11 @@ int main(int argc, char ** argv) { std::fprintf(stderr, "[unpark] draft: %s\n", dflash27b_last_error()); stream_emit(-1); continue; } + if (g_draft_swa_window > 0) { + dw.swa_window = g_draft_swa_window; + for (int il = 0; il < dw.n_layer - 1; il++) + dw.layers[il].is_swa = true; + } draft_parked = false; std::printf("[unpark] draft restored\n"); std::fflush(stdout); } @@ -2801,6 +2840,11 @@ int main(int argc, char ** argv) { dflash27b_last_error()); stream_emit(-1); continue; } + if (g_draft_swa_window > 0) { + dw.swa_window = g_draft_swa_window; + for (int il = 0; il < dw.n_layer - 1; il++) + dw.layers[il].is_swa = true; + } draft_parked = false; std::printf("[compress] draft restored\n"); std::fflush(stdout); } @@ -3451,7 +3495,8 @@ int main(int argc, char ** argv) { // window are invisible to the draft but still in the target's KV // cache (the target verify uses the full history). constexpr int DRAFT_CTX_MAX = 2048; - const int draft_ctx = std::min(committed, DRAFT_CTX_MAX); + const int draft_ctx = std::min(committed, + std::max(DRAFT_CTX_MAX, g_draft_ctx_max)); const int draft_start = committed - draft_ctx; int mirror_slot0 = 0; const bool use_mirror_view =