[feature] Add prefix_kv_cache transfer between dp rankers. #1093
base: main
@@ -3,6 +3,7 @@
 import torch
 import torch.distributed as dist
 from typing import List, Union
+from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_for_dp
 from lightllm.server.pd_io_struct import KVMoveTask
 from lightllm.utils.log_utils import init_logger
 from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
@@ -93,7 +94,7 @@ def alloc_kv_move_buffer(self, max_req_total_len):
         """
         Special interface used in PD-disaggregation mode.
         """
-        if isinstance(self, MemoryManager) and type(self) != MemoryManager:
+        if isinstance(self, MemoryManager) and type(self) is not MemoryManager:
             raise NotImplementedError("subclass need reimpl this method")
         self.kv_move_buffer = torch.empty(
             (1, max_req_total_len + 8, 2 * self.head_num, self.head_dim), dtype=self.dtype, device="cuda"
@@ -103,7 +104,7 @@ def alloc_kv_move_buffer(self, max_req_total_len):
         return

     def alloc_paged_kv_move_buffer(self, page_num, page_size) -> torch.Tensor:
-        if isinstance(self, MemoryManager) and type(self) != MemoryManager:
+        if isinstance(self, MemoryManager) and type(self) is not MemoryManager:
             raise NotImplementedError("subclass need reimpl this method")

         num_kv_head = get_num_key_value_heads(get_env_start_args().model_dir)
@@ -401,6 +402,35 @@ def get_index_kv_buffer(self, index):
     def load_index_kv_buffer(self, index, load_tensor_dict):
         self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"])

+    def copy_kv_from_other_dp_ranks(
+        self,
+        mem_managers: List["MemoryManager"],
+        move_token_indexes: torch.Tensor,
+        token_dp_indexes: torch.Tensor,
+        mem_indexes: torch.Tensor,
+        dp_size_in_node: int,
+        rank_in_dp: int,
+    ):
+        if not hasattr(self, "mem_ptrs_dict"):
+            self.mem_ptrs_dict = {}
+            for layer_index in range(self.layer_num):
+                mems_ptr = []
+                for i in range(0, len(mem_managers)):
+                    mems_ptr.append(mem_managers[i].kv_buffer[layer_index, :, :, :].data_ptr())
Suggested change:
-                for i in range(0, len(mem_managers)):
-                    mems_ptr.append(mem_managers[i].kv_buffer[layer_index, :, :, :].data_ptr())
+                for mem_manager in mem_managers:
+                    mems_ptr.append(mem_manager.kv_buffer[layer_index, :, :, :].data_ptr())
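For context on how a pointer list like `mems_ptr` is typically consumed, here is a minimal sketch, using only public torch APIs: the per-rank `data_ptr()` addresses are packed into an int64 CUDA tensor so a transfer kernel (such as the imported `kv_trans_for_dp`, whose exact signature is not shown in this diff) can dereference peer `kv_buffer`s by DP rank. The helper name `build_layer_ptr_table` is hypothetical, not part of the PR.

```python
from typing import List, Sequence

import torch


def build_layer_ptr_table(mem_managers: Sequence, layer_index: int) -> torch.Tensor:
    """Hypothetical helper: pack each rank's kv_buffer base address for one layer.

    Each entry is the CUDA address (as int64) of mem_managers[i].kv_buffer[layer_index],
    so a Triton/CUDA transfer kernel can index into peer buffers by dp rank.
    """
    ptrs: List[int] = [m.kv_buffer[layer_index, :, :, :].data_ptr() for m in mem_managers]
    # Keep the table on the GPU so the kernel can read it without extra host round-trips.
    return torch.tensor(ptrs, dtype=torch.int64, device="cuda")
```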
@@ -269,6 +269,15 @@ def normal_or_p_d_start(args):
         args.router_max_wait_tokens = 0

     send_and_receive_node_ip(args)  # multi-node: used to send and receive node IPs
+    # DP prompt cache fetch must be disabled in PD-disaggregation mode, and requires dp > 1
+    if not args.disable_dp_prompt_cache_fetch:
+        if args.run_mode != "normal" or args.dp <= 1:
+            args.disable_dp_prompt_cache_fetch = True
+            logger.warning(
+                """PD split mode or dp <= 1 does not support dp_prompt_cache_fetch;
+                overriding disable_dp_prompt_cache_fetch to True"""
+            )
+
     set_env_start_args(args)
     logger.info(f"all start args:{args}")
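As a quick sanity check of the gating above, a minimal sketch using a bare `argparse.Namespace` in place of lightllm's real start-args object; the helper name `force_disable_if_unsupported` is made up for illustration.

```python
from argparse import Namespace


def force_disable_if_unsupported(args) -> bool:
    """Mirror the gating in the diff: keep DP prompt cache fetch only when
    run_mode == "normal" and dp > 1; otherwise force-disable it."""
    if not args.disable_dp_prompt_cache_fetch:
        if args.run_mode != "normal" or args.dp <= 1:
            args.disable_dp_prompt_cache_fetch = True
    return args.disable_dp_prompt_cache_fetch


print(force_disable_if_unsupported(Namespace(run_mode="prefill", dp=4, disable_dp_prompt_cache_fetch=False)))  # True
print(force_disable_if_unsupported(Namespace(run_mode="normal", dp=1, disable_dp_prompt_cache_fetch=False)))   # True
print(force_disable_if_unsupported(Namespace(run_mode="normal", dp=2, disable_dp_prompt_cache_fetch=False)))   # False
```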
@@ -121,6 +121,10 @@ class Req(ctypes.Structure):
         ("cpu_cache_match_page_indexes", CpuCachePageList),
         # page size used for chunked hashing
         ("cpu_cache_token_page_size", ctypes.c_int),
+        # maximum kv cache length across all DP ranks
+        ("dp_max_kv_len", ctypes.c_int),
+        # the dp_rank that holds the maximum kv cache length
+        ("dp_max_kv_rank", ctypes.c_int),
     ]

     def get_str(self):
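A rough sketch of how these two fields could be filled when comparing prefix kv cache hits across DP ranks; the selection helper `pick_dp_max_kv` is a hypothetical illustration, not the PR's actual router code.

```python
from typing import Sequence, Tuple


def pick_dp_max_kv(match_lens_per_rank: Sequence[int]) -> Tuple[int, int]:
    """Return (dp_max_kv_len, dp_max_kv_rank): the longest prefix kv cache match
    across DP ranks and the rank that owns it (-1 if no rank has a match)."""
    dp_max_kv_len, dp_max_kv_rank = 0, -1
    for rank, match_len in enumerate(match_lens_per_rank):
        if match_len > dp_max_kv_len:
            dp_max_kv_len, dp_max_kv_rank = match_len, rank
    return dp_max_kv_len, dp_max_kv_rank


print(pick_dp_max_kv([0, 128, 96, 128]))  # (128, 1): the first rank holding the longest match
```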
@@ -171,6 +175,7 @@ def init(
         self.alloc_shm_numpy_len = self.input_len + self.sample_params.max_new_tokens + 1024  # + 1024 for safe
         self.create_logprobs_shm_array()
         self.create_prompt_ids_shm_array()
+        self.create_kv_indexes_shm_array()
         self.chunked_prefill_size = chunked_prefill_size
         self.shm_prompt_ids.arr[0 : len(prompt_ids)] = prompt_ids
         self.mtp_accepted_token_num = 0
@@ -181,6 +186,9 @@ def init(
         self.post_init()

         self.cpu_cache_token_page_size = get_env_start_args().cpu_cache_token_page_size
+        # initialize DP-mode related fields
+        self.dp_max_kv_len = 0
+        self.dp_max_kv_rank = -1
         if get_env_start_args().enable_cpu_cache:
             self._fill_input_token_hash()
         return
@@ -225,12 +233,32 @@ def link_logprobs_shm_array(self):
         self.shm_logprobs.link_shm()
         return

+    def create_kv_indexes_shm_array(self):
+        service_uni_name = get_unique_server_name()
+        name = f"{service_uni_name}_shm_kv_indexes_{self.index_in_shm_mem}"
+        self.shm_kv_indexes = ShmArray(name, (self.alloc_shm_numpy_len,), dtype=np.int64)
+        self.shm_kv_indexes.create_shm()
+        return
+
+    def link_kv_indexes_shm_array(self):
+        service_uni_name = get_unique_server_name()
+        name = f"{service_uni_name}_shm_kv_indexes_{self.index_in_shm_mem}"
+        self.shm_kv_indexes = ShmArray(name, (self.alloc_shm_numpy_len,), dtype=np.int64)
+        self.shm_kv_indexes.link_shm()
+        return
+
     def get_prompt_ids(self):
         return self.shm_prompt_ids.arr[: self.input_len].tolist()

     def get_prompt_ids_numpy(self):
         return self.shm_prompt_ids.arr[: self.input_len]

+    def get_kv_indexes(self):
+        return self.shm_kv_indexes.arr[: self.input_len].tolist()
+
+    def get_kv_indexes_numpy(self):
+        return self.shm_kv_indexes.arr[: self.input_len]
+
     def to_router_rpc_obj(self):
         if hasattr(self, "multimodal_params"):
             return (
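The create/link split above mirrors the pattern already used for `shm_prompt_ids`: one process creates the named shared-memory segment, others attach to it by name. A minimal sketch of that pattern with the standard library's `multiprocessing.shared_memory` (lightllm's `ShmArray` wrapper may differ in detail):

```python
import numpy as np
from multiprocessing import shared_memory

# Creator side (analogous to create_kv_indexes_shm_array): allocate a named
# segment sized for 16 int64 kv indexes and fill a few entries.
shm = shared_memory.SharedMemory(name="demo_shm_kv_indexes_0", create=True, size=8 * 16)
kv_indexes = np.ndarray((16,), dtype=np.int64, buffer=shm.buf)
kv_indexes[:4] = [10, 11, 12, 13]

# Reader side (normally another process, analogous to link_kv_indexes_shm_array):
# attach to the existing segment by name and read the same data.
peer = shared_memory.SharedMemory(name="demo_shm_kv_indexes_0", create=False)
view = np.ndarray((16,), dtype=np.int64, buffer=peer.buf)
print(view[:4].tolist())  # [10, 11, 12, 13]

peer.close()
shm.close()
shm.unlink()
```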
Review comment: These stride parameters (`input_stride_1`, `input_stride_2`) are either unused or their values are not used after being cast, which makes the kernel signature more complex than necessary. The same applies to `output_stride_1` and `output_stride_2` on lines 206-207. For clarity and maintainability, it is recommended to remove these unused parameters. Consequently, the call to this kernel in `kv_trans_for_dp` should be updated to pass only the required strides (e.g., `output.stride(0)`) instead of unpacking all strides with `*output.stride()`.
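To make the suggestion concrete, here is a hedged sketch of a trimmed wrapper and its call site; `kv_copy_rows` is a stand-in for illustration, not the real kernel in `lightllm.common.kv_trans_kernel.kv_trans_v2`.

```python
import torch


def kv_copy_rows(output: torch.Tensor, row_ids, output_stride_0: int) -> None:
    """Stand-in for the transfer kernel's wrapper: only output.stride(0) is
    needed to locate a token's row; the inner strides were passed but never used."""
    flat = output.view(-1)
    for r in row_ids:
        flat[r * output_stride_0 : (r + 1) * output_stride_0] = 1.0  # stand-in for the real copy


out = torch.zeros((8, 2 * 4, 64))          # (tokens, 2 * kv heads, head_dim), contiguous
kv_copy_rows(out, [0, 3], out.stride(0))   # pass only the required stride, not *out.stride()
print(out[3].sum().item())                 # 512.0 -> one full row (8 heads * 64 dims) written
```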