Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions dflash/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,11 @@ if(DFLASH27B_TESTS)
target_include_directories(test_vs_oracle PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src)
target_link_libraries(test_vs_oracle PRIVATE dflash27b ggml ggml-cuda)
endif()
# Register the SWA mask contract test only when its source is present,
# matching the EXISTS-guarded pattern used for the other optional tests
# in this block (test_vs_oracle, smoke_load_target).
if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/test_draft_swa_mask_contract.cpp")
add_executable(test_draft_swa_mask_contract test/test_draft_swa_mask_contract.cpp)
target_include_directories(test_draft_swa_mask_contract PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src)
target_link_libraries(test_draft_swa_mask_contract PRIVATE dflash27b ggml ggml-cuda)
endif()
if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/smoke_load_target.cpp")
add_executable(smoke_load_target test/smoke_load_target.cpp)
target_include_directories(smoke_load_target PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src)
Expand Down
15 changes: 14 additions & 1 deletion dflash/src/dflash_graph.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
// Shared inputs/outputs for the DFlash draft graph builder.
#pragma once

#include <cstdint>
#include <vector>

#include "ggml.h"

namespace dflash27b {
Expand All @@ -13,11 +16,15 @@ struct DraftGraphInputs {
ggml_tensor * target_hidden_cat;// [5*hidden, ctx_len, 1] f32
ggml_tensor * positions_q; // [q_len] i32 values [ctx_len..ctx_len+q_len-1]
ggml_tensor * positions_k; // [ctx_len+q_len] i32 values [0..ctx_len+q_len-1]
// Optional SWA mask for long-context sliding-attention layers.
// Shape [kv_len, q_len] or padded [kv_pad, q_pad], type F16, values
// 0 for visible positions and -inf for masked positions.
ggml_tensor * attn_mask = nullptr;
// Optional: if non-null, the graph projects final hidden states through
// this LM head (shape [hidden, vocab]) and returns logits instead of
// hidden states. Used for DFlash integration where the draft shares the
// target's lm_head.
ggml_tensor * lm_head;
ggml_tensor * lm_head = nullptr;
};

struct DraftGraphOutputs {
Expand All @@ -30,4 +37,10 @@ DraftGraphOutputs build_draft_graph(
const DraftWeights & w,
const DraftGraphInputs & in);

bool draft_graph_needs_swa_mask(const DraftWeights & w, int ctx_len);
void build_draft_swa_mask(std::vector<uint16_t> & out,
int ctx_len,
int q_len,
int swa_window);

} // namespace dflash27b
66 changes: 65 additions & 1 deletion dflash/src/qwen3_dflash_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,46 @@
#include "internal.h"
#include "dflash_graph.h"

#include <algorithm>
#include <cmath>
#include <cstdio>

namespace dflash27b {

// Decide whether the caller must provide a sliding-window attention mask:
// a mask is needed only when a window is configured, the full K length
// (context plus one draft block) exceeds that window, and at least one
// layer is actually a sliding-window layer.
bool draft_graph_needs_swa_mask(const DraftWeights & w, int ctx_len) {
    const int kv_total = ctx_len + DFLASH27B_DRAFT_BLOCK_SIZE;
    const bool window_active = w.swa_window > 0 && kv_total > w.swa_window;
    if (!window_active) {
        return false;
    }
    // Scan layers up to n_layer (not layers.size()) — stop at the first SWA hit.
    bool any_swa_layer = false;
    for (int layer = 0; layer < w.n_layer && !any_swa_layer; ++layer) {
        any_swa_layer = w.layers[layer].is_swa;
    }
    return any_swa_layer;
}

// Fill `out` with an F16 sliding-window mask of shape [ctx_len+q_len, q_len]
// stored row-major with the K index fastest. Row q corresponds to absolute
// query position ctx_len+q: every K position from max(0, ctx_len+q -
// swa_window) through the end of the sequence is visible (+0.0), everything
// before that is masked (-inf).
void build_draft_swa_mask(std::vector<uint16_t> & out,
                          int ctx_len,
                          int q_len,
                          int swa_window) {
    // IEEE-754 binary16 bit patterns: +0.0 keeps a position, -inf removes it.
    static constexpr uint16_t kVisibleF16 = 0x0000;
    static constexpr uint16_t kMaskedF16  = 0xFC00;

    const int kv_len = ctx_len + q_len;
    out.assign((size_t)kv_len * (size_t)q_len, kMaskedF16);
    for (int row = 0; row < q_len; ++row) {
        const int query_pos     = ctx_len + row;
        const int first_visible = std::max(0, query_pos - swa_window);
        // From first_visible to the end of K stays attendable — the draft
        // block itself is attended non-causally, so no upper cutoff.
        uint16_t * row_begin = out.data() + (size_t)row * kv_len;
        std::fill(row_begin + first_visible, row_begin + kv_len, kVisibleF16);
    }
}

DraftGraphOutputs build_draft_graph(
ggml_context * ctx,
const DraftWeights & w,
Expand Down Expand Up @@ -118,8 +154,36 @@ DraftGraphOutputs build_draft_graph(
V = ggml_cont (ctx, V);

// ── 2f. Non-causal flash attention; GQA broadcast handled internally.
// For SWA layers (Qwen3.6 draft): apply sliding window mask
// limiting context K/V to the last `swa_window` positions.
const float scale = 1.0f / std::sqrt((float)head_dim);
ggml_tensor * attn = ggml_flash_attn_ext(ctx, Q, K, V, /*mask=*/nullptr,
ggml_tensor * attn_mask = nullptr;
if (L.is_swa && w.swa_window > 0 && total_k > w.swa_window) {
if (!in.attn_mask) {
set_last_error("build_draft_graph: SWA layer requires a non-null attn_mask");
return {};
}
if (in.attn_mask->type != GGML_TYPE_F16) {
char buf[128];
std::snprintf(buf, sizeof(buf),
"build_draft_graph: SWA attn_mask must be F16, got %s",
ggml_type_name(in.attn_mask->type));
set_last_error(buf);
return {};
}
if (in.attn_mask->ne[0] < total_k || in.attn_mask->ne[1] < q_len) {
char buf[160];
std::snprintf(buf, sizeof(buf),
"build_draft_graph: SWA attn_mask too small (%lld x %lld, need >= %d x %d)",
(long long)in.attn_mask->ne[0],
(long long)in.attn_mask->ne[1],
total_k, q_len);
set_last_error(buf);
return {};
}
attn_mask = in.attn_mask;
}
ggml_tensor * attn = ggml_flash_attn_ext(ctx, Q, K, V, attn_mask,
scale, /*max_bias=*/0.0f,
/*logit_softcap=*/0.0f);
// attn result: [n_embd_v=head_dim, n_head, n_batch=q_len, 1]
Expand Down
12 changes: 12 additions & 0 deletions dflash/test/smoke_draft_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@ int main(int argc, char ** argv) {
ggml_tensor * target_hid = ggml_new_tensor_3d(gctx, GGML_TYPE_F32, fc_in, ctx_len, 1);
ggml_tensor * pos_q = ggml_new_tensor_1d(gctx, GGML_TYPE_I32, q_len);
ggml_tensor * pos_k = ggml_new_tensor_1d(gctx, GGML_TYPE_I32, ctx_len + q_len);
ggml_tensor * attn_mask = nullptr;
if (draft_graph_needs_swa_mask(w, ctx_len)) {
attn_mask = ggml_new_tensor_2d(gctx, GGML_TYPE_F16, ctx_len + q_len, q_len);
ggml_set_name(attn_mask, "draft_swa_mask");
ggml_set_input(attn_mask);
}
ggml_set_name(noise_embed, "noise_embed");
ggml_set_name(target_hid, "target_hidden_cat");
ggml_set_name(pos_q, "positions_q");
Expand All @@ -101,6 +107,7 @@ int main(int argc, char ** argv) {
gi.target_hidden_cat = target_hid;
gi.positions_q = pos_q;
gi.positions_k = pos_k;
gi.attn_mask = attn_mask;

DraftGraphOutputs go = build_draft_graph(gctx, w, gi);
if (!go.hidden_states) { std::fprintf(stderr, "build_draft_graph returned null\n"); return 1; }
Expand Down Expand Up @@ -141,6 +148,11 @@ int main(int argc, char ** argv) {
for (int i = 0; i < ctx_len + q_len; i++) pk[i] = i;
ggml_backend_tensor_set(pos_k, pk.data(), 0, sizeof(int32_t) * pk.size());
}
if (attn_mask) {
std::vector<uint16_t> mask;
build_draft_swa_mask(mask, ctx_len, q_len, w.swa_window);
ggml_backend_tensor_set(attn_mask, mask.data(), 0, sizeof(uint16_t) * mask.size());
}

// ── 7. Compute
auto status = ggml_backend_graph_compute(backend, gf);
Expand Down
177 changes: 177 additions & 0 deletions dflash/test/test_draft_swa_mask_contract.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
#include "dflash_graph.h"
#include "internal.h"

#include "ggml.h"

#include <cstdio>
#include <vector>

using namespace dflash27b;

namespace {

// One scenario for the SWA-mask wiring contract of build_draft_graph.
struct GraphCase {
    bool is_swa = false;        // mark the single draft layer as sliding-window attention
    int swa_window = 0;         // DraftWeights::swa_window used for this case
    int ctx_len = 0;            // context length fed to the graph builder
    bool provide_mask = false;  // whether the caller allocates and passes attn_mask
    bool expect_mask = false;   // whether flash_attn_ext should end up wired with the mask
    const char * label = "";    // case name printed in PASS/FAIL output
};

// Shorthand for a 1-D F32 tensor of length n (placeholder weight vector).
ggml_tensor * new_vec(ggml_context * ctx, int64_t n) {
    ggml_tensor * vec = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n);
    return vec;
}

// Shorthand for a 2-D F32 tensor of shape [ne0, ne1] (placeholder weight matrix).
ggml_tensor * new_mat(ggml_context * ctx, int64_t ne0, int64_t ne1) {
    ggml_tensor * mat = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne0, ne1);
    return mat;
}

// Build a minimal one-layer draft graph under scenario `tc` and verify the
// SWA-mask wiring contract: the GGML_OP_FLASH_ATTN_EXT node must carry a
// mask exactly when the case expects one, and when expected it must be the
// very tensor the caller supplied. Returns true on PASS; prints a FAIL
// diagnostic and returns false otherwise.
bool run_case(const GraphCase & tc) {
    // Metadata-only ggml context: no_alloc means tensors carry no data
    // buffers, which is all we need to inspect graph topology.
    ggml_init_params ip{};
    ip.mem_size = 2 * 1024 * 1024;
    ip.mem_buffer = nullptr;
    ip.no_alloc = true;
    ggml_context * ctx = ggml_init(ip);
    if (!ctx) {
        std::fprintf(stderr, "FAIL %s: ggml_init failed\n", tc.label);
        return false;
    }

    // Tiny model dimensions — just enough structure for build_draft_graph.
    constexpr int hidden = 8;
    constexpr int n_head = 2;
    constexpr int n_kv = 1;
    constexpr int head_dim = 4;
    constexpr int q_len = DFLASH27B_DRAFT_BLOCK_SIZE;
    constexpr int inter = 12;
    constexpr int fc_in = 5 * hidden;
    const int total_k = tc.ctx_len + q_len;

    DraftWeights w{};
    w.n_layer = 1;
    w.n_head = n_head;
    w.n_head_kv = n_kv;
    w.head_dim = head_dim;
    w.swa_window = tc.swa_window;
    w.layers.resize(1);

    w.fc = new_mat(ctx, fc_in, hidden);
    w.hidden_norm = new_vec(ctx, hidden);
    w.out_norm = new_vec(ctx, hidden);

    // Single layer; tc.is_swa decides whether it is a sliding-window layer.
    DraftLayer & layer = w.layers[0];
    layer.attn_norm = new_vec(ctx, hidden);
    layer.ffn_norm = new_vec(ctx, hidden);
    layer.wq = new_mat(ctx, hidden, n_head * head_dim);
    layer.wk = new_mat(ctx, hidden, n_kv * head_dim);
    layer.wv = new_mat(ctx, hidden, n_kv * head_dim);
    layer.wo = new_mat(ctx, n_head * head_dim, hidden);
    layer.q_norm = new_vec(ctx, head_dim);
    layer.k_norm = new_vec(ctx, head_dim);
    layer.w_gate = new_mat(ctx, hidden, inter);
    layer.w_up = new_mat(ctx, hidden, inter);
    layer.w_down = new_mat(ctx, inter, hidden);
    layer.is_swa = tc.is_swa;

    // Graph inputs with the shapes documented on DraftGraphInputs.
    ggml_tensor * noise_embed = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hidden, q_len, 1);
    ggml_tensor * target_hidden_cat = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, fc_in, tc.ctx_len, 1);
    ggml_tensor * positions_q = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, q_len);
    ggml_tensor * positions_k = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, total_k);
    ggml_tensor * attn_mask = nullptr;
    if (tc.provide_mask) {
        // F16 [total_k, q_len], matching the builder's mask validation.
        attn_mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, total_k, q_len);
        ggml_set_name(attn_mask, "draft_swa_mask");
        ggml_set_input(attn_mask);
    }

    ggml_set_input(noise_embed);
    ggml_set_input(target_hidden_cat);
    ggml_set_input(positions_q);
    ggml_set_input(positions_k);

    DraftGraphInputs in{};
    in.ctx_len = tc.ctx_len;
    in.noise_embed = noise_embed;
    in.target_hidden_cat = target_hidden_cat;
    in.positions_q = positions_q;
    in.positions_k = positions_k;
    in.attn_mask = attn_mask;

    DraftGraphOutputs out = build_draft_graph(ctx, w, in);
    if (!out.hidden_states) {
        std::fprintf(stderr, "FAIL %s: build_draft_graph failed: %s\n",
                     tc.label, dflash27b_last_error());
        ggml_free(ctx);
        return false;
    }

    ggml_cgraph * gf = ggml_new_graph_custom(ctx, 256, false);
    ggml_build_forward_expand(gf, out.hidden_states);

    // Locate the (single-layer) flash attention node in the built graph.
    ggml_tensor * flash = nullptr;
    for (int i = 0; i < ggml_graph_n_nodes(gf); ++i) {
        ggml_tensor * node = ggml_graph_node(gf, i);
        if (node && node->op == GGML_OP_FLASH_ATTN_EXT) {
            flash = node;
            break;
        }
    }

    if (!flash) {
        std::fprintf(stderr, "FAIL %s: no flash_attn_ext node found\n", tc.label);
        ggml_free(ctx);
        return false;
    }

    // src[3] is the mask slot of ggml_flash_attn_ext(Q, K, V, mask).
    const bool got_mask = flash->src[3] != nullptr;
    if (got_mask != tc.expect_mask) {
        std::fprintf(stderr, "FAIL %s: expected mask=%d got mask=%d\n",
                     tc.label, tc.expect_mask ? 1 : 0, got_mask ? 1 : 0);
        ggml_free(ctx);
        return false;
    }
    // The builder must wire the caller's tensor itself, not a copy or view.
    if (tc.expect_mask && flash->src[3] != attn_mask) {
        std::fprintf(stderr, "FAIL %s: flash_attn_ext did not use caller mask tensor\n", tc.label);
        ggml_free(ctx);
        return false;
    }

    std::printf("PASS %s\n", tc.label);
    ggml_free(ctx);
    return true;
}

} // namespace

int main() {
std::vector<GraphCase> cases(3);
cases[0].is_swa = true;
cases[0].swa_window = 8;
cases[0].ctx_len = 12;
cases[0].provide_mask = true;
cases[0].expect_mask = true;
cases[0].label = "swa-long-context-wires-mask";

cases[1].is_swa = false;
cases[1].swa_window = 8;
cases[1].ctx_len = 12;
cases[1].provide_mask = true;
cases[1].expect_mask = false;
cases[1].label = "non-swa-layer-ignores-mask";

cases[2].is_swa = true;
cases[2].swa_window = 64;
cases[2].ctx_len = 12;
cases[2].provide_mask = true;
cases[2].expect_mask = false;
cases[2].label = "short-context-keeps-full-attn";

int failed = 0;
for (const GraphCase & tc : cases) {
if (!run_case(tc)) {
++failed;
}
}

return failed == 0 ? 0 : 1;
}
12 changes: 12 additions & 0 deletions dflash/test/test_vs_oracle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,12 @@ int main(int argc, char ** argv) {
ggml_tensor * target_hid = ggml_new_tensor_3d(gctx, GGML_TYPE_F32, m.fc_in, m.ctx_len, 1);
ggml_tensor * pos_q = ggml_new_tensor_1d(gctx, GGML_TYPE_I32, m.q_len);
ggml_tensor * pos_k = ggml_new_tensor_1d(gctx, GGML_TYPE_I32, m.ctx_len + m.q_len);
ggml_tensor * attn_mask = nullptr;
if (draft_graph_needs_swa_mask(w, m.ctx_len)) {
attn_mask = ggml_new_tensor_2d(gctx, GGML_TYPE_F16, m.ctx_len + m.q_len, m.q_len);
ggml_set_name(attn_mask, "draft_swa_mask");
ggml_set_input(attn_mask);
}
ggml_set_name(noise_embed, "noise_embed");
ggml_set_name(target_hid, "target_hidden_cat");
ggml_set_name(pos_q, "positions_q");
Expand All @@ -132,6 +138,7 @@ int main(int argc, char ** argv) {
gi.target_hidden_cat = target_hid;
gi.positions_q = pos_q;
gi.positions_k = pos_k;
gi.attn_mask = attn_mask;
DraftGraphOutputs go = build_draft_graph(gctx, w, gi);
if (!go.hidden_states) return 1;
ggml_set_output(go.hidden_states);
Expand All @@ -154,6 +161,11 @@ int main(int argc, char ** argv) {
for (int i = 0; i < m.ctx_len + m.q_len; i++) pk[i] = i;
ggml_backend_tensor_set(pos_q, pq.data(), 0, sizeof(int32_t) * pq.size());
ggml_backend_tensor_set(pos_k, pk.data(), 0, sizeof(int32_t) * pk.size());
if (attn_mask) {
std::vector<uint16_t> mask;
build_draft_swa_mask(mask, m.ctx_len, m.q_len, w.swa_window);
ggml_backend_tensor_set(attn_mask, mask.data(), 0, sizeof(uint16_t) * mask.size());
}

// Compute
auto status = ggml_backend_graph_compute(backend, gf);
Expand Down