|
16 | 16 | from typing import Dict, List, Optional, Tuple |
17 | 17 |
|
18 | 18 | import torch |
| 19 | +import torch.nn as nn |
19 | 20 | from strenum import StrEnum |
| 21 | +from torch import fx |
20 | 22 | from torch._prims_common import DeviceLikeType |
21 | 23 |
|
22 | 24 | from tensorrt_llm._torch.attention_backend.interface import AttentionRuntimeFeatures |
23 | | -from tensorrt_llm._torch.pyexecutor._util import _create_kv_cache_manager, get_kv_cache_manager_cls |
| 25 | +from tensorrt_llm._torch.pyexecutor._util import ( |
| 26 | + _create_kv_cache_manager, |
| 27 | + get_decoding_mode, |
| 28 | + get_kv_cache_manager_cls, |
| 29 | +) |
24 | 30 | from tensorrt_llm._torch.pyexecutor.guided_decoder import GuidedDecoder |
25 | 31 | from tensorrt_llm._torch.pyexecutor.llm_request import get_draft_token_length |
26 | 32 | from tensorrt_llm._torch.pyexecutor.py_executor_creator import get_guided_decoding_config |
|
30 | 36 | from tensorrt_llm.llmapi.llm_args import ( |
31 | 37 | ContextChunkingPolicy, |
32 | 38 | LoadFormat, |
| 39 | + SamplerType, |
33 | 40 | SpeculativeConfig, |
34 | 41 | TorchLlmArgs, |
35 | 42 | ) |
|
42 | 49 | from ...pyexecutor.model_engine import ModelEngine, PyTorchModelEngine |
43 | 50 | from ...pyexecutor.py_executor import PyExecutor |
44 | 51 | from ...pyexecutor.resource_manager import KVCacheManager, ResourceManager, ResourceManagerType |
45 | | -from ...pyexecutor.sampler import TorchSampler |
| 52 | +from ...pyexecutor.sampler import TorchSampler, TRTLLMSampler |
46 | 53 | from ...pyexecutor.scheduler import ( |
47 | 54 | BindCapacityScheduler, |
48 | 55 | BindMicroBatchScheduler, |
|
53 | 60 | from ..distributed import common as dist |
54 | 61 | from ..llm_args import LlmArgs |
55 | 62 | from ..transform.optimizer import InferenceOptimizer |
| 63 | +from ..utils._graph import named_graphmodules |
56 | 64 | from ..utils.logger import ad_logger |
57 | 65 | from .interface import CachedSequenceInterface, GetInferenceModel |
58 | 66 |
|
@@ -283,9 +291,9 @@ def __init__( |
283 | 291 | self.llm_args.batch_wait_timeout_iters = 0 |
284 | 292 | self.llm_args.batch_wait_max_tokens_ratio = 0.0 |
285 | 293 | self.llm_args.max_num_tokens = seq_info.max_num_tokens |
| 294 | + self.llm_args.max_seq_len = seq_info.max_seq_len |
286 | 295 | self.iter_counter = 0 |
287 | 296 | self.iter_states = {} |
288 | | - self.llm_args.max_seq_len = seq_info.max_seq_len |
289 | 297 |
|
290 | 298 | # NOTE (lucaslie): not a declared base member in the base class; required by PyExecutor... |
291 | 299 | self.max_beam_width = max_beam_width |
@@ -487,6 +495,9 @@ def _compute_logits(self) -> List[torch.Tensor]: |
487 | 495 | # run the model |
488 | 496 | logits: torch.Tensor = self.model(**self.cache_seq_interface.named_args)[0] |
489 | 497 |
|
| 498 | +    # TRTLLMSampler expects float32 logits; cast here for parity with PyTorchModelEngine, which always casts logits to float32.
| 499 | + logits = logits.float() |
| 500 | + |
490 | 501 | # return a list of tensors |
491 | 502 | return self.cache_seq_interface.info.unnest_sequences(logits) |
492 | 503 |
|
@@ -574,6 +585,91 @@ def create_draft_model_engine_maybe( |
574 | 585 | return draft_model_engine |
575 | 586 |
|
576 | 587 |
|
| 588 | +class TRTLLMSamplerModelConfig: |
| 589 | + def __init__(self, vocab_size_padded: int): |
| 590 | + self.config = SimpleNamespace() |
| 591 | + self.config.vocab_size = vocab_size_padded |
| 592 | + |
| 593 | + # Initialized to dummy values as they are not used in the C++ code underlying TRTLLMSampler. |
| 594 | + self.config.num_hidden_layers = 42 |
| 595 | + self.config.hidden_size = 42 |
| 596 | + self.config.num_attention_heads = 42 |
| 597 | + |
| 598 | + |
| 599 | +def get_model_dtype(model: nn.Module) -> torch.dtype:
| 600 | +    """Infer the model's output dtype from its fx graph.
| 601 | +
| 602 | +    The dtype is read from the "val" metadata recorded on the output node of the
| 603 | +    model's fx graph; a RuntimeError is raised if no graph or dtype can be found.
| 604 | +    """
| 605 | +    # Find the graph module (the model may be compiled or wrapped).
| 606 | +    graph_module = model if isinstance(model, fx.GraphModule) else None
| 607 | +    if graph_module is None:
| 608 | +        for _, gm in named_graphmodules(model):
| 609 | +            graph_module = gm
| 610 | +            break
| 611 | +    if graph_module is None:
| 612 | +        raise RuntimeError("Failed to get the model dtype: no fx graph found.")
| 613 | +
| 614 | +    # Read the dtype from the "val" metadata of the graph's output node.
| 615 | +    output_nodes = graph_module.graph.find_nodes(op="output")
| 616 | +    if output_nodes:
| 617 | +        output_node = output_nodes[0]
| 618 | +        if hasattr(output_node, "meta") and "val" in output_node.meta:
| 619 | +            val = output_node.meta["val"]
| 620 | +            if hasattr(val, "dtype"):
| 621 | +                return val.dtype
| 622 | +            if isinstance(val, (tuple, list)) and val:
| 623 | +                first = val[0]
| 624 | +                if hasattr(first, "dtype"):
| 625 | +                    return first.dtype
| 626 | +
| 627 | +    raise RuntimeError("Failed to get the model dtype from the graph.")
| 628 | +
| 629 | + |
| 630 | +def instantiate_sampler( |
| 631 | + ad_config: LlmArgs, |
| 632 | + max_num_sequences: int, |
| 633 | + max_draft_len: int, |
| 634 | + max_total_draft_tokens: int, |
| 635 | + dist_mapping: Mapping, |
| 636 | + engine: ADEngine, |
| 637 | +): |
| 638 | + if ad_config.sampler_type == SamplerType.TorchSampler: |
| 639 | + # search sampler with speculative decoding |
| 640 | + sampler_args = TorchSampler.Args( |
| 641 | + max_seq_len=ad_config.max_seq_len, |
| 642 | + max_draft_len=max_draft_len, |
| 643 | + max_total_draft_tokens=max_total_draft_tokens, |
| 644 | + max_num_sequences=max_num_sequences, |
| 645 | + max_beam_width=ad_config.max_beam_width, |
| 646 | + disable_overlap_scheduler=ad_config.disable_overlap_scheduler, |
| 647 | + ) |
| 648 | + sampler = TorchSampler(sampler_args) |
| 649 | + |
| 650 | + elif ad_config.sampler_type == SamplerType.TRTLLMSampler: |
| 651 | + vocab_size_padded: int = engine.cache_seq_interface.info.vocab_size_padded |
| 652 | + sampler_model_config = TRTLLMSamplerModelConfig(vocab_size_padded) |
| 653 | + decoding_mode = get_decoding_mode(ad_config.decoding_config, ad_config.max_beam_width) |
| 654 | + model_dtype: torch.dtype = get_model_dtype(engine.model) |
| 655 | + sampler = TRTLLMSampler( |
| 656 | + model=sampler_model_config, |
| 657 | + model_dtype=model_dtype, |
| 658 | + mapping=dist_mapping, |
| 659 | + decoding_mode=decoding_mode, |
| 660 | + disable_overlap_scheduler=ad_config.disable_overlap_scheduler, |
| 661 | + max_seq_len=ad_config.max_seq_len, |
| 662 | + max_batch_size=ad_config.max_batch_size, |
| 663 | + max_beam_width=ad_config.max_beam_width, |
| 664 | + decoding_config=ad_config.decoding_config, |
| 665 | + kv_cache_config=ad_config.kv_cache_config, |
| 666 | + ) |
| 667 | + else: |
| 668 | + raise ValueError(f"Sampler type {ad_config.sampler_type} is not supported.") |
| 669 | + |
| 670 | + return sampler |
| 671 | + |
| 672 | + |
577 | 673 | def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[TokenizerBase] = None): |
578 | 674 | """Create an AutoDeploy executor from the given configuration and tokenizer. |
579 | 675 | The tokenizer is required for guided decoding. |
@@ -695,23 +791,21 @@ def create_autodeploy_executor(ad_config: LlmArgs, tokenizer: Optional[Tokenizer |
695 | 791 | ) |
696 | 792 | scheduler = SimpleScheduler(capacitor_scheduler, mb_scheduler) |
697 | 793 |
|
698 | | - # search sampler with speculative decoding |
699 | | - sampler_args = TorchSampler.Args( |
700 | | - max_seq_len=ad_config.max_seq_len, |
| 794 | + vocab_size_padded = engine.cache_seq_interface.info.vocab_size_padded |
| 795 | + sampler = instantiate_sampler( |
| 796 | + ad_config=ad_config, |
| 797 | + max_num_sequences=max_num_sequences, |
701 | 798 | max_draft_len=max_draft_len, |
702 | 799 | max_total_draft_tokens=max_total_draft_tokens, |
703 | | - max_num_sequences=max_num_sequences, |
704 | | - max_beam_width=ad_config.max_beam_width, |
705 | | - disable_overlap_scheduler=ad_config.disable_overlap_scheduler, |
| 800 | + dist_mapping=dist_mapping, |
| 801 | + engine=engine, |
706 | 802 | ) |
707 | | - sampler = TorchSampler(sampler_args) |
708 | 803 |
|
709 | | - # Guided (structured) decoding. |
| 804 | +    # Guided (structured) decoding.
710 | 805 | guided_decoder = None |
711 | 806 | if ( |
712 | 807 | (guided_decoding_backend := ad_config.guided_decoding_backend) is not None |
713 | 808 | ) and dist_mapping.is_last_pp_rank(): |
714 | | - vocab_size_padded = engine.cache_seq_interface.info.vocab_size_padded |
715 | 809 | if vocab_size_padded is None: |
716 | 810 | raise RuntimeError( |
717 | 811 | "Could not determine the vocabulary size. Required for guided decoding." |
|