150 changes: 149 additions & 1 deletion lightllm/common/basemodel/infer_struct.py
@@ -1,12 +1,16 @@
import torch
import triton
import collections
from lightllm.common.mem_manager import MemoryManager
from lightllm.common.req_manager import ReqManager
from lightllm.distributed import CustomProcessGroup
from typing import Tuple, Any, Optional
from typing import Tuple, Any, Optional, List
from .triton_kernel.gen_prefill_params import gen_prefill_params
from .triton_kernel.gen_decode_params import gen_decode_params
from .triton_kernel.multimodal_emb import mark_multimodal_obj
from .batch_objs import ModelInput
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.utils.dist_utils import get_global_dp_rank


class InferStateInfo:
@@ -69,6 +73,18 @@ def __init__(self):
# input uses this; it is not used by any other model or scenario.
self.deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None

# In single-node multi-dp mode, if the amount of data handled by each dp becomes unbalanced
# during prefill, the inference data can be redistributed across the dps: an all-to-all moves
# the tokens to their assigned dps before the attention computation, and another all-to-all
# moves them back afterwards. This keeps the workload of the dps roughly even and improves
# compute efficiency during prefill. The variables below are only used in this scenario; in
# the normal case they are unused.
self.need_dp_prefill_balance: bool = False
self.dp_origin_lens: List[int] = None
self.dp_handle_lens: List[int] = None
# self.dp_input_lens: torch.Tensor = None
self.dp_output_split_sizes: List[List[int]] = None
self.dp_input_split_sizes: List[List[int]] = None

def init_some_extra_state(self, model, input_ids: torch.Tensor):
if self.is_prefill:
(
@@ -123,3 +139,135 @@ def mark_multimodal_objs_for_prefill(self, input_ids: torch.Tensor):
for mark, obj in zip(marks_array, multi_objs):
obj["_prefill_"] = mark > 0
return

def prefill_dp_balance(self, input_ids: torch.Tensor):
"""
During prefill in dp mode, re-adjust and redistribute the input data across the dps to reduce
the prefill performance degradation caused by dps handling very uneven amounts of data.
"""
assert self.is_prefill
import torch.distributed as dist

Review comment (medium): The statement `import torch.distributed as dist` is local to the `prefill_dp_balance` method. It would be better to move it to the top of the file for consistency and to avoid repeated import overhead. The same applies to the other local imports in `_all_to_all_balance_get` and `_all_to_all_unbalance_get`.
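
As a sketch of that suggestion, the module-level imports could look roughly like this (assuming the `g_cache_manager` import does not introduce a circular import, which is not verified here):

```python
# Top of lightllm/common/basemodel/infer_struct.py (sketch of the reviewer's suggestion)
import collections
import torch
import torch.distributed as dist

from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
```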


args = get_env_start_args()

dp_input_lens = torch.empty(size=(args.dp,), device="cuda", dtype=torch.int32)
input_len = torch.empty(size=(1,), device="cuda", dtype=torch.int32)
input_len.fill_(len(input_ids))
dist.all_gather_into_tensor(
output_tensor=dp_input_lens,
input_tensor=input_len,
group=self.dist_group.dp_prefill_balance_group,
async_op=False,
)
dp_input_lens = dp_input_lens.detach().cpu()
self.dp_origin_lens = dp_input_lens.tolist()
sum_input_len = dp_input_lens.sum().item()
dp_handle_lens = [sum_input_len // args.dp for _ in range(args.dp)]
for i in range(sum_input_len % args.dp):
dp_handle_lens[i] += 1

self.dp_handle_lens = dp_handle_lens
# 分配每个dp 的原始输入和分配后的原始输入
origin_datas = collections.deque()
for origin_dp_index, origin_dp_input_len in enumerate(dp_input_lens.numpy()):

Review comment (medium): Calling `.numpy()` inside a loop can be inefficient, as it may cause a GPU-to-CPU synchronization on each iteration. It would be better to call it once before the loop and iterate over the resulting numpy array.
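
A minimal sketch of that rearrangement (the actual performance impact is not measured here):

```python
# Sketch: convert once, then iterate over the resulting numpy array.
dp_input_lens_np = dp_input_lens.numpy()
origin_datas = collections.deque()
for origin_dp_index, origin_dp_input_len in enumerate(dp_input_lens_np):
    origin_datas.append((origin_dp_index, 0, origin_dp_input_len))
```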

origin_datas.append((origin_dp_index, 0, origin_dp_input_len))

dest_dp_inputs = []
for dest_dp_index in range(args.dp):
dest_dp_data = []
need_size = dp_handle_lens[dest_dp_index]
while len(origin_datas) != 0:
origin_data = origin_datas.popleft()
origin_dp_index, start, end = origin_data
if end - start > need_size:
dest_dp_data.append((origin_dp_index, start, start + need_size))
origin_datas.appendleft((origin_dp_index, start + need_size, end))
break
else:
dest_dp_data.append((origin_dp_index, start, end))
need_size -= end - start
if need_size == 0:
break

dest_dp_inputs.append(dest_dp_data)

dp_output_split_sizes = [[0 for _ in range(args.dp)] for _ in range(args.dp)]
for dest_dp_index, dest_dp_data in enumerate(dest_dp_inputs):
for origin_dp_index, start, end in dest_dp_data:
dp_output_split_sizes[dest_dp_index][origin_dp_index] = end - start
dp_input_split_sizes = [[0 for _ in range(args.dp)] for _ in range(args.dp)]
for dest_dp_index, dest_dp_data in enumerate(dest_dp_inputs):
for origin_dp_index, start, end in dest_dp_data:
dp_input_split_sizes[origin_dp_index][dest_dp_index] = end - start

self.dp_input_split_sizes = dp_input_split_sizes
self.dp_output_split_sizes = dp_output_split_sizes

new_input_ids = self._all_to_all_balance_get(input_ids)
if hasattr(self, "position_ids") and self.position_ids is not None:
self.position_ids = self._all_to_all_balance_get(self.position_ids)
if hasattr(self, "position_cos") and self.position_cos is not None:
self.position_cos = self._all_to_all_balance_get(self.position_cos)
if hasattr(self, "position_sin") and self.position_sin is not None:
self.position_sin = self._all_to_all_balance_get(self.position_sin)

return new_input_ids
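
To make the split-size bookkeeping above concrete, here is a minimal standalone sketch that reproduces the same greedy assignment for a hypothetical 2-dp case (the token counts are made up for illustration):

```python
import collections

def compute_split_sizes(dp_input_lens):
    """Reproduces the split-size bookkeeping of prefill_dp_balance, for illustration only."""
    dp = len(dp_input_lens)
    total = sum(dp_input_lens)
    handle_lens = [total // dp + (1 if i < total % dp else 0) for i in range(dp)]

    origin_datas = collections.deque((i, 0, n) for i, n in enumerate(dp_input_lens))
    dest_dp_inputs = []
    for dest in range(dp):
        dest_data, need = [], handle_lens[dest]
        while origin_datas:
            origin, start, end = origin_datas.popleft()
            if end - start > need:
                dest_data.append((origin, start, start + need))
                origin_datas.appendleft((origin, start + need, end))
                break
            dest_data.append((origin, start, end))
            need -= end - start
            if need == 0:
                break
        dest_dp_inputs.append(dest_data)

    output_split = [[0] * dp for _ in range(dp)]  # dp_output_split_sizes[dest][origin]
    input_split = [[0] * dp for _ in range(dp)]   # dp_input_split_sizes[origin][dest]
    for dest, dest_data in enumerate(dest_dp_inputs):
        for origin, start, end in dest_data:
            output_split[dest][origin] = end - start
            input_split[origin][dest] = end - start
    return handle_lens, output_split, input_split

# Hypothetical: dp=2, rank 0 holds 10 prefill tokens, rank 1 holds 2.
print(compute_split_sizes([10, 2]))
# -> ([6, 6], [[6, 0], [4, 2]], [[6, 4], [0, 2]])
```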

def _all_to_all_balance_get(self, data: torch.Tensor):
dp_rank = get_global_dp_rank()
import torch.distributed as dist
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager

old_shape = data.shape
data = data.view(-1)

origin_len = self.dp_origin_lens[dp_rank]
assert data.shape[0] % origin_len == 0
scale_size = data.shape[0] // origin_len
handle_len = self.dp_handle_lens[dp_rank]

dest_data = g_cache_manager.alloc_tensor(
shape=(handle_len * scale_size,),
data_type=data.dtype,
device="cuda",
is_graph_out=False,
microbatch_index=self.microbatch_index,
)
dist.all_to_all_single(
output=dest_data.view(-1),
input=data.view(-1),
output_split_sizes=[e * scale_size for e in self.dp_output_split_sizes[dp_rank]],
input_split_sizes=[e * scale_size for e in self.dp_input_split_sizes[dp_rank]],
group=self.dist_group.dp_prefill_balance_group,
async_op=False,
)
return dest_data.view(-1, *old_shape[1:])

def _all_to_all_unbalance_get(self, data: torch.Tensor):
dp_rank = get_global_dp_rank()
import torch.distributed as dist
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager

old_shape = data.shape
data = data.view(-1)

handle_len = self.dp_handle_lens[dp_rank]
scale_size = data.shape[0] // handle_len
assert data.shape[0] % handle_len == 0
origin_len = self.dp_origin_lens[dp_rank]
origin_data = g_cache_manager.alloc_tensor(
shape=(origin_len * scale_size,),
data_type=data.dtype,
device="cuda",
is_graph_out=False,
microbatch_index=self.microbatch_index,
)
dist.all_to_all_single(
output=origin_data.view(-1),
input=data,
output_split_sizes=[e * scale_size for e in self.dp_input_split_sizes[dp_rank]],
input_split_sizes=[e * scale_size for e in self.dp_output_split_sizes[dp_rank]],
group=self.dist_group.dp_prefill_balance_group,
async_op=False,
)
return origin_data.view(-1, *old_shape[1:])
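
As a sanity check on how `dp_input_split_sizes` and `dp_output_split_sizes` swap roles between the two methods above, here is a pure-Python simulation of the balance/unbalance round trip (a sketch reusing the hypothetical 2-dp split sizes from the earlier example, not real distributed code):

```python
def simulated_all_to_all(per_rank_data, input_split, output_split):
    """Mimics dist.all_to_all_single across all ranks at once, with lists instead of tensors."""
    dp = len(per_rank_data)
    # Each rank slices its input according to its row of input_split ...
    chunks = [[] for _ in range(dp)]
    for src in range(dp):
        offset = 0
        for dst, size in enumerate(input_split[src]):
            chunks[src].append(per_rank_data[src][offset:offset + size])
            offset += size
            # ... and the receiver expects exactly output_split[dst][src] elements from src.
            assert size == output_split[dst][src]
    # Each rank's output is the concatenation of the chunks addressed to it, in source-rank order.
    return [[tok for src in range(dp) for tok in chunks[src][dst]] for dst in range(dp)]

tokens = [list(range(10)), [100, 101]]   # rank 0 holds 10 prefill tokens, rank 1 holds 2
input_split = [[6, 4], [0, 2]]           # dp_input_split_sizes from the earlier example
output_split = [[6, 0], [4, 2]]          # dp_output_split_sizes from the earlier example

balanced = simulated_all_to_all(tokens, input_split, output_split)
assert balanced == [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 100, 101]]

# Unbalancing swaps the split-size roles, exactly as _all_to_all_unbalance_get does.
restored = simulated_all_to_all(balanced, output_split, input_split)
assert restored == tokens
```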
@@ -7,6 +7,7 @@
from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv
from lightllm.distributed import all_reduce
from typing import Tuple
from lightllm.utils.envs_utils import get_env_start_args


class TransformerLayerInferTpl(TransformerLayerInfer):
@@ -31,8 +32,14 @@ def _ffn_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.T
raise Exception("need to impl")

def _pre_cache_kv(self, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
if infer_state.is_prefill and get_env_start_args().enable_dp_prefill_balance:
shape = infer_state.kv_buffer_shapedtype[0]
shape = (len(infer_state.position_ids), *shape[1:])
else:
shape = infer_state.kv_buffer_shapedtype[0]

cache_kv = self.alloc_tensor(
shape=infer_state.kv_buffer_shapedtype[0],
shape=shape,
dtype=infer_state.kv_buffer_shapedtype[1],
device="cuda",
is_graph_out=False,
6 changes: 6 additions & 0 deletions lightllm/distributed/communication_op.py
@@ -36,6 +36,7 @@
get_global_rank,
get_current_rank_in_dp,
create_new_group_for_current_dp,
create_dp_special_inter_group,
)
from lightllm.utils.device_utils import get_device_sm_count
from lightllm.utils.sgl_utils import HAS_SGL_KERNEL
@@ -62,6 +63,11 @@ def __init__(self):
self.custom_gather = None
self.dp_world_size = get_dp_world_size()
self.device_group = create_new_group_for_current_dp("nccl")
if get_env_start_args().dp > 1 and get_env_start_args().enable_dp_prefill_balance:
self.dp_prefill_balance_group = create_dp_special_inter_group("nccl")
else:
self.dp_prefill_balance_group = None

self.autotune_group = dist.new_group([i for i in range(get_global_world_size())], backend="gloo")

def init_custom_reduce(self) -> None:
@@ -255,6 +255,10 @@ def _tpsp_get_o(
dest_size = triton.cdiv(input.shape[0], self.tp_world_size_) * self.tp_world_size_
o_tensor = self.alloc_tensor((dest_size, self.embed_dim_), dtype=input.dtype, device=input.device)
layer_weight.o_weight_.mm(input, out=o_tensor[0 : len(infer_state.position_cos), :])
e_o_tensor = o_tensor[len(infer_state.position_cos) :, :]
if e_o_tensor.shape[0] > 0:
e_o_tensor.fill_(0)

if self.tp_world_size_ > 1:
sp_token_num = o_tensor.shape[0] // self.tp_world_size_
reduce_o_tensor = self.alloc_tensor((sp_token_num, self.embed_dim_), dtype=input.dtype, device=input.device)
10 changes: 10 additions & 0 deletions lightllm/models/llama/layer_infer/post_layer_infer.py
@@ -12,6 +12,7 @@
from lightllm.common.basemodel import PostLayerInferTpl
from lightllm.utils.infer_utils import mark_cost_time
from lightllm.distributed.communication_op import all_gather
from lightllm.utils.envs_utils import get_env_start_args


class LlamaPostLayerInfer(PostLayerInferTpl):
@@ -116,6 +117,9 @@ def tpsp_token_forward(
# len(infer_state.position_sin) gives the real input length
input_embdings = gather_data[0 : len(infer_state.position_sin)]

if infer_state.is_prefill and get_env_start_args().enable_dp_prefill_balance:
input_embdings = infer_state._all_to_all_unbalance_get(data=input_embdings)

return self.token_forward(input_embdings=input_embdings, infer_state=infer_state, layer_weight=layer_weight)

def overlap_tpsp_token_forward(
@@ -130,12 +134,18 @@
infer_state.hook()
infer_state.hook = None

if infer_state.is_prefill and get_env_start_args().enable_dp_prefill_balance:
input_embdings = infer_state._all_to_all_unbalance_get(data=input_embdings)

logics = self.tpsp_token_forward(input_embdings, infer_state, layer_weight=layer_weight)

if getattr(infer_state1, "hook", None) is not None:
infer_state1.hook()
infer_state1.hook = None

if infer_state1.is_prefill and get_env_start_args().enable_dp_prefill_balance:
input_embdings1 = infer_state1._all_to_all_unbalance_get(data=input_embdings1)

logics1 = self.tpsp_token_forward(input_embdings1, infer_state1, layer_weight=layer_weight)

return logics, logics1
9 changes: 9 additions & 0 deletions lightllm/models/llama/layer_infer/pre_layer_infer.py
@@ -9,6 +9,7 @@
from lightllm.utils.infer_utils import mark_cost_time
from lightllm.models.llama.triton_kernel.embedding import embedding
from lightllm.distributed.communication_op import all_reduce
from lightllm.utils.envs_utils import get_env_start_args


class LlamaPreLayerInfer(PreLayerInferTpl):
@@ -42,6 +43,9 @@ def token_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weigh
def tpsp_context_forward(
self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight
):
if get_env_start_args().enable_dp_prefill_balance:
input_ids = infer_state.prefill_dp_balance(input_ids=input_ids)

input_embdings = self.context_forward(input_ids=input_ids, infer_state=infer_state, layer_weight=layer_weight)
from lightllm.common.basemodel.triton_kernel.sp_pad_copy import sp_pad_copy

@@ -86,12 +90,17 @@ def overlap_tpsp_context_forward(
infer_state1: LlamaInferStateInfo,
layer_weight: LlamaPreAndPostLayerWeight,
):
if get_env_start_args().enable_dp_prefill_balance:
input_ids = infer_state.prefill_dp_balance(input_ids=input_ids)

input_embdings = self.context_forward(input_ids=input_ids, infer_state=infer_state, layer_weight=layer_weight)
from lightllm.common.basemodel.triton_kernel.sp_pad_copy import sp_pad_copy

padded_input_embdings = sp_pad_copy(input_embdings, sp_rank_id=self.tp_rank_, sp_world_size=self.tp_world_size_)

if get_env_start_args().enable_dp_prefill_balance:
input_ids1 = infer_state1.prefill_dp_balance(input_ids=input_ids1)

input_embdings1 = self.context_forward(
input_ids=input_ids1, infer_state=infer_state1, layer_weight=layer_weight
)
11 changes: 11 additions & 0 deletions lightllm/models/llama/layer_infer/transformer_layer_infer.py
@@ -233,6 +233,11 @@ def _tpsp_get_qkv(
infer_state.position_cos,
infer_state.position_sin,
)

if infer_state.is_prefill and get_env_start_args().enable_dp_prefill_balance:
q = infer_state._all_to_all_unbalance_get(data=q)
cache_kv = infer_state._all_to_all_unbalance_get(data=cache_kv)

Review comment (high): The code performs an `_all_to_all_unbalance_get` before the attention calculation and an `_all_to_all_balance_get` after it (in `_tpsp_get_o`). This introduces two all-to-all communication steps per layer, which can be a significant performance bottleneck, and it seems to defeat the purpose of load balancing for prefill, since attention, the compute-heavy part, would run on unbalanced data. Please clarify the reasoning behind this design. If the attention kernel does not support the balanced data layout, document the limitation and the workaround with a code comment.
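
If such a comment were added, it might read roughly like the sketch below; the stated reason is an assumption about the intent, not something confirmed in this PR:

```python
if infer_state.is_prefill and get_env_start_args().enable_dp_prefill_balance:
    # NOTE (sketch): attention presumably has to see each request's tokens on the dp rank that
    # owns its KV-cache slots, so the balanced layout produced by prefill_dp_balance is
    # converted back here before attention and re-balanced afterwards in _tpsp_get_o.
    q = infer_state._all_to_all_unbalance_get(data=q)
    cache_kv = infer_state._all_to_all_unbalance_get(data=cache_kv)
```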


return q, cache_kv

def _context_attention_flashinfer_kernel_fp8(
@@ -402,10 +407,16 @@ def _get_o(
def _tpsp_get_o(
self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
) -> torch.Tensor:
if infer_state.is_prefill and get_env_start_args().enable_dp_prefill_balance:
input = infer_state._all_to_all_balance_get(data=input)

input = input.view(-1, self.tp_o_head_num_ * self.head_dim_)
dest_size = triton.cdiv(input.shape[0], self.tp_world_size_) * self.tp_world_size_
o_tensor = self.alloc_tensor((dest_size, self.embed_dim_), dtype=input.dtype, device=input.device)
layer_weight.o_proj.mm(input, out=o_tensor[0 : len(infer_state.position_cos), :])
e_o_tensor = o_tensor[len(infer_state.position_cos) :, :]
if e_o_tensor.shape[0] > 0:
e_o_tensor.fill_(0)

if self.tp_world_size_ > 1:
sp_token_num = o_tensor.shape[0] // self.tp_world_size_
6 changes: 5 additions & 1 deletion lightllm/models/qwen/layer_infer/transformer_layer_infer.py
@@ -2,7 +2,7 @@
import torch.functional as F
import torch.distributed as dist
import numpy as np

from typing import Tuple
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
from lightllm.models.qwen.layer_weights.transformer_layer_weight import QwenTransformerLayerWeight
@@ -32,3 +32,7 @@ def _get_qkv(self, input_emb, infer_state: QwenInferStateInfo, layer_weight: Qwe
if infer_state.logn_values is not None:
q.mul_(infer_state.logn_values.view(-1, 1))
return q, cache_kv

def _tpsp_get_qkv(self, input, infer_state, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]:
# TODO
raise Exception("not impl")
@@ -2,6 +2,7 @@
import torch.functional as F
import torch.distributed as dist
import numpy as np
from typing import Tuple
from functools import partial

from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton
@@ -32,3 +33,7 @@ def _get_qkv(self, input, infer_state, layer_weight):
cache_kv[:, : self.tp_k_head_num_, :] = new_k.squeeze(0).permute(1, 0, 2)

return new_q, cache_kv

def _tpsp_get_qkv(self, input, infer_state, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]:
# TODO
raise Exception("not impl")
4 changes: 4 additions & 0 deletions lightllm/models/qwen3/layer_infer/transformer_layer_infer.py
@@ -56,3 +56,7 @@ def _get_qkv(
infer_state.position_sin,
)
return q, cache_kv

def _tpsp_get_qkv(self, input, infer_state, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]:
# TODO
raise Exception("not impl")