Commit 4627ad1
TRTLLMSampler initial commit; works for AutoDeploy by upcasting the logits to fp32 at the end of the forward() function. This is hacky and should be resolved.
Signed-off-by: Govind Ramnarayan <[email protected]>
1 parent cee7071 commit 4627ad1
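The workaround described in the commit message comes down to a single cast after the forward pass. A toy, self-contained illustration of the idea (made-up tensor shapes, not the engine code; the actual change is the logits.float() call in _compute_logits in ad_executor.py below):

import torch

# Pretend the AutoDeploy graph produced half-precision logits.
logits = torch.randn(2, 8, 32000, dtype=torch.bfloat16)

# TRTLLMSampler consumes float32 logits, so upcast before handing them to the sampler.
logits = logits.float()
assert logits.dtype == torch.float32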

3 files changed: +163 -13 lines changed

tensorrt_llm/_torch/auto_deploy/llm_args.py

Lines changed: 6 additions & 1 deletion

@@ -8,7 +8,7 @@
 
 from tensorrt_llm.models.modeling_utils import QuantConfig
 
-from ...llmapi.llm_args import BaseLlmArgs, BuildConfig, KvCacheConfig, _ParallelConfig
+from ...llmapi.llm_args import BaseLlmArgs, BuildConfig, KvCacheConfig, SamplerType, _ParallelConfig
 from .models import ModelFactory, ModelFactoryRegistry
 from .utils._config import DynamicYamlMixInForSettings
 from .utils.logger import ad_logger
@@ -130,6 +130,11 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
         "supported in AutoDeploy.",
     )
 
+    sampler_type: Union[str, SamplerType] = Field(
+        default=SamplerType.TorchSampler,
+        description="The type of sampler to use. Options are TRTLLMSampler or TorchSampler. Defaults to TorchSampler.",
+    )
+
     # NOTE: we do not support copy_on_partial_reuse in AutoDeploy yet
     # see https://github.com/NVIDIA/TensorRT-LLM/issues/7142
     kv_cache_config: KvCacheConfig = Field(
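The new sampler_type field is what AutoDeploy users toggle to switch sampler back ends. A minimal, hypothetical sketch of selecting it in Python (not part of this commit; it assumes AutoDeployConfig can be constructed with just a model identifier, and relies on the field's Union[str, SamplerType] typing to also accept the plain string "TRTLLMSampler"):

from tensorrt_llm._torch.auto_deploy.llm_args import AutoDeployConfig
from tensorrt_llm.llmapi.llm_args import SamplerType

# TorchSampler remains the default; opt into the C++-backed sampler explicitly.
ad_config = AutoDeployConfig(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",  # placeholder model id
    sampler_type=SamplerType.TRTLLMSampler,  # or the string "TRTLLMSampler"
)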

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 106 additions & 12 deletions

@@ -16,11 +16,17 @@
 from typing import Dict, List, Optional, Tuple
 
 import torch
+import torch.nn as nn
 from strenum import StrEnum
+from torch import fx
 from torch._prims_common import DeviceLikeType
 
 from tensorrt_llm._torch.attention_backend.interface import AttentionRuntimeFeatures
-from tensorrt_llm._torch.pyexecutor._util import _create_kv_cache_manager, get_kv_cache_manager_cls
+from tensorrt_llm._torch.pyexecutor._util import (
+    _create_kv_cache_manager,
+    get_decoding_mode,
+    get_kv_cache_manager_cls,
+)
 from tensorrt_llm._torch.pyexecutor.guided_decoder import GuidedDecoder
 from tensorrt_llm._torch.pyexecutor.llm_request import get_draft_token_length
 from tensorrt_llm._torch.pyexecutor.py_executor_creator import get_guided_decoding_config
@@ -30,6 +36,7 @@
 from tensorrt_llm.llmapi.llm_args import (
     ContextChunkingPolicy,
     LoadFormat,
+    SamplerType,
     SpeculativeConfig,
     TorchLlmArgs,
 )
@@ -42,7 +49,7 @@
 from ...pyexecutor.model_engine import ModelEngine, PyTorchModelEngine
 from ...pyexecutor.py_executor import PyExecutor
 from ...pyexecutor.resource_manager import KVCacheManager, ResourceManager, ResourceManagerType
-from ...pyexecutor.sampler import TorchSampler
+from ...pyexecutor.sampler import TorchSampler, TRTLLMSampler
 from ...pyexecutor.scheduler import (
     BindCapacityScheduler,
     BindMicroBatchScheduler,
@@ -53,6 +60,7 @@
 from ..distributed import common as dist
 from ..llm_args import LlmArgs
 from ..transform.optimizer import InferenceOptimizer
+from ..utils._graph import named_graphmodules
 from ..utils.logger import ad_logger
 from .interface import CachedSequenceInterface, GetInferenceModel
 
@@ -283,9 +291,9 @@ def __init__(
         self.llm_args.batch_wait_timeout_iters = 0
         self.llm_args.batch_wait_max_tokens_ratio = 0.0
         self.llm_args.max_num_tokens = seq_info.max_num_tokens
+        self.llm_args.max_seq_len = seq_info.max_seq_len
         self.iter_counter = 0
         self.iter_states = {}
-        self.llm_args.max_seq_len = seq_info.max_seq_len
 
         # NOTE (lucaslie): not a declared base member in the base class; required by PyExecutor...
         self.max_beam_width = max_beam_width
@@ -487,6 +495,9 @@ def _compute_logits(self) -> List[torch.Tensor]:
         # run the model
         logits: torch.Tensor = self.model(**self.cache_seq_interface.named_args)[0]
 
+        # TRTLLMSampler expects float32 logits. PyTorchModelEngine always casts to float32 regardless.
+        logits = logits.float()
+
         # return a list of tensors
         return self.cache_seq_interface.info.unnest_sequences(logits)
 
@@ -574,6 +585,91 @@ def create_draft_model_engine_maybe(
     return draft_model_engine
 
 
+class TRTLLMSamplerModelConfig:
+    def __init__(self, vocab_size_padded: int):
+        self.config = SimpleNamespace()
+        self.config.vocab_size = vocab_size_padded
+
+        # Initialized to dummy values as they are not used in the C++ code underlying TRTLLMSampler.
+        self.config.num_hidden_layers = 42
+        self.config.hidden_size = 42
+        self.config.num_attention_heads = 42
+
+
+def get_model_dtype(model: nn.Module) -> torch.dtype:
+    # Find the graph module (handle potentially compiled/wrapped models)
+
+    try:
+        if isinstance(model, fx.GraphModule):
+            graph_module = model
+        else:
+            for _, gm in named_graphmodules(model):
+                graph_module = gm
+                break
+
+        # Get the output node
+        output_nodes = graph_module.graph.find_nodes(op="output")
+
+        output_node = output_nodes[0]
+
+        if hasattr(output_node, "meta") and "val" in output_node.meta:
+            val = output_node.meta["val"]
+            if hasattr(val, "dtype"):
+                return val.dtype
+            if isinstance(val, (tuple, list)) and val:
+                first = val[0]
+                if hasattr(first, "dtype"):
+                    return first.dtype
+
+    except Exception:
+        raise RuntimeError("Failed to get the model dtype from the graph.")
+
+    raise RuntimeError("Failed to get the model dtype from the graph.")
+
+
+def instantiate_sampler(
+    ad_config: LlmArgs,
+    max_num_sequences: int,
+    max_draft_len: int,
+    max_total_draft_tokens: int,
+    dist_mapping: Mapping,
+    engine: ADEngine,
+):
+    if ad_config.sampler_type == SamplerType.TorchSampler:
+        # search sampler with speculative decoding
+        sampler_args = TorchSampler.Args(
+            max_seq_len=ad_config.max_seq_len,
+            max_draft_len=max_draft_len,
+            max_total_draft_tokens=max_total_draft_tokens,
+            max_num_sequences=max_num_sequences,
+            max_beam_width=ad_config.max_beam_width,
+            disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
+        )
+        sampler = TorchSampler(sampler_args)
+
+    elif ad_config.sampler_type == SamplerType.TRTLLMSampler:
+        vocab_size_padded: int = engine.cache_seq_interface.info.vocab_size_padded
+        sampler_model_config = TRTLLMSamplerModelConfig(vocab_size_padded)
+        decoding_mode = get_decoding_mode(ad_config.decoding_config, ad_config.max_beam_width)
+        model_dtype: torch.dtype = get_model_dtype(engine.model)
+        sampler = TRTLLMSampler(
+            model=sampler_model_config,
+            model_dtype=model_dtype,
+            mapping=dist_mapping,
+            decoding_mode=decoding_mode,
+            disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
+            max_seq_len=ad_config.max_seq_len,
+            max_batch_size=ad_config.max_batch_size,
+            max_beam_width=ad_config.max_beam_width,
+            decoding_config=ad_config.decoding_config,
+            kv_cache_config=ad_config.kv_cache_config,
+        )
+    else:
+        raise ValueError(f"Sampler type {ad_config.sampler_type} is not supported.")
+
+    return sampler
+
+
 def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[TokenizerBase] = None):
     """Create an AutoDeploy executor from the given configuration and tokenizer.
     The tokenizer is required for guided decoding.
@@ -695,23 +791,21 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer
     )
     scheduler = SimpleScheduler(capacitor_scheduler, mb_scheduler)
 
-    # search sampler with speculative decoding
-    sampler_args = TorchSampler.Args(
-        max_seq_len=ad_config.max_seq_len,
+    vocab_size_padded = engine.cache_seq_interface.info.vocab_size_padded
+    sampler = instantiate_sampler(
+        ad_config=ad_config,
+        max_num_sequences=max_num_sequences,
         max_draft_len=max_draft_len,
        max_total_draft_tokens=max_total_draft_tokens,
-        max_num_sequences=max_num_sequences,
-        max_beam_width=ad_config.max_beam_width,
-        disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
+        dist_mapping=dist_mapping,
+        engine=engine,
    )
-    sampler = TorchSampler(sampler_args)
 
-    # Guided (structured) decoding.
+    # Guided (istructured) decoding.
     guided_decoder = None
     if (
         (guided_decoding_backend := ad_config.guided_decoding_backend) is not None
     ) and dist_mapping.is_last_pp_rank():
-        vocab_size_padded = engine.cache_seq_interface.info.vocab_size_padded
         if vocab_size_padded is None:
             raise RuntimeError(
                 "Could not determine the vocabulary size. Required for guided decoding."
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from _model_test_utils import get_small_model_config
17+
from build_and_run_ad import ExperimentConfig, main
18+
19+
from tensorrt_llm.llmapi.llm_args import SamplerType
20+
21+
22+
def test_ad_trtllm_sampler_smoke():
23+
"""Test TRTLLMSampler in AutoDeploy smoke test."""
24+
# Get small model config
25+
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
26+
experiment_config = get_small_model_config(model_id)
27+
28+
# Configure for TRTLLMSampler
29+
experiment_config["args"]["runtime"] = "trtllm"
30+
experiment_config["args"]["world_size"] = 1
31+
experiment_config["args"]["sampler_type"] = SamplerType.TRTLLMSampler
32+
33+
# Setup simple prompt
34+
experiment_config["prompt"]["batch_size"] = 1
35+
experiment_config["prompt"]["queries"] = {"prompt": "What is the capital of France?"}
36+
experiment_config["prompt"]["sp_kwargs"] = {
37+
"max_tokens": 10,
38+
"temperature": 1.0,
39+
"top_k": 1,
40+
}
41+
42+
print(f"Experiment config: {experiment_config}")
43+
cfg = ExperimentConfig(**experiment_config)
44+
45+
print("Running smoke test with TRTLLMSampler...")
46+
results = main(cfg)
47+
48+
# Basic assertion that we got some output
49+
prompts_and_outputs = results["prompts_and_outputs"]
50+
assert len(prompts_and_outputs) == 1
51+
assert len(prompts_and_outputs[0][1]) > 0
