tpsp mode support dp prefill balance. #1086
@@ -1,12 +1,16 @@
import torch
import triton
import collections
from lightllm.common.mem_manager import MemoryManager
from lightllm.common.req_manager import ReqManager
from lightllm.distributed import CustomProcessGroup
from typing import Tuple, Any, Optional
from typing import Tuple, Any, Optional, List
from .triton_kernel.gen_prefill_params import gen_prefill_params
from .triton_kernel.gen_decode_params import gen_decode_params
from .triton_kernel.multimodal_emb import mark_multimodal_obj
from .batch_objs import ModelInput
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.utils.dist_utils import get_global_dp_rank


class InferStateInfo:

@@ -69,6 +73,18 @@ def __init__(self):
        # used as its input; no other model or scenario uses it.
        self.deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None

        # In single-node multi-dp mode, the prefill-phase workload can become unbalanced
        # across the dp ranks. In that case the inference data can be redistributed: before
        # the attention computation the tokens are all-to-all'ed to their assigned dp ranks,
        # and after the computation they are all-to-all'ed back. This keeps the amount of
        # data handled by each dp rank roughly even and improves prefill efficiency. The
        # variables below are only used in that scenario; in the normal case they stay unused.
        self.need_dp_prefill_balance: bool = False
        self.dp_origin_lens: List[int] = None
        self.dp_handle_lens: List[int] = None
        # self.dp_input_lens: torch.Tensor = None
        self.dp_output_split_sizes: List[List[int]] = None
        self.dp_input_split_sizes: List[List[int]] = None

    def init_some_extra_state(self, model, input_ids: torch.Tensor):
        if self.is_prefill:
            (

@@ -123,3 +139,135 @@ def mark_multimodal_objs_for_prefill(self, input_ids: torch.Tensor):
        for mark, obj in zip(marks_array, multi_objs):
            obj["_prefill_"] = mark > 0
        return

    def prefill_dp_balance(self, input_ids: torch.Tensor):
        """
        During prefill in dp mode, rebalance and redistribute the input data across the dp
        ranks, to avoid the prefill performance drop caused by heavily unbalanced per-dp
        workloads.
        """
        assert self.is_prefill
        import torch.distributed as dist

        args = get_env_start_args()

        dp_input_lens = torch.empty(size=(args.dp,), device="cuda", dtype=torch.int32)
        input_len = torch.empty(size=(1,), device="cuda", dtype=torch.int32)
        input_len.fill_(len(input_ids))
        dist.all_gather_into_tensor(
            output_tensor=dp_input_lens,
            input_tensor=input_len,
            group=self.dist_group.dp_prefill_balance_group,
            async_op=False,
        )
        dp_input_lens = dp_input_lens.detach().cpu()
        self.dp_origin_lens = dp_input_lens.tolist()
        sum_input_len = dp_input_lens.sum().item()
        dp_handle_lens = [sum_input_len // args.dp for _ in range(args.dp)]
        for i in range(sum_input_len % args.dp):
            dp_handle_lens[i] += 1

        self.dp_handle_lens = dp_handle_lens
        # Partition each dp rank's original input into the segments that each dp rank
        # will handle after rebalancing.
        origin_datas = collections.deque()
        for origin_dp_index, origin_dp_input_len in enumerate(dp_input_lens.numpy()):
            origin_datas.append((origin_dp_index, 0, origin_dp_input_len))

        dest_dp_inputs = []
        for dest_dp_index in range(args.dp):
            dest_dp_data = []
            need_size = dp_handle_lens[dest_dp_index]
            while len(origin_datas) != 0:
                origin_data = origin_datas.popleft()
                origin_dp_index, start, end = origin_data
                if end - start > need_size:
                    dest_dp_data.append((origin_dp_index, start, start + need_size))
                    origin_datas.appendleft((origin_dp_index, start + need_size, end))
                    break
                else:
                    dest_dp_data.append((origin_dp_index, start, end))
                    need_size -= end - start
                    if need_size == 0:
                        break

            dest_dp_inputs.append(dest_dp_data)

        dp_output_split_sizes = [[0 for _ in range(args.dp)] for _ in range(args.dp)]
        for dest_dp_index, dest_dp_data in enumerate(dest_dp_inputs):
            for origin_dp_index, start, end in dest_dp_data:
                dp_output_split_sizes[dest_dp_index][origin_dp_index] = end - start
        dp_input_split_sizes = [[0 for _ in range(args.dp)] for _ in range(args.dp)]
        for dest_dp_index, dest_dp_data in enumerate(dest_dp_inputs):
            for origin_dp_index, start, end in dest_dp_data:
                dp_input_split_sizes[origin_dp_index][dest_dp_index] = end - start

        self.dp_input_split_sizes = dp_input_split_sizes
        self.dp_output_split_sizes = dp_output_split_sizes

        new_input_ids = self._all_to_all_balance_get(input_ids)
        if hasattr(self, "position_ids") and self.position_ids is not None:
            self.position_ids = self._all_to_all_balance_get(self.position_ids)
        if hasattr(self, "position_cos") and self.position_cos is not None:
            self.position_cos = self._all_to_all_balance_get(self.position_cos)
        if hasattr(self, "position_sin") and self.position_sin is not None:
            self.position_sin = self._all_to_all_balance_get(self.position_sin)

        return new_input_ids

    def _all_to_all_balance_get(self, data: torch.Tensor):
        dp_rank = get_global_dp_rank()
        import torch.distributed as dist
        from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager

        old_shape = data.shape
        data = data.view(-1)

        origin_len = self.dp_origin_lens[dp_rank]
        assert data.shape[0] % origin_len == 0
        scale_size = data.shape[0] // origin_len
        handle_len = self.dp_handle_lens[dp_rank]

        dest_data = g_cache_manager.alloc_tensor(
            shape=(handle_len * scale_size,),
            data_type=data.dtype,
            device="cuda",
            is_graph_out=False,
            microbatch_index=self.microbatch_index,
        )
        dist.all_to_all_single(
            output=dest_data.view(-1),
            input=data.view(-1),
            output_split_sizes=[e * scale_size for e in self.dp_output_split_sizes[dp_rank]],
            input_split_sizes=[e * scale_size for e in self.dp_input_split_sizes[dp_rank]],
            group=self.dist_group.dp_prefill_balance_group,
            async_op=False,
        )
        return dest_data.view(-1, *old_shape[1:])

    def _all_to_all_unbalance_get(self, data: torch.Tensor):
        dp_rank = get_global_dp_rank()
        import torch.distributed as dist
        from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager

        old_shape = data.shape
        data = data.view(-1)

        handle_len = self.dp_handle_lens[dp_rank]
        scale_size = data.shape[0] // handle_len
        assert data.shape[0] % handle_len == 0
        origin_len = self.dp_origin_lens[dp_rank]
        origin_data = g_cache_manager.alloc_tensor(
            shape=(origin_len * scale_size,),
            data_type=data.dtype,
            device="cuda",
            is_graph_out=False,
            microbatch_index=self.microbatch_index,
        )
        dist.all_to_all_single(
            output=origin_data.view(-1),
            input=data,
            output_split_sizes=[e * scale_size for e in self.dp_input_split_sizes[dp_rank]],
            input_split_sizes=[e * scale_size for e in self.dp_output_split_sizes[dp_rank]],
            group=self.dist_group.dp_prefill_balance_group,
            async_op=False,
        )
        return origin_data.view(-1, *old_shape[1:])
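
To make the split-size bookkeeping above easier to follow, here is a minimal CPU-only sketch of the same partitioning scheme. It is illustration only: `compute_balance_plan` and the toy token counts are hypothetical and not part of the PR, but the deque walk mirrors what `prefill_dp_balance` does with `dp_input_lens`.

```python
import collections


def compute_balance_plan(dp_input_lens):
    """Given per-dp prefill token counts, return the balanced per-dp target lengths
    and the all_to_all_single split-size matrices (receive side / send side)."""
    dp = len(dp_input_lens)
    total = sum(dp_input_lens)
    # Balanced target per dp rank; the first `total % dp` ranks take one extra token.
    handle_lens = [total // dp + (1 if i < total % dp else 0) for i in range(dp)]

    # Walk the original inputs in rank order and carve out each destination's share.
    origin = collections.deque((i, 0, n) for i, n in enumerate(dp_input_lens))
    dest_inputs = []
    for dest in range(dp):
        need, segs = handle_lens[dest], []
        while origin and need > 0:
            src, start, end = origin.popleft()
            take = min(end - start, need)
            segs.append((src, start, start + take))
            if start + take < end:
                origin.appendleft((src, start + take, end))
            need -= take
        dest_inputs.append(segs)

    # output_split[d][s]: tokens rank d receives from rank s (output_split_sizes);
    # input_split[s][d]: tokens rank s sends to rank d (input_split_sizes).
    output_split = [[0] * dp for _ in range(dp)]
    input_split = [[0] * dp for _ in range(dp)]
    for dest, segs in enumerate(dest_inputs):
        for src, start, end in segs:
            output_split[dest][src] = end - start
            input_split[src][dest] = end - start
    return handle_lens, output_split, input_split


if __name__ == "__main__":
    # 4 dp ranks with an unbalanced prefill batch of 10 / 2 / 7 / 1 tokens.
    handle_lens, output_split, input_split = compute_balance_plan([10, 2, 7, 1])
    print(handle_lens)    # [5, 5, 5, 5]
    print(output_split)   # [[5, 0, 0, 0], [5, 0, 0, 0], [0, 2, 3, 0], [0, 0, 4, 1]]
    print(input_split)    # [[5, 5, 0, 0], [0, 0, 2, 0], [0, 0, 3, 4], [0, 0, 0, 1]]
```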
---
@@ -233,6 +233,11 @@ def _tpsp_get_qkv(
            infer_state.position_cos,
            infer_state.position_sin,
        )

        if infer_state.is_prefill and get_env_start_args().enable_dp_prefill_balance:
            q = infer_state._all_to_all_unbalance_get(data=q)
            cache_kv = infer_state._all_to_all_unbalance_get(data=cache_kv)

        return q, cache_kv

    def _context_attention_flashinfer_kernel_fp8(

@@ -402,10 +407,16 @@ def _get_o(
    def _tpsp_get_o(
        self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
    ) -> torch.Tensor:
        if infer_state.is_prefill and get_env_start_args().enable_dp_prefill_balance:
            input = infer_state._all_to_all_balance_get(data=input)

        input = input.view(-1, self.tp_o_head_num_ * self.head_dim_)
        dest_size = triton.cdiv(input.shape[0], self.tp_world_size_) * self.tp_world_size_
        o_tensor = self.alloc_tensor((dest_size, self.embed_dim_), dtype=input.dtype, device=input.device)
        layer_weight.o_proj.mm(input, out=o_tensor[0 : len(infer_state.position_cos), :])
        e_o_tensor = o_tensor[len(infer_state.position_cos) :, :]
        if e_o_tensor.shape[0] > 0:
            e_o_tensor.fill_(0)

        if self.tp_world_size_ > 1:
            sp_token_num = o_tensor.shape[0] // self.tp_world_size_
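
A companion sketch of the balance → unbalance round trip that `_tpsp_get_o` and `_tpsp_get_qkv` perform around the attention call: the all-to-all is simulated on CPU with a hypothetical helper (no process group, reusing `compute_balance_plan` from the sketch above), just to show that `_all_to_all_unbalance_get` restores the original per-rank token order after `_all_to_all_balance_get`.

```python
import torch


def simulated_all_to_all(per_rank_inputs, input_split_sizes):
    """Emulate dist.all_to_all_single for all ranks at once on 1-D CPU tensors."""
    world = len(per_rank_inputs)
    chunks = []  # chunks[src][dst] = slice of src's data that is sent to dst
    for src in range(world):
        offsets = [0]
        for n in input_split_sizes[src]:
            offsets.append(offsets[-1] + n)
        chunks.append([per_rank_inputs[src][offsets[d]:offsets[d + 1]] for d in range(world)])
    # Each destination concatenates what it receives, in source-rank order.
    return [torch.cat([chunks[src][dst] for src in range(world)]) for dst in range(world)]


# Same toy setup as above: 4 dp ranks holding 10 / 2 / 7 / 1 prefill tokens.
lens = [10, 2, 7, 1]
handle_lens, output_split, input_split = compute_balance_plan(lens)
originals = [torch.arange(n) + 100 * r for r, n in enumerate(lens)]

balanced = simulated_all_to_all(originals, input_split)    # like _all_to_all_balance_get
restored = simulated_all_to_all(balanced, output_split)    # like _all_to_all_unbalance_get

print([t.numel() for t in balanced])  # [5, 5, 5, 5] -> even per-dp workload
assert all(torch.equal(a, b) for a, b in zip(originals, restored))
```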
The import `import torch.distributed as dist` is local to the `prefill_dp_balance` method. It's better to move it to the top of the file for consistency and to avoid repeated import overhead. The same applies to the other local imports in `_all_to_all_balance_get` and `_all_to_all_unbalance_get`.
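
For illustration, the suggested refactor would look roughly like the sketch below. This assumes hoisting is actually possible here; if a circular-import constraint is what forced these imports to be local, they would have to stay inside the methods.

```python
# Module-level imports (added alongside the existing ones):
import torch
import torch.distributed as dist  # was imported inside prefill_dp_balance
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager


class InferStateInfo:
    def prefill_dp_balance(self, input_ids: torch.Tensor):
        assert self.is_prefill
        args = get_env_start_args()  # dist now comes from module scope
        ...

    def _all_to_all_balance_get(self, data: torch.Tensor):
        dp_rank = get_global_dp_rank()  # g_cache_manager likewise from module scope
        ...
```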