Commit d505acb

[refactor] Adjust request handling in MTPSampler
Signed-off-by: Robin Kobus <[email protected]>
Parent: d0b9832


tensorrt_llm/_torch/speculative/mtp.py

Lines changed: 17 additions & 18 deletions
@@ -267,16 +267,7 @@ def update_requests(
         new_tokens_lens_list = state.host.new_tokens_lens.tolist()
         next_draft_tokens_list = state.host.next_draft_tokens.tolist()
         beam_idx = BEAM
-        for req in state.scheduled_requests.context_requests:
-            if req.state == LlmRequestState.GENERATION_COMPLETE or req.context_remaining_length != 0:
-                continue
-            new_token = add_token(req, new_tokens, beam=beam_idx)
-            TorchSampler._handle_stop_criteria(req,
-                                               new_token,
-                                               max_seq_len=self.max_seq_len)
-            self._request_common_handling(req, next_draft_tokens_list)
-
-        for req in state.scheduled_requests.generation_requests:
+        for req in state.requests:
             if req.state == LlmRequestState.GENERATION_COMPLETE:
                 continue
             num_new_tokens = new_tokens_lens_list[req.py_seq_slot]
@@ -298,14 +289,22 @@ def sample_async(
         # next_draft_tokens_device: predicted draft tokens, device tensor, shape: batch_size, nextn
         # next_new_tokens_device: input tokens for the next iteration, device tensor, shape: batch_size, nextn + 1

-        requests = scheduled_requests.all_requests()
-        slots = torch.as_tensor([r.py_seq_slot for r in requests])
+        finished_context_requests = [
+            req for req in scheduled_requests.context_requests
+            if req.is_last_context_chunk
+        ]
+
+        sampling_requests = finished_context_requests + scheduled_requests.generation_requests
+        num_sampling_requests = len(sampling_requests)
+
+        slots = torch.as_tensor([r.py_seq_slot for r in sampling_requests])
         slots = slots.to(device="cuda", non_blocking=True)

-        o_new_tokens = outputs['new_tokens'][:len(requests)]
-        o_new_tokens_lens = outputs['new_tokens_lens'][:len(requests)]
-        o_next_draft_tokens = outputs['next_draft_tokens'][:len(requests)]
-        o_next_new_tokens = outputs['next_new_tokens'][:len(requests)]
+        o_new_tokens = outputs['new_tokens'][:num_sampling_requests]
+        o_new_tokens_lens = outputs['new_tokens_lens'][:num_sampling_requests]
+        o_next_draft_tokens = outputs[
+            'next_draft_tokens'][:num_sampling_requests]
+        o_next_new_tokens = outputs['next_new_tokens'][:num_sampling_requests]

         new_tokens = self.store.new_tokens
         next_new_tokens = self.store.next_new_tokens
@@ -331,9 +330,9 @@ def sample_async(
         sampler_event.record()
         # add dummy draft tokens to context requests to prepare kv cache in advance
         # with the max draft token length
-        for request in scheduled_requests.context_requests:
+        for request in finished_context_requests:
             request.py_draft_tokens = [1] * self.draft_len
-        return SampleStateMTP(scheduled_requests=scheduled_requests,
+        return SampleStateMTP(requests=sampling_requests,
                               device=device,
                               host=host,
                               sampler_event=sampler_event)
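
For context, the sketch below illustrates the request-selection logic this commit introduces: only context requests that have reached their last context chunk are sampled, and they are merged with the generation requests into a single sampling_requests list that later drives slot gathering, output slicing, and the unified loop in update_requests. This is a minimal, standalone approximation rather than the TensorRT-LLM implementation; Request and ScheduledRequests are hypothetical stand-ins for the real request containers, carrying only the attributes the diff touches.

# Minimal sketch (assumptions noted above) of the new request selection.
from dataclasses import dataclass, field
from typing import List


@dataclass
class Request:
    # Hypothetical stand-in for an LlmRequest; only diff-relevant fields.
    py_seq_slot: int
    is_last_context_chunk: bool = True
    py_draft_tokens: List[int] = field(default_factory=list)


@dataclass
class ScheduledRequests:
    # Hypothetical stand-in for the scheduled-requests container.
    context_requests: List[Request]
    generation_requests: List[Request]


def select_sampling_requests(scheduled: ScheduledRequests) -> List[Request]:
    """Finished (last-chunk) context requests plus all generation requests."""
    finished_context_requests = [
        req for req in scheduled.context_requests if req.is_last_context_chunk
    ]
    return finished_context_requests + scheduled.generation_requests


if __name__ == "__main__":
    scheduled = ScheduledRequests(
        context_requests=[
            Request(py_seq_slot=0, is_last_context_chunk=False),  # still chunking: skipped
            Request(py_seq_slot=1, is_last_context_chunk=True),   # last chunk: sampled
        ],
        generation_requests=[Request(py_seq_slot=2)],
    )
    sampling_requests = select_sampling_requests(scheduled)
    # Sampler outputs are then sliced to len(sampling_requests), and the same
    # list is stored on the sample state so update_requests can iterate it
    # directly instead of looping context and generation requests separately.
    print([r.py_seq_slot for r in sampling_requests])  # -> [1, 2]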
