@@ -91,14 +91,20 @@ def values(self):
         return vars(self).values()


+@dataclass(kw_only=True)
+class SamplerEvent:
+    cuda_event: torch.cuda.Event
+    worker_futures: Optional[list[futures.Future[Any]]] = None
+
+
 @dataclass(kw_only=True)
 class SampleState:
     scheduled_requests: ScheduledRequests

     device: Optional[SampleStateTensors] = None
     host: Optional[SampleStateTensors] = None

-    sampler_event: Optional[torch.cuda.Event] = None
+    sampler_event: Optional[SamplerEvent] = None


 class Sampler(ABC):
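For reviewers, a minimal standalone sketch of the contract the new `SamplerEvent` bundle captures: a sample state is complete only once its CUDA event has fired and any futures pending on the async worker have resolved. The `wait_for` helper below is hypothetical, added purely for illustration.

```python
# Hypothetical illustration (not part of this diff) of the SamplerEvent
# completion contract: worker futures are drained first, since they may
# still be enqueueing GPU work, then the CUDA event is awaited.
from concurrent import futures
from typing import Any, Optional

import torch


def wait_for(cuda_event: torch.cuda.Event,
             worker_futures: Optional[list[futures.Future[Any]]]) -> None:
    if worker_futures:
        futures.wait(worker_futures)  # blocks until all worker tasks finish
    cuda_event.synchronize()          # blocks until the recorded GPU work is done
```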
@@ -593,12 +599,15 @@ class AsyncWorkerMixin:
     operations will seamlessly run on the main thread
     """

+    MAX_WORKERS = 1
+
     def _async_worker_active(self) -> bool:
         return self._async_worker is not None

     def _async_worker_init(self, enable_async_worker: bool):
         self.enable_async_worker = enable_async_worker
         self._async_worker = None
+        self._async_worker_futures: list[futures.Future[Any]] = []

     def async_worker_start(self):
         assert self.enable_async_worker
@@ -613,7 +622,7 @@ def _async_worker_initializer(device_id):
            torch.cuda.set_stream(torch.cuda.Stream())

        self._async_worker = futures.ThreadPoolExecutor(
-            max_workers=1,
+            max_workers=self.MAX_WORKERS,
            initializer=_async_worker_initializer,
            initargs=(torch.cuda.current_device(),),
        )
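As a standalone sketch of the worker-pool pattern used here (assuming a CUDA-capable build; names are illustrative): a single-thread executor whose initializer pins the thread to the caller's device and installs a dedicated side stream, so work submitted to it can overlap with the main stream.

```python
# Illustrative sketch: one worker thread, pinned to the caller's device,
# running on its own CUDA stream so submitted work overlaps the main stream.
from concurrent import futures

import torch


def _worker_init(device_id: int) -> None:
    torch.cuda.set_device(device_id)            # same device as the caller
    torch.cuda.set_stream(torch.cuda.Stream())  # dedicated side stream


pool = futures.ThreadPoolExecutor(
    max_workers=1,  # a single worker keeps submissions strictly FIFO
    initializer=_worker_init,
    initargs=(torch.cuda.current_device(),),
)
```

A single worker also means the executor runs tasks in submission order, so the future bookkeeping added in this change can treat them as ordered.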
@@ -631,41 +640,54 @@ def _async_worker_run(self, ready: torch.cuda.Event, func, /, *args, **kwargs):
         ready.synchronize()

         # Do the work
-        return func(*args, **kwargs)
+        result = func(*args, **kwargs)
+
+        # Work submitted to the async worker is expected to block at the end,
+        # consistent with the semantics of futures; make sure that we wait for
+        # everything to complete
+        torch.cuda.current_stream().synchronize()
+
+        return result

     def _async_worker_submit(self, func, /, *args, **kwargs):
         if self._async_worker_active():
             # Record an event on the main thread/stream that we will
             # synchronize with on the worker thread/stream
             ready = torch.cuda.Event()
             ready.record()
-            return self._async_worker.submit(self._async_worker_run, ready, func, *args, **kwargs)
+
+            # Submit the async work
+            result = self._async_worker.submit(self._async_worker_run, ready, func, *args, **kwargs)
+
+            # Save the future, so that we can await it later
+            self._async_worker_futures.append(result)
+
+            return result
         else:
             # If the async worker is not in use, just execute the function
             return func(*args, **kwargs)

-    def _copy_to_host(self, src: torch.Tensor, pin_memory=False) -> torch.Tensor:
-        dest = torch.empty_like(src, device="cpu", pin_memory=pin_memory)
+    def _copy_to_host(self, src: torch.Tensor) -> torch.Tensor:
+        dest = torch.empty_like(src, device="cpu", pin_memory=True)
         self._async_worker_submit(dest.copy_, src, non_blocking=True)
         return dest

-    def _sampler_event_get(self) -> torch.cuda.Event | futures.Future[torch.cuda.Event]:
-        def _get_sampler_event() -> torch.cuda.Event:
-            sampler_event = torch.cuda.Event()
-            sampler_event.record()
-            return sampler_event
+    def _sampler_event_get(self) -> SamplerEvent:
+        cuda_event = torch.cuda.Event()
+        cuda_event.record()

-        return self._async_worker_submit(_get_sampler_event)
+        # Transfer ownership to worker_futures and re-initialize
+        worker_futures = self._async_worker_futures
+        self._async_worker_futures = []
+
+        return SamplerEvent(cuda_event=cuda_event, worker_futures=worker_futures)
     @staticmethod
-    def _sampler_event_synchronize(
-        sampler_event: torch.cuda.Event | futures.Future[torch.cuda.Event] | None,
-    ):
+    def _sampler_event_synchronize(sampler_event: Optional[SamplerEvent]):
         if sampler_event:
-            if isinstance(sampler_event, futures.Future):
-                sampler_event.result().synchronize()
-            else:
-                sampler_event.synchronize()
+            if sampler_event.worker_futures:
+                futures.wait(sampler_event.worker_futures)
+            sampler_event.cuda_event.synchronize()


 class TorchSampler(Sampler, AsyncWorkerMixin):
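To make the end-to-end flow concrete, here is a hedged, standalone sketch of the pinned-copy-plus-event pattern these helpers implement (single executor, no mixin; the tensor shape is arbitrary):

```python
# Standalone sketch of the async D2H pattern: stage into pinned memory on a
# worker thread, then gate host reads behind futures + event synchronization.
from concurrent import futures

import torch

pool = futures.ThreadPoolExecutor(max_workers=1)

src = torch.randn(4096, device="cuda")
dest = torch.empty_like(src, device="cpu", pin_memory=True)

ready = torch.cuda.Event()
ready.record()  # marks the point on the main stream that produced `src`


def _copy() -> None:
    ready.synchronize()                        # wait for `src` to be ready
    dest.copy_(src, non_blocking=True)         # async copy into pinned memory
    torch.cuda.current_stream().synchronize()  # future resolves only when done


fut = pool.submit(_copy)
# ... main-thread work can proceed here ...
futures.wait([fut])     # drain the worker (cf. _sampler_event_synchronize)
value = dest[0].item()  # safe: the copy is complete
```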
@@ -1216,7 +1238,7 @@ def sample_async(
         self._write_finish_reasons(
             requests, finish_reasons=finish_reasons, seq_slots=seq_slots, new_tokens=new_tokens
         )
-        finish_reasons_host = finish_reasons.to(device="cpu", non_blocking=True)
+        finish_reasons_host = self._copy_to_host(finish_reasons)

         sampler_event = self._sampler_event_get()
         return SampleStateTorch(
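A brief note on the `finish_reasons` change: routing through `_copy_to_host` stages the transfer in pinned memory and tracks it as a worker future, so the later `_sampler_event_synchronize` call is what guarantees the host tensor is valid. An illustrative snippet (not from this diff) of the underlying discipline:

```python
# Illustrative only: a pinned, non-blocking D2H copy is safe to read on the
# host only after an event recorded behind it has been synchronized.
import torch

finish = torch.zeros(64, dtype=torch.int32, device="cuda")  # stand-in tensor
host = torch.empty_like(finish, device="cpu", pin_memory=True)
host.copy_(finish, non_blocking=True)  # returns immediately, copy in flight

done = torch.cuda.Event()
done.record()
done.synchronize()  # after this, `host` is safe to read
print(host[:4].tolist())
```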
@@ -1839,8 +1861,8 @@ def _process_requests(
                 logprobs_cuda, k=max(req.py_num_logprobs for req in requests), dim=-1
             )
             # Use a single D2H copy to reduce overheads
-            topk_vals = self._copy_to_host(topk_vals_cuda, pin_memory=True)
-            topk_indices = self._copy_to_host(topk_indices_cuda, pin_memory=True)
+            topk_vals = self._copy_to_host(topk_vals_cuda)
+            topk_indices = self._copy_to_host(topk_indices_cuda)
             current_offset = 0
             for req_id, steps in zip(
                 logprobs_req_indices, req_num_steps[logprobs_req_indices].tolist()
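For context on the "single D2H copy" comment above, a standalone sketch (shapes and vocab size are illustrative): one batched `torch.topk` plus one pinned host copy per tensor replaces many small per-request transfers.

```python
# Illustrative sketch: batch the top-k on device, then move the much smaller
# results to the host in one non-blocking copy per tensor.
import torch

logprobs_cuda = torch.randn(8, 4, 32_000, device="cuda")  # [reqs, steps, vocab]
k = 4

topk_vals_cuda, topk_indices_cuda = torch.topk(logprobs_cuda, k=k, dim=-1)

topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=True)
topk_vals.copy_(topk_vals_cuda, non_blocking=True)
topk_indices = torch.empty_like(topk_indices_cuda, device="cpu", pin_memory=True)
topk_indices.copy_(topk_indices_cuda, non_blocking=True)

torch.cuda.current_stream().synchronize()  # copies complete; slice per request
```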