Skip to content

Commit d708701

Browse files
authored
[TRTLLM-8271][fix] Fix CDL overlap scheduling performance (#7971)
Signed-off-by: Mike Iovine <[email protected]>
1 parent c8bef27 commit d708701

File tree

1 file changed

+20
-2
lines changed

1 file changed

+20
-2
lines changed

tensorrt_llm/_torch/speculative/model_drafter.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,8 @@ def _setup_draft_batch_and_resources(
525525
return draft_batch, req_id_to_old_request
526526

527527
def process_static_draft_outputs(
528-
self, outputs: Any, draft_batch: ScheduledRequests,
528+
self, outputs: torch.Tensor | SampleState,
529+
draft_batch: ScheduledRequests,
529530
req_id_to_old_request: Dict[int, LlmRequest]) -> None:
530531
"""
531532
Process outputs from static draft loop, update target requests, and clean up resources.
@@ -535,7 +536,13 @@ def process_static_draft_outputs(
535536
draft_batch: The draft batch that was processed
536537
req_id_to_old_request: Mapping from draft request ID to original request
537538
"""
538-
outputs_host = outputs.cpu()
539+
if isinstance(outputs, torch.Tensor):
540+
# For non-overlap scheduler path.
541+
outputs_host = outputs.cpu()
542+
else:
543+
outputs_host = outputs.host.new_tokens
544+
outputs.sampler_event.synchronize()
545+
539546
for token_idx in range(self.max_draft_tokens):
540547
for req_idx, req in enumerate(draft_batch.all_requests()):
541548
target_model_req = req_id_to_old_request[req.py_request_id]
@@ -703,6 +710,17 @@ def generate_draft_tokens_with_overlap(
703710
draft_length=self.max_draft_tokens,
704711
draft_batch=draft_batch,
705712
req_id_to_old_request=req_id_to_old_request)
713+
714+
new_tokens_host = outputs.to(device='cpu', non_blocking=True)
715+
sampler_event = torch.cuda.Event()
716+
sampler_event.record()
717+
718+
outputs = SampleState(
719+
scheduled_requests=draft_batch,
720+
device=SampleStateTensors(new_tokens=outputs),
721+
host=SampleStateTensors(new_tokens=new_tokens_host),
722+
sampler_event=sampler_event)
723+
706724
return target_inputs, outputs, draft_batch
707725

708726
# Handle guided decoder and sampling for non-static loop

0 commit comments

Comments
 (0)