 from tensorrt_llm._torch.models.modeling_utils import \
     MODEL_CLASS_VISION_ENCODER_MAPPING
 from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
+from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str, confidential_compute_enabled
 from tensorrt_llm.bindings.executor import DecodingMode
 from tensorrt_llm.llmapi.llm_args import (CacheTransceiverConfig,
                                           EagleDecodingConfig, KvCacheConfig,
@@ -816,7 +817,8 @@ def create_py_executor_instance(
 def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
                               max_batch_size: int,
                               speculative_config: SpeculativeConfig,
-                              max_beam_width: int):
+                              max_beam_width: int,
+                              use_host_copy_thread: bool):
     max_num_sequences = max_batch_size * mapping.pp_size
     max_draft_len = (0 if speculative_config is None else
                      speculative_config.max_draft_len)
@@ -829,6 +831,7 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
         max_total_draft_tokens=max_total_draft_tokens,
         max_num_sequences=max_num_sequences,
         max_beam_width=max_beam_width,
+        use_host_copy_thread=use_host_copy_thread,
     )


@@ -839,12 +842,15 @@ def instantiate_sampler(engine: PyTorchModelEngine,
                         speculative_config: SpeculativeConfig,
                         decoding_config: trtllm.DecodingConfig,
                         kv_cache_config: KvCacheConfig):
+    use_host_copy_thread = confidential_compute_enabled()
+
     sampler_args = create_torch_sampler_args(
         mapping,
         max_seq_len=engine.max_seq_len,
         max_batch_size=max_batch_size,
         speculative_config=speculative_config,
-        max_beam_width=max_beam_width)
+        max_beam_width=max_beam_width,
+        use_host_copy_thread=use_host_copy_thread)
     decoding_mode = get_decoding_mode(decoding_config=decoding_config,
                                       max_beam_width=max_beam_width)
     if mapping.cp_config.get('cp_type') == CpType.STAR:
@@ -870,7 +876,8 @@ def instantiate_sampler(engine: PyTorchModelEngine,
                              max_batch_size=max_batch_size,
                              max_beam_width=max_beam_width,
                              decoding_config=decoding_config,
-                             kv_cache_config=kv_cache_config)
+                             kv_cache_config=kv_cache_config,
+                             use_host_copy_thread=use_host_copy_thread)
     if not engine.model.model_config.is_generation:
         # NOTE: choose sampler based on model type
         return EarlyStopSampler()
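
Usage note (not part of the diff): a minimal sketch of how the new plumbing fits together after this change, using only the names visible in the hunks above; the downstream consumers of use_host_copy_thread inside the samplers are not shown in this section, so anything beyond these call sites is an assumption.

    # Sketch derived from the hunks above; names not shown there are assumptions.
    # instantiate_sampler() now derives the flag once from the environment and
    # threads it through create_torch_sampler_args() into the sampler arguments.
    use_host_copy_thread = confidential_compute_enabled()  # True when confidential compute is active
    sampler_args = create_torch_sampler_args(
        mapping,
        max_seq_len=engine.max_seq_len,
        max_batch_size=max_batch_size,
        speculative_config=speculative_config,
        max_beam_width=max_beam_width,
        use_host_copy_thread=use_host_copy_thread)
    # The sampler constructed in the last hunk receives the same flag as an
    # additional keyword argument alongside kv_cache_config.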