
Commit 3bfe777

some cleanup of the impl and the test
Signed-off-by: Govind Ramnarayan <[email protected]>
1 parent 877cb5f commit 3bfe777


5 files changed (+129, -61 lines)


tensorrt_llm/_torch/auto_deploy/models/hf.py

Lines changed: 1 addition & 1 deletion
@@ -345,7 +345,7 @@ def build_and_load_model(self, device: DeviceLikeType) -> nn.Module:
             "trust_remote_code": True,
             "tp_plan": "auto",
             **unused_kwargs,
-            # "dtype": "auto",  # takes precedence over unused_kwargs! -- REMOVED
+            "dtype": "auto",  # takes precedence over unused_kwargs!
         },
     )
     model.eval()
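
Side note on why restoring that key matters: in a Python dict literal, a key written after **unused_kwargs overrides any value of the same name carried in the unpacked mapping, so "dtype": "auto" always wins. A minimal self-contained sketch of that rule (the unused_kwargs values are made up for illustration):

unused_kwargs = {"dtype": "float16", "attn_implementation": "eager"}  # hypothetical values

# Keys listed after **unused_kwargs take precedence, so "dtype" resolves to "auto".
model_kwargs = {
    "trust_remote_code": True,
    **unused_kwargs,
    "dtype": "auto",
}
assert model_kwargs["dtype"] == "auto"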

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 74 additions & 40 deletions
@@ -144,7 +144,14 @@ def build_from_config(cls, ad_config: LlmArgs):
         build_and_optimize = InferenceOptimizer(factory=factory, config=ad_config.transforms)

         # construct engine
-        return cls(build_and_optimize, seq_info, device, max_beam_width, reporting_info)
+        return cls(
+            build_and_optimize,
+            seq_info,
+            device,
+            max_beam_width,
+            ad_config.sampler_type,
+            reporting_info,
+        )

     @torch.inference_mode()
     def __init__(
@@ -153,6 +160,7 @@ def __init__(
         seq_info: SequenceInfo,
         device: DeviceLikeType,
         max_beam_width: int = 1,
+        sampler_type: SamplerType = SamplerType.TorchSampler,
         reporting_info: ReportingInfo = ReportingInfo(),
     ) -> None:
         """Initialize the engine with model and sequence information."""
@@ -168,6 +176,7 @@ def __init__(
         self.llm_args.batch_wait_timeout_iters = 0
         self.llm_args.batch_wait_max_tokens_ratio = 0.0
         self.llm_args.max_num_tokens = seq_info.max_num_tokens
+        self.sampler_type = sampler_type
         self.iter_counter = 0
         self.iter_states = {}

@@ -301,10 +310,12 @@ def _compute_logits(self) -> List[torch.Tensor]:
         logits: torch.Tensor = self.model(**self.cache_seq_interface.named_args)[0]

         # Ensure logits are float32 as TRTLLMSampler expects float32
-        if logits.dtype != torch.float32:
-            print("Changing logits dtype to float32")
-            print(f"Old logits.dtype: {logits.dtype}")
-            logits = logits.float()
+        # TODO(govind): Should this be put into the AD graph so it can be fused with other operations?
+        if self.sampler_type == SamplerType.TRTLLMSampler and logits.dtype != torch.float32:
+            ad_logger.info(
+                f"Logits type {logits.dtype} is not supported by TRTLLMSampler. Casting to float32."
+            )
+            logits = logits.to(torch.float32)

         # return a list of tensors
         return self.cache_seq_interface.info.unnest_sequences(logits)
@@ -351,6 +362,57 @@ def __init__(self, ad_config: LlmArgs):
         self.config.num_attention_heads = factory.num_attention_heads


+def get_torch_dtype(ad_config: LlmArgs):
+    # if the model dtype is "auto", we infer it from the model config
+    model_dtype = ad_config.dtype
+    if model_dtype == "auto":
+        model_dtype = ad_config.create_factory().dtype
+    if isinstance(model_dtype, str):
+        model_dtype = str_dtype_to_torch(model_dtype)
+    return model_dtype
+
+
+def instantiate_sampler(
+    ad_config: LlmArgs,
+    max_num_sequences: int,
+    max_draft_len: int,
+    max_total_draft_tokens: int,
+    dist_mapping: Mapping,
+):
+    if ad_config.sampler_type == SamplerType.TorchSampler:
+        # search sampler with speculative decoding
+        sampler_args = TorchSampler.Args(
+            max_seq_len=ad_config.max_seq_len,
+            max_draft_len=max_draft_len,
+            max_total_draft_tokens=max_total_draft_tokens,
+            max_num_sequences=max_num_sequences,
+            max_beam_width=ad_config.max_beam_width,
+            disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
+        )
+        sampler = TorchSampler(sampler_args)
+
+    elif ad_config.sampler_type == SamplerType.TRTLLMSampler:
+        tllm_model_config = TRTLLMSamplerModelConfig(ad_config=ad_config)
+        decoding_mode = get_decoding_mode(ad_config.decoding_config, ad_config.max_beam_width)
+        model_dtype = get_torch_dtype(ad_config)
+        sampler = TRTLLMSampler(
+            model=tllm_model_config,
+            model_dtype=model_dtype,
+            mapping=dist_mapping,
+            decoding_mode=decoding_mode,
+            disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
+            max_seq_len=ad_config.max_seq_len,
+            max_batch_size=ad_config.max_batch_size,
+            max_beam_width=ad_config.max_beam_width,
+            decoding_config=ad_config.decoding_config,
+            kv_cache_config=ad_config.kv_cache_config,
+        )
+    else:
+        raise ValueError(f"Sampler type {ad_config.sampler_type} is not supported.")
+
+    return sampler
+
+
 def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[TokenizerBase] = None):
     """Create an AutoDeploy executor from the given configuration and tokenizer.
     The tokenizer is required for guided decoding.
@@ -447,42 +509,14 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
     )
     scheduler = SimpleScheduler(capacitor_scheduler, mb_scheduler)

-    if ad_config.sampler_type == SamplerType.TorchSampler:
-        # search sampler with speculative decoding
-        sampler_args = TorchSampler.Args(
-            max_seq_len=ad_config.max_seq_len,
-            max_draft_len=max_draft_len,
-            max_total_draft_tokens=max_total_draft_tokens,
-            max_num_sequences=max_num_sequences,
-            max_beam_width=ad_config.max_beam_width,
-            disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
-        )
-        sampler = TorchSampler(sampler_args)
+    sampler = instantiate_sampler(
+        ad_config=ad_config,
+        max_num_sequences=max_num_sequences,
+        max_draft_len=max_draft_len,
+        max_total_draft_tokens=max_total_draft_tokens,
+        dist_mapping=dist_mapping,
+    )

-    elif ad_config.sampler_type == SamplerType.TRTLLMSampler:
-        tllm_model_config = TRTLLMSamplerModelConfig(ad_config=ad_config)
-        decoding_mode = get_decoding_mode(ad_config.decoding_config, ad_config.max_beam_width)
-        # if the model dtype is "auto", we infer it from the model config
-        model_dtype = ad_config.dtype
-        print(f"model_dtype: {model_dtype}")
-        if model_dtype == "auto":
-            model_dtype = ad_config.create_factory().dtype
-            print(f"model_dtype was auto. Setting to: {model_dtype}")
-        if isinstance(model_dtype, str):
-            model_dtype = str_dtype_to_torch(model_dtype)
-            print(f"model_dtype was string. Setting to: {model_dtype}")
-        sampler = TRTLLMSampler(
-            model=tllm_model_config,
-            model_dtype=model_dtype,
-            mapping=dist_mapping,
-            decoding_mode=decoding_mode,
-            disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
-            max_seq_len=ad_config.max_seq_len,
-            max_batch_size=ad_config.max_batch_size,
-            max_beam_width=ad_config.max_beam_width,
-            decoding_config=ad_config.decoding_config,
-            kv_cache_config=ad_config.kv_cache_config,
-        )
     # Guided (istructured) decoding.
     guided_decoder = None
     if (
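
The new get_torch_dtype helper above replaces the inline prints with a single dtype-resolution path. A standalone sketch of the same behavior, decoupled from LlmArgs (only str_dtype_to_torch is assumed to come from tensorrt_llm._utils; the rest is illustrative):

import torch

from tensorrt_llm._utils import str_dtype_to_torch


def resolve_model_dtype(configured_dtype, factory_dtype):
    # Mirrors get_torch_dtype: "auto" falls back to the model factory's dtype,
    # then string names are normalized to torch dtypes.
    dtype = configured_dtype
    if dtype == "auto":
        dtype = factory_dtype  # stand-in for ad_config.create_factory().dtype
    if isinstance(dtype, str):
        dtype = str_dtype_to_torch(dtype)
    return dtype


assert resolve_model_dtype("auto", "bfloat16") == torch.bfloat16
assert resolve_model_dtype(torch.float16, "bfloat16") == torch.float16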

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 1 addition & 19 deletions
@@ -2588,32 +2588,14 @@ def __init__(
         max_beam_width: int,
         decoding_config: Optional[DecodingConfig] = None,
         kv_cache_config: Optional[KvCacheConfig] = None,
-        logits_dtype: DataType = DataType.FLOAT,
     ):
         vocab_size = model.config.vocab_size
         num_hidden_layers = model.config.num_hidden_layers
         hidden_size = model.config.hidden_size
         num_heads = model.config.num_attention_heads

-        print(
-            f"vocab_size: {vocab_size}, num_hidden_layers: {num_hidden_layers}, hidden_size: {hidden_size}, \
-            num_heads: {num_heads}"
-        )
-        print(f"model_dtype: {model_dtype}")
-        print(f"mapping: {mapping}")
-        print(f"decoding_mode: {decoding_mode}")
-        print(f"disable_overlap_scheduler: {disable_overlap_scheduler}")
-        print(f"max_seq_len: {max_seq_len}")
-        print(f"max_batch_size: {max_batch_size}")
-        print(f"max_beam_width: {max_beam_width}")
-        print(f"decoding_config: {decoding_config}")
-        print(f"kv_cache_config: {kv_cache_config}")
-
         self.model_datatype = torch_dtype_to_binding(model_dtype)
-        self.logits_datatype = logits_dtype
-
-        print(f"self.model_datatype: {self.model_datatype}")
-        print(f"self.logits_datatype: {self.logits_datatype}")
+        self.logits_datatype = DataType.FLOAT
         self.decoding_mode = decoding_mode
         self.decoding_config = decoding_config if decoding_config else DecodingConfig(decoding_mode)
         max_attn_window = kv_cache_config.max_attention_window
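
Taken together with the _compute_logits change above, these two hunks establish one invariant: the AutoDeploy executor hands TRTLLMSampler float32 logits, so the sampler can hard-code DataType.FLOAT instead of taking a logits_dtype argument. A small torch-only sketch of that invariant (assumed behavior, not the sampler's actual code path):

import torch


def cast_logits_for_trtllm_sampler(logits: torch.Tensor) -> torch.Tensor:
    # Mirrors the cast in _compute_logits: anything not already float32 is upcast.
    return logits if logits.dtype == torch.float32 else logits.to(torch.float32)


logits = torch.randn(2, 8, dtype=torch.bfloat16)
assert cast_logits_for_trtllm_sampler(logits).dtype == torch.float32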

tensorrt_llm/llmapi/llm.py

Lines changed: 0 additions & 1 deletion
@@ -1137,7 +1137,6 @@ def __init__(self,
                  revision: Optional[str] = None,
                  tokenizer_revision: Optional[str] = None,
                  **kwargs: Any) -> None:
-        print(f"dtype: {dtype}")
         super().__init__(model, tokenizer, tokenizer_mode, skip_tokenizer_init,
                          trust_remote_code, tensor_parallel_size, dtype,
                          revision, tokenizer_revision, **kwargs)
Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from _model_test_utils import get_small_model_config
+from build_and_run_ad import ExperimentConfig, main
+
+from tensorrt_llm.llmapi.llm_args import SamplerType
+
+
+def test_ad_trtllm_sampler_smoke():
+    """Test TRTLLMSampler in AutoDeploy smoke test."""
+    # Get small model config
+    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+    experiment_config = get_small_model_config(model_id)
+
+    # Configure for TRTLLMSampler
+    experiment_config["args"]["runtime"] = "trtllm"
+    experiment_config["args"]["world_size"] = 1
+    experiment_config["args"]["sampler_type"] = SamplerType.TRTLLMSampler
+    # experiment_config["args"]["sampler_type"] = SamplerType.TorchSampler
+    # experiment_config["args"]["dtype"] = "float32"
+
+    # Setup simple prompt
+    experiment_config["prompt"]["batch_size"] = 1
+    experiment_config["prompt"]["queries"] = {"prompt": "What is the capital of France?"}
+    experiment_config["prompt"]["sp_kwargs"] = {
+        "max_tokens": 10,
+        "temperature": 1.0,
+        "top_k": 1,
+    }
+
+    print(f"Experiment config: {experiment_config}")
+    cfg = ExperimentConfig(**experiment_config)
+
+    print("Running smoke test with TRTLLMSampler...")
+    results = main(cfg)
+
+    # Basic assertion that we got some output
+    prompts_and_outputs = results["prompts_and_outputs"]
+    assert len(prompts_and_outputs) == 1
+    assert len(prompts_and_outputs[0][1]) > 0
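
For readers unfamiliar with build_and_run_ad.main, the final assertions imply that results["prompts_and_outputs"] is a list of (prompt, generated_text) pairs, one per query in the batch. A tiny illustration of the shape being checked (the values below are made up, not actual model output):

results = {
    "prompts_and_outputs": [
        ("What is the capital of France?", "The capital of France is Paris."),  # made-up output
    ]
}

prompts_and_outputs = results["prompts_and_outputs"]
assert len(prompts_and_outputs) == 1       # one prompt in the batch
assert len(prompts_and_outputs[0][1]) > 0  # non-empty generated text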
