@@ -17,7 +17,7 @@
 from abc import ABC, abstractmethod
 from collections import defaultdict
 from collections.abc import Iterable
-from concurrent.futures import Future, ThreadPoolExecutor
+from concurrent import futures
 from dataclasses import dataclass
 from itertools import repeat
 from typing import Any, Callable, List, Optional, TypeVar, cast
@@ -82,7 +82,7 @@
 
 @dataclass(kw_only=True)
 class SampleStateTensors:
-    new_tokens: torch.Tensor | Future[torch.Tensor]
+    new_tokens: torch.Tensor | futures.Future[torch.Tensor]
     log_probs: torch.Tensor | None = None
 
     def values(self):
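For context, here is a minimal self-contained sketch (not the PR's code) of the pattern this hunk adopts: a state field typed as either a ready host tensor or a `futures.Future` that will resolve to one, so producers can hand the device-to-host copy to a worker thread. The class and the `resolve` helper below are purely illustrative names.

```python
from concurrent import futures
from dataclasses import dataclass

import torch


@dataclass(kw_only=True)
class TensorOrFuture:
    # Either a tensor that is already on the host, or a Future produced by a
    # worker thread that is still copying it off the device.
    new_tokens: torch.Tensor | futures.Future[torch.Tensor]

    def resolve(self) -> torch.Tensor:
        # Block only when the value is still in flight; cache the result so
        # repeated callers do not wait twice.
        if isinstance(self.new_tokens, futures.Future):
            self.new_tokens = self.new_tokens.result()
        return self.new_tokens
```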
@@ -574,12 +574,12 @@ class AsyncWorkerMixin:
     def _async_worker_active(self) -> bool:
         return self._async_worker is not None
 
-    def _async_worker_init(self, use_async_worker: bool):
-        self.use_async_worker = use_async_worker
+    def _async_worker_init(self, enable_async_worker: bool):
+        self.enable_async_worker = enable_async_worker
         self._async_worker = None
 
     def async_worker_start(self):
-        assert self.use_async_worker
+        assert self.enable_async_worker
         assert not self._async_worker_active()
 
         def _async_worker_initializer(device_id):
@@ -590,7 +590,7 @@ def _async_worker_initializer(device_id):
             # blocking copies from gating subsequent async work
             torch.cuda.set_stream(torch.cuda.Stream())
 
-        self._async_worker = ThreadPoolExecutor(
+        self._async_worker = futures.ThreadPoolExecutor(
             max_workers=1,
             initializer=_async_worker_initializer,
             initargs=(torch.cuda.current_device(),),
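The executor construction above is the heart of the change. Below is a hedged sketch of the same pattern, using only real `torch` and `concurrent.futures` APIs with otherwise illustrative names: a one-thread pool whose initializer binds the worker to the caller's CUDA device and switches it onto a private stream, so blocking device-to-host copies issued from the worker do not serialize work on the main stream.

```python
from concurrent import futures

import torch


def _worker_initializer(device_id: int) -> None:
    torch.cuda.set_device(device_id)            # same GPU as the caller
    torch.cuda.set_stream(torch.cuda.Stream())  # private side stream for copies


def make_async_worker() -> futures.ThreadPoolExecutor:
    # max_workers=1 keeps submissions ordered: tasks run FIFO on one thread,
    # so results come back in submission order.
    return futures.ThreadPoolExecutor(
        max_workers=1,
        initializer=_worker_initializer,
        initargs=(torch.cuda.current_device(),),
    )
```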
@@ -653,7 +653,7 @@ class Args:
         max_num_sequences: int
         max_beam_width: int
         max_total_draft_tokens: int
-        use_async_worker: Optional[bool] = False
+        enable_async_worker: Optional[bool] = False
 
     def __init__(self, args: Args):
         self.max_seq_len = args.max_seq_len
@@ -674,7 +674,7 @@ def __init__(self, args: Args):
         self._global_seed = 42
         self._generator = None
 
-        self._async_worker_init(args.use_async_worker)
+        self._async_worker_init(args.enable_async_worker)
 
     def get_generator(self, device: torch.device) -> torch.Generator:
         """Get a deterministic generator for the specified device.
@@ -755,9 +755,8 @@ def handle_logprobs(
         if request.py_return_log_probs:
             if self._async_worker_active():
                 # These should be futures if we used the async worker
-                assert isinstance(request.py_topk_logprobs_values, Future) and isinstance(
-                    request.py_topk_logprobs_vals, Future
-                )
+                assert isinstance(request.py_topk_logprobs_vals, futures.Future)
+                assert isinstance(request.py_topk_logprobs_indices, futures.Future)
                 topk_log_probs_vals = request.py_topk_logprobs_vals.result()
                 topk_log_probs_indices = request.py_topk_logprobs_indices.result()
             else:
@@ -1079,7 +1078,7 @@ def update_requests(
         assert state.host is not None
 
         if self._async_worker_active():
-            assert isinstance(state.host.new_tokens, Future)
+            assert isinstance(state.host.new_tokens, futures.Future)
             new_tokens = state.host.new_tokens.result()
         else:
             new_tokens = state.host.new_tokens
@@ -1686,7 +1685,9 @@ class SampleStateTensorsHostTRTLLM(SampleStateTensors):
 class SampleStateTRTLLM(SampleState):
     finalize_events: dict[str, CudaEvent] | None = None
     """`Optional` to accommodate `_forward_step_inter_pp`, which creates a `SampleState` without `finalize_events`"""
-    host: Optional[SampleStateTensorsHostTRTLLM | Future[SampleStateTensorsHostTRTLLM]] = None
+    host: Optional[SampleStateTensorsHostTRTLLM | futures.Future[SampleStateTensorsHostTRTLLM]] = (
+        None
+    )
 
 
 class TRTLLMSampler(Sampler, AsyncWorkerMixin):
@@ -1709,7 +1710,7 @@ def __init__(
         max_beam_width: int,
         decoding_config: Optional[DecodingConfig] = None,
         kv_cache_config: Optional[KvCacheConfig] = None,
-        use_async_worker: Optional[bool] = False,
+        enable_async_worker: Optional[bool] = False,
     ):
         vocab_size = model.config.vocab_size
         num_hidden_layers = model.config.num_hidden_layers
@@ -1760,7 +1761,7 @@ def __init__(
         self._initialize_store()
         self._instantiate_algorithms()
 
-        self._async_worker_init(use_async_worker)
+        self._async_worker_init(enable_async_worker)
 
     def _initialize_store(self):
         torch_stream = torch.cuda.current_stream().cuda_stream
@@ -1983,7 +1984,7 @@ def update_requests(
 
         if self._async_worker_active():
             # Wait for and "unpack" the host tensors
-            assert isinstance(state.host, Future)
+            assert isinstance(state.host, futures.Future)
             state.host = state.host.result()
 
         beam_width = self.beam_width(state.scheduled_requests.all_requests())
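Putting the pieces together, a hedged end-to-end sketch of the submit/resolve round trip. The `submit_host_copy` and `resolve_host` helpers are hypothetical names, not the PR's API; the real code additionally fences against the producing stream (see `finalize_events` above). The producer submits the device-to-host copy on the worker, stores the `Future` in the sample state, and the consumer resolves it exactly once.

```python
from concurrent import futures

import torch


def submit_host_copy(
    worker: futures.ThreadPoolExecutor, device_tensor: torch.Tensor
) -> futures.Future[torch.Tensor]:
    # The copy runs on the worker thread, on its private stream (see the
    # initializer sketch above). A production version must first synchronize
    # with the stream that produced device_tensor, e.g. via a recorded event.
    return worker.submit(lambda: device_tensor.to("cpu"))


def resolve_host(state) -> None:
    # Consumer side: wait once, then downstream code sees a plain value.
    if isinstance(state.host, futures.Future):
        state.host = state.host.result()
```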