@@ -525,7 +525,8 @@ def _setup_draft_batch_and_resources(
         return draft_batch, req_id_to_old_request

     def process_static_draft_outputs(
-            self, outputs: Any, draft_batch: ScheduledRequests,
+            self, outputs: torch.Tensor | SampleState,
+            draft_batch: ScheduledRequests,
             req_id_to_old_request: Dict[int, LlmRequest]) -> None:
         """
         Process outputs from static draft loop, update target requests, and clean up resources.
@@ -535,7 +536,13 @@ def process_static_draft_outputs(
             draft_batch: The draft batch that was processed
             req_id_to_old_request: Mapping from draft request ID to original request
         """
-        outputs_host = outputs.cpu()
+        if isinstance(outputs, torch.Tensor):
+            # For non-overlap scheduler path.
+            outputs_host = outputs.cpu()
+        else:
+            outputs_host = outputs.host.new_tokens
+            outputs.sampler_event.synchronize()
+
         for token_idx in range(self.max_draft_tokens):
             for req_idx, req in enumerate(draft_batch.all_requests()):
                 target_model_req = req_id_to_old_request[req.py_request_id]
@@ -703,6 +710,17 @@ def generate_draft_tokens_with_overlap(
                 draft_length=self.max_draft_tokens,
                 draft_batch=draft_batch,
                 req_id_to_old_request=req_id_to_old_request)
+
+            new_tokens_host = outputs.to(device='cpu', non_blocking=True)
+            sampler_event = torch.cuda.Event()
+            sampler_event.record()
+
+            outputs = SampleState(
+                scheduled_requests=draft_batch,
+                device=SampleStateTensors(new_tokens=outputs),
+                host=SampleStateTensors(new_tokens=new_tokens_host),
+                sampler_event=sampler_event)
+
             return target_inputs, outputs, draft_batch

         # Handle guided decoder and sampling for non-static loop
0 commit comments