@@ -392,9 +392,6 @@ def __init__(
         self._cuda_graphs = {}
         self._cuda_graph_mem_pool = self._torch_compile_backend._graph_pool_handle if self._torch_compile_enabled else None
         self._run_cuda_graphs = pytorch_backend_config.use_cuda_graph
-        if self._run_cuda_graphs and self.max_beam_width > 1:
-            raise NotImplementedError(
-                "CUDA Graph + beam search is not implemented yet.")

         self._cuda_graph_padding_enabled = pytorch_backend_config.cuda_graph_padding_enabled

@@ -425,6 +422,17 @@ def __init__(
         self.lora_model_config: Optional[LoraModelConfig] = None
         self.cuda_graph_dummy_request = None

+        # Setup the local cache indirection buffer only once and reuse it.
+        # This way it can also be used for CUDA graphs.
+        if self.use_beam_search:
+            self.cache_indirection_attention = torch.zeros(
+                (self.batch_size, self.max_beam_width, self.max_seq_len +
+                 (0 if self._disable_overlap_scheduler else 1)),
+                device="cuda",
+                dtype=torch.int32)
+        else:
+            self.cache_indirection_attention = None
+
     def set_lora_model_config(self, lora_target_modules: list[str],
                               trtllm_modules_to_hf_modules: dict[str, str]):
         self.lora_model_config = LoraModelConfig(
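The one-time allocation above matters because CUDA graph replay reads whatever device address was captured, so a persistent buffer of shape (batch, beam, seq_len) can simply be refreshed in place between replays. A minimal standalone sketch of that reuse pattern, with made-up sizes (requires a CUDA device):

import torch

# Made-up sizes for illustration; the engine derives these from its config.
batch_size, max_beam_width, max_seq_len = 4, 2, 16

# Allocate once. CUDA graphs capture raw device pointers, so this tensor must
# keep its identity (and address) across iterations.
cache_indirection_attention = torch.zeros(
    (batch_size, max_beam_width, max_seq_len), dtype=torch.int32, device="cuda")

def refresh(new_values: torch.Tensor) -> None:
    # Update contents in place instead of rebinding the name to a new tensor,
    # so a previously captured graph still reads the fresh values.
    cache_indirection_attention.copy_(new_values)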
@@ -444,6 +452,10 @@ def use_mrope(self):
         logger.info(f"Detected use_mrope: {use_mrope}")
         return use_mrope

+    @property
+    def use_beam_search(self):
+        return self.max_beam_width > 1
+
     @contextmanager
     def set_warmup_flag(self):
         self.in_warmup = True
@@ -487,7 +499,9 @@ def warmup(self, resource_manager: ResourceManager) -> None:
         self.cuda_graph_dummy_request = None

         def get_cuda_graph_warmup_request(batch_size):
-            available_blocks = kv_cache_manager.get_num_free_blocks()
+            # Divide by max_beam_width to get an approximation of the number of requests that can be run in parallel.
+            available_blocks = kv_cache_manager.get_num_free_blocks(
+            ) // self.max_beam_width
             if available_blocks >= batch_size:
                 result = ScheduledRequests()
                 result.context_requests = []
@@ -498,9 +512,10 @@ def get_cuda_graph_warmup_request(batch_size):
                     is_gen=True,
                     max_num_draft_tokens=self.max_draft_len,
                     use_mrope=use_mrope,
-                )
+                    max_beam_width=self.max_beam_width)
+                # Divide by max_beam_width to get an approximation of the number of tokens that can be added to the final request.
                 available_tokens = kv_cache_manager.get_num_available_tokens(
-                    self.max_draft_len)
+                    self.max_draft_len) // self.max_beam_width

                 # Add one dummy request with the maximum possible sequence length.
                 # The sequence length is limited by both the max_seq_len and the number of available blocks.
@@ -511,7 +526,7 @@ def get_cuda_graph_warmup_request(batch_size):
                     is_gen=True,
                     max_num_draft_tokens=self.max_draft_len,
                     use_mrope=use_mrope,
-                )[0]
+                    max_beam_width=self.max_beam_width)[0]
                 # Add the longest request before all other seq_len=1 request to simulate the padding CUDA graph case.
                 # This batch contains both the longest request and the shortest requests,
                 # it also contains the maximum number of requests and the maximum token number,
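The two // self.max_beam_width adjustments in the warmup hunks above are a conservative per-prompt estimate: with beam search each prompt occupies roughly max_beam_width sequences worth of KV blocks and tokens. A rough arithmetic sketch with made-up numbers:

# Made-up numbers for illustration only.
free_blocks = 1024
free_tokens = 8192
max_beam_width = 4

# Roughly how many warmup prompts can be scheduled in parallel,
# since every prompt fans out into max_beam_width sequences:
available_blocks = free_blocks // max_beam_width  # 256

# Roughly how many tokens the single long dummy request can still grow by,
# since each of its beams consumes its own share of the cache:
available_tokens = free_tokens // max_beam_width  # 2048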
@@ -739,6 +754,7 @@ def _set_up_attn_metadata(self, kv_cache_manager: KVCacheManager):
             self.model.model_config.pretrained_config) and (
                 self.attn_runtime_features.cache_reuse
                 or self.attn_runtime_features.chunked_prefill)
+        cache_indirection = self.cache_indirection_attention if self.attn_backend.Metadata is TrtllmAttentionMetadata else None
         if kv_cache_manager is None:
             return self.attn_backend.Metadata(
                 max_num_requests=self.batch_size,
@@ -748,7 +764,8 @@ def _set_up_attn_metadata(self, kv_cache_manager: KVCacheManager):
                 mapping=self.mapping,
                 runtime_features=self.attn_runtime_features,
                 enable_flash_mla=self.model.model_config.enable_flash_mla,
-                enable_paged_context_mla=enable_paged_context_mla)
+                enable_paged_context_mla=enable_paged_context_mla,
+                cache_indirection=cache_indirection)

         if self.attn_metadata is not None:
             # This assertion can be relaxed if needed: just create a new metadata
@@ -764,7 +781,9 @@ def _set_up_attn_metadata(self, kv_cache_manager: KVCacheManager):
             mapping=self.mapping,
             runtime_features=self.attn_runtime_features,
             enable_flash_mla=self.model.model_config.enable_flash_mla,
-            enable_paged_context_mla=enable_paged_context_mla)
+            enable_paged_context_mla=enable_paged_context_mla,
+            cache_indirection=cache_indirection)
+
         return self.attn_metadata

     def _set_up_spec_metadata(
@@ -795,7 +814,8 @@ def _get_padded_batch(self, scheduled_requests: ScheduledRequests,
                           kv_cache_manager) -> int:
         can_run_cuda_graph = scheduled_requests.can_run_cuda_graph
         batch_size = scheduled_requests.batch_size
-        new_batch_size = batch_size
+        # The number of sequences in the batch is the number of prompts times the beam width.
+        new_batch_size = batch_size * self.max_beam_width
         if self._run_cuda_graphs and self.enable_attention_dp and self.mapping.tp_size > 1:
             graph_batch_size = self.dist.tp_allgather(
                 [can_run_cuda_graph, batch_size])
@@ -831,7 +851,8 @@ def _get_padded_batch(self, scheduled_requests: ScheduledRequests,
                 [MAX_UINT64 - 1],
                 is_gen=True,
                 max_num_draft_tokens=self.max_draft_len,
-                use_mrope=self.use_mrope)[0]
+                use_mrope=self.use_mrope,
+                max_beam_width=self.max_beam_width)[0]
             self.cuda_graph_dummy_request.is_cuda_graph_dummy = True

         scheduled_requests.generation_requests.extend(
@@ -903,19 +924,21 @@ def _maybe_get_cuda_graph(
         if batch_size not in self._cuda_graph_batch_sizes:
             return None

+        num_sequences_in_batch = batch_size * self.max_beam_width
         attn_metadata = self.attn_metadata.create_cuda_graph_metadata(
-            batch_size, False, spec_max_draft_tokens)
+            num_sequences_in_batch, False, spec_max_draft_tokens)
         assert attn_metadata.is_cuda_graph

         if self.is_spec_decode:
             spec_metadata = self.spec_metadata.create_cuda_graph_metadata(
-                batch_size)
+                num_sequences_in_batch)
             spec_metadata.draft_tokens = self.draft_tokens_cuda
         else:
             spec_metadata = None

         self._cuda_graphs[batch_size] = DecodingCUDAGraphRunner(
-            batch_size, "cuda", attn_metadata, spec_metadata, self.use_mrope)
+            num_sequences_in_batch, "cuda", attn_metadata, spec_metadata,
+            self.use_mrope)
         return self._cuda_graphs[batch_size]

     def __del__(self) -> None:
@@ -1439,16 +1462,16 @@ def previous_seq_slots_device():

         num_generation_requests = len(scheduled_requests.generation_requests)
         # Cache indirection is only used for beam search on generation requests
-        if self.max_beam_width > 1 and num_generation_requests > 0 and cache_indirection_buffer is not None:
-            cache_indirection_attention = torch.zeros_like(
-                cache_indirection_buffer)
-            #Copy cache indirection to local buffer with offsets changing: seq_slots[i] -> i
-            cache_indirection_attention[:num_generation_requests].copy_(
-                cache_indirection_buffer[gen_request_seq_slots])
-            attn_metadata.cache_indirection = cache_indirection_attention
-            attn_metadata.beam_width = self.max_beam_width
+        if self.use_beam_search and num_generation_requests > 0:
+            # CUDA Graph needs to set beam width during warmup (where the graph is captured), to ensure that cache indirection buffer is correctly picked up by the CUDA graph
+            is_cuda_graph_during_warmup = self.in_warmup and attn_metadata.is_cuda_graph
+            if cache_indirection_buffer is not None:
+                #Copy cache indirection to local buffer with offsets changing: seq_slots[i] -> i
+                self.cache_indirection_attention[:num_generation_requests].copy_(
+                    cache_indirection_buffer[gen_request_seq_slots])
+            if cache_indirection_buffer is not None or is_cuda_graph_during_warmup:
+                attn_metadata.beam_width = self.max_beam_width
         else:
-            attn_metadata.cache_indirection = None
             attn_metadata.beam_width = 1

         attn_metadata.request_ids = request_ids
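The copy in the final hunk remaps rows from kv-cache seq_slot order into dense batch order (seq_slots[i] -> i) inside the persistent buffer. A standalone sketch of that gather with hypothetical slot numbers (CPU tensors for clarity):

import torch

# Hypothetical: 8 request slots, beam width 2, short sequences.
cache_indirection_buffer = torch.arange(8 * 2 * 4, dtype=torch.int32).reshape(8, 2, 4)
cache_indirection_attention = torch.zeros_like(cache_indirection_buffer)

gen_request_seq_slots = [5, 1, 6]  # slots of the active generation requests
num_generation_requests = len(gen_request_seq_slots)

# Row i of the local buffer now holds the indirection of the request occupying
# slot gen_request_seq_slots[i]; the remaining rows stay zero-filled.
cache_indirection_attention[:num_generation_requests].copy_(
    cache_indirection_buffer[gen_request_seq_slots])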