@@ -93,14 +93,20 @@ def values(self):
9393 return vars (self ).values ()
9494
9595
96+ @dataclass (kw_only = True )
97+ class SamplerEvent :
98+ cuda_event : torch .cuda .Event
99+ worker_futures : Optional [list [futures .Future [Any ]]] = None
100+
101+
@dataclass(kw_only=True)
class SampleState:
    """Result bundle for one sampling step.

    NOTE(review): field semantics inferred from names/types — confirm:
    `device` holds sampled tensors resident on the GPU and `host` their
    copied-to-CPU counterparts.
    """

    # The batch of requests this sample state was produced for.
    scheduled_requests: ScheduledRequests

    # Sampled tensors on the device, if any.
    device: Optional[SampleStateTensors] = None
    # Host-side copies of the sampled tensors, if any.
    host: Optional[SampleStateTensors] = None

    # Synchronization handle; callers should synchronize on it before
    # reading the host tensors (see _sampler_event_synchronize).
    sampler_event: Optional[SamplerEvent] = None
105111
106112class Sampler (ABC ):
@@ -595,12 +601,15 @@ class AsyncWorkerMixin:
595601 operations will seamlessly run on the main thread
596602 """
597603
    # Size of the async worker thread pool. NOTE(review): the design appears
    # to assume exactly one worker (work items run serially on a single
    # dedicated CUDA stream set up in the pool initializer) — confirm before
    # ever raising this.
    MAX_WORKERS = 1

598606 def _async_worker_active (self ) -> bool :
599607 return self ._async_worker is not None
600608
601609 def _async_worker_init (self , enable_async_worker : bool ):
602610 self .enable_async_worker = enable_async_worker
603611 self ._async_worker = None
612+ self ._async_worker_futures : list [futures .Future [any ]] = []
604613
605614 def async_worker_start (self ):
606615 assert self .enable_async_worker
@@ -615,7 +624,7 @@ def _async_worker_initializer(device_id):
615624 torch .cuda .set_stream (torch .cuda .Stream ())
616625
617626 self ._async_worker = futures .ThreadPoolExecutor (
618- max_workers = 1 ,
627+ max_workers = self . MAX_WORKERS ,
619628 initializer = _async_worker_initializer ,
620629 initargs = (torch .cuda .current_device (),),
621630 )
@@ -633,41 +642,54 @@ def _async_worker_run(self, ready: torch.cuda.Event, func, /, *args, **kwargs):
633642 ready .synchronize ()
634643
635644 # Do the work
636- return func (* args , ** kwargs )
645+ result = func (* args , ** kwargs )
646+
647+ # Work submitted to the async worker is expected to block at the end,
648+ # consistent with the semantics of futures; make sure that we wait for
649+ # everything to complete
650+ torch .cuda .current_stream ().synchronize ()
651+
652+ return result
637653
638654 def _async_worker_submit (self , func , / , * args , ** kwargs ):
639655 if self ._async_worker_active ():
640656 # Record an event on the main thread/stream that we will
641657 # synchronize with on the worker thread/stream
642658 ready = torch .cuda .Event ()
643659 ready .record ()
644- return self ._async_worker .submit (self ._async_worker_run , ready , func , * args , ** kwargs )
660+
661+ # Submit the async work
662+ result = self ._async_worker .submit (self ._async_worker_run , ready , func , * args , ** kwargs )
663+
664+ # Save the future, so that we can await it later
665+ self ._async_worker_futures .append (result )
666+
667+ return result
645668 else :
646669 # If the async worker is not in use, just execute the function
647670 return func (* args , ** kwargs )
648671
649- def _copy_to_host (self , src : torch .Tensor , pin_memory = False ) -> torch .Tensor :
650- dest = torch .empty_like (src , device = "cpu" , pin_memory = pin_memory )
672+ def _copy_to_host (self , src : torch .Tensor ) -> torch .Tensor :
673+ dest = torch .empty_like (src , device = "cpu" , pin_memory = True )
651674 self ._async_worker_submit (dest .copy_ , src , non_blocking = True )
652675 return dest
653676
654- def _sampler_event_get (self ) -> torch .cuda .Event | futures .Future [torch .cuda .Event ]:
655- def _get_sampler_event () -> torch .cuda .Event :
656- sampler_event = torch .cuda .Event ()
657- sampler_event .record ()
658- return sampler_event
677+ def _sampler_event_get (self ) -> SamplerEvent :
678+ cuda_event = torch .cuda .Event ()
679+ cuda_event .record ()
659680
660- return self ._async_worker_submit (_get_sampler_event )
681+ # Transfer ownership to worker_futures and re-initialize
682+ worker_futures = self ._async_worker_futures
683+ self ._async_worker_futures = []
684+
685+ return SamplerEvent (cuda_event = cuda_event , worker_futures = worker_futures )
661686
662687 @staticmethod
663- def _sampler_event_synchronize (
664- sampler_event : torch .cuda .Event | futures .Future [torch .cuda .Event ] | None ,
665- ):
688+ def _sampler_event_synchronize (sampler_event : SamplerEvent ):
666689 if sampler_event :
667- if isinstance (sampler_event , futures .Future ):
668- sampler_event .result ().synchronize ()
669- else :
670- sampler_event .synchronize ()
690+ if sampler_event .worker_futures :
691+ futures .wait (sampler_event .worker_futures )
692+ sampler_event .cuda_event .synchronize ()
671693
672694
673695class TorchSampler (Sampler , AsyncWorkerMixin ):
@@ -1229,7 +1251,7 @@ def sample_async(
12291251 self ._write_finish_reasons (
12301252 requests , finish_reasons = finish_reasons , seq_slots = seq_slots , new_tokens = new_tokens
12311253 )
1232- finish_reasons_host = finish_reasons . to ( device = "cpu" , non_blocking = True )
1254+ finish_reasons_host = self . _copy_to_host ( finish_reasons )
12331255
12341256 sampler_event = self ._sampler_event_get ()
12351257 return SampleStateTorch (
@@ -1852,8 +1874,8 @@ def _process_requests(
18521874 logprobs_cuda , k = max (req .py_num_logprobs for req in requests ), dim = - 1
18531875 )
18541876 # Use a single D2H copy to reduce overheads
1855- topk_vals = self ._copy_to_host (topk_vals_cuda , pin_memory = True )
1856- topk_indices = self ._copy_to_host (topk_indices_cuda , pin_memory = True )
1877+ topk_vals = self ._copy_to_host (topk_vals_cuda )
1878+ topk_indices = self ._copy_to_host (topk_indices_cuda )
18571879 current_offset = 0
18581880 for req_id , steps in zip (
18591881 logprobs_req_indices , req_num_steps [logprobs_req_indices ].tolist ()
0 commit comments