
Conversation

@yubofredwang (Contributor) commented Jul 2, 2025

Motivation

The current implementation only fakes support for page size > 1: in assign_draft_cache_locs we move the indices rather than the real KV cache. That works only when the kernel backend runs with page size = 1; when it runs with page size > 1, we need to duplicate the real KV cache. The overhead of this duplication should be acceptable because the draft KV cache has only one layer.

Modifications

Produce source_cache_loc and target_cache_loc in assign_draft_cache_locs, then call self.draft_model_runner.token_to_kv_pool.move_kv_cache(target_cache_loc, source_cache_loc) to copy the actual KV cache. See the sketch after the example below.

Example case:
Given speculative_num_steps=5, page_size=4, topk=8, num_new_pages_per_topk_=2, and a prefix occupying cache locations [4,5,6,7,8,9,10], the last page holds last_page_len=3 tokens that must be duplicated. The generated arrays are:
source_cache_loc (len=21): [8,9,10] * 7, i.e. [8,9,10,8,9,10,...]
target_cache_loc (len=21): [16,17,18,24,25,26,...]
out_cache_loc to be filled: [11,12,13,14,15,19,20,21,22,23,27,...]
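
Below is a minimal plain-Python sketch of how these arrays can be derived for the example above. It only illustrates the page/branch arithmetic: the variable names are made up, and the real computation happens inside the assign_draft_cache_locs Triton kernel.

page_size = 4
topk = 8
speculative_num_steps = 5
prefix_locs = [4, 5, 6, 7, 8, 9, 10]            # existing KV cache locations of the prefix

last_page_len = len(prefix_locs) % page_size    # 3 tokens sit on a partial last page
last_page_locs = prefix_locs[-last_page_len:]   # [8, 9, 10]
last_page_start = prefix_locs[-1] + 1 - last_page_len                               # 8
num_new_pages_per_topk = -(-(last_page_len + speculative_num_steps) // page_size)   # ceil -> 2

# topk branch 0 reuses the original partial page in place; branches 1..topk-1 get copies
source_cache_loc, target_cache_loc, out_cache_loc = [], [], []
for k in range(topk):
    branch_start = last_page_start + k * num_new_pages_per_topk * page_size
    if k > 0:
        source_cache_loc += last_page_locs
        target_cache_loc += list(range(branch_start, branch_start + last_page_len))
    # draft tokens for this branch go right after the (duplicated) partial page
    out_cache_loc += list(range(branch_start + last_page_len,
                                branch_start + last_page_len + speculative_num_steps))

print(source_cache_loc[:6])   # [8, 9, 10, 8, 9, 10]        (len = 21)
print(target_cache_loc[:6])   # [16, 17, 18, 24, 25, 26]    (len = 21)
print(out_cache_loc[:11])     # [11, 12, 13, 14, 15, 19, 20, 21, 22, 23, 27]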

Questions:

  1. The current implementation of move_kv_cache parallelizes over layers, which is inefficient; the copy task in this PR could instead be parallelized over the target_cache_loc dimension.
  2. We incur a device sync because we need duplicate_cache_len (the length of target_cache_loc) on the host; otherwise we would have to copy all page_size * topk elements, since the actual length cannot be known without the sync. A small illustration follows this list.
  3. Discuss how to restore the kv_token_pool.
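
A small hypothetical illustration of the sync mentioned in question 2 (tensor names are made up):

import torch

last_page_lens = torch.tensor([3, 1], device="cuda")     # per-request partial-page lengths
duplicate_cache_len = int(last_page_lens.sum().item())   # .item() forces a GPU -> CPU sync
# Without this sync, the copy kernel would have to be launched over the worst case,
# page_size * topk slots per request, and mask out the unused ones.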

Benchmark Result:

python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3-8B-Instruct --speculative-algo EAGLE \
    --speculative-draft lmsys/sglang-EAGLE-LLaMA3-Instruct-8B --speculative-num-steps 5 \
    --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --dtype float16 --port 30000

Results

MT-Bench:
page_size=1
#questions: 20, Throughput: 696.78 token/s, Acceptance length: 4.38
page_size=8
#questions: 20, Throughput: 544.93 token/s, Acceptance length: 4.37

send_one:
page_size=1
acc_length=3.97

page_size=8
acc_length=3.5

@yubofredwang changed the title from "Support duplicate cache" to "Support duplicate cache in eagle draft model" on Jul 2, 2025
@gemini-code-assist (bot) left a comment

Summary of Changes

Hello @yubofredwang, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the speculative decoding (EAGLE) algorithm by correctly implementing KV cache duplication for scenarios where page_size and topk are greater than one. Previously, the system only moved indices, which was insufficient for kernel backends with larger page sizes. The changes introduce a robust mechanism to identify and copy relevant KV cache entries, ensuring the accuracy and integrity of the speculative decoding process under more complex memory management schemes. This improves the overall correctness and reliability of the EAGLE implementation.

Highlights

  • KV Cache Duplication: Implemented the actual duplication of KV cache entries for speculative decoding when page_size > 1 and topk > 1, moving from a previous 'fake' index-only copy to a real KV cache content copy. This ensures correct behavior for these configurations.
  • assign_draft_cache_locs Enhancement: Modified the assign_draft_cache_locs Triton kernel to generate source_cache_loc and target_cache_loc arrays. These arrays precisely map the locations from which KV cache data should be copied and to where it should be duplicated.
  • Integration of KV Cache Copy: Integrated a crucial call to self.draft_model_runner.token_to_kv_pool.move_kv_cache within the _draft_preprocess_decode method. This call executes the actual KV cache duplication based on the source and target locations prepared by the assign_draft_cache_locs kernel.
  • New Unit Tests: Added comprehensive unit tests in test_eagle_utils.py to validate the correctness of the assign_draft_cache_locs function and the end-to-end KV cache duplication process for both single and multi-sequence scenarios, ensuring the new logic works as expected.

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces crucial support for duplicating KV cache when page_size > 1 in the EAGLE speculative decoding model, which was previously a known limitation. The changes involve modifying the assign_draft_cache_locs function to handle the new source_cache_loc, target_cache_loc, and last_page_lens_cumsum parameters, and adding a new Triton kernel copy_all_layer_kv_cache to perform the actual KV cache duplication. Comprehensive unit tests have been added to validate the new functionality for both single and multi-sequence scenarios, which is excellent for maintaining code quality. There are a few areas for improvement regarding performance and code clarity that are highlighted in the specific comments.
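
For orientation, here is a rough PyTorch-level sketch of what such a copy amounts to semantically; the PR implements it as a fused Triton kernel, and the buffer names here are hypothetical:

def copy_all_layer_kv_cache_reference(k_buffers, v_buffers, tgt_loc, src_loc):
    # one (K, V) buffer per draft layer; rows are indexed by cache location
    for k_buf, v_buf in zip(k_buffers, v_buffers):
        k_buf[tgt_loc] = k_buf[src_loc]
        v_buf[tgt_loc] = v_buf[src_loc]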

@yubofredwang (Contributor, Author) commented:
Latest testing results:

python3 benchmark/mtbench/bench_sglang_eagle.py --num-questions 80 --parallel 10 --question-file /shared/user/mtbench/question.jsonl

without CUDA graph

page_size = 4
#questions: 80, Throughput: 1168.43 token/s, Acceptance length: 2.66
page_size = 1
#questions: 80, Throughput: 1224.18 token/s, Acceptance length: 2.67

command:

SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1 python3 -m sglang.launch_server --model /shared/public/elr-models/meta-llama/Meta-Llama-3.1-8B-Instruct/07eb05b21d191a58c577b4a45982fe0c049d0693 --speculative-algorithm EAGLE3 --speculative-draft-model-path /shared/public/elr-models/jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B/e5ed08d66f528a95ce89f5d4fd136a28f6def714 --speculative-num-steps 2 --speculative-eagle-topk 2 --speculative-num-draft-tokens 5 --mem-fraction 0.9 --page-size 128 --attention-backend fa3 --dtype bfloat16 --trust-remote-code --disable-cuda-graph
