3 changes: 0 additions & 3 deletions docs/CN/source/tutorial/api_server_args_zh.rst
@@ -160,9 +160,6 @@ PD disaggregation mode parameters
.. option:: --dp

Data parallelism size, default is ``1``

This parameter is useful for deepseekv2. When using a deepseekv2 model, set dp equal to the tp parameter.
In all other cases, do not set it; keep the default value of 1.

.. option:: --nccl_host

3 changes: 0 additions & 3 deletions docs/EN/source/tutorial/api_server_args_zh.rst
@@ -160,9 +160,6 @@ Different Parallel Mode Setting Parameters
.. option:: --dp

Data parallelism size, default is ``1``

This parameter is useful for deepseekv2. When using a deepseekv2 model, set dp equal to the tp parameter.
In all other cases, do not set it; keep the default value of 1.

.. option:: --nccl_host

147 changes: 120 additions & 27 deletions lightllm/common/mem_manager.py
@@ -2,12 +2,13 @@
import os
import torch
import torch.distributed as dist
from typing import List, Union
from typing import List, Union, Optional
from lightllm.server.pd_io_struct import KVMoveTask
from lightllm.utils.log_utils import init_logger
from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory
from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_d_node, kv_trans_v2_for_p_node
from lightllm.utils.dist_utils import get_current_rank_in_node
from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
from lightllm.distributed.pynccl import PyNcclCommunicator
@@ -103,7 +104,10 @@ def send_to_decode_node(
dp_size_in_node: int,
nccl_comm: PyNcclCommunicator,
):
assert dp_size_in_node == 1
if dp_size_in_node > 1:
return self.send_to_decode_node_p2p(
move_tasks, mem_managers, dp_size_in_node, nccl_comm
)

# First copy the data into the staging buffer on a designated device, then send it.
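# Illustrative note, not part of this diff: staging the scattered token rows
# into one contiguous buffer first lets nccl_comm.send() ship a single
# contiguous message instead of one transfer per token row.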

@@ -143,8 +147,10 @@ def receive_from_prefill_node(
dp_size_in_node: int,
nccl_comm: PyNcclCommunicator,
):
assert dp_size_in_node == 1

if dp_size_in_node > 1:
return self.receive_from_prefill_node_p2p(
move_tasks, mem_managers, dp_size_in_node, nccl_comm
)
# First receive the data into the staging buffer on a designated device, then copy it to the other devices.

move_token_indexes = []
@@ -183,29 +189,73 @@ def send_to_decode_node_p2p(
"""
Implementation that uses p2p triton kernels to copy and transfer the data.
"""
assert dp_size_in_node == 1

# First copy the data into the staging buffer on a designated device, then send it.
mem_ptrs_dict = None
if dp_size_in_node > 1:
mem_ptrs_dict = {}
# number of devices occupied by one dp replica
group_stride = max(1, len(mem_managers) // dp_size_in_node)
for layer_index in range(self.layer_num):
mems_ptr = []
for i in range(0, len(mem_managers), group_stride):
mems_ptr.append(mem_managers[i].kv_buffer[layer_index, :, :, :].data_ptr())
mem_ptrs_dict[layer_index] = torch.tensor(mems_ptr, dtype=torch.uint64, device="cuda")
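# Illustrative note, not part of this diff: with 8 mem_managers and
# dp_size_in_node = 4, group_stride = 8 // 4 = 2, so the pointers gathered
# above come from managers 0, 2, 4 and 6, i.e. one representative device
# per dp replica.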

move_token_indexes = []
token_dp_indexes = []
for task in move_tasks:
if task.move_kv_len != 0:
move_token_indexes.extend(task.prefill_token_indexes[-task.move_kv_len :])
if dp_size_in_node > 1:
token_dp_indexes.extend([task.prefill_dp_index for _ in range(task.move_kv_len)])

if len(move_token_indexes) == 0:
return

move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda")
for i, mem in enumerate(mem_managers):
for layer_index in range(mem.layer_num):
move_buffer = mem._get_kv_move_data_p2p(move_token_indexes, layer_index, self.kv_move_buffer)
nccl_comm.send(move_buffer, dst=1)
token_dp_tensor = (
torch.tensor(token_dp_indexes, dtype=torch.int32, device="cuda") if dp_size_in_node > 1 else None
)
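# Illustrative note, not part of this diff: two tasks moving 2 and 3 tokens
# from dp replicas 0 and 1 yield token_dp_indexes == [0, 0, 1, 1, 1], one dp
# index per moved token; token_dp_tensor stays None when dp_size_in_node == 1
# so the original single-replica kv_trans path is taken below.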

for layer_index in range(self.layer_num):
move_buffer = self._get_kv_move_data_p2p(
move_token_indexes,
layer_index,
self.kv_move_buffer,
token_dp_indexes=token_dp_tensor,
dp_size_in_node=dp_size_in_node,
mem_ptrs_dict=mem_ptrs_dict
)
nccl_comm.send(move_buffer, dst=1)
return

def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, kv_move_buffer: torch.Tensor):
def _get_kv_move_data_p2p(
self,
token_indexes: torch.Tensor,
layer_index: int,
kv_move_buffer: torch.Tensor,
token_dp_indexes: Optional[torch.Tensor] = None,
dp_size_in_node: int = 1,
mem_ptrs_dict: Optional[dict] = None
):
move_token_num = len(token_indexes)
move_size = self.token_dim_size * move_token_num
move_buffer = kv_move_buffer.view(-1)[0:move_size].view(move_token_num, 2 * self.head_num, self.head_dim)
kv_trans(
self.kv_buffer[layer_index, :, :, :], token_indexes, move_buffer, self.kv_move_buf_indexes[0:move_token_num]
)

if dp_size_in_node == 1 or token_dp_indexes is None:
kv_trans(
self.kv_buffer[layer_index, :, :, :],
token_indexes,
move_buffer,
self.kv_move_buf_indexes[0:move_token_num],
)
else:
kv_trans_v2_for_p_node(
input_mems=mem_ptrs_dict[layer_index],
input_idx=token_indexes,
input_dp_idx=token_dp_indexes,
output=move_buffer,
output_idx=self.kv_move_buf_indexes[0:move_token_num],
dp_size_in_node=dp_size_in_node,
)
return move_buffer
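# Illustrative sketch, not part of this diff: a rough pure-PyTorch reference
# for the gather that kv_trans_v2_for_p_node is expected to perform when
# dp_size_in_node > 1. The real kernel is a triton kernel reading raw uint64
# device pointers; here the per-replica kv buffers are passed as a list of
# plain tensors instead, and the name and signature are assumptions.
def kv_trans_v2_for_p_node_ref(input_buffers, input_idx, input_dp_idx, output, output_idx):
    # For each moved token, pick the kv buffer of its dp replica and gather
    # that token's kv row into the contiguous send buffer.
    for pos in range(input_idx.numel()):
        dp = int(input_dp_idx[pos])
        output[output_idx[pos]] = input_buffers[dp][input_idx[pos]]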

def receive_from_prefill_node_p2p(
@@ -215,29 +265,72 @@ def receive_from_prefill_node_p2p(
dp_size_in_node: int,
nccl_comm: PyNcclCommunicator,
):
assert dp_size_in_node == 1

# First receive the data into the staging buffer on a designated device, then copy it to the other devices.
mem_ptrs_dict = None
if dp_size_in_node > 1:
mem_ptrs_dict = {}
for layer_index in range(self.layer_num):
mems_ptr = []
for mem in mem_managers:
mems_ptr.append(mem.kv_buffer[layer_index, :, :, :].data_ptr())
mem_ptrs_dict[layer_index] = torch.tensor(mems_ptr, dtype=torch.uint64, device="cuda")

move_token_indexes = []
token_dp_indexes = []
for task in move_tasks:
if task.move_kv_len != 0:
move_token_indexes.extend(task.decode_token_indexes[-task.move_kv_len :])
if dp_size_in_node > 1:
token_dp_indexes.extend([task.decode_dp_index for _ in range(task.move_kv_len)])

if len(move_token_indexes) == 0:
return

move_token_indexes = torch.tensor(move_token_indexes, dtype=torch.int64, device="cuda")
token_dp_tensor = (
torch.tensor(token_dp_indexes, dtype=torch.int32, device="cuda") if dp_size_in_node > 1 else None
)

token_num = len(move_token_indexes)
move_size = self.token_dim_size * token_num
recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(token_num, 2 * self.head_num, self.head_dim)
for i, mem in enumerate(mem_managers):
for layer_index in range(mem.layer_num):
nccl_comm.recv(recive_buffer, src=0)
mem._write_kv_move_data_p2p(move_token_indexes, recive_buffer, layer_index)
move_token_num = len(move_token_indexes)
move_size = self.token_dim_size * move_token_num
recive_buffer = self.kv_move_buffer.view(-1)[0:move_size].view(move_token_num, 2 * self.head_num, self.head_dim)

for layer_index in range(self.layer_num):
nccl_comm.recv(recive_buffer, src=0)
self._write_kv_move_data_p2p(
move_token_indexes,
recive_buffer,
layer_index,
token_dp_indexes=token_dp_tensor,
dp_size_in_node=dp_size_in_node,
mem_ptrs_dict=mem_ptrs_dict
)
return

def _write_kv_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: torch.Tensor, layer_index):
def _write_kv_move_data_p2p(
self,
token_indexes: torch.Tensor,
buffer_tensor: torch.Tensor,
layer_index: int,
token_dp_indexes: Optional[torch.Tensor] = None,
dp_size_in_node: int = 1,
mem_ptrs_dict: Optional[dict] = None
):
move_token_num = len(token_indexes)
kv_trans(buffer_tensor, self.kv_move_buf_indexes[0:move_token_num], self.kv_buffer[layer_index], token_indexes)
if dp_size_in_node == 1 or token_dp_indexes is None:
kv_trans(
buffer_tensor,
self.kv_move_buf_indexes[0:move_token_num],
self.kv_buffer[layer_index],
token_indexes,
)
else:
kv_trans_v2_for_d_node(
output_mems=mem_ptrs_dict[layer_index],
output_idx=token_indexes,
output_dp_idx=token_dp_indexes,
input=buffer_tensor,
input_idx=self.kv_move_buf_indexes[0:move_token_num],
dp_size_in_node=dp_size_in_node,
)
return
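# Illustrative sketch, not part of this diff: the matching scatter reference
# for kv_trans_v2_for_d_node on the decode side (name and signature are
# assumptions, mirroring the gather sketch above).
def kv_trans_v2_for_d_node_ref(output_buffers, output_idx, output_dp_idx, input, input_idx):
    # For each received token, scatter its kv row from the contiguous receive
    # buffer into the kv buffer of the token's dp replica.
    for pos in range(output_idx.numel()):
        dp = int(output_dp_idx[pos])
        output_buffers[dp][output_idx[pos]] = input[input_idx[pos]]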

def _free_buffers(self):