Commits
7dc8cda
duplicate cache
yubofredwang Jul 2, 2025
6d5cbc6
more
yubofredwang Jul 2, 2025
c49b341
fix padding copy
yubofredwang Jul 2, 2025
2ed27b0
to revert
yubofredwang Jul 2, 2025
8d1ac5c
Merge branch 'main' into support-duplicate-cache
yubofredwang Jul 2, 2025
0399f42
remove comments
yubofredwang Jul 2, 2025
084d94a
fix comments
yubofredwang Jul 2, 2025
75603e4
revert code
yubofredwang Jul 2, 2025
0947d4e
format
yubofredwang Jul 2, 2025
ea6935a
push
yubofredwang Jul 2, 2025
dafa283
done
yubofredwang Jul 2, 2025
6d2a3e8
remove print
yubofredwang Jul 2, 2025
7a85c10
fix
yubofredwang Jul 3, 2025
96979ba
Merge branch 'main' into support-duplicate-cache
yubofredwang Jul 3, 2025
29a7f73
Merge branch 'main' into support-duplicate-cache
zhyncs Jul 6, 2025
972d9c7
Merge branch 'main' into support-duplicate-cache
yubofredwang Jul 7, 2025
4dc971d
Merge branch 'main' into support-duplicate-cache
yubofredwang Jul 16, 2025
54c7798
Merge branch 'sgl-project:main' into support-duplicate-cache
yubofredwang Jul 18, 2025
b8caa91
fix
yubofredwang Jul 18, 2025
76cbab3
generate correct
yubofredwang Jul 22, 2025
3cd2193
Merge
yubofredwang Jul 22, 2025
6ba9469
minor fix
yubofredwang Jul 26, 2025
1f61f50
fix batch
yubofredwang Jul 26, 2025
d355597
Merge branch 'sgl-project:main' into support-duplicate-cache
yubofredwang Jul 26, 2025
367c051
Merge remote-tracking branch 'refs/remotes/origin/support-duplicate-c…
yubofredwang Jul 26, 2025
d654fe1
fix
yubofredwang Jul 26, 2025
40ef116
Merge remote-tracking branch 'upstream/main' into support-duplicate-c…
yubofredwang Nov 16, 2025
222de79
fix merge
yubofredwang Nov 16, 2025
45c3e1e
tests done
yubofredwang Nov 17, 2025
2edb6be
Merge remote-tracking branch 'upstream/main' into support-duplicate-c…
yubofredwang Nov 17, 2025
48f272a
minor fix
yubofredwang Nov 17, 2025
63 changes: 47 additions & 16 deletions python/sglang/srt/speculative/eagle_utils.py
@@ -744,6 +744,9 @@ def assign_draft_cache_locs(
extend_lens,
num_new_pages_per_topk,
out_cache_loc,
source_cache_loc,
target_cache_loc,
last_page_lens_cumsum,
pool_len: tl.constexpr,
topk: tl.constexpr,
speculative_num_steps: tl.constexpr,
@@ -776,39 +779,67 @@ def assign_draft_cache_locs(
if page_size == 1 or topk == 1:
return

# Part 2: Copy the indices for the last partial page
# Part 2: Copy indices into source_cache_loc and target_cache_loc
# Expected output: src:[8,9,10,8,9,10...] tgt:[16,17,18,24,25,26...]
prefix_len = tl.load(seq_lens + pid)
last_page_len = prefix_len % page_size
offsets = tl.arange(0, page_size)
mask = offsets < last_page_len
num_new_pages_per_topk_ = tl.load(num_new_pages_per_topk + pid)
prefix_base = token_pool + prefix_len - last_page_len

for topk_id in range(topk):
value = tl.load(prefix_base + offsets, mask=mask)
src_indices = tl.load(prefix_base + offsets, mask=mask)
last_page_lens_cumsum_ = tl.load(last_page_lens_cumsum + pid)
# Skip the first one since no copy is needed
for topk_id in range(1, topk):
tl.store(
source_cache_loc
+ (topk - 1) * (last_page_lens_cumsum_ - last_page_len)
+ (topk_id - 1) * last_page_len
+ offsets,
src_indices,
mask=mask,
)
tgt_indices = tl.load(
prefix_base + topk_id * num_new_pages_per_topk_ * page_size + offsets,
value,
mask=mask,
)

# Part 3: Remove the padding in out_cache_loc
iter_offest = tl.arange(0, iter_upper)
tl.store(
target_cache_loc
+ (topk - 1) * (last_page_lens_cumsum_ - last_page_len)
+ (topk_id - 1) * last_page_len
+ offsets,
tgt_indices,
mask=mask,
)
# Part 3: Copy and remove the used indices for duplication
# speculative_num_steps=5, page_size=4, num_new_pages_per_topk_=2, last_page_len=1
# - xxxxx .. | - xxxxx .. |
# topk=0 topk=1
# "-" means prefix tokens
# "x" means speculative draft tokens
# "." means padded tokens
# we only want to copy the "x" part.
iter_offset = tl.arange(0, iter_upper)
for topk_id in range(topk):
mask_upper = iter_offset < (speculative_num_steps + last_page_len)
mask_lower = iter_offset >= last_page_len
combined_mask = mask_upper & mask_lower
indices = tl.load(
prefix_base
+ topk_id * num_new_pages_per_topk_ * page_size
+ last_page_len
+ iter_offest,
mask=iter_offest < speculative_num_steps,
prefix_base + topk_id * num_new_pages_per_topk_ * page_size + iter_offset,
mask=combined_mask,
other=0,
)
padding_len = (iter_upper - speculative_num_steps) * pid * topk
all_len = pid * num_new_pages_per_topk_ * page_size * topk
ptr_offset = all_len - padding_len
tl.store(
out_cache_loc
+ pid * topk * speculative_num_steps
- last_page_len
+ ptr_offset
+ topk_id * speculative_num_steps
+ iter_offest,
+ iter_offset,
indices,
mask=iter_offest < speculative_num_steps,
mask=combined_mask,
)


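Part 2 of the kernel above only records *which* KV slots need duplicating: for every extra top-k branch it repeats the slot indices of the prefix's last partial page (source) and pairs them with the matching slots in that branch's newly allocated pages (target). The following is a minimal NumPy sketch of that index math, not the kernel itself; `page_size`, `topk`, `num_new_pages_per_topk`, and the request-to-token mapping are made-up illustrative values chosen to reproduce the src/tgt example in the kernel comment.

```python
import numpy as np

page_size = 8
topk = 3
num_new_pages_per_topk = 1        # pages reserved per top-k branch (assumed)

# One request whose prefix occupies 11 tokens -> the last partial page holds 3.
prefix_len = 11
last_page_len = prefix_len % page_size            # 3

# Pretend req_to_token maps token positions to KV slots 0..63 contiguously.
req_to_token = np.arange(64)

prefix_base = prefix_len - last_page_len          # position 8
src = req_to_token[prefix_base:prefix_base + last_page_len]   # slots [8, 9, 10]

source_cache_loc, target_cache_loc = [], []
for topk_id in range(1, topk):                    # branch 0 keeps the original page
    tgt_base = prefix_base + topk_id * num_new_pages_per_topk * page_size
    source_cache_loc.extend(src.tolist())
    target_cache_loc.extend(req_to_token[tgt_base:tgt_base + last_page_len].tolist())

print(source_cache_loc)   # [8, 9, 10, 8, 9, 10]
print(target_cache_loc)   # [16, 17, 18, 24, 25, 26]
```

In the real kernel each request writes into the same flat buffers, so its region starts at `(topk - 1) * (last_page_lens_cumsum_ - last_page_len)`, i.e. the exclusive cumulative sum of the preceding requests' partial-page lengths times the number of extra branches.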
46 changes: 35 additions & 11 deletions python/sglang/srt/speculative/eagle_worker.py
@@ -427,21 +427,13 @@ def _draft_preprocess_decode(self, batch: ScheduleBatch):
# "x" means speculative draft tokens
# "." means padded tokens

# TODO(lmzheng): The current implementation is still a fake support
# for page size > 1. In the `assign_draft_cache_locs` below,
# we directly move the indices instead of the real kv cache.
# This only works when the kernel backend runs with page size = 1.
# If the kernel backend runs with page size > 1, we need to
# duplicate the real KV cache. The overhead of duplicating KV
# cache seems okay because the draft KV cache only has one layer.
# see a related copy operation in MHATokenToKVPool::move_kv_cache.

(
prefix_lens,
seq_lens,
last_loc,
self.num_new_pages_per_topk,
self.extend_lens,
last_page_lens,
) = get_last_loc_large_page_size_large_top_k(
batch.req_to_token_pool.req_to_token,
batch.req_pool_indices,
@@ -450,7 +442,6 @@ def _draft_preprocess_decode(self, batch: ScheduleBatch):
self.topk,
self.page_size,
)

# TODO(lmzheng): remove this device sync
extend_num_tokens = torch.sum(self.extend_lens).item()

@@ -463,6 +454,22 @@ def _draft_preprocess_decode(self, batch: ScheduleBatch):
backup_state=True,
)
)
if self.page_size > 1 and self.topk > 1:
last_page_lens_cumsum = torch.cumsum(last_page_lens, dim=0)
duplicate_cache_len = torch.sum(last_page_lens) * (self.topk - 1)
# TODO: Remove device sync here
target_cache_loc = torch.zeros(
duplicate_cache_len, dtype=torch.int32, device=self.device
)
source_cache_loc = torch.zeros(
duplicate_cache_len, dtype=torch.int32, device=self.device
)
else:
# When no duplication is needed, pass empty placeholder tensors to the kernel
duplicate_cache_len = 0
last_page_lens_cumsum = torch.empty(0, dtype=torch.int32, device=self.device)
source_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)
target_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device)

assign_draft_cache_locs[(num_seqs,)](
batch.req_pool_indices,
@@ -471,6 +478,9 @@ def _draft_preprocess_decode(self, batch: ScheduleBatch):
self.extend_lens,
self.num_new_pages_per_topk,
out_cache_loc,
source_cache_loc,
target_cache_loc,
last_page_lens_cumsum,
batch.req_to_token_pool.req_to_token.shape[1],
self.topk,
self.speculative_num_steps,
@@ -480,6 +490,10 @@ def _draft_preprocess_decode(self, batch: ScheduleBatch):
)

if self.page_size > 1 and self.topk > 1:
if duplicate_cache_len > 0:
self.draft_model_runner.token_to_kv_pool.move_kv_cache(
target_cache_loc, source_cache_loc
)
# Remove padded slots
out_cache_loc = out_cache_loc[
: num_seqs * self.topk * self.speculative_num_steps
@@ -533,6 +547,9 @@ def draft(self, batch: ScheduleBatch):
# Run forward steps
score_list, token_list, parents_list = self.draft_forward(forward_batch)

# Should this line be here?
# self.token_to_kv_pool_allocator.restore_state(self.token_to_kv_pool_state_backup)

if batch.forward_mode.is_idle():
return EagleVerifyInput.create_idle_input(
self.topk,
@@ -957,4 +974,11 @@ def get_last_loc_large_page_size_large_top_k(
prefix_lens,
)

return prefix_lens, seq_lens, last_loc, num_new_pages_per_topk, extend_lens
return (
prefix_lens,
seq_lens,
last_loc,
num_new_pages_per_topk,
extend_lens,
last_page_lens,
)
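Taken together, the worker-side changes follow a simple flow: compute the per-request last-page lengths, size the flat source/target buffers, let `assign_draft_cache_locs` fill them, then duplicate the actual KV entries with `move_kv_cache`. Below is a condensed sketch of that flow, not the worker code itself: the `last_page_lens` values, `page_size`, `topk`, and the CPU device are illustrative, the explicit `int(...)` stands in for the device sync the TODO above refers to, and the `move_kv_cache(target, source)` call is shown only as it appears in this diff.

```python
import torch

page_size, topk = 8, 3
device = "cpu"                     # "cuda" in the real worker

# Per-request length of the last partial prefix page, as returned by
# get_last_loc_large_page_size_large_top_k in the diff above (values assumed).
last_page_lens = torch.tensor([3, 5], dtype=torch.int32, device=device)

if page_size > 1 and topk > 1:
    last_page_lens_cumsum = torch.cumsum(last_page_lens, dim=0)
    # Each extra top-k branch needs its own copy of every partial page.
    duplicate_cache_len = int(torch.sum(last_page_lens)) * (topk - 1)   # (3 + 5) * 2 = 16
    source_cache_loc = torch.zeros(duplicate_cache_len, dtype=torch.int32, device=device)
    target_cache_loc = torch.zeros(duplicate_cache_len, dtype=torch.int32, device=device)
else:
    duplicate_cache_len = 0
    last_page_lens_cumsum = torch.empty(0, dtype=torch.int32, device=device)
    source_cache_loc = torch.empty(0, dtype=torch.int32, device=device)
    target_cache_loc = torch.empty(0, dtype=torch.int32, device=device)

# assign_draft_cache_locs[(num_seqs,)](...) then fills source_cache_loc and
# target_cache_loc on the GPU, after which the worker duplicates the real KV
# entries of the partial pages (call copied from the diff above):
#
#     if duplicate_cache_len > 0:
#         self.draft_model_runner.token_to_kv_pool.move_kv_cache(
#             target_cache_loc, source_cache_loc
#         )
print(duplicate_cache_len, source_cache_loc.shape)   # 16 torch.Size([16])
```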