Commit 0df758e

[TRTLLM-6650][feat] Enhance beam search support with CUDA graph integration (#6217)
Signed-off-by: Stefan Niebler <[email protected]>
1 parent ff72ca9 commit 0df758e

File tree: 5 files changed (+63, -36 lines)

tensorrt_llm/_torch/attention_backend/interface.py

Lines changed: 3 additions & 0 deletions
@@ -135,6 +135,9 @@ class AttentionMetadata:
     _num_ctx_tokens: int = field(init=False, default=0, repr=False)
     _num_tokens: int = field(init=False, default=0, repr=False)
 
+    # This buffer is currently only used for TrtllmAttentionMetadata.
+    cache_indirection: Optional[torch.Tensor] = None
+
     def __post_init__(self) -> None:
         if self.is_cross:
             assert self.cross is None or self.cross is self, "Cross attention metadata should not have sub metadata"
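For context, the cache indirection buffer used by beam search maps each (request, beam, position) to the beam whose KV-cache entry should be read at that position. The sketch below is illustrative only, written in plain PyTorch with hypothetical tensor names; it is not the attention kernel, but it shows how a buffer with this layout could drive a gather over a per-beam KV cache.

import torch

batch, beam_width, seq_len, head_dim = 2, 4, 8, 16

# Hypothetical per-beam KV tensor: [batch, beam, seq, head_dim].
kv = torch.randn(batch, beam_width, seq_len, head_dim)

# cache_indirection[b, j, t] = index of the source beam whose token at step t
# belongs to output beam j of request b. All zeros means "read from beam 0".
cache_indirection = torch.zeros(batch, beam_width, seq_len, dtype=torch.int32)

# Reorder the KV entries along the beam dimension according to the indirection.
index = cache_indirection.long().unsqueeze(-1).expand(-1, -1, -1, head_dim)
reordered_kv = torch.gather(kv, dim=1, index=index)
assert reordered_kv.shape == kv.shape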

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 1 addition & 2 deletions
@@ -517,10 +517,9 @@ def is_nvfp4_output_kernel_available(
 class TrtllmAttentionMetadata(AttentionMetadata):
     workspace: Optional[torch.Tensor] = None
 
-    # TrtllmAttention needs to know the beam width and access to the cache indirection buffer,
+    # TrtllmAttention needs to know the beam width to access the cache indirection buffer,
     # when beam search is enabled.
     beam_width: int = 1
-    cache_indirection: Optional[torch.Tensor] = None
 
     # TrtllmAttention needs to know the max sequence length.
     # Implemented as a property to support no cache mode.

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 46 additions & 23 deletions
@@ -392,9 +392,6 @@ def __init__(
         self._cuda_graphs = {}
         self._cuda_graph_mem_pool = self._torch_compile_backend._graph_pool_handle if self._torch_compile_enabled else None
         self._run_cuda_graphs = pytorch_backend_config.use_cuda_graph
-        if self._run_cuda_graphs and self.max_beam_width > 1:
-            raise NotImplementedError(
-                "CUDA Graph + beam search is not implemented yet.")
 
         self._cuda_graph_padding_enabled = pytorch_backend_config.cuda_graph_padding_enabled
 
@@ -425,6 +422,17 @@ def __init__(
         self.lora_model_config: Optional[LoraModelConfig] = None
         self.cuda_graph_dummy_request = None
 
+        # Set up the local cache indirection buffer only once and reuse it.
+        # This way it can also be used for CUDA graphs.
+        if self.use_beam_search:
+            self.cache_indirection_attention = torch.zeros(
+                (self.batch_size, self.max_beam_width, self.max_seq_len +
+                 (0 if self._disable_overlap_scheduler else 1)),
+                device="cuda",
+                dtype=torch.int32)
+        else:
+            self.cache_indirection_attention = None
+
     def set_lora_model_config(self, lora_target_modules: list[str],
                               trtllm_modules_to_hf_modules: dict[str, str]):
         self.lora_model_config = LoraModelConfig(
@@ -444,6 +452,10 @@ def use_mrope(self):
         logger.info(f"Detected use_mrope: {use_mrope}")
         return use_mrope
 
+    @property
+    def use_beam_search(self):
+        return self.max_beam_width > 1
+
     @contextmanager
     def set_warmup_flag(self):
         self.in_warmup = True
@@ -487,7 +499,9 @@ def warmup(self, resource_manager: ResourceManager) -> None:
         self.cuda_graph_dummy_request = None
 
         def get_cuda_graph_warmup_request(batch_size):
-            available_blocks = kv_cache_manager.get_num_free_blocks()
+            # Divide by max_beam_width to get an approximation of the number of requests that can be run in parallel.
+            available_blocks = kv_cache_manager.get_num_free_blocks(
+            ) // self.max_beam_width
             if available_blocks >= batch_size:
                 result = ScheduledRequests()
                 result.context_requests = []
@@ -498,9 +512,10 @@ def get_cuda_graph_warmup_request(batch_size):
                     is_gen=True,
                     max_num_draft_tokens=self.max_draft_len,
                     use_mrope=use_mrope,
-                )
+                    max_beam_width=self.max_beam_width)
+                # Divide by max_beam_width to get an approximation of the number of tokens that can be added to the final request.
                 available_tokens = kv_cache_manager.get_num_available_tokens(
-                    self.max_draft_len)
+                    self.max_draft_len) // self.max_beam_width
 
                 # Add one dummy request with the maximum possible sequence length.
                 # The sequence length is limited by both the max_seq_len and the number of available blocks.
@@ -511,7 +526,7 @@ def get_cuda_graph_warmup_request(batch_size):
                     is_gen=True,
                     max_num_draft_tokens=self.max_draft_len,
                     use_mrope=use_mrope,
-                )[0]
+                    max_beam_width=self.max_beam_width)[0]
                 # Add the longest request before all other seq_len=1 request to simulate the padding CUDA graph case.
                 # This batch contains both the longest request and the shortest requests,
                 # it also contains the maximum number of requests and the maximum token number,
@@ -739,6 +754,7 @@ def _set_up_attn_metadata(self, kv_cache_manager: KVCacheManager):
             self.model.model_config.pretrained_config) and (
                 self.attn_runtime_features.cache_reuse
                 or self.attn_runtime_features.chunked_prefill)
+        cache_indirection = self.cache_indirection_attention if self.attn_backend.Metadata is TrtllmAttentionMetadata else None
         if kv_cache_manager is None:
             return self.attn_backend.Metadata(
                 max_num_requests=self.batch_size,
@@ -748,7 +764,8 @@ def _set_up_attn_metadata(self, kv_cache_manager: KVCacheManager):
                 mapping=self.mapping,
                 runtime_features=self.attn_runtime_features,
                 enable_flash_mla=self.model.model_config.enable_flash_mla,
-                enable_paged_context_mla=enable_paged_context_mla)
+                enable_paged_context_mla=enable_paged_context_mla,
+                cache_indirection=cache_indirection)
 
         if self.attn_metadata is not None:
             # This assertion can be relaxed if needed: just create a new metadata
@@ -764,7 +781,9 @@ def _set_up_attn_metadata(self, kv_cache_manager: KVCacheManager):
             mapping=self.mapping,
             runtime_features=self.attn_runtime_features,
             enable_flash_mla=self.model.model_config.enable_flash_mla,
-            enable_paged_context_mla=enable_paged_context_mla)
+            enable_paged_context_mla=enable_paged_context_mla,
+            cache_indirection=cache_indirection)
+
         return self.attn_metadata
 
     def _set_up_spec_metadata(
@@ -795,7 +814,8 @@ def _get_padded_batch(self, scheduled_requests: ScheduledRequests,
                           kv_cache_manager) -> int:
         can_run_cuda_graph = scheduled_requests.can_run_cuda_graph
         batch_size = scheduled_requests.batch_size
-        new_batch_size = batch_size
+        # The number of sequences in the batch is the number of prompts times the beam width.
+        new_batch_size = batch_size * self.max_beam_width
         if self._run_cuda_graphs and self.enable_attention_dp and self.mapping.tp_size > 1:
             graph_batch_size = self.dist.tp_allgather(
                 [can_run_cuda_graph, batch_size])
@@ -831,7 +851,8 @@ def _get_padded_batch(self, scheduled_requests: ScheduledRequests,
                 [MAX_UINT64 - 1],
                 is_gen=True,
                 max_num_draft_tokens=self.max_draft_len,
-                use_mrope=self.use_mrope)[0]
+                use_mrope=self.use_mrope,
+                max_beam_width=self.max_beam_width)[0]
             self.cuda_graph_dummy_request.is_cuda_graph_dummy = True
 
         scheduled_requests.generation_requests.extend(
@@ -903,19 +924,21 @@ def _maybe_get_cuda_graph(
         if batch_size not in self._cuda_graph_batch_sizes:
             return None
 
+        num_sequences_in_batch = batch_size * self.max_beam_width
         attn_metadata = self.attn_metadata.create_cuda_graph_metadata(
-            batch_size, False, spec_max_draft_tokens)
+            num_sequences_in_batch, False, spec_max_draft_tokens)
         assert attn_metadata.is_cuda_graph
 
         if self.is_spec_decode:
             spec_metadata = self.spec_metadata.create_cuda_graph_metadata(
-                batch_size)
+                num_sequences_in_batch)
             spec_metadata.draft_tokens = self.draft_tokens_cuda
         else:
             spec_metadata = None
 
         self._cuda_graphs[batch_size] = DecodingCUDAGraphRunner(
-            batch_size, "cuda", attn_metadata, spec_metadata, self.use_mrope)
+            num_sequences_in_batch, "cuda", attn_metadata, spec_metadata,
+            self.use_mrope)
         return self._cuda_graphs[batch_size]
 
     def __del__(self) -> None:
@@ -1439,16 +1462,16 @@ def previous_seq_slots_device():
 
         num_generation_requests = len(scheduled_requests.generation_requests)
         # Cache indirection is only used for beam search on generation requests
-        if self.max_beam_width > 1 and num_generation_requests > 0 and cache_indirection_buffer is not None:
-            cache_indirection_attention = torch.zeros_like(
-                cache_indirection_buffer)
-            #Copy cache indirection to local buffer with offsets changing: seq_slots[i] -> i
-            cache_indirection_attention[:num_generation_requests].copy_(
-                cache_indirection_buffer[gen_request_seq_slots])
-            attn_metadata.cache_indirection = cache_indirection_attention
-            attn_metadata.beam_width = self.max_beam_width
+        if self.use_beam_search and num_generation_requests > 0:
+            # CUDA graphs need the beam width set during warmup (where the graph is captured) so that the cache indirection buffer is correctly picked up by the captured graph.
+            is_cuda_graph_during_warmup = self.in_warmup and attn_metadata.is_cuda_graph
+            if cache_indirection_buffer is not None:
+                # Copy cache indirection to the local buffer with offsets changing: seq_slots[i] -> i
+                self.cache_indirection_attention[:num_generation_requests].copy_(
+                    cache_indirection_buffer[gen_request_seq_slots])
+            if cache_indirection_buffer is not None or is_cuda_graph_during_warmup:
+                attn_metadata.beam_width = self.max_beam_width
         else:
-            attn_metadata.cache_indirection = None
             attn_metadata.beam_width = 1
 
         attn_metadata.request_ids = request_ids
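A minimal sketch of the buffer handling introduced above, with illustrative sizes and CPU tensors (the real buffers live on the GPU): the local cache_indirection_attention buffer is allocated once so a captured CUDA graph can keep referencing the same storage, and on each step the rows of the sampler-owned buffer are copied into it so that the request at seq_slots[i] lands at local index i.

import torch

batch_size, max_beam_width, max_seq_len = 8, 4, 128

# Allocated once at engine construction; an extra position is reserved when
# the overlap scheduler is enabled, mirroring the "+1" in the diff above.
disable_overlap_scheduler = False
cache_indirection_attention = torch.zeros(
    (batch_size, max_beam_width,
     max_seq_len + (0 if disable_overlap_scheduler else 1)),
    dtype=torch.int32)

# Sampler-owned buffer, indexed by sequence slot (same layout, illustrative).
cache_indirection_buffer = torch.zeros_like(cache_indirection_attention)

# Sequence slots of the generation requests in this batch, in batch order.
gen_request_seq_slots = torch.tensor([5, 2, 7])
num_generation_requests = len(gen_request_seq_slots)

# Remap seq_slots[i] -> i: row i of the local buffer now holds the
# indirection of the request occupying slot gen_request_seq_slots[i].
cache_indirection_attention[:num_generation_requests].copy_(
    cache_indirection_buffer[gen_request_seq_slots])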

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 6 additions & 2 deletions
@@ -375,11 +375,15 @@ def add_dummy_requests(
         prepare_resource: bool = True,
         max_num_draft_tokens: int = 0,
         use_mrope: bool = False,
+        max_beam_width: int = 1,
     ):
-        beam_width = 1  # TODO: more than 1 beam?
+        beam_width = max_beam_width
         requests = []
         for i, req_id in enumerate(request_ids):
-            sampling_params = SamplingParams()
+            # The exact choice of n can be ignored for dummy requests.
+            sampling_params = SamplingParams(n=beam_width,
+                                             best_of=beam_width,
+                                             use_beam_search=beam_width > 1)
             # Here 1+max_num_draft_tokens is used to extend the prompt length to
             # a non-zero number to skip illegal memory access issue in MLA kernel
             # during warmup.
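As a usage note, the dummy requests now carry beam-search sampling parameters so that warmup reserves KV-cache resources for every beam. A standalone illustration of the same construction (the beam width value is arbitrary):

from tensorrt_llm import SamplingParams

beam_width = 4  # would come from max_beam_width in the executor
sampling_params = SamplingParams(n=beam_width,
                                 best_of=beam_width,
                                 use_beam_search=beam_width > 1)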

tests/unittest/_torch/test_beam_search.py

Lines changed: 7 additions & 9 deletions
@@ -5,7 +5,7 @@
 from utils.util import force_ampere, similar
 
 from tensorrt_llm import LLM, SamplingParams
-from tensorrt_llm.llmapi.llm_utils import KvCacheConfig
+from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig
 
 
 @pytest.fixture(scope="module")
@@ -46,13 +46,12 @@ def llm(fixed_params, input_prompts):
         enable_trtllm_sampler=True,
         max_beam_width=fixed_params["max_beam_width"],
         disable_overlap_scheduler=True,
-        #TODO: remove this once we have a proper fix for CUDA graph in beam search
         cuda_graph_config=None,
     )
 
 
 @pytest.fixture(scope="module")
-def llm_overlap(fixed_params, input_prompts):
+def llm_cuda_graph(fixed_params, input_prompts):
     return LLM(
         model=os.path.join(llm_models_root(), "llama-models-v2",
                            "TinyLlama-1.1B-Chat-v1.0"),
@@ -64,8 +63,7 @@ def llm_overlap(fixed_params, input_prompts):
         enable_trtllm_sampler=True,
         max_beam_width=fixed_params["max_beam_width"],
         disable_overlap_scheduler=False,
-        #TODO: remove this once we have a proper fix for CUDA graph in beam search
-        cuda_graph_config=None,
+        cuda_graph_config=CudaGraphConfig(enabled=True),
     )
 
 
@@ -132,10 +130,10 @@ def test_beam_search_output_shapes(gather_context_logits: bool,
 @pytest.mark.parametrize("num_output_beams", [1, 2])
 @pytest.mark.parametrize("num_prompts", [1, 2])
 @pytest.mark.threadleak(enabled=False)
-def test_beam_search_output_shapes_overlap(
+def test_beam_search_output_shapes_cuda_graph_and_overlap(
         gather_context_logits: bool, gather_generation_logits: bool,
         return_log_probs: bool, num_output_beams: int, num_prompts: int,
-        llm_overlap, fixed_params, input_prompts, expected_outputs):
+        llm_cuda_graph, fixed_params, input_prompts, expected_outputs):
     if return_log_probs and num_prompts > 1:
         pytest.skip(
             "Beam search currently does not support return_log_probs with multiple prompts"
@@ -149,8 +147,8 @@ def test_beam_search_output_shapes_overlap(
         return_generation_logits=gather_generation_logits,
         logprobs=return_log_probs,
     )
-    outputs = llm_overlap.generate(input_prompts[:num_prompts],
-                                   sampling_params=sampling_params)
+    outputs = llm_cuda_graph.generate(input_prompts[:num_prompts],
+                                      sampling_params=sampling_params)
     assert len(outputs) == num_prompts
     for output_idx, output in enumerate(outputs):
         if gather_context_logits:
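Putting the pieces together, beam search can now run with CUDA graphs enabled through the LLM API, mirroring the updated fixture above. The sketch below is hedged: the checkpoint path and KV-cache fraction are placeholders, and the exact fixture parameters (fixed_params) are not reproduced here.

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig

llm = LLM(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # placeholder checkpoint
    kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.5),
    enable_trtllm_sampler=True,
    max_beam_width=2,
    disable_overlap_scheduler=False,
    cuda_graph_config=CudaGraphConfig(enabled=True),
)

sampling_params = SamplingParams(n=2, best_of=2, use_beam_search=True,
                                 max_tokens=32)
outputs = llm.generate(["The future of AI is"],
                       sampling_params=sampling_params)
for beam in outputs[0].outputs:
    print(beam.text)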
