Skip to content

Commit 5d6fe53

Browse files
committed
fix
1 parent 5ac8bfa commit 5d6fe53

File tree

1 file changed

+17
-10
lines changed

1 file changed

+17
-10
lines changed

lightllm/common/basemodel/infer_struct.py

Lines changed: 17 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -167,39 +167,46 @@ def prefill_dp_balance(self, input_ids: torch.Tensor):
167167
for i in range(sum_input_len % args.dp):
168168
dp_handle_lens[i] += 1
169169

170-
self.dp_handle_lens = dp_handle_lens
170+
self.dp_handle_lens = dp_handle_lens.copy()
171+
172+
dest_dp_inputs = [[] for _ in range(args.dp)]
171173
# 分配每个dp 的原始输入和分配后的原始输入
172174
origin_datas = collections.deque()
173175
for origin_dp_index, origin_dp_input_len in enumerate(dp_input_lens.numpy()):
174-
origin_datas.append((origin_dp_index, 0, origin_dp_input_len))
176+
handle_len = dp_handle_lens[origin_dp_index]
177+
if origin_dp_input_len > handle_len:
178+
origin_datas.append((origin_dp_index, handle_len, origin_dp_input_len))
179+
dp_handle_lens[origin_dp_index] = 0
180+
dest_dp_inputs[origin_dp_index].append((origin_dp_index, 0, handle_len))
181+
else:
182+
dp_handle_lens[origin_dp_index] -= origin_dp_input_len
183+
dest_dp_inputs[origin_dp_index].append((origin_dp_index, 0, origin_dp_input_len))
175184

176-
dest_dp_inputs = []
177185
for dest_dp_index in range(args.dp):
178-
dest_dp_data = []
179186
need_size = dp_handle_lens[dest_dp_index]
187+
if need_size == 0:
188+
continue
180189
while len(origin_datas) != 0:
181190
origin_data = origin_datas.popleft()
182191
origin_dp_index, start, end = origin_data
183192
if end - start > need_size:
184-
dest_dp_data.append((origin_dp_index, start, start + need_size))
193+
dest_dp_inputs[dest_dp_index].append((origin_dp_index, start, start + need_size))
185194
origin_datas.appendleft((origin_dp_index, start + need_size, end))
186195
break
187196
else:
188-
dest_dp_data.append((origin_dp_index, start, end))
197+
dest_dp_inputs[dest_dp_index].append((origin_dp_index, start, end))
189198
need_size -= end - start
190199
if need_size == 0:
191200
break
192201

193-
dest_dp_inputs.append(dest_dp_data)
194-
195202
dp_output_split_sizes = [[0 for _ in range(args.dp)] for _ in range(args.dp)]
196203
for dest_dp_index, dest_dp_data in enumerate(dest_dp_inputs):
197204
for origin_dp_index, start, end in dest_dp_data:
198-
dp_output_split_sizes[dest_dp_index][origin_dp_index] = end - start
205+
dp_output_split_sizes[dest_dp_index][origin_dp_index] += end - start
199206
dp_input_split_sizes = [[0 for _ in range(args.dp)] for _ in range(args.dp)]
200207
for dest_dp_index, dest_dp_data in enumerate(dest_dp_inputs):
201208
for origin_dp_index, start, end in dest_dp_data:
202-
dp_input_split_sizes[origin_dp_index][dest_dp_index] = end - start
209+
dp_input_split_sizes[origin_dp_index][dest_dp_index] += end - start
203210

204211
self.dp_input_split_sizes = dp_input_split_sizes
205212
self.dp_output_split_sizes = dp_output_split_sizes

0 commit comments

Comments
 (0)