19 | 19 | # yapf: disable |
20 | 20 | from ..bindings.executor import (BatchingType, CapacitySchedulerPolicy, |
21 | 21 | ContextChunkingPolicy, ExecutorConfig, |
22 | | - KvCacheRetentionConfig, SchedulerConfig) |
| 22 | + GuidedDecodingConfig, KvCacheRetentionConfig, |
| 23 | + PeftCacheConfig, SchedulerConfig) |
23 | 24 | # yapf: enable |
24 | | -from ..builder import BuildConfig, Engine, build |
| 25 | +from ..builder import BuildConfig, Engine, EngineConfig, build |
25 | 26 | from ..llmapi.llm_args import TrtLlmArgs |
26 | 27 | from ..logger import logger |
27 | 28 | from ..mapping import Mapping |
30 | 31 | from ..module import Module |
31 | 32 | from .build_cache import (BuildCache, BuildCacheConfig, CachedStage, |
32 | 33 | get_build_cache_config_from_env) |
33 | | -from .llm_args import (CalibConfig, DraftTargetDecodingConfig, |
| 34 | +from .llm_args import (BaseLlmArgs, CalibConfig, DraftTargetDecodingConfig, |
34 | 35 | EagleDecodingConfig, KvCacheConfig, LlmArgs, |
35 | 36 | LookaheadDecodingConfig, MedusaDecodingConfig, |
36 | | - MTPDecodingConfig, NGramDecodingConfig, _ModelFormatKind, |
37 | | - _ModelWrapper, _ParallelConfig, get_model_format, |
38 | | - update_llm_args_with_extra_dict, |
| 37 | + MTPDecodingConfig, NGramDecodingConfig, PybindMirror, |
| 38 | + _ModelFormatKind, _ModelWrapper, _ParallelConfig, |
| 39 | + get_model_format, update_llm_args_with_extra_dict, |
39 | 40 | update_llm_args_with_extra_options) |
40 | 41 | from .mpi_session import MPINodeState, MpiSession |
41 | | -from .tokenizer import TransformersTokenizer, load_hf_tokenizer |
| 42 | +from .tokenizer import (TransformersTokenizer, _xgrammar_tokenizer_info, |
| 43 | + load_hf_tokenizer) |
42 | 44 | # TODO[chunweiy]: move the following symbols back to utils scope, and remove the following import |
43 | 45 | from .utils import (download_hf_model, download_hf_pretrained_config, |
44 | 46 | enable_llm_debug, get_directory_size_in_gb, print_colored, |
@@ -855,6 +857,103 @@ class LlmBuildStats: |
855 | 857 | build_steps_info: List[Tuple[str, float]] = field(default_factory=list) |
856 | 858 | |
857 | 859 | |
| 860 | +def llm_args_to_executor_config(args: BaseLlmArgs, tokenizer, hf_model_dir=None, trt_engine_dir=None) -> ExecutorConfig: |
| 861 | + max_batch_size = args.max_batch_size |
| 862 | + max_num_tokens = args.max_num_tokens |
| 863 | + max_seq_len = args.max_seq_len |
| 864 | + |
| 865 | + build_config = args.build_config if isinstance( |
| 866 | + args, TrtLlmArgs) else BuildConfig() |
| 867 | + |
| 868 | + max_batch_size = max_batch_size or build_config.max_batch_size |
| 869 | + max_num_tokens = max_num_tokens or build_config.max_num_tokens |
| 870 | + max_seq_len = max_seq_len or build_config.max_seq_len |
| 871 | + |
| 872 | + executor_config = ExecutorConfig( |
| 873 | + max_beam_width=args.max_beam_width, |
| 874 | + scheduler_config=PybindMirror.maybe_to_pybind(args.scheduler_config), |
| 875 | + batching_type=PybindMirror.maybe_to_pybind(args.batching_type) |
| 876 | + or BatchingType.INFLIGHT, |
| 877 | + max_batch_size=max_batch_size, |
| 878 | + max_num_tokens=max_num_tokens, |
| 879 | + gather_generation_logits=args.gather_generation_logits) |
| 880 | + if args.backend is None: |
| 881 | + # also set executor_config.max_seq_len in TRT workflow, to deduce default max_tokens |
| 882 | + if max_seq_len is not None: |
| 883 | + executor_config.max_seq_len = max_seq_len |
| 884 | + else: |
| 885 | + engine_config = EngineConfig.from_json_file(args.model / |
| 886 | + "config.json") |
| 887 | + executor_config.max_seq_len = engine_config.build_config.max_seq_len |
| 888 | + |
| 889 | + if args.kv_cache_config is not None: |
| 890 | + executor_config.kv_cache_config = PybindMirror.maybe_to_pybind( |
| 891 | + args.kv_cache_config) |
| 892 | + if os.getenv("FORCE_DETERMINISTIC", "0") == "1": |
| 893 | + # Disable KV cache reuse for deterministic mode |
| 894 | + executor_config.kv_cache_config.enable_block_reuse = False |
| 895 | + executor_config.kv_cache_config.enable_partial_reuse = False |
| 896 | + |
| 897 | + if args.peft_cache_config is not None: |
| 898 | + executor_config.peft_cache_config = PybindMirror.maybe_to_pybind( |
| 899 | + args.peft_cache_config) |
| 900 | + elif isinstance(args, |
| 901 | + TrtLlmArgs) and args.build_config.plugin_config.lora_plugin: |
| 902 | + engine_config = EngineConfig.from_json_file(args.model / "config.json") |
| 903 | + lora_config = engine_config.build_config.lora_config |
| 904 | + max_lora_rank = lora_config.max_lora_rank |
| 905 | + num_lora_modules = engine_config.pretrained_config.num_hidden_layers * \ |
| 906 | + len(lora_config.lora_target_modules + lora_config.missing_qkv_modules) |
| 907 | + executor_config.peft_cache_config = PeftCacheConfig( |
| 908 | + num_device_module_layer=max_lora_rank * num_lora_modules * |
| 909 | + args.max_loras, |
| 910 | + num_host_module_layer=max_lora_rank * num_lora_modules * |
| 911 | + args.max_cpu_loras, |
| 912 | + ) |
| 913 | + if args.decoding_config is not None: |
| 914 | + executor_config.decoding_config = args.decoding_config |
| 915 | + |
| 916 | + if args.guided_decoding_backend == 'xgrammar': |
| 917 | + executor_config.guided_decoding_config = GuidedDecodingConfig( |
| 918 | + backend=GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR, |
| 919 | + **_xgrammar_tokenizer_info(tokenizer)) |
| 920 | + elif args.guided_decoding_backend is not None: |
| 921 | + raise ValueError( |
| 922 | + f"Unrecognized guided decoding backend {args.guided_decoding_backend}" |
| 923 | + ) |
| 924 | + |
| 925 | + executor_config.normalize_log_probs = args.normalize_log_probs |
| 926 | + executor_config.enable_chunked_context = args.enable_chunked_prefill |
| 927 | + executor_config.max_beam_width = args.max_beam_width or build_config.max_beam_width |
| 928 | + if isinstance( |
| 929 | + args, |
| 930 | + TrtLlmArgs) and args.extended_runtime_perf_knob_config is not None: |
| 931 | + executor_config.extended_runtime_perf_knob_config = PybindMirror.maybe_to_pybind( |
| 932 | + args.extended_runtime_perf_knob_config) |
| 933 | + |
| 934 | + if args.cache_transceiver_config is not None: |
| 935 | + executor_config.cache_transceiver_config = PybindMirror.maybe_to_pybind( |
| 936 | + args.cache_transceiver_config) |
| 937 | + |
| 938 | + from tensorrt_llm._torch.pyexecutor.config import update_executor_config |
| 939 | + update_executor_config( |
| 940 | + executor_config, |
| 941 | + backend=args.backend, |
| 942 | + pytorch_backend_config=args.get_pytorch_backend_config() |
| 943 | + if args.backend in ["pytorch", "_autodeploy"] else None, |
| 944 | + mapping=args.parallel_config.to_mapping(), |
| 945 | + build_config=args.build_config |
| 946 | + if isinstance(args, TrtLlmArgs) else None, |
| 947 | + speculative_config=args.speculative_config, |
| 948 | + hf_model_dir=hf_model_dir, |
| 949 | + trt_engine_dir=trt_engine_dir, |
| 950 | + max_input_len=args.max_input_len, |
| 951 | + max_seq_len=max_seq_len) |
| 952 | + |
| 953 | + executor_config.llm_parallel_config = args.parallel_config |
| 954 | + return executor_config |
| 955 | + |
| 956 | + |
858 | 957 | __all__ = [ |
859 | 958 | 'LlmArgs', |
860 | 959 | 'LlmBuildStats', |