@@ -20,6 +20,7 @@
                                             MultimodalRuntimeData)
 from tensorrt_llm.inputs.registry import (create_input_processor,
                                           create_input_processor_with_hash)
+from tensorrt_llm.llmapi.llm_args import TorchLlmArgs
 from tensorrt_llm.logger import logger
 from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.lora_manager import LoraModelConfig
@@ -38,7 +39,6 @@
 from ..expert_statistic import ExpertStatistic
 from ..memory_buffer_utils import with_shared_pool
 from ..metadata import KVCacheParams
-from ..models.checkpoints.base_checkpoint_loader import BaseCheckpointLoader
 from ..models.modeling_multimodal_utils import filter_mm_token_from_input_ids
 from ..models.modeling_utils import DecoderModelForCausalLM
 from ..modules.fused_moe.moe_load_balancer import (MoeLoadBalancer,
@@ -52,7 +52,7 @@
 from ..utils import (get_model_extra_attrs,
                      set_per_request_piecewise_cuda_graph_flag,
                      set_torch_compiling, with_model_extra_attrs)
-from .config import PyTorchConfig
+from .config import PyTorchConfig, _construct_checkpoint_loader
 from .config_utils import is_mla
 from .cuda_graph_runner import CUDAGraphRunner
 from .guided_decoder import CapturableGuidedDecoder
@@ -131,29 +131,36 @@ def __init__(
         *,
         model_path: str,
         pytorch_backend_config: PyTorchConfig,
-        checkpoint_loader: BaseCheckpointLoader,
-        batch_size: int = 8,
-        max_beam_width: int = 1,
-        max_num_tokens: int = 8192,
-        max_seq_len: Optional[int] = None,
         mapping: Optional[Mapping] = None,
         attn_runtime_features: Optional[AttentionRuntimeFeatures] = None,
         dist: Optional[MPIDist] = None,
         spec_config: Optional["DecodingBaseConfig"] = None,
-        sparse_attention_config: Optional["SparseAttentionConfig"] = None,
-        lora_config: Optional[LoraConfig] = None,
         is_draft_model: bool = False,
         drafting_loop_wrapper: Optional[Callable[[torch.nn.Module],
                                                  torch.nn.Module]] = None,
         model: Optional[torch.nn.Module] = None,
+        llm_args: Optional[TorchLlmArgs] = None,
     ):
+        assert llm_args is not None, "llm_args must be provided for PyTorchModelEngine"
+
         self.forward_pass_callable = None
         self.ub_buffers = None
-        self.batch_size = batch_size
+        (
+            max_beam_width,
+            max_num_tokens,
+            max_seq_len,
+            max_batch_size,
+        ) = llm_args.get_runtime_sizes()
+
+        self.batch_size = max_batch_size
         self.max_num_tokens = max_num_tokens
         self.max_seq_len = max_seq_len
         self.max_beam_width = max_beam_width
 
+        checkpoint_loader = _construct_checkpoint_loader(
+            llm_args.backend, llm_args.checkpoint_loader,
+            llm_args.checkpoint_format)
+
         self.mapping = mapping
         if mapping.has_pp():
             init_pp_comm(mapping)
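
The hunk above folds the engine's sizing parameters and checkpoint-loader construction into the single `llm_args` object. A minimal caller-side sketch of the new pattern follows; only the `TorchLlmArgs` import path, `get_runtime_sizes()`, and its tuple order are taken from this diff, so how `TorchLlmArgs` is actually built here is an assumption:

```python
# Sketch only, not a verified snippet: how a caller now supplies runtime
# sizes through llm_args instead of separate keyword arguments.
from tensorrt_llm.llmapi.llm_args import TorchLlmArgs

llm_args = TorchLlmArgs(model="/path/to/model")  # hypothetical arguments

# get_runtime_sizes() replaces the four removed parameters and returns
# them in this order (per the unpacking in the hunk above):
max_beam_width, max_num_tokens, max_seq_len, max_batch_size = (
    llm_args.get_runtime_sizes())
print(max_batch_size, max_beam_width, max_num_tokens, max_seq_len)
```
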
@@ -171,7 +178,7 @@ def __init__(
             spec_config.max_total_draft_tokens = 0
         self.spec_config = spec_config
         self.is_spec_decode = spec_config is not None
-        self.sparse_attention_config = sparse_attention_config
+        self.sparse_attention_config = None if is_draft_model else llm_args.sparse_attention_config
         self.enable_spec_decode = self.is_spec_decode
         self.is_draft_model = is_draft_model
 
@@ -181,13 +188,15 @@ def __init__(
         self.input_processor_with_hash = create_input_processor_with_hash(
             self.input_processor)
         if model is None:
+            lora_config: Optional[
+                LoraConfig] = None if is_draft_model else llm_args.lora_config
             loader = ModelLoader(
                 pytorch_backend_config=pytorch_backend_config,
                 mapping=self.mapping,
                 spec_config=self.spec_config,
                 sparse_attention_config=self.sparse_attention_config,
-                max_num_tokens=max_num_tokens,
-                max_seq_len=max_seq_len,
+                max_num_tokens=self.max_num_tokens,
+                max_seq_len=self.max_seq_len,
                 lora_config=lora_config,
             )
             self.model, moe_load_balancer = loader.load(
@@ -273,29 +282,27 @@ def __init__(
 
         self.attn_backend = get_attention_backend(
             pytorch_backend_config.attn_backend,
-            sparse_attn_config=sparse_attention_config)
+            sparse_attn_config=self.sparse_attention_config)
 
         if self.is_spec_decode:
             self.spec_metadata = None
             update_spec_config_from_model_config(self.spec_config,
                                                  self.model.config)
-            max_num_draft_tokens = self.original_max_total_draft_tokens * batch_size
+            max_num_draft_tokens = self.original_max_total_draft_tokens * self.batch_size
             self.draft_tokens_cuda = torch.empty((max_num_draft_tokens, ),
                                                  dtype=torch.int,
                                                  device='cuda')
             self.gather_ids_cuda = torch.empty((self.max_num_tokens, ),
                                                dtype=torch.int,
                                                device='cuda')
-            self.num_accepted_draft_tokens_cuda = torch.empty((batch_size, ),
-                                                              dtype=torch.int,
-                                                              device='cuda')
+            self.num_accepted_draft_tokens_cuda = torch.empty(
+                (self.batch_size, ), dtype=torch.int, device='cuda')
             self.previous_pos_indices_cuda = torch.empty(
                 (self.max_num_tokens, ), dtype=torch.int, device='cuda')
             self.previous_pos_id_offsets_cuda = torch.zeros(
                 (self.max_num_tokens, ), dtype=torch.int, device='cuda')
-            self.previous_kv_lens_offsets_cuda = torch.zeros((batch_size, ),
-                                                             dtype=torch.int,
-                                                             device='cuda')
+            self.previous_kv_lens_offsets_cuda = torch.zeros(
+                (self.batch_size, ), dtype=torch.int, device='cuda')
             self.without_logits = self.spec_config.spec_dec_mode.without_logits(
             ) or self.model_is_wrapped
             self.max_draft_len = spec_config.max_draft_len
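
The spec-decode buffers above are now sized from `self.batch_size`, which comes from `llm_args`, rather than from the removed `batch_size` parameter. A small illustration of the sizing arithmetic, with made-up numbers and CPU tensors so it runs anywhere:

```python
import torch

# Hypothetical values for illustration only.
max_batch_size = 8
original_max_total_draft_tokens = 4

# Per-step draft-token buffer: one slot per draft token per request.
max_num_draft_tokens = original_max_total_draft_tokens * max_batch_size  # 32
draft_tokens = torch.empty((max_num_draft_tokens, ), dtype=torch.int)

# Per-request buffers (accepted-token counts, KV-length offsets) are sized
# by batch alone, mirroring num_accepted_draft_tokens_cuda above.
num_accepted = torch.empty((max_batch_size, ), dtype=torch.int)
assert draft_tokens.numel() == 32 and num_accepted.numel() == 8
```
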