@@ -144,7 +144,14 @@ def build_from_config(cls, ad_config: LlmArgs):
144144 build_and_optimize = InferenceOptimizer (factory = factory , config = ad_config .transforms )
145145
146146 # construct engine
147- return cls (build_and_optimize , seq_info , device , max_beam_width , reporting_info )
147+ return cls (
148+ build_and_optimize ,
149+ seq_info ,
150+ device ,
151+ max_beam_width ,
152+ ad_config .sampler_type ,
153+ reporting_info ,
154+ )
148155
149156 @torch .inference_mode ()
150157 def __init__ (
@@ -153,6 +160,7 @@ def __init__(
153160 seq_info : SequenceInfo ,
154161 device : DeviceLikeType ,
155162 max_beam_width : int = 1 ,
163+ sampler_type : SamplerType = SamplerType .TorchSampler ,
156164 reporting_info : ReportingInfo = ReportingInfo (),
157165 ) -> None :
158166 """Initialize the engine with model and sequence information."""
@@ -168,6 +176,7 @@ def __init__(
168176 self .llm_args .batch_wait_timeout_iters = 0
169177 self .llm_args .batch_wait_max_tokens_ratio = 0.0
170178 self .llm_args .max_num_tokens = seq_info .max_num_tokens
179+ self .sampler_type = sampler_type
171180 self .iter_counter = 0
172181 self .iter_states = {}
173182
@@ -301,10 +310,12 @@ def _compute_logits(self) -> List[torch.Tensor]:
301310 logits : torch .Tensor = self .model (** self .cache_seq_interface .named_args )[0 ]
302311
303312 # Ensure logits are float32 as TRTLLMSampler expects float32
304- if logits .dtype != torch .float32 :
305- print ("Changing logits dtype to float32" )
306- print (f"Old logits.dtype: { logits .dtype } " )
307- logits = logits .float ()
313+ # TODO(govind): Should this be put into the AD graph so it can be fused with other operations?
314+ if self .sampler_type == SamplerType .TRTLLMSampler and logits .dtype != torch .float32 :
315+ ad_logger .info (
316+ f"Logits type { logits .dtype } is not supported by TRTLLMSampler. Casting to float32."
317+ )
318+ logits = logits .to (torch .float32 )
308319
309320 # return a list of tensors
310321 return self .cache_seq_interface .info .unnest_sequences (logits )
@@ -351,6 +362,57 @@ def __init__(self, ad_config: LlmArgs):
351362 self .config .num_attention_heads = factory .num_attention_heads
352363
353364
def get_torch_dtype(ad_config: LlmArgs):
    """Resolve the torch dtype configured for the model.

    If ``ad_config.dtype`` is the sentinel string "auto", the dtype is taken
    from the model factory instead. A string dtype name is converted to its
    ``torch.dtype`` equivalent via ``str_dtype_to_torch``; anything else is
    returned unchanged.
    """
    resolved = ad_config.dtype
    # "auto" defers the dtype decision to the model factory.
    if resolved == "auto":
        resolved = ad_config.create_factory().dtype
    # The config/factory may hand back a string name rather than a torch.dtype.
    return str_dtype_to_torch(resolved) if isinstance(resolved, str) else resolved
373+
374+
def instantiate_sampler(
    ad_config: LlmArgs,
    max_num_sequences: int,
    max_draft_len: int,
    max_total_draft_tokens: int,
    dist_mapping: Mapping,
):
    """Build the sampler selected by ``ad_config.sampler_type``.

    Args:
        ad_config: AutoDeploy LLM configuration (sampler type, seq/beam limits,
            decoding and KV-cache configs).
        max_num_sequences: Maximum number of concurrent sequences.
        max_draft_len: Maximum draft length for speculative decoding.
        max_total_draft_tokens: Total draft-token budget for speculative decoding.
        dist_mapping: Distributed mapping passed through to the TRTLLM sampler.

    Returns:
        A configured ``TorchSampler`` or ``TRTLLMSampler`` instance.

    Raises:
        ValueError: If the configured sampler type is not supported.
    """
    sampler_type = ad_config.sampler_type

    if sampler_type == SamplerType.TorchSampler:
        # Search sampler with speculative-decoding support.
        return TorchSampler(
            TorchSampler.Args(
                max_seq_len=ad_config.max_seq_len,
                max_draft_len=max_draft_len,
                max_total_draft_tokens=max_total_draft_tokens,
                max_num_sequences=max_num_sequences,
                max_beam_width=ad_config.max_beam_width,
                disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
            )
        )

    if sampler_type == SamplerType.TRTLLMSampler:
        return TRTLLMSampler(
            model=TRTLLMSamplerModelConfig(ad_config=ad_config),
            model_dtype=get_torch_dtype(ad_config),
            mapping=dist_mapping,
            decoding_mode=get_decoding_mode(ad_config.decoding_config, ad_config.max_beam_width),
            disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
            max_seq_len=ad_config.max_seq_len,
            max_batch_size=ad_config.max_batch_size,
            max_beam_width=ad_config.max_beam_width,
            decoding_config=ad_config.decoding_config,
            kv_cache_config=ad_config.kv_cache_config,
        )

    raise ValueError(f"Sampler type {ad_config.sampler_type} is not supported.")
414+
415+
354416def create_autodeploy_executor (ad_config : LlmArgs , tokenizer : Optional [TokenizerBase ] = None ):
355417 """Create an AutoDeploy executor from the given configuration and tokenizer.
356418 The tokenizer is required for guided decoding.
@@ -447,42 +509,14 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
447509 )
448510 scheduler = SimpleScheduler (capacitor_scheduler , mb_scheduler )
449511
450- if ad_config .sampler_type == SamplerType .TorchSampler :
451- # search sampler with speculative decoding
452- sampler_args = TorchSampler .Args (
453- max_seq_len = ad_config .max_seq_len ,
454- max_draft_len = max_draft_len ,
455- max_total_draft_tokens = max_total_draft_tokens ,
456- max_num_sequences = max_num_sequences ,
457- max_beam_width = ad_config .max_beam_width ,
458- disable_overlap_scheduler = ad_config .disable_overlap_scheduler ,
459- )
460- sampler = TorchSampler (sampler_args )
512+ sampler = instantiate_sampler (
513+ ad_config = ad_config ,
514+ max_num_sequences = max_num_sequences ,
515+ max_draft_len = max_draft_len ,
516+ max_total_draft_tokens = max_total_draft_tokens ,
517+ dist_mapping = dist_mapping ,
518+ )
461519
462- elif ad_config .sampler_type == SamplerType .TRTLLMSampler :
463- tllm_model_config = TRTLLMSamplerModelConfig (ad_config = ad_config )
464- decoding_mode = get_decoding_mode (ad_config .decoding_config , ad_config .max_beam_width )
465- # if the model dtype is "auto", we infer it from the model config
466- model_dtype = ad_config .dtype
467- print (f"model_dtype: { model_dtype } " )
468- if model_dtype == "auto" :
469- model_dtype = ad_config .create_factory ().dtype
470- print (f"model_dtype was auto. Setting to: { model_dtype } " )
471- if isinstance (model_dtype , str ):
472- model_dtype = str_dtype_to_torch (model_dtype )
473- print (f"model_dtype was string. Setting to: { model_dtype } " )
474- sampler = TRTLLMSampler (
475- model = tllm_model_config ,
476- model_dtype = model_dtype ,
477- mapping = dist_mapping ,
478- decoding_mode = decoding_mode ,
479- disable_overlap_scheduler = ad_config .disable_overlap_scheduler ,
480- max_seq_len = ad_config .max_seq_len ,
481- max_batch_size = ad_config .max_batch_size ,
482- max_beam_width = ad_config .max_beam_width ,
483- decoding_config = ad_config .decoding_config ,
484- kv_cache_config = ad_config .kv_cache_config ,
485- )
486520 # Guided (istructured) decoding.
487521 guided_decoder = None
488522 if (
0 commit comments