Skip to content

Commit 1c7b4be

Browse files
committed
[TRTLLM-6756][feat] Update beam search in TorchSampler to cover more features
- Add a metadata object to grouped_request to pass additional data that is not part of the SamplingStrategy definition. - Add several buffers to the TorchSampler Store for beam search features, which are only allocated when beam search is used. - Add support for beam search with streaming enabled. - Beam search no longer requires all beams to finish at the same iteration. - gather_generation_logits can now be used together with beam search. - Logprob generation is now possible with beam search enabled. Top-k logprobs is not supported. - Updated test_beam_search.py to also cover TorchSampler. - General changes for formatting and readability. Signed-off-by: Stefan Niebler <[email protected]>
1 parent 1162819 commit 1c7b4be

File tree

4 files changed

+552
-232
lines changed

4 files changed

+552
-232
lines changed

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -811,10 +811,15 @@ def create_py_executor_instance(
811811
peft_cache_config=peft_cache_config)
812812

813813

814-
def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
815-
max_batch_size: int,
816-
speculative_config: SpeculativeConfig,
817-
max_beam_width: int, use_overlap_scheduler: bool):
814+
def create_torch_sampler_args(
815+
mapping: Mapping,
816+
*,
817+
max_seq_len: int,
818+
max_batch_size: int,
819+
speculative_config: SpeculativeConfig,
820+
max_beam_width: int,
821+
use_overlap_scheduler: bool,
822+
):
818823
max_num_sequences = max_batch_size * mapping.pp_size
819824
max_draft_len = (0 if speculative_config is None else
820825
speculative_config.max_draft_len)
@@ -829,13 +834,18 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
829834
use_overlap_scheduler=use_overlap_scheduler)
830835

831836

832-
def instantiate_sampler(engine: PyTorchModelEngine,
833-
pytorch_backend_config: PyTorchConfig, mapping: Mapping,
834-
max_batch_size: int, max_beam_width: int,
835-
max_seq_len: int, mm_encoder_only: bool,
836-
speculative_config: SpeculativeConfig,
837-
decoding_config: trtllm.DecodingConfig,
838-
kv_cache_config: KvCacheConfig):
837+
def instantiate_sampler(
838+
engine: PyTorchModelEngine,
839+
pytorch_backend_config: PyTorchConfig,
840+
mapping: Mapping,
841+
max_batch_size: int,
842+
max_beam_width: int,
843+
max_seq_len: int,
844+
mm_encoder_only: bool,
845+
speculative_config: SpeculativeConfig,
846+
decoding_config: trtllm.DecodingConfig,
847+
kv_cache_config: KvCacheConfig,
848+
):
839849
sampler_args = create_torch_sampler_args(
840850
mapping,
841851
max_seq_len=engine.max_seq_len,

0 commit comments

Comments
 (0)