[feature] Add prefix_kv_cache transfer between dp rankers. #1093
base: main
Conversation
Summary of Changes
Hello @WANDY666, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces a crucial optimization for distributed inference by enabling the transfer of prefix KV cache data between different data-parallel ranks. The feature aims to improve performance by allowing requests to reuse pre-computed KV states even when they are processed by different workers. The changes span core components such as memory management, request handling, and the distributed inference backend, supported by a new Triton kernel for efficient data movement and a fix for inter-process CUDA tensor transfers. Additionally, the benchmarking tools have been upgraded to provide more comprehensive performance analysis.
Code Review
This pull request introduces a significant feature to enable prefix KV cache sharing across data-parallel (DP) ranks, which is a complex but valuable optimization for distributed inference. The implementation is extensive, involving a new Triton kernel for peer-to-peer memory transfer, modifications to the memory manager, and sophisticated orchestration logic in the DP backend. A notable part of this change is the custom PyTorch reduction function for CUDA tensors to facilitate P2P communication. Additionally, the benchmark script has been substantially enhanced to support more realistic workloads and provide better metrics.
My review focuses on improving code clarity, fixing potential bugs, and identifying areas for refactoring. Key findings include a potential type error when reading an environment variable, some dead code in the benchmark script, and several opportunities to simplify and clean up the code for better maintainability.
from lightllm.common.mem_manager import MemoryManager
import torch.multiprocessing as mp

min_trans_token_num = os.getenv("MIN_TRANS_TOKEN_NUM", 128)
os.getenv returns a string value. The min_trans_token_num variable is later used in a numerical comparison with an integer (alloc_size < self.min_trans_token_num). To prevent potential type errors during runtime, the value from the environment variable should be explicitly cast to an integer.
Suggested change:
-    min_trans_token_num = os.getenv("MIN_TRANS_TOKEN_NUM", 128)
+    min_trans_token_num = int(os.getenv("MIN_TRANS_TOKEN_NUM", "128"))
from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.p2p_fix import (
    p2p_fix_rebuild_cuda_tensor,
)
    input_stride_0,
    input_stride_1,
    input_stride_2,
These stride parameters (input_stride_1, input_stride_2) are either unused or their values are not used after being cast. This makes the kernel signature more complex than necessary. The same applies to output_stride_1 and output_stride_2 on lines 206-207. For improved clarity and maintainability, it's recommended to remove these unused parameters. Consequently, the call to this kernel in kv_trans_for_dp should be updated to pass only the required strides (e.g., output.stride(0)) instead of unpacking all strides with *output.stride().
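For illustration, here is a minimal, self-contained sketch of that pattern. It is not the PR's actual kv_trans kernel; the names _copy_rows_kernel and copy_rows are hypothetical, and it assumes the trailing dimensions of both tensors are contiguous so that only the row stride needs to be passed.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _copy_rows_kernel(
    input_ptr,
    input_stride_0,    # only the row stride is dereferenced; trailing dims assumed contiguous
    output_ptr,
    output_stride_0,
    row_numel,
    BLOCK: tl.constexpr,
):
    row_id = tl.program_id(0)
    offs = tl.arange(0, BLOCK)
    mask = offs < row_numel
    vals = tl.load(input_ptr + row_id * input_stride_0 + offs, mask=mask)
    tl.store(output_ptr + row_id * output_stride_0 + offs, vals, mask=mask)


def copy_rows(inp: torch.Tensor, out: torch.Tensor):
    # Pass only the strides the kernel actually reads,
    # rather than unpacking everything with *inp.stride() / *out.stride().
    rows = inp.shape[0]
    row_numel = inp[0].numel()
    BLOCK = triton.next_power_of_2(row_numel)
    _copy_rows_kernel[(rows,)](
        inp,
        inp.stride(0),
        out,
        out.stride(0),
        row_numel,
        BLOCK=BLOCK,
    )
```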
lightllm/common/mem_manager.py
Outdated
for i in range(0, len(mem_managers)):
    mems_ptr.append(mem_managers[i].kv_buffer[layer_index, :, :, :].data_ptr())
This C-style for loop can be replaced with a more Pythonic for-each loop to improve readability and conciseness.
Suggested change:
-    for i in range(0, len(mem_managers)):
-        mems_ptr.append(mem_managers[i].kv_buffer[layer_index, :, :, :].data_ptr())
+    for mem_manager in mem_managers:
+        mems_ptr.append(mem_manager.kv_buffer[layer_index, :, :, :].data_ptr())
lightllm/server/api_start.py
Outdated
if not args.disable_dp_prompt_cache_fetch:
    if args.run_mode != "normal" or args.dp <= 1:
        args.disable_dp_prompt_cache_fetch = True
        logger.warning(
            """PD split mode or dp <= 1 does not support dp_prompt_cache_fetch;
            overriding disable_dp_prompt_cache_fetch to True"""
        )
The nested if statements can be combined into a single if statement with a compound condition. This will make the logic more concise and easier to read.
if not args.disable_dp_prompt_cache_fetch and (args.run_mode != "normal" or args.dp <= 1):
    args.disable_dp_prompt_cache_fetch = True
    logger.warning(
        """PD split mode or dp <= 1 does not support dp_prompt_cache_fetch;
        overriding disable_dp_prompt_cache_fetch to True"""
    )
lightllm/server/core/objs/req.py
Outdated
def create_kv_indexes_shm_array(self):
    service_uni_name = get_unique_server_name()
    name = f"{service_uni_name}_shm_kv_indexes_{self.index_in_shm_mem}"
    self.shm_kv_indexes = ShmArray(name, (self.alloc_shm_numpy_len,), dtype=np.int64)
    self.shm_kv_indexes.create_shm()
    return
The function create_kv_indexes_shm_array is very similar to create_prompt_ids_shm_array and create_logprobs_shm_array. This code duplication can be reduced by creating a generic helper function to handle the creation of shared memory arrays. This would improve maintainability. A similar refactoring could be applied to the link_* functions.
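A hedged sketch of that refactor, reusing ShmArray, get_unique_server_name, np, and the attribute names from the snippet above; the helper name _create_shm_array is an assumption, not code from the PR.

```python
def _create_shm_array(self, suffix: str, shape, dtype):
    # Generic helper: build the per-server shm name once and create the array.
    name = f"{get_unique_server_name()}_{suffix}_{self.index_in_shm_mem}"
    arr = ShmArray(name, shape, dtype=dtype)
    arr.create_shm()
    return arr

def create_kv_indexes_shm_array(self):
    self.shm_kv_indexes = self._create_shm_array("shm_kv_indexes", (self.alloc_shm_numpy_len,), np.int64)
    return
```

With this in place, create_prompt_ids_shm_array and create_logprobs_shm_array would collapse into similar one-line wrappers, and the link_* functions could share an analogous attach-style helper.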
async for chunk, _ in response.content.iter_chunks():
    now_time = time.time()
    delta_time = now_time - start_time
    if is_first:
        is_first = False
        ttft = delta_time
    text += json.loads(chunk.decode("utf-8")[6:])["choices"][0]["delta"].get("content", "")
    if delta_time < 0.005:
        receive_n += 1
    chunks.append(delta_time)
    start_time = now_time
The current implementation of processing server-sent events (SSE) by iterating over iter_chunks() and decoding each chunk individually can be fragile. A single event message might be split across multiple chunks, or one chunk could contain multiple messages. If a multi-byte UTF-8 character is split, chunk.decode('utf-8') will raise a UnicodeDecodeError. A more robust approach would be to buffer the incoming bytes and process complete messages, which are typically separated by \n\n in SSE.
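As a rough illustration, here is a minimal sketch of that buffering approach. It is not the PR's benchmark code; consume_sse is a hypothetical helper, and it assumes an OpenAI-style stream where each event carries a single "data:" line and ends with a "[DONE]" sentinel.

```python
import json
import aiohttp


async def consume_sse(response: aiohttp.ClientResponse) -> str:
    """Accumulate raw bytes and only decode complete SSE events (separated by b"\\n\\n")."""
    text = ""
    buffer = b""
    async for chunk, _ in response.content.iter_chunks():
        buffer += chunk                           # a chunk may hold a partial event
        while b"\n\n" in buffer:                  # process every complete event in the buffer
            event, buffer = buffer.split(b"\n\n", 1)
            line = event.decode("utf-8").strip()  # safe: only complete events are decoded
            if not line.startswith("data:"):
                continue
            payload = line[len("data:"):].strip()
            if payload == "[DONE]":               # stream terminator
                return text
            delta = json.loads(payload)["choices"][0]["delta"]
            text += delta.get("content", "")
    return text
```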
if delta_time < 0.005:
    receive_n += 1
No description provided.