Skip to content

Commit 5ac8bfa

Browse files
committed
fix
1 parent b99deda commit 5ac8bfa

File tree

5 files changed

+9
-7
lines changed

5 files changed

+9
-7
lines changed

lightllm/common/basemodel/infer_struct.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ def prefill_dp_balance(self, input_ids: torch.Tensor):
147147
assert self.is_prefill
148148
import torch.distributed as dist
149149

150+
self.need_dp_prefill_balance = True
151+
150152
args = get_env_start_args()
151153

152154
dp_input_lens = torch.empty(size=(args.dp,), device="cuda", dtype=torch.int32)

lightllm/distributed/communication_op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def __init__(self):
6363
self.custom_gather = None
6464
self.dp_world_size = get_dp_world_size()
6565
self.device_group = create_new_group_for_current_dp("nccl")
66-
if get_env_start_args().dp > 1 and get_env_start_args().enable_dp_prefill_balance:
66+
if get_env_start_args().enable_dp_prefill_balance:
6767
self.dp_prefill_balance_group = create_dp_special_inter_group("nccl")
6868
else:
6969
self.dp_prefill_balance_group = None

lightllm/models/llama/layer_infer/post_layer_infer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def tpsp_token_forward(
117117
# len(infer_state.position_sin) 获取真实输入长度
118118
input_embdings = gather_data[0 : len(infer_state.position_sin)]
119119

120-
if infer_state.is_prefill and get_env_start_args().enable_dp_prefill_balance:
120+
if infer_state.need_dp_prefill_balance:
121121
input_embdings = infer_state._all_to_all_unbalance_get(data=input_embdings)
122122

123123
return self.token_forward(input_embdings=input_embdings, infer_state=infer_state, layer_weight=layer_weight)
@@ -134,7 +134,7 @@ def overlap_tpsp_token_forward(
134134
infer_state.hook()
135135
infer_state.hook = None
136136

137-
if infer_state.is_prefill and get_env_start_args().enable_dp_prefill_balance:
137+
if infer_state.need_dp_prefill_balance:
138138
input_embdings = infer_state._all_to_all_unbalance_get(data=input_embdings)
139139

140140
logics = self.tpsp_token_forward(input_embdings, infer_state, layer_weight=layer_weight)
@@ -143,7 +143,7 @@ def overlap_tpsp_token_forward(
143143
infer_state1.hook()
144144
infer_state1.hook = None
145145

146-
if infer_state1.is_prefill and get_env_start_args().enable_dp_prefill_balance:
146+
if infer_state1.need_dp_prefill_balance:
147147
input_embdings1 = infer_state1._all_to_all_unbalance_get(data=input_embdings1)
148148

149149
logics1 = self.tpsp_token_forward(input_embdings1, infer_state1, layer_weight=layer_weight)

lightllm/models/llama/layer_infer/transformer_layer_infer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def _tpsp_get_qkv(
228228
infer_state.position_sin,
229229
)
230230

231-
if infer_state.is_prefill and get_env_start_args().enable_dp_prefill_balance:
231+
if infer_state.need_dp_prefill_balance:
232232
q = infer_state._all_to_all_unbalance_get(data=q)
233233
cache_kv = infer_state._all_to_all_unbalance_get(data=cache_kv)
234234

@@ -401,7 +401,7 @@ def _get_o(
401401
def _tpsp_get_o(
402402
self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
403403
) -> torch.Tensor:
404-
if infer_state.is_prefill and get_env_start_args().enable_dp_prefill_balance:
404+
if infer_state.need_dp_prefill_balance:
405405
input = infer_state._all_to_all_balance_get(data=input)
406406

407407
input = input.view(-1, self.tp_o_head_num_ * self.head_dim_)

lightllm/server/api_start.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def normal_or_p_d_start(args):
140140
assert args.router_token_ratio == 0.0
141141

142142
if args.enable_dp_prefill_balance:
143-
assert args.enable_tpsp_mix_mode and args.dp > 1, "need set --enable_tpsp_mix_mode firstly"
143+
assert args.enable_tpsp_mix_mode and args.dp > 1, "need set --enable_tpsp_mix_mode firstly and --dp > 1"
144144

145145
# mtp params check
146146
if args.mtp_mode is not None:

0 commit comments

Comments (0)