@@ -267,16 +267,7 @@ def update_requests(
         new_tokens_lens_list = state.host.new_tokens_lens.tolist()
         next_draft_tokens_list = state.host.next_draft_tokens.tolist()
         beam_idx = BEAM
-        for req in state.scheduled_requests.context_requests:
-            if req.state == LlmRequestState.GENERATION_COMPLETE or req.context_remaining_length != 0:
-                continue
-            new_token = add_token(req, new_tokens, beam=beam_idx)
-            TorchSampler._handle_stop_criteria(req,
-                                               new_token,
-                                               max_seq_len=self.max_seq_len)
-            self._request_common_handling(req, next_draft_tokens_list)
-
-        for req in state.scheduled_requests.generation_requests:
+        for req in state.requests:
             if req.state == LlmRequestState.GENERATION_COMPLETE:
                 continue
             num_new_tokens = new_tokens_lens_list[req.py_seq_slot]
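The unified loop works because `state.requests` is built in `sample_async` (below) to hold only requests that actually produced sampled tokens: context requests on their final chunk plus the generation requests. A minimal sketch of the invariant this relies on, using the attribute name from the removed guard:

```python
# Sketch only: every request reaching the unified loop has finished prefill,
# so the old `context_remaining_length != 0` guard is now redundant.
for req in state.requests:
    assert req.context_remaining_length == 0
```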
@@ -298,14 +289,22 @@ def sample_async(
         # next_draft_tokens_device: predicted draft tokens, device tensor, shape: batch_size, nextn
         # next_new_tokens_device: input tokens for the next iteration, device tensor, shape: batch_size, nextn + 1

-        requests = scheduled_requests.all_requests()
-        slots = torch.as_tensor([r.py_seq_slot for r in requests])
+        finished_context_requests = [
+            req for req in scheduled_requests.context_requests
+            if req.is_last_context_chunk
+        ]
+
+        sampling_requests = finished_context_requests + scheduled_requests.generation_requests
+        num_sampling_requests = len(sampling_requests)
+
+        slots = torch.as_tensor([r.py_seq_slot for r in sampling_requests])
         slots = slots.to(device="cuda", non_blocking=True)

-        o_new_tokens = outputs['new_tokens'][:len(requests)]
-        o_new_tokens_lens = outputs['new_tokens_lens'][:len(requests)]
-        o_next_draft_tokens = outputs['next_draft_tokens'][:len(requests)]
-        o_next_new_tokens = outputs['next_new_tokens'][:len(requests)]
+        o_new_tokens = outputs['new_tokens'][:num_sampling_requests]
+        o_new_tokens_lens = outputs['new_tokens_lens'][:num_sampling_requests]
+        o_next_draft_tokens = outputs[
+            'next_draft_tokens'][:num_sampling_requests]
+        o_next_new_tokens = outputs['next_new_tokens'][:num_sampling_requests]

         new_tokens = self.store.new_tokens
         next_new_tokens = self.store.next_new_tokens
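Slicing the output tensors by `num_sampling_requests` assumes the model packs its rows in the same order as `sampling_requests` (finished context requests first, then generation requests); the `slots` tensor then maps those dense batch rows back to per-request storage. A minimal sketch of that slot scatter, with shapes and store layout assumed:

```python
import torch

# Assumed layout: the store keeps one persistent row per sequence slot,
# while model outputs are packed densely over the sampled batch.
max_slots, nextn = 8, 3
store_next_new_tokens = torch.zeros(max_slots, nextn + 1, dtype=torch.long)

o_next_new_tokens = torch.arange(2 * (nextn + 1)).view(2, nextn + 1)  # 2 sampled requests
slots = torch.as_tensor([5, 2])  # their py_seq_slot values

# Scatter each batch row into its persistent slot (CPU here for brevity;
# the real code does the equivalent on CUDA with non-blocking copies).
store_next_new_tokens.index_copy_(0, slots, o_next_new_tokens)
```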
@@ -331,9 +330,9 @@ def sample_async(
         sampler_event.record()
         # add dummy draft tokens to context requests to prepare kv cache in advance
         # with the max draft token length
-        for request in scheduled_requests.context_requests:
+        for request in finished_context_requests:
             request.py_draft_tokens = [1] * self.draft_len
-        return SampleStateMTP(scheduled_requests=scheduled_requests,
+        return SampleStateMTP(requests=sampling_requests,
                               device=device,
                               host=host,
                               sampler_event=sampler_event)
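The dummy `py_draft_tokens` assigned above only reserve KV-cache space for the next iteration's draft tokens; the `1` values are placeholders that get overwritten. The return statement now hands `update_requests` the same flat list it iterates as `state.requests`. A hypothetical sketch of the state container implied by the diff; everything beyond the four visible field names is an assumption:

```python
from dataclasses import dataclass
from typing import Any, List

# Hypothetical sketch of SampleStateMTP as suggested by the diff; field
# types are guesses, only the names come from the code above.
@dataclass
class SampleStateMTP:
    requests: List[Any]   # finished context requests + generation requests
    device: Any           # device-side tensors (new_tokens, next_draft_tokens, ...)
    host: Any             # host-side copies read later by update_requests
    sampler_event: Any    # CUDA event recorded after the async D2H copies
```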