
Commit 6311856

init

Signed-off-by: QI JUN <[email protected]>

1 parent e44f768

1 file changed: +106 -7 lines changed


tensorrt_llm/llmapi/llm_utils.py

Lines changed: 106 additions & 7 deletions
@@ -19,9 +19,10 @@
 # yapf: disable
 from ..bindings.executor import (BatchingType, CapacitySchedulerPolicy,
                                  ContextChunkingPolicy, ExecutorConfig,
-                                 KvCacheRetentionConfig, SchedulerConfig)
+                                 GuidedDecodingConfig, KvCacheRetentionConfig,
+                                 PeftCacheConfig, SchedulerConfig)
 # yapf: enable
-from ..builder import BuildConfig, Engine, build
+from ..builder import BuildConfig, Engine, EngineConfig, build
 from ..llmapi.llm_args import TrtLlmArgs
 from ..logger import logger
 from ..mapping import Mapping
@@ -30,15 +31,16 @@
 from ..module import Module
 from .build_cache import (BuildCache, BuildCacheConfig, CachedStage,
                           get_build_cache_config_from_env)
-from .llm_args import (CalibConfig, DraftTargetDecodingConfig,
+from .llm_args import (BaseLlmArgs, CalibConfig, DraftTargetDecodingConfig,
                        EagleDecodingConfig, KvCacheConfig, LlmArgs,
                        LookaheadDecodingConfig, MedusaDecodingConfig,
-                       MTPDecodingConfig, NGramDecodingConfig, _ModelFormatKind,
-                       _ModelWrapper, _ParallelConfig, get_model_format,
-                       update_llm_args_with_extra_dict,
+                       MTPDecodingConfig, NGramDecodingConfig, PybindMirror,
+                       _ModelFormatKind, _ModelWrapper, _ParallelConfig,
+                       get_model_format, update_llm_args_with_extra_dict,
                        update_llm_args_with_extra_options)
 from .mpi_session import MPINodeState, MpiSession
-from .tokenizer import TransformersTokenizer, load_hf_tokenizer
+from .tokenizer import (TransformersTokenizer, _xgrammar_tokenizer_info,
+                        load_hf_tokenizer)
 # TODO[chunweiy]: move the following symbols back to utils scope, and remove the following import
 from .utils import (download_hf_model, download_hf_pretrained_config,
                     enable_llm_debug, get_directory_size_in_gb, print_colored,
@@ -855,6 +857,103 @@ class LlmBuildStats:
     build_steps_info: List[Tuple[str, float]] = field(default_factory=list)
 
 
+def llm_args_to_executor_config(args: BaseLlmArgs, tokenizer) -> ExecutorConfig:
+    max_batch_size = args.max_batch_size
+    max_num_tokens = args.max_num_tokens
+    max_seq_len = args.max_seq_len
+
+    build_config = args.build_config if isinstance(
+        args, TrtLlmArgs) else BuildConfig()
+
+    max_batch_size = max_batch_size or build_config.max_batch_size
+    max_num_tokens = max_num_tokens or build_config.max_num_tokens
+    max_seq_len = max_seq_len or build_config.max_seq_len
+
+    executor_config = ExecutorConfig(
+        max_beam_width=args.max_beam_width,
+        scheduler_config=PybindMirror.maybe_to_pybind(args.scheduler_config),
+        batching_type=PybindMirror.maybe_to_pybind(args.batching_type)
+        or BatchingType.INFLIGHT,
+        max_batch_size=max_batch_size,
+        max_num_tokens=max_num_tokens,
+        gather_generation_logits=args.gather_generation_logits)
+    if args.backend is None:
+        # also set executor_config.max_seq_len in TRT workflow, to deduce default max_tokens
+        if max_seq_len is not None:
+            executor_config.max_seq_len = max_seq_len
+        else:
+            engine_config = EngineConfig.from_json_file(args.model /
+                                                        "config.json")
+            executor_config.max_seq_len = engine_config.build_config.max_seq_len
+
+    if args.kv_cache_config is not None:
+        executor_config.kv_cache_config = PybindMirror.maybe_to_pybind(
+            args.kv_cache_config)
+    if os.getenv("FORCE_DETERMINISTIC", "0") == "1":
+        # Disable KV cache reuse for deterministic mode
+        executor_config.kv_cache_config.enable_block_reuse = False
+        executor_config.kv_cache_config.enable_partial_reuse = False
+
+    if args.peft_cache_config is not None:
+        executor_config.peft_cache_config = PybindMirror.maybe_to_pybind(
+            args.peft_cache_config)
+    elif isinstance(args,
+                    TrtLlmArgs) and args.build_config.plugin_config.lora_plugin:
+        engine_config = EngineConfig.from_json_file(args.model / "config.json")
+        lora_config = engine_config.build_config.lora_config
+        max_lora_rank = lora_config.max_lora_rank
+        num_lora_modules = engine_config.pretrained_config.num_hidden_layers * \
+            len(lora_config.lora_target_modules + lora_config.missing_qkv_modules)
+        executor_config.peft_cache_config = PeftCacheConfig(
+            num_device_module_layer=max_lora_rank * num_lora_modules *
+            args.max_loras,
+            num_host_module_layer=max_lora_rank * num_lora_modules *
+            args.max_cpu_loras,
+        )
+    if args.decoding_config is not None:
+        executor_config.decoding_config = args.decoding_config
+
+    if args.guided_decoding_backend == 'xgrammar':
+        executor_config.guided_decoding_config = GuidedDecodingConfig(
+            backend=GuidedDecodingConfig.GuidedDecodingBackend.XGRAMMAR,
+            **_xgrammar_tokenizer_info(tokenizer))
+    elif args.guided_decoding_backend is not None:
+        raise ValueError(
+            f"Unrecognized guided decoding backend {args.guided_decoding_backend}"
+        )
+
+    executor_config.normalize_log_probs = args.normalize_log_probs
+    executor_config.enable_chunked_context = args.enable_chunked_prefill
+    executor_config.max_beam_width = args.max_beam_width or args.build_config.max_beam_width
+    if isinstance(
+            args,
+            TrtLlmArgs) and args.extended_runtime_perf_knob_config is not None:
+        executor_config.extended_runtime_perf_knob_config = PybindMirror.maybe_to_pybind(
+            args.extended_runtime_perf_knob_config)
+
+    if args.cache_transceiver_config is not None:
+        executor_config.cache_transceiver_config = PybindMirror.maybe_to_pybind(
+            args.cache_transceiver_config)
+
+    from tensorrt_llm._torch.pyexecutor.config import update_executor_config
+    update_executor_config(
+        executor_config,
+        backend=args.backend,
+        pytorch_backend_config=args.get_pytorch_backend_config()
+        if args.backend in ["pytorch", "_autodeploy"] else None,
+        mapping=args.parallel_config.to_mapping(),
+        build_config=args.build_config
+        if isinstance(args, TrtLlmArgs) else None,
+        speculative_config=args.speculative_config,
+        hf_model_dir=self._hf_model_dir,
+        trt_engine_dir=self._engine_dir,
+        max_input_len=args.max_input_len,
+        max_seq_len=max_seq_len)
+
+    executor_config.llm_parallel_config = args.parallel_config
+    return executor_config
+
+
 __all__ = [
     'LlmArgs',
     'LlmBuildStats',
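
For orientation, a hypothetical sketch of the inputs and output the new helper is built around follows. The construction of the args object and the tokenizer load are assumptions, not taken from this commit, and the function body as committed still references self._hf_model_dir / self._engine_dir, so it is evidently meant to be invoked from the LLM's own setup path rather than standalone.

    # Hypothetical usage sketch (not part of this commit). Assumes a TrtLlmArgs
    # instance and a Hugging Face tokenizer are available; the exact way both
    # are constructed here is an assumption.
    from tensorrt_llm.llmapi.llm_args import TrtLlmArgs
    from tensorrt_llm.llmapi.llm_utils import llm_args_to_executor_config
    from tensorrt_llm.llmapi.tokenizer import load_hf_tokenizer

    args = TrtLlmArgs(model="/path/to/engine_or_hf_model")  # assumed constructor usage
    tokenizer = load_hf_tokenizer(args.model)               # assumed call signature

    # Translates high-level LLM args (batching limits, KV cache, PEFT cache,
    # guided decoding, scheduler, parallelism) into a bindings-level ExecutorConfig.
    executor_config = llm_args_to_executor_config(args, tokenizer)
    print(executor_config.max_batch_size, executor_config.max_num_tokens)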
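
The LoRA branch above sizes the PEFT cache from the engine config. A small worked sketch of that arithmetic, using entirely hypothetical values, is shown below.

    # Worked sketch of the PEFT cache sizing formula from the diff above;
    # all values are hypothetical and only illustrate the arithmetic.
    num_hidden_layers = 32                                 # engine_config.pretrained_config.num_hidden_layers
    lora_target_modules = ["attn_q", "attn_k", "attn_v"]   # lora_config.lora_target_modules
    missing_qkv_modules = []                               # lora_config.missing_qkv_modules
    max_lora_rank = 64                                     # lora_config.max_lora_rank
    max_loras, max_cpu_loras = 4, 8                        # args.max_loras / args.max_cpu_loras

    # One cache entry per (layer, adapted module) pair.
    num_lora_modules = num_hidden_layers * len(lora_target_modules + missing_qkv_modules)  # 32 * 3 = 96
    num_device_module_layer = max_lora_rank * num_lora_modules * max_loras      # 64 * 96 * 4 = 24576
    num_host_module_layer = max_lora_rank * num_lora_modules * max_cpu_loras    # 64 * 96 * 8 = 49152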
