Skip to content

Commit 02f569a

Browse files
committed
[TRTLLM-6756][feat] Enhance TorchSampler to support beam search sampling.
- Added BeamSearchArgs class and updated methods to handle beam search logic, including cache indirection updates and beam score management.
- Modified create_torch_sampler_args to include use_overlap_scheduler parameter.
- Updated sampling strategy to accommodate beam search requests.

Signed-off-by: Stefan Niebler <[email protected]>
1 parent d0663e1 commit 02f569a

File tree

2 files changed

+527
-66
lines changed

2 files changed

+527
-66
lines changed

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -820,7 +820,7 @@ def create_py_executor_instance(
820820
def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
821821
max_batch_size: int,
822822
speculative_config: SpeculativeConfig,
823-
max_beam_width: int):
823+
max_beam_width: int, use_overlap_scheduler: bool):
824824
max_num_sequences = max_batch_size * mapping.pp_size
825825
max_draft_len = (0 if speculative_config is None else
826826
speculative_config.max_draft_len)
@@ -832,13 +832,12 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
832832
else:
833833
max_total_draft_tokens = max_draft_len
834834

835-
return TorchSampler.Args(
836-
max_seq_len=max_seq_len,
837-
max_draft_len=max_draft_len,
838-
max_total_draft_tokens=max_total_draft_tokens,
839-
max_num_sequences=max_num_sequences,
840-
max_beam_width=max_beam_width,
841-
)
835+
return TorchSampler.Args(max_seq_len=max_seq_len,
836+
max_draft_len=max_draft_len,
837+
max_total_draft_tokens=max_total_draft_tokens,
838+
max_num_sequences=max_num_sequences,
839+
max_beam_width=max_beam_width,
840+
use_overlap_scheduler=use_overlap_scheduler)
842841

843842

844843
def instantiate_sampler(engine: PyTorchModelEngine,
@@ -853,7 +852,9 @@ def instantiate_sampler(engine: PyTorchModelEngine,
853852
max_seq_len=engine.max_seq_len,
854853
max_batch_size=max_batch_size,
855854
speculative_config=speculative_config,
856-
max_beam_width=max_beam_width)
855+
max_beam_width=max_beam_width,
856+
use_overlap_scheduler=not pytorch_backend_config.
857+
disable_overlap_scheduler)
857858
decoding_mode = get_decoding_mode(decoding_config=decoding_config,
858859
max_beam_width=max_beam_width)
859860
if mapping.cp_config.get('cp_type') == CpType.STAR:

0 commit comments

Comments
 (0)