99 changes: 99 additions & 0 deletions lightllm/common/kv_trans_kernel/kv_trans_v2.py
@@ -191,3 +191,102 @@ def kv_trans_v2_for_d_node(
num_warps=1,
)
return


@triton.jit
def _kv_trans_for_dp_kernel(
input_mems_ptr,
input_stride_0,
input_stride_1,
input_stride_2,
Comment on lines +199 to +201

Severity: medium

These stride parameters (input_stride_1, input_stride_2) are either unused or their values are not used after being cast. This makes the kernel signature more complex than necessary. The same applies to output_stride_1 and output_stride_2 on lines 206-207. For improved clarity and maintainability, it's recommended to remove these unused parameters. Consequently, the call to this kernel in kv_trans_for_dp should be updated to pass only the required strides (e.g., output.stride(0)) instead of unpacking all strides with *output.stride().
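
If that simplification were adopted, the call site in kv_trans_for_dp could pass the row strides explicitly instead of unpacking `*output.stride()`. A sketch of the adjusted call (not part of this PR; it assumes the kernel signature is reduced to `input_stride_0`/`output_stride_0`, and it relies on the source kv_buffers sharing output's row layout, which the current code already assumes):

```python
# Hypothetical call after removing the unused stride parameters.
_kv_trans_for_dp_kernel[grid](
    input_mems,
    output.stride(0),  # row stride of the source kv_buffers (same layout as output)
    input_idx,
    input_dp_idx,
    output,
    output.stride(0),  # row stride of the destination kv_buffer
    output_idx,
    token_num=token_num,
    head_num=head_num,
    head_dim=head_dim,
    grid_count=grid_count,
    BLOCK_SIZE=BLOCK_SIZE,
    NUM_STAGES=NUM_STAGES,
    CARD_NUM_PER_D=card_num_per_d,
    RANK_IN_DP=rank_in_dp,
    num_warps=1,
)
```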

input_token_idx_ptr,
input_token_dp_index_ptr,
output_ptr,
output_stride_0,
output_stride_1,
output_stride_2,
output_token_idx_ptr,
token_num: int,
head_num: int,
head_dim: int,
grid_count: int,
BLOCK_SIZE: tl.constexpr,
NUM_STAGES: tl.constexpr,
CARD_NUM_PER_D: tl.constexpr,
RANK_IN_DP: tl.constexpr,
):
input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)
output_stride_1 = tl.cast(output_stride_1, dtype=tl.int64)

head_num_dim = head_num * head_dim
tid = tl.program_id(0)

offs = tl.arange(0, BLOCK_SIZE)
while tid < token_num:
dp_index = tl.load(input_token_dp_index_ptr + tid)
mem_index = RANK_IN_DP + dp_index * CARD_NUM_PER_D
input_token_idx = tl.load(input_token_idx_ptr + tid)
output_token_idx = tl.load(output_token_idx_ptr + tid)
for block_idx in tl.range(0, tl.cdiv(head_num_dim, BLOCK_SIZE), 1, num_stages=NUM_STAGES):
cur_offs = block_idx * BLOCK_SIZE + offs
input_ptr = tl.load(input_mems_ptr + mem_index).to(tl.pointer_type(output_ptr.dtype.element_ty))
in_datas = tl.load(input_ptr + input_stride_0 * input_token_idx + cur_offs, mask=cur_offs < head_num_dim)
tl.store(output_ptr + output_stride_0 * output_token_idx + cur_offs, in_datas, mask=cur_offs < head_num_dim)

tid += grid_count

return


def kv_trans_for_dp(
input_mems: torch.Tensor,
input_idx: torch.Tensor,
input_dp_idx: torch.Tensor,
output: torch.Tensor,
output_idx: torch.Tensor,
dp_size_in_node: int,
rank_in_dp: int,
):
"""
input_mems is a torch.uint64 tensor; it stores the base pointers of the kv cache held by the corresponding mem_manager objects currently in use.
"""
assert input_mems.is_contiguous()
assert output.is_contiguous()
assert len(input_mems.shape) == 1
assert len(output.shape) == 3
assert len(input_idx) == len(output_idx)
assert len(output_idx) == len(input_dp_idx)
assert len(input_mems) % dp_size_in_node == 0

card_num_per_d = len(input_mems) // dp_size_in_node

_, head_num, head_dim = output.shape
token_num = len(output_idx)
# use limited resources for the data transfer to avoid occupying too many SM compute units
grid_count = 20
BLOCK_SIZE = 256
NUM_STAGES = 3
grid = (grid_count,)

_kv_trans_for_dp_kernel[grid](
input_mems,
*output.stride(),
input_idx,
input_dp_idx,
output,
*output.stride(),
output_idx,
token_num=token_num,
head_num=head_num,
head_dim=head_dim,
grid_count=grid_count,
BLOCK_SIZE=BLOCK_SIZE,
NUM_STAGES=NUM_STAGES,
CARD_NUM_PER_D=card_num_per_d,
RANK_IN_DP=rank_in_dp,
num_warps=1,
)

return
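
For context, a minimal sketch of how a caller is expected to assemble `input_mems` and invoke this helper, mirroring what `copy_kv_from_other_dp_ranks` in mem_manager.py (further down in this diff) does; the shapes, sizes, and index values below are illustrative only:

```python
import torch

# Four per-rank kv buffers for one layer; in the PR these come from each
# MemoryManager's kv_buffer[layer_index].
kv_buffers = [torch.zeros((1024, 8, 128), dtype=torch.float16, device="cuda") for _ in range(4)]
# Pack their base addresses into a uint64 tensor, as the docstring above describes.
input_mems = torch.tensor([buf.data_ptr() for buf in kv_buffers], dtype=torch.uint64, device="cuda")

output = torch.zeros((1024, 8, 128), dtype=torch.float16, device="cuda")
move_token_indexes = torch.arange(16, device="cuda")                  # source token slots
token_dp_indexes = torch.zeros(16, dtype=torch.int32, device="cuda")  # owning dp rank per token
mem_indexes = torch.arange(16, device="cuda")                         # destination token slots

kv_trans_for_dp(
    input_mems=input_mems,
    input_idx=move_token_indexes,
    input_dp_idx=token_dp_indexes,
    output=output,
    output_idx=mem_indexes,
    dp_size_in_node=1,  # 4 cards // 1 dp group => CARD_NUM_PER_D == 4
    rank_in_dp=0,
)
```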
34 changes: 32 additions & 2 deletions lightllm/common/mem_manager.py
@@ -3,6 +3,7 @@
import torch
import torch.distributed as dist
from typing import List, Union
from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_for_dp
from lightllm.server.pd_io_struct import KVMoveTask
from lightllm.utils.log_utils import init_logger
from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
@@ -93,7 +94,7 @@ def alloc_kv_move_buffer(self, max_req_total_len):
"""
A special interface used in PD-disaggregated (prefill/decode split) mode.
"""
if isinstance(self, MemoryManager) and type(self) != MemoryManager:
if isinstance(self, MemoryManager) and type(self) is not MemoryManager:
raise NotImplementedError("subclass need reimpl this method")
self.kv_move_buffer = torch.empty(
(1, max_req_total_len + 8, 2 * self.head_num, self.head_dim), dtype=self.dtype, device="cuda"
@@ -103,7 +104,7 @@ def alloc_kv_move_buffer(self, max_req_total_len):
return

def alloc_paged_kv_move_buffer(self, page_num, page_size) -> torch.Tensor:
if isinstance(self, MemoryManager) and type(self) != MemoryManager:
if isinstance(self, MemoryManager) and type(self) is not MemoryManager:
raise NotImplementedError("subclass need reimpl this method")

num_kv_head = get_num_key_value_heads(get_env_start_args().model_dir)
@@ -401,6 +402,35 @@ def get_index_kv_buffer(self, index):
def load_index_kv_buffer(self, index, load_tensor_dict):
self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"])

def copy_kv_from_other_dp_ranks(
self,
mem_managers: List["MemoryManager"],
move_token_indexes: torch.Tensor,
token_dp_indexes: torch.Tensor,
mem_indexes: torch.Tensor,
dp_size_in_node: int,
rank_in_dp: int,
):
if not hasattr(self, "mem_ptrs_dict"):
self.mem_ptrs_dict = {}
for layer_index in range(self.layer_num):
mems_ptr = []
for i in range(0, len(mem_managers)):
mems_ptr.append(mem_managers[i].kv_buffer[layer_index, :, :, :].data_ptr())

Severity: medium

This C-style for loop can be replaced with a more Pythonic for-each loop to improve readability and conciseness.

Suggested change
for i in range(0, len(mem_managers)):
mems_ptr.append(mem_managers[i].kv_buffer[layer_index, :, :, :].data_ptr())
for mem_manager in mem_managers:
mems_ptr.append(mem_manager.kv_buffer[layer_index, :, :, :].data_ptr())

mems_ptr = torch.tensor(mems_ptr, dtype=torch.uint64, device="cuda")
self.mem_ptrs_dict[layer_index] = mems_ptr

for layer_index in range(self.layer_num):
kv_trans_for_dp(
input_mems=self.mem_ptrs_dict[layer_index],
input_idx=move_token_indexes,
input_dp_idx=token_dp_indexes,
output=self.kv_buffer[layer_index],
output_idx=mem_indexes,
dp_size_in_node=dp_size_in_node,
rank_in_dp=rank_in_dp,
)


class ReadOnlyStaticsMemoryManager:
"""
7 changes: 7 additions & 0 deletions lightllm/server/api_cli.py
@@ -537,4 +537,11 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--disk_cache_storage_size", type=float, default=10, help="""The capacity of disk cache. GB used."""
)
parser.add_argument(
"--disable_dp_prompt_cache_fetch",
action="store_true",
default=False,
help="""Disable prefix prompt cache fetch for data parallel inference.
The fetch is enabled by default, but it is currently not supported in pd separated mode.""",
)
return parser
9 changes: 9 additions & 0 deletions lightllm/server/api_start.py
@@ -269,6 +269,15 @@ def normal_or_p_d_start(args):
args.router_max_wait_tokens = 0

send_and_receive_node_ip(args) # exchanges node IPs between machines in multi-node setups
# DP prompt cache fetch must be disabled in PD-disaggregated mode, and it requires dp > 1
if not args.disable_dp_prompt_cache_fetch:
if args.run_mode != "normal" or args.dp <= 1:
args.disable_dp_prompt_cache_fetch = True
logger.warning(
"""PD split mode or dp <= 1 does not support dp_prompt_cache_fetch;
overriding disable_dp_prompt_cache_fetch to True"""
)

Severity: medium

The nested if statements can be combined into a single if statement with a compound condition. This will make the logic more concise and easier to read.

    if not args.disable_dp_prompt_cache_fetch and (args.run_mode != "normal" or args.dp <= 1):
        args.disable_dp_prompt_cache_fetch = True
        logger.warning(
            """PD split mode or dp <= 1 does not support dp_prompt_cache_fetch;
            overriding disable_dp_prompt_cache_fetch to True"""
        )


set_env_start_args(args)
logger.info(f"all start args:{args}")

28 changes: 28 additions & 0 deletions lightllm/server/core/objs/req.py
@@ -121,6 +121,10 @@ class Req(ctypes.Structure):
("cpu_cache_match_page_indexes", CpuCachePageList),
# block size used for chunked hashing
("cpu_cache_token_page_size", ctypes.c_int),
# maximum kv cache length across all DP ranks
("dp_max_kv_len", ctypes.c_int),
# the dp_rank that holds the maximum kv cache length
("dp_max_kv_rank", ctypes.c_int),
]

def get_str(self):
@@ -171,6 +175,7 @@ def init(
self.alloc_shm_numpy_len = self.input_len + self.sample_params.max_new_tokens + 1024 # + 1024 for safe
self.create_logprobs_shm_array()
self.create_prompt_ids_shm_array()
self.create_kv_indexes_shm_array()
self.chunked_prefill_size = chunked_prefill_size
self.shm_prompt_ids.arr[0 : len(prompt_ids)] = prompt_ids
self.mtp_accepted_token_num = 0
@@ -181,6 +186,9 @@
self.post_init()

self.cpu_cache_token_page_size = get_env_start_args().cpu_cache_token_page_size
# initialize DP-mode related fields
self.dp_max_kv_len = 0
self.dp_max_kv_rank = -1
if get_env_start_args().enable_cpu_cache:
self._fill_input_token_hash()
return
@@ -225,12 +233,32 @@ def link_logprobs_shm_array(self):
self.shm_logprobs.link_shm()
return

def create_kv_indexes_shm_array(self):
service_uni_name = get_unique_server_name()
name = f"{service_uni_name}_shm_kv_indexes_{self.index_in_shm_mem}"
self.shm_kv_indexes = ShmArray(name, (self.alloc_shm_numpy_len,), dtype=np.int64)
self.shm_kv_indexes.create_shm()
return

Severity: medium

The function create_kv_indexes_shm_array is very similar to create_prompt_ids_shm_array and create_logprobs_shm_array. This code duplication can be reduced by creating a generic helper function to handle the creation of shared memory arrays. This would improve maintainability. A similar refactoring could be applied to the link_* functions.
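
One possible shape for such a helper, shown only as a sketch (the helper name `_make_shm_array` and the `link_only` flag are hypothetical, not part of the PR; ShmArray, get_unique_server_name, index_in_shm_mem, and alloc_shm_numpy_len are reused as-is):

```python
def _make_shm_array(self, suffix: str, dtype, link_only: bool = False) -> ShmArray:
    # Builds (or attaches to) a shared-memory array following the existing naming scheme.
    service_uni_name = get_unique_server_name()
    name = f"{service_uni_name}_{suffix}_{self.index_in_shm_mem}"
    arr = ShmArray(name, (self.alloc_shm_numpy_len,), dtype=dtype)
    if link_only:
        arr.link_shm()
    else:
        arr.create_shm()
    return arr

def create_kv_indexes_shm_array(self):
    self.shm_kv_indexes = self._make_shm_array("shm_kv_indexes", np.int64)

def link_kv_indexes_shm_array(self):
    self.shm_kv_indexes = self._make_shm_array("shm_kv_indexes", np.int64, link_only=True)
```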


def link_kv_indexes_shm_array(self):
service_uni_name = get_unique_server_name()
name = f"{service_uni_name}_shm_kv_indexes_{self.index_in_shm_mem}"
self.shm_kv_indexes = ShmArray(name, (self.alloc_shm_numpy_len,), dtype=np.int64)
self.shm_kv_indexes.link_shm()
return

def get_prompt_ids(self):
return self.shm_prompt_ids.arr[: self.input_len].tolist()

def get_prompt_ids_numpy(self):
return self.shm_prompt_ids.arr[: self.input_len]

def get_kv_indexes(self):
return self.shm_kv_indexes.arr[: self.input_len].tolist()

def get_kv_indexes_numpy(self):
return self.shm_kv_indexes.arr[: self.input_len]

def to_router_rpc_obj(self):
if hasattr(self, "multimodal_params"):
return (
1 change: 1 addition & 0 deletions lightllm/server/core/objs/start_args_type.py
@@ -111,6 +111,7 @@ class StartArgs:
cpu_cache_token_page_size: int = field(default=64)
enable_disk_cache: bool = field(default=False)
disk_cache_storage_size: float = field(default=10)
disable_dp_prompt_cache_fetch: bool = field(default=False)
# zmp ports
router_port: int = field(default=None)
detokenization_port: int = field(default=None)
1 change: 1 addition & 0 deletions lightllm/server/router/manager.py
@@ -139,6 +139,7 @@ async def wait_to_model_ready(self):
info_queue=self.info_queue,
mem_queue=self.mem_queues[(rank_id % node_world_size)],
router_lock=self.router_lock,
mem_queues=self.mem_queues,
)
)
tasks.append(task)
1 change: 1 addition & 0 deletions lightllm/server/router/model_infer/infer_batch.py
@@ -361,6 +361,7 @@ def _init_all_state(self):
self.shm_req = g_infer_context.shm_req_manager.get_req_obj_by_index(self.shm_index)
self.shm_req.link_prompt_ids_shm_array()
self.shm_req.link_logprobs_shm_array()
self.shm_req.link_kv_indexes_shm_array()
self.sampling_param: InferSamplingParams = InferSamplingParams(self.shm_req, self.vocab_size)

# in nixl PD-disaggregated mode, update the starting position from which the prefill node should begin transferring
@@ -10,9 +10,8 @@

class DPForDecodeNode(DPChunkedPrefillBackend):
def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None:
super().__init__()
super().__init__(mem_queue=mem_queue)
self.info_queue: mp.Queue = info_queue
self.mem_queue: mp.Queue = mem_queue
self.classed_req_strict_prefill = False
return

@@ -10,10 +10,9 @@

class DPChunkedForPrefillNode(DPChunkedPrefillBackend):
def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None:
super().__init__()
super().__init__(mem_queue=mem_queue)
self.support_overlap = False
self.info_queue: mp.Queue = info_queue
self.mem_queue: mp.Queue = mem_queue
self.classed_req_no_decode = True

def init_custom(self):