 from tensorrt_llm._torch.model_config import ModelConfig
 from tensorrt_llm._torch.models.modeling_utils import \
     MODEL_CLASS_VISION_ENCODER_MAPPING
-from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
-from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str, confidential_compute_enabled
+from tensorrt_llm._utils import (confidential_compute_enabled,
+                                 str_dtype_to_binding, torch_dtype_to_str)
 from tensorrt_llm.bindings.executor import DecodingMode
 from tensorrt_llm.llmapi.llm_args import (CacheTransceiverConfig,
                                           EagleDecodingConfig, KvCacheConfig,
@@ -815,8 +815,7 @@ def create_py_executor_instance(
 def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
                               max_batch_size: int,
                               speculative_config: SpeculativeConfig,
-                              max_beam_width: int,
-                              use_host_copy_thread: bool):
+                              max_beam_width: int, use_async_worker: bool):
     max_num_sequences = max_batch_size * mapping.pp_size
     max_draft_len = (0 if speculative_config is None else
                      speculative_config.max_draft_len)
@@ -829,7 +828,7 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
         max_total_draft_tokens=max_total_draft_tokens,
         max_num_sequences=max_num_sequences,
         max_beam_width=max_beam_width,
-        use_host_copy_thread=use_host_copy_thread,
+        use_async_worker=use_async_worker,
     )
 
@@ -840,15 +839,15 @@ def instantiate_sampler(engine: PyTorchModelEngine,
                         speculative_config: SpeculativeConfig,
                         decoding_config: trtllm.DecodingConfig,
                         kv_cache_config: KvCacheConfig):
-    use_host_copy_thread = confidential_compute_enabled()
+    use_async_worker = confidential_compute_enabled()
 
     sampler_args = create_torch_sampler_args(
         mapping,
         max_seq_len=engine.max_seq_len,
         max_batch_size=max_batch_size,
         speculative_config=speculative_config,
         max_beam_width=max_beam_width,
-        use_host_copy_thread=use_host_copy_thread)
+        use_async_worker=use_async_worker)
     decoding_mode = get_decoding_mode(decoding_config=decoding_config,
                                       max_beam_width=max_beam_width)
     if mapping.cp_config.get('cp_type') == CpType.STAR:
@@ -875,7 +874,7 @@ def instantiate_sampler(engine: PyTorchModelEngine,
             max_beam_width=max_beam_width,
             decoding_config=decoding_config,
             kv_cache_config=kv_cache_config,
-            use_host_copy_thread=use_host_copy_thread)
+            use_async_worker=use_async_worker)
     if not engine.model.model_config.is_generation:
         # NOTE: choose sampler based on model type
         return EarlyStopSampler()
0 commit comments