Conversation

@WANDY666 (Contributor) commented Nov 4, 2025

No description provided.

@gemini-code-assist

Summary of Changes

Hello @WANDY666, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a crucial optimization for distributed inference by enabling the transfer of prefix KV cache data between different data parallel ranks. This feature aims to improve performance by allowing requests to leverage pre-computed KV states even when they are processed by different workers. The changes span core components such as memory management, request handling, and the distributed inference backend, supported by a new Triton kernel for efficient data movement and a fix for inter-process CUDA tensor transfers. Additionally, the benchmarking tools have been upgraded to provide more comprehensive performance analysis.

Highlights

  • KV Cache Transfer: Introduced a new Triton kernel (_kv_trans_for_dp_kernel) and a corresponding Python function (kv_trans_for_dp) to efficiently transfer KV cache data between different data parallel (DP) ranks (a conceptual sketch follows this list).
  • Memory Management Integration: The MemoryManager now includes a copy_kv_from_other_dp_ranks method, enabling the system to copy prefix KV cache from other DP ranks to the current rank's memory.
  • Request Object Enhancements: The Req object has been extended with dp_max_kv_len and dp_max_kv_rank fields to track the maximum KV cache length and the DP rank possessing it, facilitating prefix cache fetching. It also now manages shared memory arrays for KV indexes.
  • Data Parallel Backend Logic: The DPChunkedPrefillBackend has been updated to support fetching prefix KV cache from other DP ranks, including logic for matching radix cache, coordinating across ranks using dist.barrier(), and performing the KV transfer.
  • CUDA P2P Fix: A new module p2p_fix.py was added to address issues with CUDA tensor serialization and deserialization across processes, which is critical for enabling peer-to-peer (P2P) data transfers between GPUs in a distributed setup.
  • CLI Argument for Control: A new command-line argument --disable_dp_prompt_cache_fetch has been added to allow users to disable the prefix prompt cache fetching feature for data parallel inference.
  • Benchmarking Tool Improvements: The benchmark_sharegpt.py script has been significantly enhanced to support OpenAI API format, concurrency control, and more detailed latency metrics including time to first token and inter-token latency.
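
The kernel itself lives in the diff rather than in this summary; for orientation, here is a minimal, hypothetical sketch of a token-indexed KV row copy in Triton. It is not the PR's _kv_trans_for_dp_kernel: the flat [num_tokens, row_numel] layout, the names _copy_kv_rows_kernel/copy_kv_rows, and the assumption that the source buffer is directly addressable from the current device (e.g. via CUDA P2P) are illustrative assumptions only.

import torch
import triton
import triton.language as tl


@triton.jit
def _copy_kv_rows_kernel(
    src_ptr, dst_ptr,          # flattened [num_tokens, row_numel] KV buffers
    src_idx_ptr, dst_idx_ptr,  # token indices to gather from / scatter to
    row_numel,                 # elements per token row (heads * head_dim)
    BLOCK: tl.constexpr,
):
    token_id = tl.program_id(0)   # one program per token to move
    block_id = tl.program_id(1)   # one program per BLOCK-wide slice of the row
    src_row = tl.load(src_idx_ptr + token_id)
    dst_row = tl.load(dst_idx_ptr + token_id)
    offs = block_id * BLOCK + tl.arange(0, BLOCK)
    mask = offs < row_numel
    vals = tl.load(src_ptr + src_row * row_numel + offs, mask=mask)
    tl.store(dst_ptr + dst_row * row_numel + offs, vals, mask=mask)


def copy_kv_rows(src, src_idx, dst, dst_idx, block=256):
    # src/dst: contiguous [num_tokens, row_numel] views of one layer's KV buffer.
    assert src.is_contiguous() and dst.is_contiguous()
    row_numel = src.shape[1]
    grid = (src_idx.numel(), triton.cdiv(row_numel, block))
    _copy_kv_rows_kernel[grid](src, dst, src_idx, dst_idx, row_numel, BLOCK=block)
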
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces a significant feature to enable prefix KV cache sharing across data-parallel (DP) ranks, which is a complex but valuable optimization for distributed inference. The implementation is extensive, involving a new Triton kernel for peer-to-peer memory transfer, modifications to the memory manager, and sophisticated orchestration logic in the DP backend. A notable part of this change is the custom PyTorch reduction function for CUDA tensors to facilitate P2P communication. Additionally, the benchmark script has been substantially enhanced to support more realistic workloads and provide better metrics.

My review focuses on improving code clarity, fixing potential bugs, and identifying areas for refactoring. Key findings include a potential type error when reading an environment variable, some dead code in the benchmark script, and several opportunities to simplify and clean up the code for better maintainability.

from lightllm.common.mem_manager import MemoryManager
import torch.multiprocessing as mp

min_trans_token_num = os.getenv("MIN_TRANS_TOKEN_NUM", 128)

high

os.getenv returns a string whenever the environment variable is set (the non-string default is only returned when it is unset). The min_trans_token_num value is later used in a numerical comparison with an integer (alloc_size < self.min_trans_token_num), which raises a TypeError for a string. To prevent this, the value from the environment variable should be explicitly cast to an integer.

Suggested change
min_trans_token_num = os.getenv("MIN_TRANS_TOKEN_NUM", 128)
min_trans_token_num = int(os.getenv("MIN_TRANS_TOKEN_NUM", "128"))

Comment on lines 123 to 125
from lightllm.server.router.model_infer.mode_backend.continues_batch.pd_mode.p2p_fix import (
    p2p_fix_rebuild_cuda_tensor,
)

high

This import is unnecessary because the function p2p_fix_rebuild_cuda_tensor is defined in this same file. Furthermore, the import path seems incorrect as it points to a different directory (.../continues_batch/pd_mode/...). This should be removed to avoid confusion and potential import errors.

Comment on lines +199 to +201
input_stride_0,
input_stride_1,
input_stride_2,

medium

These stride parameters (input_stride_1, input_stride_2) are either unused or their values are not used after being cast. This makes the kernel signature more complex than necessary. The same applies to output_stride_1 and output_stride_2 on lines 206-207. For improved clarity and maintainability, it's recommended to remove these unused parameters. Consequently, the call to this kernel in kv_trans_for_dp should be updated to pass only the required strides (e.g., output.stride(0)) instead of unpacking all strides with *output.stride().

Comment on lines 418 to 419
for i in range(0, len(mem_managers)):
    mems_ptr.append(mem_managers[i].kv_buffer[layer_index, :, :, :].data_ptr())

medium

This C-style for loop can be replaced with a more Pythonic for-each loop to improve readability and conciseness.

Suggested change
for i in range(0, len(mem_managers)):
    mems_ptr.append(mem_managers[i].kv_buffer[layer_index, :, :, :].data_ptr())
for mem_manager in mem_managers:
    mems_ptr.append(mem_manager.kv_buffer[layer_index, :, :, :].data_ptr())

Comment on lines 273 to 279
if not args.disable_dp_prompt_cache_fetch:
    if args.run_mode != "normal" or args.dp <= 1:
        args.disable_dp_prompt_cache_fetch = True
        logger.warning(
            """PD split mode or dp <= 1 does not support dp_prompt_cache_fetch;
            overriding disable_dp_prompt_cache_fetch to True"""
        )

medium

The nested if statements can be combined into a single if statement with a compound condition. This will make the logic more concise and easier to read.

    if not args.disable_dp_prompt_cache_fetch and (args.run_mode != "normal" or args.dp <= 1):
        args.disable_dp_prompt_cache_fetch = True
        logger.warning(
            """PD split mode or dp <= 1 does not support dp_prompt_cache_fetch;
            overriding disable_dp_prompt_cache_fetch to True"""
        )

Comment on lines 236 to 241
def create_kv_indexes_shm_array(self):
    service_uni_name = get_unique_server_name()
    name = f"{service_uni_name}_shm_kv_indexes_{self.index_in_shm_mem}"
    self.shm_kv_indexes = ShmArray(name, (self.alloc_shm_numpy_len,), dtype=np.int64)
    self.shm_kv_indexes.create_shm()
    return

medium

The function create_kv_indexes_shm_array is very similar to create_prompt_ids_shm_array and create_logprobs_shm_array. This code duplication can be reduced by creating a generic helper function to handle the creation of shared memory arrays. This would improve maintainability. A similar refactoring could be applied to the link_* functions.
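
As one possible shape of that refactor (a sketch only: the helper name _create_shm_array is hypothetical, and ShmArray and get_unique_server_name are assumed to behave exactly as in the snippet above):

import numpy as np

def _create_shm_array(self, suffix, length, dtype):
    # Shared naming and creation logic for the per-request shm arrays.
    service_uni_name = get_unique_server_name()
    name = f"{service_uni_name}_{suffix}_{self.index_in_shm_mem}"
    shm = ShmArray(name, (length,), dtype=dtype)
    shm.create_shm()
    return shm

def create_kv_indexes_shm_array(self):
    self.shm_kv_indexes = self._create_shm_array("shm_kv_indexes", self.alloc_shm_numpy_len, np.int64)
    return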

Comment on lines +230 to +240
async for chunk, _ in response.content.iter_chunks():
    now_time = time.time()
    delta_time = now_time - start_time
    if is_first:
        is_first = False
        ttft = delta_time
    text += json.loads(chunk.decode("utf-8")[6:])["choices"][0]["delta"].get("content", "")
    if delta_time < 0.005:
        receive_n += 1
    chunks.append(delta_time)
    start_time = now_time

medium

The current implementation of processing server-sent events (SSE) by iterating over iter_chunks() and decoding each chunk individually can be fragile. A single event message might be split across multiple chunks, or one chunk could contain multiple messages. If a multi-byte UTF-8 character is split, chunk.decode('utf-8') will raise a UnicodeDecodeError. A more robust approach would be to buffer the incoming bytes and process complete messages, which are typically separated by \n\n in SSE.
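
A minimal sketch of that buffered approach, assuming the same aiohttp streaming response and "data: "-prefixed payloads as the loop above; the timing bookkeeping is omitted, and the "[DONE]" sentinel is an assumption based on the OpenAI streaming format:

buffer = b""
async for chunk, _ in response.content.iter_chunks():
    buffer += chunk
    # SSE events are separated by a blank line; only parse complete events.
    while b"\n\n" in buffer:
        event, buffer = buffer.split(b"\n\n", 1)
        line = event.decode("utf-8").strip()
        if not line.startswith("data: "):
            continue
        payload = line[len("data: "):]
        if payload == "[DONE]":
            continue
        text += json.loads(payload)["choices"][0]["delta"].get("content", "")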

Comment on lines +237 to +238
if delta_time < 0.005:
    receive_n += 1

medium

The variable receive_n is incremented here, but its value is never used. This, along with its initialization on line 222, appears to be dead code and should be removed to improve clarity.

Comment on lines +273 to +274
if delta_time < 0.005:
    receive_n += 1

medium

The variable receive_n is incremented here, but its value is never used. This, along with its initialization on line 261, appears to be dead code and should be removed.
