Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 16 additions & 38 deletions dflash/src/internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,23 +141,12 @@ struct TargetWeights {
int capture_layer_ids[DFLASH27B_DRAFT_N_TARGET_LAYERS] = {1, 16, 31, 46, 61};
};

struct TargetLoadPlan {
int layer_begin = 0; // inclusive
int layer_end = -1; // exclusive; <0 means all layers
bool load_output = true; // output_norm + lm_head
};

// Load a Q4_K_M target model from a GGUF file on disk.
// Returns false and sets last_error on failure.
bool load_target_gguf(const std::string & path,
ggml_backend_t backend,
TargetWeights & out);

bool load_target_gguf_partial(const std::string & path,
ggml_backend_t backend,
const TargetLoadPlan & plan,
TargetWeights & out);

void free_target_weights(TargetWeights & w);

// ─── Draft weights (z-lab DFlash, bf16) ───────────────────────────
Expand All @@ -174,6 +163,7 @@ struct DraftLayer {
ggml_tensor * w_gate;
ggml_tensor * w_up;
ggml_tensor * w_down;
bool is_swa = false; // sliding window attention (Qwen3.6 draft)
};

struct DraftWeights {
Expand All @@ -193,6 +183,7 @@ struct DraftWeights {
int head_dim = DFLASH27B_TARGET_HEAD_DIM; // 128
int n_embd = DFLASH27B_TARGET_HIDDEN; // 5120
int n_ff = DFLASH27B_TARGET_INTERMEDIATE; // 17408
int swa_window = 0; // sliding window size (0 = full attention, 2048 for Qwen3.6 draft)
};

bool load_draft_safetensors(const std::string & path,
Expand Down Expand Up @@ -228,12 +219,6 @@ struct TargetCache {
ggml_type kv_k_type = GGML_TYPE_Q8_0;
ggml_type kv_v_type = GGML_TYPE_Q8_0;

// When true, K is FWHT-rotated in the graph before writing to the
// standard-type cache (Q4_0/Q8_0/etc), and Q is rotated at attention
// time. This gives TurboQuant-style outlier spreading with fast FA
// kernels that work on all GPU architectures.
bool kv_k_rotated = false;

// Full-attention KV cache: one K and one V per full-attention layer.
// Layout: [head_dim, max_ctx, n_head_kv] f16, contiguous per layer.
std::vector<ggml_tensor *> attn_k; // size = n_full_attn_layers (16)
Expand Down Expand Up @@ -269,13 +254,15 @@ struct TargetCache {
std::vector<ggml_tensor *> conv_input_cache; // size = n_delta (48)

// Rolling target layer features captured during target forward passes.
// Shape [5 * hidden, target_feat_cap] bf16. target_feat_cap is typically
// << max_ctx (e.g. 4096) so the buffer stays small at 128K context. The
// graph writes to slot `(kv_start + i) % target_feat_cap` so positions
// beyond the cap wrap and overwrite older entries. Readers (draft) only
// need the last DRAFT_CTX_MAX positions, so wrap is invisible in
// practice. Fed into the draft graph's fc projection after a bf16→f32
// cast (ggml_get_to_fp32_cuda).
// Shape [5 * hidden, target_feat_cap] bf16 for single-seq caches, or
// [5 * hidden, target_feat_cap, n_seqs] for batched scratch caches.
// target_feat_cap is typically << max_ctx (e.g. 4096) so the buffer stays
// small at 128K context. The graph writes to slot
// `(kv_start + i) % target_feat_cap` so positions beyond the cap wrap and
// overwrite older entries. Readers (draft) only need the last
// DRAFT_CTX_MAX positions, so wrap is invisible in practice. Fed into the
// draft graph's fc projection after a bf16->f32 cast
// (dflash27b_launch_bf16_to_f32).
ggml_tensor * target_feat = nullptr;
int target_feat_cap = 0;
};
Expand Down Expand Up @@ -384,17 +371,8 @@ bool create_target_cache(const TargetWeights & w,
int max_verify_tokens,
ggml_backend_t backend,
TargetCache & out,
bool prefill_only = false);

bool create_target_cache_partial(const TargetWeights & w,
int max_ctx,
int max_verify_tokens,
ggml_backend_t backend,
TargetCache & out,
bool prefill_only,
int layer_begin,
int layer_end,
bool allocate_target_feat);
bool prefill_only = false,
int n_seqs = 1);

void free_target_cache(TargetCache & c);

Expand Down Expand Up @@ -434,15 +412,15 @@ struct DeltaNetCapture {
};

struct QwenGraphInputs {
ggml_tensor * inp_embed; // [hidden, n_tokens, 1] f32 pre-embedded by the caller
ggml_tensor * positions; // [4 * n_tokens] i32 (M-RoPE needs 4 per token)
ggml_tensor * inp_embed; // [hidden, n_tokens, n_seqs] f32; pre-embedded by the caller
ggml_tensor * positions; // [4 * n_tokens] i32; shared across n_seqs for the current batched probe
ggml_tensor * attn_mask; // optional [kv_len, n_tokens_padded] f32 (causal); nullptr for n_tokens==1
int n_tokens; // number of new tokens in this forward
int n_seqs = 1; // batch dimension; n_seqs>1 is capture-free and same-position only for now
int kv_start; // position where the new tokens begin
bool capture_layers; // if true, write captured layer features into cache.target_feat
bool capture_delta_intermediate = false; // if true, populate out_delta_captures
int fa_window = 0; // sliding window for FA layers: 0 = full attention
bool last_token_logits_only = false; // if true, only compute logits for last token (prefill optimization)
ggml_tensor * parent_ids = nullptr; // [n_tokens] i32; tree mode when non-null
};

Expand Down
Loading