Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions dflash/test/test_daemon_reset_merge_resolution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import re
import unittest
from pathlib import Path


# Repo root (the dflash/ directory): this file lives in dflash/test/, so
# resolve() two levels up from the file itself — parents[1] — is dflash/.
ROOT = Path(__file__).resolve().parents[1]
# The C++ harness whose daemon-loop source text this test inspects.
SOURCE = ROOT / "test" / "test_dflash.cpp"


class DaemonResetMergeResolutionTest(unittest.TestCase):
    """Guards the daemon-loop reset path in test_dflash.cpp.

    Between requests the cache must be reset in place and both transient
    StepGraphs must be freed — the block must not fall back to a full
    destroy/free/recreate cycle.
    """

    # Calls that must appear inside the `if (!daemon_first_iter)` block.
    _REQUIRED_CALLS = (
        "step_graph_free(target_sg);",
        "step_graph_free(draft_sg);",
        "reset_target_cache(cache);",
    )

    # Calls whose presence would indicate a teardown/realloc regression.
    _FORBIDDEN_CALLS = (
        "step_graph_destroy(target_sg);",
        "step_graph_destroy(draft_sg);",
        "free_target_cache(cache);",
        "create_target_cache(",
    )

    def test_daemon_reset_reuses_cache_and_frees_both_transient_graphs(self):
        text = SOURCE.read_text()
        # Capture the body of the reset block, up to (and anchored by) the
        # closing brace and the `daemon_first_iter = false;` that follows it.
        reset_block = re.search(
            r"if \(!daemon_first_iter\) \{\n(?P<body>.*?)\n\s+\}\n\s+daemon_first_iter = false;",
            text,
            re.S,
        )
        self.assertIsNotNone(reset_block, "daemon reset block not found")

        inner = reset_block.group("body")
        for call in self._REQUIRED_CALLS:
            self.assertIn(call, inner)
        for call in self._FORBIDDEN_CALLS:
            self.assertNotIn(call, inner)


if __name__ == "__main__":
    # Allow running this file directly (python test_daemon_reset_merge_resolution.py)
    # in addition to discovery via `python -m unittest`.
    unittest.main()
46 changes: 31 additions & 15 deletions dflash/test/test_dflash.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1213,7 +1213,15 @@ int main(int argc, char ** argv) {
std::fflush(stdout);
}

StepGraph sg;
// Two StepGraphs: target verify (huge graph ~3000 nodes) and draft forward
// (small graph ~200 nodes) have very different topologies. Sharing one
// ggml_gallocr made every call to one path see needs_realloc=true after a
// call to the other (n_nodes mismatch), forcing a graph re-walk and often
// a cudaMalloc on every spec-decode iteration (issue #55). Splitting the
// gallocrs lets each settle into a steady state and stop reallocating.
StepGraph target_sg;
StepGraph draft_sg;
StepGraph & sg = target_sg; // alias for prefill / target-verify call sites
bool daemon_first_iter = true;

while (true) {
Expand All @@ -1229,9 +1237,12 @@ int main(int argc, char ** argv) {
// Reset cache state between requests. On the first request the
// cache was promoted from prefill-only to full (with rollback
// tensors) by migrate_prefill_cache. On subsequent requests we
// just zero all state tensors in place — no GPU buffer free/alloc.
// just zero all state tensors in place and drop transient graph
// descriptors for both target/draft graphs; persistent gallocr
// buffers stay resident.
if (!daemon_first_iter) {
step_graph_free(sg);
step_graph_free(target_sg);
step_graph_free(draft_sg);
reset_target_cache(cache);
}
daemon_first_iter = false;
Expand Down Expand Up @@ -1409,7 +1420,8 @@ int main(int argc, char ** argv) {

// Promote prefill-only cache to full decode cache
auto t_mig0 = std::chrono::steady_clock::now();
step_graph_destroy(sg);
step_graph_destroy(target_sg);
step_graph_destroy(draft_sg);
if (!migrate_prefill_cache(w, max_ctx, max_verify_tokens, backend, cache)) {
std::fprintf(stderr, "cache migration: %s\n", dflash27b_last_error());
return 1;
Expand Down Expand Up @@ -1490,7 +1502,8 @@ int main(int argc, char ** argv) {
// Promote prefill-only cache to full decode cache with rollback tensors.
// Copies KV, SSM/conv state, and target_feat device→device (~1 ms).
auto t_mig0 = std::chrono::steady_clock::now();
step_graph_destroy(sg);
step_graph_destroy(target_sg);
step_graph_destroy(draft_sg);
if (!migrate_prefill_cache(w, max_ctx, max_verify_tokens, backend, cache)) {
std::fprintf(stderr, "cache migration: %s\n", dflash27b_last_error());
return 1;
Expand Down Expand Up @@ -1548,14 +1561,16 @@ int main(int argc, char ** argv) {
const int draft_ctx = std::min(committed, DRAFT_CTX_MAX);
const int draft_start = committed - draft_ctx;

// 2) Draft forward
if (!build_draft_step(sg, dw, w, backend, /*ctx_len=*/draft_ctx)) {
// 2) Draft forward — uses draft_sg so its gallocr settles into a
// small-graph shape (≈200 nodes) instead of bouncing with the target
// verify path's huge-graph shape (≈3000 nodes) (issue #55).
if (!build_draft_step(draft_sg, dw, w, backend, /*ctx_len=*/draft_ctx)) {
std::fprintf(stderr, "draft build failed\n"); return 1;
}
auto T_draft_build = sync_us();
tt_draft_build += std::chrono::duration<double, std::micro>(T_draft_build - T0).count();

ggml_backend_tensor_set(sg.inp_embed, noise_embed_buf.data(), 0,
ggml_backend_tensor_set(draft_sg.inp_embed, noise_embed_buf.data(), 0,
sizeof(float) * noise_embed_buf.size());

// target_hidden_cat: copy the draft-window slice of cache.target_feat
Expand All @@ -1572,13 +1587,13 @@ int main(int argc, char ** argv) {

dflash27b_launch_bf16_to_f32(
(const char *)cache.target_feat->data + (size_t)slot0 * elt_feat * fc_in,
sg.target_hidden_cat->data,
draft_sg.target_hidden_cat->data,
(size_t)pre_n * fc_in,
nullptr);
if (post_n > 0) {
dflash27b_launch_bf16_to_f32(
(const char *)cache.target_feat->data,
(char *)sg.target_hidden_cat->data + (size_t)pre_n * fc_in * sizeof(float),
(char *)draft_sg.target_hidden_cat->data + (size_t)pre_n * fc_in * sizeof(float),
(size_t)post_n * fc_in,
nullptr);
}
Expand All @@ -1587,17 +1602,17 @@ int main(int argc, char ** argv) {

for (int i = 0; i < q_len; i++) pos_q_buf[i] = draft_ctx + i;
for (int i = 0; i < draft_ctx + q_len; i++) pos_k_buf[i] = i;
ggml_backend_tensor_set(sg.positions, pos_q_buf.data(), 0, sizeof(int32_t) * q_len);
ggml_backend_tensor_set(sg.positions_k, pos_k_buf.data(), 0, sizeof(int32_t) * (draft_ctx + q_len));
ggml_backend_tensor_set(draft_sg.positions, pos_q_buf.data(), 0, sizeof(int32_t) * q_len);
ggml_backend_tensor_set(draft_sg.positions_k, pos_k_buf.data(), 0, sizeof(int32_t) * (draft_ctx + q_len));
auto T_draft_set = sync_us();
tt_draft_set += std::chrono::duration<double, std::micro>(T_draft_set - T_draft_copy).count();

auto st = ggml_backend_graph_compute(backend, sg.gf);
auto st = ggml_backend_graph_compute(backend, draft_sg.gf);
if (st != GGML_STATUS_SUCCESS) { std::fprintf(stderr, "draft compute %d\n", (int)st); return 1; }
auto T_draft_compute = sync_us();
tt_draft_compute += std::chrono::duration<double, std::micro>(T_draft_compute - T_draft_set).count();

ggml_backend_tensor_get(sg.logits, draft_logits_buf.data(), 0,
ggml_backend_tensor_get(draft_sg.logits, draft_logits_buf.data(), 0,
sizeof(float) * vocab * q_len);
for (int i = 0; i < q_len; i++) {
draft_tok[i] = argmax_f32(draft_logits_buf.data() + (size_t)i * vocab, vocab);
Expand Down Expand Up @@ -2359,7 +2374,8 @@ int main(int argc, char ** argv) {

} // end while(true)

step_graph_destroy(sg);
step_graph_destroy(target_sg);
step_graph_destroy(draft_sg);
free_target_cache(cache);
free_draft_weights(dw);
free_target_weights(w);
Expand Down