 
 @dataclass(kw_only=True)
 class SampleStateTensors:
-    new_tokens: torch.Tensor | futures.Future[torch.Tensor]
+    new_tokens: torch.Tensor
     log_probs: torch.Tensor | None = None
 
     def values(self):
@@ -623,22 +623,15 @@ def async_worker_stop(self):
         self._async_worker.shutdown(wait=True)
         self._async_worker = None
 
-    def _async_worker_run(self, ready, func, /, *args, **kwargs):
+    def _async_worker_run(self, ready: torch.cuda.Event, func, /, *args, **kwargs):
         # Make sure the async work takes place after all prior operations on
         # the primary stream. synchronize() is intentionally chosen instead of
         # wait() here; otherwise, blocking copies will stall subsequent CUDA
         # API calls on the main thread
         ready.synchronize()
 
         # Do the work
-        result = func(*args, **kwargs)
-
-        # work submitted to the async worker is expected to block at the end,
-        # consistent with the semantics of futures; make sure that we wait for
-        # everything to complete
-        torch.cuda.current_stream().synchronize()
-
-        return result
+        return func(*args, **kwargs)
 
     def _async_worker_submit(self, func, /, *args, **kwargs):
         if self._async_worker_active():
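Note on the hunk above: the worker thread blocks on the CUDA event with synchronize() rather than enqueueing a stream-side wait(), so blocking copies issued by the worker never stall CUDA API calls made from the main thread. A minimal, self-contained sketch of that hand-off, with a hypothetical helper name and toy payload (not the mixin's API):

    from concurrent import futures

    import torch

    def submit_after_gpu_work(worker: futures.ThreadPoolExecutor, func, *args, **kwargs):
        # Record an event marking everything queued so far on the main thread's stream.
        ready = torch.cuda.Event()
        ready.record()

        def run():
            # Block only the worker thread until the GPU reaches the event; a
            # stream-side ready.wait() followed by a blocking copy here could
            # stall CUDA API calls issued afterwards from the main thread.
            ready.synchronize()
            return func(*args, **kwargs)

        return worker.submit(run)

    if torch.cuda.is_available():
        x = torch.randn(4, device="cuda")
        with futures.ThreadPoolExecutor(max_workers=1) as pool:
            fut = submit_after_gpu_work(pool, x.cpu)  # blocking D2H copy runs on the worker
            print(fut.result())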
@@ -651,6 +644,29 @@ def _async_worker_submit(self, func, /, *args, **kwargs):
         # If the async worker is not in use, just execute the function
         return func(*args, **kwargs)
 
+    def _copy_to_host(self, src: torch.Tensor, pin_memory=False) -> torch.Tensor:
+        dest = torch.empty_like(src, device="cpu", pin_memory=pin_memory)
+        self._async_worker_submit(dest.copy_, src, non_blocking=True)
+        return dest
+
+    def _sampler_event_get(self) -> torch.cuda.Event | futures.Future[torch.cuda.Event]:
+        def _get_sampler_event() -> torch.cuda.Event:
+            sampler_event = torch.cuda.Event()
+            sampler_event.record()
+            return sampler_event
+
+        return self._async_worker_submit(_get_sampler_event)
+
+    @staticmethod
+    def _sampler_event_synchronize(
+        sampler_event: torch.cuda.Event | futures.Future[torch.cuda.Event] | None,
+    ):
+        if sampler_event:
+            if isinstance(sampler_event, futures.Future):
+                sampler_event.result().synchronize()
+            else:
+                sampler_event.synchronize()
+
 
 class TorchSampler(Sampler):
     SampleState = SampleStateTorch
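Note on the helpers added above: _copy_to_host enqueues a non-blocking device-to-host copy into (optionally pinned) memory and returns the destination immediately; the tensor is only safe to read after a later CUDA event has been synchronized. A standalone sketch of that contract, using direct calls instead of the worker and illustrative names only:

    import torch

    def copy_to_host(src: torch.Tensor, pin_memory: bool = False) -> torch.Tensor:
        # The destination is returned right away; its contents are undefined
        # until the copy enqueued below has completed on the current stream.
        dest = torch.empty_like(src, device="cpu", pin_memory=pin_memory)
        dest.copy_(src, non_blocking=True)
        return dest

    if torch.cuda.is_available():
        gpu = torch.arange(8, device="cuda")
        host = copy_to_host(gpu, pin_memory=True)

        ev = torch.cuda.Event()
        ev.record()       # ordered after the copy on the same stream
        ev.synchronize()  # only now is `host` guaranteed to hold the data
        assert host.tolist() == list(range(8))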
@@ -675,7 +691,7 @@ class Args:
         max_num_sequences: int
         max_beam_width: int
         max_total_draft_tokens: int
-        enable_async_worker: Optional[bool] = False
+        enable_async_worker: bool = False
 
     def __init__(self, args: Args):
         self.max_seq_len = args.max_seq_len
@@ -797,17 +813,8 @@ def handle_logprobs(
         count: int,
     ):
         if request.py_return_log_probs:
-            if self._async_worker_active():
-                # These should be futures if we used the async worker
-                assert isinstance(request.py_topk_logprobs_values, futures.Future)
-                assert isinstance(request.py_topk_logprobs_vals, futures.Future)
-                topk_log_probs_vals = request.py_topk_logprobs_vals.result()
-                topk_log_probs_indices = request.py_topk_logprobs_indices.result()
-            else:
-                topk_log_probs_vals = request.py_topk_logprobs_vals
-                topk_log_probs_indices = request.py_topk_logprobs_indices
-            topk_log_probs_vals = topk_log_probs_vals[:count]
-            topk_log_probs_indices = topk_log_probs_indices[:count]
+            topk_log_probs_vals = request.py_topk_logprobs_vals[:count]
+            topk_log_probs_indices = request.py_topk_logprobs_indices[:count]
 
             token_log_probs = [
                 {
@@ -1011,9 +1018,7 @@ def _tree_sampling_batch(
             new_draft_tokens_cuda.transpose(0, 1).to(torch.int, non_blocking=True).unsqueeze(dim=-1)
         )
 
-        new_draft_tokens_host = self._async_worker_submit(
-            int_new_draft_tokens.to, "cpu", non_blocking=True
-        )
+        new_draft_tokens_host = self._copy_to_host(int_new_draft_tokens)
 
         return new_draft_tokens_host
 
@@ -1130,16 +1135,10 @@ def update_requests(
         resource_manager: Optional[ResourceManager] = None,
     ) -> None:
         assert isinstance(state, SampleStateTorch)
-        if state.sampler_event:
-            state.sampler_event.synchronize()
+        self._sampler_event_synchronize(state.sampler_event)
 
         assert state.host is not None
-
-        if self._async_worker_active():
-            assert isinstance(state.host.new_tokens, futures.Future)
-            new_tokens = state.host.new_tokens.result()
-        else:
-            new_tokens = state.host.new_tokens
+        new_tokens = state.host.new_tokens
         finish_reasons = state.host.finish_reasons_list()
 
         new_tokens_list = new_tokens.tolist()
@@ -1219,8 +1218,7 @@ def sample_async(
         )
         finish_reasons_host = finish_reasons.to(device="cpu", non_blocking=True)
 
-        sampler_event = torch.cuda.Event()
-        sampler_event.record()
+        sampler_event = self._sampler_event_get()
         return SampleStateTorch(
             scheduled_requests=scheduled_requests,
             device=SampleStateTensors(new_tokens=new_tokens),
@@ -1480,7 +1478,7 @@ def _dims_canonically_ordered(t: torch.Tensor) -> bool:
         new_tokens_cuda.view(-1, *new_tokens_cuda.shape[2:])[:, beam, ...].scatter_(
             0, batch_dest_indices_1d_cuda, batch_next_tokens_cuda_int
         )
-        new_tokens_host = self._async_worker_submit(new_tokens_cuda.to, "cpu", non_blocking=True)
+        new_tokens_host = self._copy_to_host(new_tokens_cuda)
 
         return new_tokens_host
 
@@ -1840,39 +1838,23 @@ def _process_requests(
             topk_vals_cuda, topk_indices_cuda = torch.topk(
                 logprobs_cuda, k=max(req.py_num_logprobs for req in requests), dim=-1
             )
-
-            def _copy_log_probs(
-                requests, req_num_steps, logprobs_req_indices, topk_vals_cuda, topk_indices_cuda
+            # Use a single D2H copy to reduce overheads
+            topk_vals = self._copy_to_host(topk_vals_cuda, pin_memory=True)
+            topk_indices = self._copy_to_host(topk_indices_cuda, pin_memory=True)
+            current_offset = 0
+            for req_id, steps in zip(
+                logprobs_req_indices, req_num_steps[logprobs_req_indices].tolist()
             ):
-                # Use a single D2H copy to reduce overheads
-                topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=True)
-                topk_indices = torch.empty_like(topk_indices_cuda, device="cpu", pin_memory=True)
-                topk_vals.copy_(topk_vals_cuda, non_blocking=True)
-                topk_indices.copy_(topk_indices_cuda, non_blocking=True)
-                current_offset = 0
-                for req_id, steps in zip(
-                    logprobs_req_indices, req_num_steps[logprobs_req_indices].tolist()
-                ):
-                    req = requests[req_id]
-                    next_offset = current_offset + steps
-                    # NB: Assigning views on memory which is being filled
-                    # asynchronously
-                    req.py_topk_logprobs_vals = topk_vals[
-                        current_offset:next_offset, : req.py_num_logprobs
-                    ]
-                    req.py_topk_logprobs_indices = topk_indices[
-                        current_offset:next_offset, : req.py_num_logprobs
-                    ]
-                    current_offset = next_offset
-
-                self._async_worker_submit(
-                    _copy_log_probs,
-                    requests,
-                    req_num_steps,
-                    logprobs_req_indices,
-                    topk_vals_cuda,
-                    topk_indices_cuda,
-                )
+                req = requests[req_id]
+                next_offset = current_offset + steps
+                # NB: Assigning views on memory which is being filled asynchronously
+                req.py_topk_logprobs_vals = topk_vals[
+                    current_offset:next_offset, : req.py_num_logprobs
+                ]
+                req.py_topk_logprobs_indices = topk_indices[
+                    current_offset:next_offset, : req.py_num_logprobs
+                ]
+                current_offset = next_offset
 
         # Perform sampling in batches
         batched_sampling_result = self._sample_batched_by_strategy(
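Note on the logprobs hunk above: one device-to-host copy per tensor lands in pinned memory, and each request then receives a narrow view of that shared buffer, so no per-request copies are issued; the views must not be read before the copy completes. A toy illustration of the slicing arithmetic, with made-up step counts standing in for req_num_steps and py_num_logprobs:

    import torch

    if torch.cuda.is_available():
        topk_vals_cuda = torch.randn(6, 4, device="cuda")  # 6 sampled steps, top-4 logprobs
        topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=True)
        topk_vals.copy_(topk_vals_cuda, non_blocking=True)  # single D2H copy for all requests

        views, offset = [], 0
        for steps, num_logprobs in [(2, 4), (4, 2)]:  # two hypothetical requests
            views.append(topk_vals[offset:offset + steps, :num_logprobs])
            offset += steps

        torch.cuda.current_stream().synchronize()  # views become valid only after this
        print([tuple(v.shape) for v in views])      # (2, 4) and (4, 2)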
@@ -1934,9 +1916,7 @@ class SampleStateTensorsHostTRTLLM(SampleStateTensors):
 class SampleStateTRTLLM(SampleState):
     finalize_events: dict[str, CudaEvent] | None = None
     """`Optional` to accommodate `_forward_step_inter_pp` which creates a `SampleState` without `finalize_events`"""
-    host: Optional[SampleStateTensorsHostTRTLLM | futures.Future[SampleStateTensorsHostTRTLLM]] = (
-        None
-    )
+    host: Optional[SampleStateTensorsHostTRTLLM] = None
 
 
 class TRTLLMSampler(Sampler, AsyncWorkerMixin):
@@ -1959,7 +1939,7 @@ def __init__(
         max_beam_width: int,
         decoding_config: Optional[DecodingConfig] = None,
         kv_cache_config: Optional[KvCacheConfig] = None,
-        enable_async_worker: Optional[bool] = False,
+        enable_async_worker: bool = False,
     ):
         vocab_size = model.config.vocab_size
         num_hidden_layers = model.config.num_hidden_layers
@@ -2158,10 +2138,9 @@ def sample_async(
         )
 
         finalize_events = {}
-        gather_ids = False
-        decoder_state = self.store["decoder_state"]
+        gathered_ids = None
         if beam_width > 1:
-            finished_sum_device = decoder_state.finished_sum
+            finished_sum_device = self.store["decoder_state"].finished_sum
 
         for request in scheduled_requests.all_requests():
             if request.is_context_init_state:
@@ -2170,41 +2149,31 @@ def sample_async(
                 finalize_events[request.request_id] = self._finalize_request(request, False)
             elif request.streaming:
                 finalize_events[request.request_id] = self._finalize_request(request, True)
-                gather_ids = True
-
-        device = SampleStateTensors(new_tokens=decoder_state.all_new_tokens)
-
-        def _copy_tensors_to_host(gather_ids, scheduled_requests, decoder_state):
-            gathered_ids = None
-            if gather_ids:
-                gathered_ids = decoder_state.gathered_ids.to("cpu", non_blocking=True)
-            new_output_tokens = decoder_state.all_new_tokens.to("cpu", non_blocking=True)
-            finished_sum = decoder_state.finished_sum.to("cpu", non_blocking=True)
-            finish_reasons = decoder_state.finish_reasons.to("cpu", non_blocking=True)
-            sequence_lengths = decoder_state.sequence_lengths.to("cpu", non_blocking=True)
-
-            log_probs = None
-            cum_log_probs = None
-            if any(request.py_return_log_probs for request in scheduled_requests.all_requests()):
-                log_probs = decoder_state.log_probs.to("cpu", non_blocking=True)
-                cum_log_probs = decoder_state.cum_log_probs.to("cpu", non_blocking=True)
-
-            return SampleStateTensorsHostTRTLLM(
-                new_tokens=new_output_tokens,
-                finished_sum=finished_sum,
-                finish_reasons=finish_reasons,
-                sequence_lengths=sequence_lengths,
-                log_probs=log_probs,
-                cum_log_probs=cum_log_probs,
-                gathered_ids=gathered_ids,
-            )
-
-        host = self._async_worker_submit(
-            _copy_tensors_to_host, gather_ids, scheduled_requests, decoder_state
+                gathered_ids = self._copy_to_host(self.store["decoder_state"].gathered_ids)
+        new_output_tokens = self._copy_to_host(self.store["decoder_state"].all_new_tokens)
+        finished_sum = self._copy_to_host(self.store["decoder_state"].finished_sum)
+        finish_reasons = self._copy_to_host(self.store["decoder_state"].finish_reasons)
+        sequence_lengths = self._copy_to_host(self.store["decoder_state"].sequence_lengths)
+
+        log_probs = None
+        cum_log_probs = None
+        if any(request.py_return_log_probs for request in scheduled_requests.all_requests()):
+            log_probs = self._copy_to_host(self.store["decoder_state"].log_probs)
+            cum_log_probs = self._copy_to_host(self.store["decoder_state"].cum_log_probs)
+
+        device = SampleStateTensors(new_tokens=self.store["decoder_state"].all_new_tokens)
+
+        host = SampleStateTensorsHostTRTLLM(
+            new_tokens=new_output_tokens,
+            finished_sum=finished_sum,
+            finish_reasons=finish_reasons,
+            sequence_lengths=sequence_lengths,
+            log_probs=log_probs,
+            cum_log_probs=cum_log_probs,
+            gathered_ids=gathered_ids,
         )
 
-        sampler_event = torch.cuda.Event()
-        sampler_event.record()
+        sampler_event = self._sampler_event_get()
 
         self.micro_batch_idx = (self.micro_batch_idx + 1) % self.num_micro_batches
 
@@ -2228,13 +2197,7 @@ def update_requests(
         if state.scheduled_requests.batch_size == 0:
             return
 
-        if state.sampler_event:
-            state.sampler_event.synchronize()
-
-        if self._async_worker_active():
-            # Wait for and "unpack" the host tensors
-            assert isinstance(state.host, futures.Future)
-            state.host = state.host.result()
+        self._sampler_event_synchronize(state.sampler_event)
 
         beam_width = self.beam_width(state.scheduled_requests.all_requests())
 
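Note on the two hunks above: sample_async only enqueues work (D2H copies via _copy_to_host plus an event recorded through the same single-thread worker), and update_requests waits exactly once via _sampler_event_synchronize before reading any host tensor. Because a one-worker executor runs jobs in submission order, the event is recorded after every queued copy. A rough sketch of that ordering assumption, with illustrative names rather than the sampler's code:

    from concurrent import futures

    import torch

    if torch.cuda.is_available():
        worker = futures.ThreadPoolExecutor(max_workers=1)
        src = torch.randn(1 << 20, device="cuda")
        dst = torch.empty_like(src, device="cpu", pin_memory=True)

        def record_event() -> torch.cuda.Event:
            # Recorded on the worker's current stream, after the copy job has
            # already enqueued its transfer (FIFO order of the executor).
            ev = torch.cuda.Event()
            ev.record()
            return ev

        worker.submit(dst.copy_, src, non_blocking=True)  # producer: enqueue D2H copy
        event_future = worker.submit(record_event)         # runs strictly after the copy job

        event_future.result().synchronize()                # consumer: a single wait
        assert torch.equal(dst, src.cpu())                  # dst is now fully populated
        worker.shutdown()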