
Commit 6e31c1c

Refactor to pull all common code under the AsyncWorkerMixin
Signed-off-by: Dan Hansen <[email protected]>
1 parent c80b6ca

4 files changed: 194 additions, 224 deletions
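
The AsyncWorkerMixin named in the commit message is not itself visible in this excerpt; the call sites below only pin down its surface: a use_async_worker flag plus async_worker_start() and async_worker_stop() lifecycle hooks. As a minimal sketch of how such a mixin might be shaped, assuming a queue-backed worker thread, and with every name beyond those three being hypothetical:

import queue
import threading


class AsyncWorkerMixin:
    """Queue-backed worker thread for a sampler's host-side work (sketch)."""

    def _async_worker_init(self, use_async_worker):
        # Hypothetical initializer; only use_async_worker, async_worker_start
        # and async_worker_stop are visible at this commit's call sites.
        self.use_async_worker = use_async_worker
        self._work_queue = queue.Queue()
        self._worker = None

    def async_worker_start(self):
        # Started once by py_executor_creator, after the executor is built.
        self._worker = threading.Thread(target=self._worker_loop, daemon=True)
        self._worker.start()

    def async_worker_stop(self):
        # Called from PyExecutor.shutdown(); None is the shutdown sentinel.
        if self._worker is not None:
            self._work_queue.put(None)
            self._worker.join()
            self._worker = None

    def _async_worker_submit(self, fn, *args):
        # Run fn on the worker when enabled, inline otherwise.
        if self.use_async_worker and self._worker is not None:
            self._work_queue.put((fn, args))
        else:
            fn(*args)

    def _worker_loop(self):
        # Drain submitted work until the sentinel arrives.
        while True:
            item = self._work_queue.get()
            if item is None:
                break
            fn, args = item
            fn(*args)

The duck-typed hasattr checks in the diffs below mean a sampler that has no host-side work to offload simply omits the mixin.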

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 7 additions & 8 deletions

@@ -9,8 +9,8 @@
 from tensorrt_llm._torch.model_config import ModelConfig
 from tensorrt_llm._torch.models.modeling_utils import \
     MODEL_CLASS_VISION_ENCODER_MAPPING
-from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
-from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str, confidential_compute_enabled
+from tensorrt_llm._utils import (confidential_compute_enabled,
+                                 str_dtype_to_binding, torch_dtype_to_str)
 from tensorrt_llm.bindings.executor import DecodingMode
 from tensorrt_llm.llmapi.llm_args import (CacheTransceiverConfig,
                                           EagleDecodingConfig, KvCacheConfig,

@@ -815,8 +815,7 @@ def create_py_executor_instance(
 def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
                               max_batch_size: int,
                               speculative_config: SpeculativeConfig,
-                              max_beam_width: int,
-                              use_host_copy_thread: bool):
+                              max_beam_width: int, use_async_worker: bool):
     max_num_sequences = max_batch_size * mapping.pp_size
     max_draft_len = (0 if speculative_config is None else
                      speculative_config.max_draft_len)

@@ -829,7 +828,7 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
         max_total_draft_tokens=max_total_draft_tokens,
         max_num_sequences=max_num_sequences,
         max_beam_width=max_beam_width,
-        use_host_copy_thread=use_host_copy_thread,
+        use_async_worker=use_async_worker,
     )


@@ -840,15 +839,15 @@ def instantiate_sampler(engine: PyTorchModelEngine,
                         speculative_config: SpeculativeConfig,
                         decoding_config: trtllm.DecodingConfig,
                         kv_cache_config: KvCacheConfig):
-    use_host_copy_thread = confidential_compute_enabled()
+    use_async_worker = confidential_compute_enabled()

     sampler_args = create_torch_sampler_args(
         mapping,
         max_seq_len=engine.max_seq_len,
         max_batch_size=max_batch_size,
         speculative_config=speculative_config,
         max_beam_width=max_beam_width,
-        use_host_copy_thread=use_host_copy_thread)
+        use_async_worker=use_async_worker)
     decoding_mode = get_decoding_mode(decoding_config=decoding_config,
                                       max_beam_width=max_beam_width)
     if mapping.cp_config.get('cp_type') == CpType.STAR:

@@ -875,7 +874,7 @@ def instantiate_sampler(engine: PyTorchModelEngine,
         max_beam_width=max_beam_width,
         decoding_config=decoding_config,
         kv_cache_config=kv_cache_config,
-        use_host_copy_thread=use_host_copy_thread)
+        use_async_worker=use_async_worker)
     if not engine.model.model_config.is_generation:
         # NOTE: choose sampler based on model type
         return EarlyStopSampler()
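
The arithmetic at the top of create_torch_sampler_args is easy to check by hand; with toy values (not taken from this commit):

# Toy values, mirroring the derivation in the hunk above (assumption: one
# batch can be in flight per pipeline-parallel stage, hence the multiply).
max_batch_size, pp_size = 8, 2
max_num_sequences = max_batch_size * pp_size    # 16
speculative_config = None
max_draft_len = (0 if speculative_config is None else
                 speculative_config.max_draft_len)  # 0 without speculation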

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 3 additions & 3 deletions

@@ -478,9 +478,9 @@ def shutdown(self):
         del self.model_engine
         if self.draft_model_engine is not None:
             del self.draft_model_engine
-        # Stop the sampler's host copy thread, if it was used
-        if hasattr(self.sampler, 'stop_host_copy_thread'):
-            self.sampler.stop_host_copy_thread()
+        # Stop the sampler's async worker, if it was used
+        if hasattr(self.sampler, 'async_worker_stop'):
+            self.sampler.async_worker_stop()

     def can_enqueue_requests(self) -> bool:
         """

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 7 additions & 5 deletions

@@ -179,9 +179,11 @@ def update_sampler_max_seq_len(max_seq_len, sampler):
     assert hasattr(sampler, "max_seq_len")
     sampler.max_seq_len = max_seq_len

-def maybe_start_sampler_host_copy_thread(sampler):
-    if hasattr(sampler, 'start_host_copy_thread') and sampler.use_host_copy_thread:
-        sampler.start_host_copy_thread()
+
+def maybe_start_sampler_async_worker(sampler):
+    if hasattr(sampler, 'async_worker_start') and sampler.use_async_worker:
+        sampler.async_worker_start()
+

 def get_guided_decoding_config(guided_decoding_backend: str,
                                tokenizer: Optional[TokenizerBase] = None):

@@ -674,8 +676,8 @@ def drafting_loop_wrapper(model):
     _adjust_torch_mem_fraction(pytorch_backend_config)

     # Now that we've got the instance of py_executor that we're going to keep,
-    # start the sampler's host copy thread, if needed
-    maybe_start_sampler_host_copy_thread(sampler)
+    # start the sampler's async worker, if needed
+    maybe_start_sampler_async_worker(sampler)

     py_executor.start_worker()
     return py_executor
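
The comment in the second hunk spells out the ordering: the worker starts only once the py_executor instance is the one being kept, so a failure earlier in construction cannot leak a running thread. A toy run of that sequence (stand-in classes, not the real constructors):

class ToySampler:
    # Stand-in exposing just the two attributes the helper checks.
    use_async_worker = True
    started = False

    def async_worker_start(self):
        self.started = True


def maybe_start_sampler_async_worker(sampler):
    # Same guard as the helper added in the first hunk above.
    if hasattr(sampler, 'async_worker_start') and sampler.use_async_worker:
        sampler.async_worker_start()


sampler = ToySampler()
# ... py_executor would be constructed here; if that raises, no thread started ...
maybe_start_sampler_async_worker(sampler)
assert sampler.started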
