@@ -167,39 +167,46 @@ def prefill_dp_balance(self, input_ids: torch.Tensor):
167167 for i in range (sum_input_len % args .dp ):
168168 dp_handle_lens [i ] += 1
169169
170- self .dp_handle_lens = dp_handle_lens
170+ self .dp_handle_lens = dp_handle_lens .copy ()
171+
172+ dest_dp_inputs = [[] for _ in range (args .dp )]
171173 # 分配每个dp 的原始输入和分配后的原始输入
172174 origin_datas = collections .deque ()
173175 for origin_dp_index , origin_dp_input_len in enumerate (dp_input_lens .numpy ()):
174- origin_datas .append ((origin_dp_index , 0 , origin_dp_input_len ))
176+ handle_len = dp_handle_lens [origin_dp_index ]
177+ if origin_dp_input_len > handle_len :
178+ origin_datas .append ((origin_dp_index , handle_len , origin_dp_input_len ))
179+ dp_handle_lens [origin_dp_index ] = 0
180+ dest_dp_inputs [origin_dp_index ].append ((origin_dp_index , 0 , handle_len ))
181+ else :
182+ dp_handle_lens [origin_dp_index ] -= origin_dp_input_len
183+ dest_dp_inputs [origin_dp_index ].append ((origin_dp_index , 0 , origin_dp_input_len ))
175184
176- dest_dp_inputs = []
177185 for dest_dp_index in range (args .dp ):
178- dest_dp_data = []
179186 need_size = dp_handle_lens [dest_dp_index ]
187+ if need_size == 0 :
188+ continue
180189 while len (origin_datas ) != 0 :
181190 origin_data = origin_datas .popleft ()
182191 origin_dp_index , start , end = origin_data
183192 if end - start > need_size :
184- dest_dp_data .append ((origin_dp_index , start , start + need_size ))
193+ dest_dp_inputs [ dest_dp_index ] .append ((origin_dp_index , start , start + need_size ))
185194 origin_datas .appendleft ((origin_dp_index , start + need_size , end ))
186195 break
187196 else :
188- dest_dp_data .append ((origin_dp_index , start , end ))
197+ dest_dp_inputs [ dest_dp_index ] .append ((origin_dp_index , start , end ))
189198 need_size -= end - start
190199 if need_size == 0 :
191200 break
192201
193- dest_dp_inputs .append (dest_dp_data )
194-
195202 dp_output_split_sizes = [[0 for _ in range (args .dp )] for _ in range (args .dp )]
196203 for dest_dp_index , dest_dp_data in enumerate (dest_dp_inputs ):
197204 for origin_dp_index , start , end in dest_dp_data :
198- dp_output_split_sizes [dest_dp_index ][origin_dp_index ] = end - start
205+ dp_output_split_sizes [dest_dp_index ][origin_dp_index ] + = end - start
199206 dp_input_split_sizes = [[0 for _ in range (args .dp )] for _ in range (args .dp )]
200207 for dest_dp_index , dest_dp_data in enumerate (dest_dp_inputs ):
201208 for origin_dp_index , start , end in dest_dp_data :
202- dp_input_split_sizes [origin_dp_index ][dest_dp_index ] = end - start
209+ dp_input_split_sizes [origin_dp_index ][dest_dp_index ] + = end - start
203210
204211 self .dp_input_split_sizes = dp_input_split_sizes
205212 self .dp_output_split_sizes = dp_output_split_sizes
0 commit comments