Commit 0664af4

[https:/nvbugs/5508301][feat] Move D->H copies to a worker thread when confidential compute is active
Signed-off-by: Dan Hansen <[email protected]>
1 parent 2420918 commit 0664af4

5 files changed: +223 -46 lines changed

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 10 additions & 3 deletions
@@ -10,6 +10,7 @@
 from tensorrt_llm._torch.models.modeling_utils import \
     MODEL_CLASS_VISION_ENCODER_MAPPING
 from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
+from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str, confidential_compute_enabled
 from tensorrt_llm.bindings.executor import DecodingMode
 from tensorrt_llm.llmapi.llm_args import (CacheTransceiverConfig,
                                           EagleDecodingConfig, KvCacheConfig,
@@ -816,7 +817,8 @@ def create_py_executor_instance(
 def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
                               max_batch_size: int,
                               speculative_config: SpeculativeConfig,
-                              max_beam_width: int):
+                              max_beam_width: int,
+                              use_host_copy_thread: bool):
     max_num_sequences = max_batch_size * mapping.pp_size
     max_draft_len = (0 if speculative_config is None else
                      speculative_config.max_draft_len)
@@ -829,6 +831,7 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
         max_total_draft_tokens=max_total_draft_tokens,
         max_num_sequences=max_num_sequences,
         max_beam_width=max_beam_width,
+        use_host_copy_thread=use_host_copy_thread,
     )
 
 
@@ -839,12 +842,15 @@ def instantiate_sampler(engine: PyTorchModelEngine,
                         speculative_config: SpeculativeConfig,
                         decoding_config: trtllm.DecodingConfig,
                         kv_cache_config: KvCacheConfig):
+    use_host_copy_thread = confidential_compute_enabled()
+
     sampler_args = create_torch_sampler_args(
         mapping,
         max_seq_len=engine.max_seq_len,
         max_batch_size=max_batch_size,
         speculative_config=speculative_config,
-        max_beam_width=max_beam_width)
+        max_beam_width=max_beam_width,
+        use_host_copy_thread=use_host_copy_thread)
     decoding_mode = get_decoding_mode(decoding_config=decoding_config,
                                       max_beam_width=max_beam_width)
     if mapping.cp_config.get('cp_type') == CpType.STAR:
@@ -870,7 +876,8 @@ def instantiate_sampler(engine: PyTorchModelEngine,
             max_batch_size=max_batch_size,
             max_beam_width=max_beam_width,
             decoding_config=decoding_config,
-            kv_cache_config=kv_cache_config)
+            kv_cache_config=kv_cache_config,
+            use_host_copy_thread=use_host_copy_thread)
     if not engine.model.model_config.is_generation:
         # NOTE: choose sampler based on model type
         return EarlyStopSampler()
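
The sampler-side changes that consume use_host_copy_thread live in sampler.py, which is not part of this excerpt, so the worker itself is not shown. The following is a rough sketch only, assuming a queue-fed thread plus a side CUDA stream; the class name, queue discipline, and stream handling are assumptions, not the commit's code:

# Hypothetical sketch only; the commit's sampler.py changes are not shown here.
import queue
import threading
from typing import Optional

import torch


class _HostCopyWorker:
    """Drains device-to-host copies on a side thread so the main loop never blocks."""

    _SENTINEL = object()

    def __init__(self):
        self._work = queue.Queue()
        self._thread: Optional[threading.Thread] = None
        # A dedicated stream lets copies overlap with compute on the default stream.
        self._stream = torch.cuda.Stream()

    def start(self):
        self._thread = threading.Thread(target=self._loop, daemon=True)
        self._thread.start()

    def submit(self, src: torch.Tensor, dst: torch.Tensor,
               done: threading.Event):
        # Record the producer stream's position; the worker waits on this
        # event so it never copies a half-written tensor.
        ready = torch.cuda.Event()
        ready.record()
        self._work.put((ready, src, dst, done))

    def stop(self):
        self._work.put(self._SENTINEL)
        if self._thread is not None:
            self._thread.join()
            self._thread = None

    def _loop(self):
        while True:
            item = self._work.get()
            if item is self._SENTINEL:
                return
            ready, src, dst, done = item
            with torch.cuda.stream(self._stream):
                self._stream.wait_event(ready)
                # dst should be pinned host memory for a truly async copy.
                dst.copy_(src, non_blocking=True)
            # Synchronizing here blocks only this worker thread.
            self._stream.synchronize()
            done.set()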

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 3 additions & 0 deletions
@@ -475,6 +475,9 @@ def shutdown(self):
             del self.model_engine
         if self.draft_model_engine is not None:
             del self.draft_model_engine
+        # Stop the sampler's host copy thread, if it was used
+        if hasattr(self.sampler, 'stop_host_copy_thread'):
+            self.sampler.stop_host_copy_thread()
 
     def can_enqueue_requests(self) -> bool:
         """

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 7 additions & 0 deletions
@@ -179,6 +179,9 @@ def update_sampler_max_seq_len(max_seq_len, sampler):
     assert hasattr(sampler, "max_seq_len")
     sampler.max_seq_len = max_seq_len
 
+def maybe_start_sampler_host_copy_thread(sampler):
+    if hasattr(sampler, 'start_host_copy_thread') and sampler.use_host_copy_thread:
+        sampler.start_host_copy_thread()
 
 def get_guided_decoding_config(guided_decoding_backend: str,
                                tokenizer: Optional[TokenizerBase] = None):
@@ -667,5 +670,9 @@ def drafting_loop_wrapper(model):
 
     _adjust_torch_mem_fraction(pytorch_backend_config)
 
+    # Now that we've got the instance of py_executor that we're going to keep,
+    # start the sampler's host copy thread, if needed
+    maybe_start_sampler_host_copy_thread(sampler)
+
     py_executor.start_worker()
     return py_executor
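
The comment above suggests that create_py_executor can build executor instances it later discards, so the worker thread is started only on the instance that is kept. A defensive way to satisfy both hooks, assuming sampler-side code that this excerpt does not include, is to make start and stop idempotent:

# Hypothetical sampler-side hooks; the real TorchSampler changes are not
# shown in this excerpt.
class TorchSamplerSketch:
    def __init__(self, use_host_copy_thread: bool):
        self.use_host_copy_thread = use_host_copy_thread
        self._host_copy_worker = None

    def start_host_copy_thread(self):
        # Safe to call more than once; only the first call spawns the thread.
        if self.use_host_copy_thread and self._host_copy_worker is None:
            self._host_copy_worker = _HostCopyWorker()
            self._host_copy_worker.start()

    def stop_host_copy_thread(self):
        # Safe to call on a sampler whose thread was never started.
        if self._host_copy_worker is not None:
            self._host_copy_worker.stop()
            self._host_copy_worker = None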
