99from tensorrt_llm ._torch .model_config import ModelConfig
1010from tensorrt_llm ._torch .models .modeling_utils import \
1111 MODEL_CLASS_VISION_ENCODER_MAPPING
12- from tensorrt_llm ._utils import str_dtype_to_binding , torch_dtype_to_str
13- from tensorrt_llm . _utils import str_dtype_to_binding , torch_dtype_to_str , confidential_compute_enabled
12+ from tensorrt_llm ._utils import ( confidential_compute_enabled ,
13+ str_dtype_to_binding , torch_dtype_to_str )
1414from tensorrt_llm .bindings .executor import DecodingMode
1515from tensorrt_llm .llmapi .llm_args import (CacheTransceiverConfig ,
1616 EagleDecodingConfig , KvCacheConfig ,
@@ -817,8 +817,7 @@ def create_py_executor_instance(
817817def create_torch_sampler_args (mapping : Mapping , * , max_seq_len : int ,
818818 max_batch_size : int ,
819819 speculative_config : SpeculativeConfig ,
820- max_beam_width : int ,
821- use_host_copy_thread : bool ):
820+ max_beam_width : int , use_async_worker : bool ):
822821 max_num_sequences = max_batch_size * mapping .pp_size
823822 max_draft_len = (0 if speculative_config is None else
824823 speculative_config .max_draft_len )
@@ -831,7 +830,7 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
831830 max_total_draft_tokens = max_total_draft_tokens ,
832831 max_num_sequences = max_num_sequences ,
833832 max_beam_width = max_beam_width ,
834- use_host_copy_thread = use_host_copy_thread ,
833+ use_async_worker = use_async_worker ,
835834 )
836835
837836
@@ -842,15 +841,15 @@ def instantiate_sampler(engine: PyTorchModelEngine,
842841 speculative_config : SpeculativeConfig ,
843842 decoding_config : trtllm .DecodingConfig ,
844843 kv_cache_config : KvCacheConfig ):
845- use_host_copy_thread = confidential_compute_enabled ()
844+ use_async_worker = confidential_compute_enabled ()
846845
847846 sampler_args = create_torch_sampler_args (
848847 mapping ,
849848 max_seq_len = engine .max_seq_len ,
850849 max_batch_size = max_batch_size ,
851850 speculative_config = speculative_config ,
852851 max_beam_width = max_beam_width ,
853- use_host_copy_thread = use_host_copy_thread )
852+ use_async_worker = use_async_worker )
854853 decoding_mode = get_decoding_mode (decoding_config = decoding_config ,
855854 max_beam_width = max_beam_width )
856855 if mapping .cp_config .get ('cp_type' ) == CpType .STAR :
@@ -877,7 +876,7 @@ def instantiate_sampler(engine: PyTorchModelEngine,
877876 max_beam_width = max_beam_width ,
878877 decoding_config = decoding_config ,
879878 kv_cache_config = kv_cache_config ,
880- use_host_copy_thread = use_host_copy_thread )
879+ use_async_worker = use_async_worker )
881880 if not engine .model .model_config .is_generation :
882881 # NOTE: choose sampler based on model type
883882 return EarlyStopSampler ()
0 commit comments