Commit 486179e

Refactor to pull all common code under the AsyncWorkerMixin
Signed-off-by: Dan Hansen <[email protected]>
Parent: 0664af4

4 files changed (+194, -224 lines)

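The `AsyncWorkerMixin` itself is not in the hunks below: only three of the four changed files appear on this page, and the missing one presumably carries the bulk of the +194/-224. What the visible call sites do pin down is the mixin's surface: a `use_async_worker` attribute, an `async_worker_start()` method, and an `async_worker_stop()` method. For orientation, here is a minimal sketch of a mixin with that surface; the queue-and-thread worker loop and the `async_worker_submit` helper are assumptions for illustration, not this commit's implementation. (The old name, `use_host_copy_thread`, suggests the worker's original job was deferring device-to-host copies.)

```python
import queue
import threading


class AsyncWorkerMixin:
    """Hypothetical sketch of the surface implied by this commit's call
    sites; the real mixin lives in a file not shown in this diff."""

    def _async_worker_init(self, use_async_worker: bool):
        self.use_async_worker = use_async_worker
        self._worker_queue = None
        self._worker_thread = None

    def async_worker_start(self):
        # Assumption: a daemon thread draining a work queue.
        self._worker_queue = queue.Queue()
        self._worker_thread = threading.Thread(target=self._worker_loop,
                                               daemon=True)
        self._worker_thread.start()

    def _worker_loop(self):
        while True:
            task = self._worker_queue.get()
            if task is None:  # sentinel: shut down
                break
            task()

    def async_worker_submit(self, fn, *args, **kwargs):
        # Illustrative helper: defer to the worker if it is running,
        # otherwise run inline.
        if self._worker_thread is not None:
            self._worker_queue.put(lambda: fn(*args, **kwargs))
        else:
            fn(*args, **kwargs)

    def async_worker_stop(self):
        if self._worker_thread is not None:
            self._worker_queue.put(None)
            self._worker_thread.join()
            self._worker_thread = None
```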

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 7 additions & 8 deletions
```diff
@@ -9,8 +9,8 @@
 from tensorrt_llm._torch.model_config import ModelConfig
 from tensorrt_llm._torch.models.modeling_utils import \
     MODEL_CLASS_VISION_ENCODER_MAPPING
-from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
-from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str, confidential_compute_enabled
+from tensorrt_llm._utils import (confidential_compute_enabled,
+                                 str_dtype_to_binding, torch_dtype_to_str)
 from tensorrt_llm.bindings.executor import DecodingMode
 from tensorrt_llm.llmapi.llm_args import (CacheTransceiverConfig,
                                           EagleDecodingConfig, KvCacheConfig,
@@ -817,8 +817,7 @@ def create_py_executor_instance(
 def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
                               max_batch_size: int,
                               speculative_config: SpeculativeConfig,
-                              max_beam_width: int,
-                              use_host_copy_thread: bool):
+                              max_beam_width: int, use_async_worker: bool):
     max_num_sequences = max_batch_size * mapping.pp_size
     max_draft_len = (0 if speculative_config is None else
                      speculative_config.max_draft_len)
@@ -831,7 +830,7 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
         max_total_draft_tokens=max_total_draft_tokens,
         max_num_sequences=max_num_sequences,
         max_beam_width=max_beam_width,
-        use_host_copy_thread=use_host_copy_thread,
+        use_async_worker=use_async_worker,
     )
 
 
@@ -842,15 +841,15 @@ def instantiate_sampler(engine: PyTorchModelEngine,
                         speculative_config: SpeculativeConfig,
                         decoding_config: trtllm.DecodingConfig,
                         kv_cache_config: KvCacheConfig):
-    use_host_copy_thread = confidential_compute_enabled()
+    use_async_worker = confidential_compute_enabled()
 
     sampler_args = create_torch_sampler_args(
         mapping,
         max_seq_len=engine.max_seq_len,
         max_batch_size=max_batch_size,
         speculative_config=speculative_config,
         max_beam_width=max_beam_width,
-        use_host_copy_thread=use_host_copy_thread)
+        use_async_worker=use_async_worker)
     decoding_mode = get_decoding_mode(decoding_config=decoding_config,
                                       max_beam_width=max_beam_width)
     if mapping.cp_config.get('cp_type') == CpType.STAR:
@@ -877,7 +876,7 @@ def instantiate_sampler(engine: PyTorchModelEngine,
         max_beam_width=max_beam_width,
         decoding_config=decoding_config,
         kv_cache_config=kv_cache_config,
-        use_host_copy_thread=use_host_copy_thread)
+        use_async_worker=use_async_worker)
     if not engine.model.model_config.is_generation:
         # NOTE: choose sampler based on model type
         return EarlyStopSampler()
```
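`instantiate_sampler` computes the flag once from `confidential_compute_enabled()` and threads it through `create_torch_sampler_args` into the sampler's constructor arguments, so the sampler never probes the environment itself. How `TorchSampler` consumes the field is not part of this diff; the fragment below is a hypothetical consumer wired to the mixin sketched above (the class name, the `_copy_logits_to_host` method, and the `args` layout are all illustrative).

```python
# Hypothetical consumer of the use_async_worker flag, built on the mixin
# sketched above; TorchSampler's real internals are not shown in this commit.
class TorchSamplerSketch(AsyncWorkerMixin):

    def __init__(self, args):
        # args.use_async_worker mirrors the field this commit renames.
        self._async_worker_init(use_async_worker=args.use_async_worker)

    def _copy_logits_to_host(self, logits):
        # If the worker was started, the device-to-host copy is deferred
        # to it; otherwise async_worker_submit falls back to running inline.
        self.async_worker_submit(logits.cpu)
```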

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -475,9 +475,9 @@ def shutdown(self):
         del self.model_engine
         if self.draft_model_engine is not None:
             del self.draft_model_engine
-        # Stop the sampler's host copy thread, if it was used
-        if hasattr(self.sampler, 'stop_host_copy_thread'):
-            self.sampler.stop_host_copy_thread()
+        # Stop the sampler's async worker, if it was used
+        if hasattr(self.sampler, 'async_worker_stop'):
+            self.sampler.async_worker_stop()
 
     def can_enqueue_requests(self) -> bool:
         """
```

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 7 additions & 5 deletions
```diff
@@ -179,9 +179,11 @@ def update_sampler_max_seq_len(max_seq_len, sampler):
     assert hasattr(sampler, "max_seq_len")
     sampler.max_seq_len = max_seq_len
 
-def maybe_start_sampler_host_copy_thread(sampler):
-    if hasattr(sampler, 'start_host_copy_thread') and sampler.use_host_copy_thread:
-        sampler.start_host_copy_thread()
+
+def maybe_start_sampler_async_worker(sampler):
+    if hasattr(sampler, 'async_worker_start') and sampler.use_async_worker:
+        sampler.async_worker_start()
+
 
 def get_guided_decoding_config(guided_decoding_backend: str,
                                tokenizer: Optional[TokenizerBase] = None):
@@ -671,8 +673,8 @@ def drafting_loop_wrapper(model):
     _adjust_torch_mem_fraction(pytorch_backend_config)
 
     # Now that we've got the instance of py_executor that we're going to keep,
-    # start the sampler's host copy thread, if needed
-    maybe_start_sampler_host_copy_thread(sampler)
+    # start the sampler's async worker, if needed
+    maybe_start_sampler_async_worker(sampler)
 
     py_executor.start_worker()
     return py_executor
```
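Taken together, the three files bracket the worker's lifetime: the flag is fixed when the sampler args are built, the thread starts only once the creator knows it is keeping this `py_executor` instance, and `PyExecutor.shutdown` joins it. A sketch of that ordering under the same assumptions as above; `serve` is a placeholder:

```python
def run_with_async_worker(sampler, serve):
    """Lifecycle implied by this commit's call sites (sketch only)."""
    # py_executor_creator: start the worker only for the instance we keep.
    if hasattr(sampler, 'async_worker_start') and sampler.use_async_worker:
        sampler.async_worker_start()
    try:
        serve()  # stands in for py_executor.start_worker() + request serving
    finally:
        # PyExecutor.shutdown: stop the sampler's async worker, if it was used.
        if hasattr(sampler, 'async_worker_stop'):
            sampler.async_worker_stop()
```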
