@@ -274,20 +274,42 @@ def _unwrap_singleton(p: Optional[List[T]]) -> Optional[T]:
274274 return t
275275
276276
def _get_beam_width_in(request: LlmRequest) -> int:
    """Beam width of the beams flowing *into* the current iteration.

    A request still in context-init state has produced no beams yet, so its
    input width is always 1; otherwise the request's per-iteration width for
    the current (not the next) iteration applies.
    """
    if request.is_context_init_state:
        return 1
    return request.get_beam_width_by_iter(for_next_iteration=False)
283+
284+
def _get_beam_width_out(request: LlmRequest) -> int:
    """Beam width of the beams produced *by* the current iteration."""
    next_width = request.get_beam_width_by_iter(for_next_iteration=True)
    return next_width
287+
288+
def _get_max_beam_width(request: LlmRequest) -> int:
    """Largest beam width this request can ever use.

    This is the static ``sampling_config.beam_width`` unless a variable
    beam-width schedule (``beam_width_array``) is present, whose peak value
    may exceed the static width.
    """
    config = request.sampling_config
    width_array = config.beam_width_array
    if width_array is None:
        return config.beam_width
    return max(config.beam_width, width_array.max())
295+
296+
def _request_get_sampling_params(request: LlmRequest) -> UtilsSamplingParams:
    """Bundle a request's per-sample knobs into a ``UtilsSamplingParams``.

    ``temperature``/``top_p``/``top_k`` are unwrapped from their singleton
    list form; beam widths are resolved for the current iteration (in) and
    the next iteration (out), and beam search is considered enabled whenever
    the request's maximum possible beam width exceeds 1.
    """
    config = request.sampling_config
    return UtilsSamplingParams(
        temperature=_unwrap_singleton(cast(Optional[list[float]], config.temperature)),
        top_p=_unwrap_singleton(cast(Optional[list[float]], config.top_p)),
        top_k=_unwrap_singleton(cast(Optional[list[int]], config.top_k)),
        beam_width_in=_get_beam_width_in(request),
        beam_width_out=_get_beam_width_out(request),
        use_beam_search=_get_max_beam_width(request) > 1,
    )
292314
293315
@@ -933,7 +955,6 @@ def _convert_logprobs_tensor_to_list(
933955 def handle_logprobs (
934956 self ,
935957 request : LlmRequest ,
936- state : SampleState ,
937958 * ,
938959 count : int ,
939960 ):
@@ -1095,7 +1116,7 @@ def setup_sampler_step(self, requests: ScheduledRequests):
10951116 requests: list[LlmRequest]. The requests to setup the sampler step for
10961117 """
10971118 if self ._use_beam_search :
1098- self ._prepare_beam_search (requests )
1119+ self ._prepare_beam_search (requests . all_requests () )
10991120
11001121 def _prepare_beam_search (
11011122 self ,
@@ -1260,8 +1281,6 @@ def _get_logprobs_from_request(self, request: LlmRequest) -> tuple[torch.Tensor,
12601281 logprobs_tensor: A tensor of shape (beam_width, num_generated_tokens, num_logprobs)
12611282 logprobs_indices_tensor: A tensor of shape (beam_width, num_generated_tokens, num_logprobs)
12621283 """
1263-
1264- logprobs_list = request .py_result .log_probs
12651284 num_generated_tokens = request .get_num_tokens (0 ) - request .py_prompt_len
12661285 assert request .py_num_logprobs == 1 , "Beam search only supports one logprob per token"
12671286 logprobs_tensor = torch .empty (
@@ -1282,17 +1301,19 @@ def _get_logprobs_from_request(self, request: LlmRequest) -> tuple[torch.Tensor,
12821301 device = "cuda" ,
12831302 dtype = torch .int32 ,
12841303 )
1285- for beam_idx , beam_logprobs in enumerate (logprobs_list ):
1286- for token_idx , token_logprobs in enumerate (beam_logprobs ):
1287- for key , value in token_logprobs .items ():
1288- logprobs_tensor [beam_idx , token_idx , value .rank - 1 ] = value .logprob
1289- logprobs_indices_tensor [beam_idx , token_idx , value .rank - 1 ] = key
1304+ if hasattr (request .py_result ._log_probs , "log_probs" ):
1305+ logprobs_list = request .py_result .log_probs
1306+ for beam_idx , beam_logprobs in enumerate (logprobs_list ):
1307+ for token_idx , token_logprobs in enumerate (beam_logprobs ):
1308+ for key , value in token_logprobs .items ():
1309+ logprobs_tensor [beam_idx , token_idx , value .rank - 1 ] = value .logprob
1310+ logprobs_indices_tensor [beam_idx , token_idx , value .rank - 1 ] = key
12901311 return logprobs_tensor , logprobs_indices_tensor
12911312
12921313 def _create_beam_history (
12931314 self ,
12941315 request : LlmRequest ,
1295- ) -> BeamHistory :
1316+ ) -> BeamHistory | None :
12961317 """Correct the stored tokens for each beam and return it as a BeamHistory object.
12971318
12981319 Beam Search sampling only adds new tokens to the beam.
@@ -1311,12 +1332,7 @@ def _create_beam_history(
13111332
13121333 if num_generated_tokens == 0 or request .state == LlmRequestState .GENERATION_COMPLETE :
13131334 # early return if no tokens have been generated yet or the request is already finished
1314- return BeamHistory (
1315- tokens = None ,
1316- logprobs = None ,
1317- logprobs_indices = None ,
1318- cum_logprobs = None ,
1319- )
1335+ return None
13201336 cache_indirection = self .store .cache_indirection [
13211337 request .py_seq_slot , :num_beams , prompt_length :num_tokens
13221338 ]
@@ -1325,58 +1341,47 @@ def _create_beam_history(
13251341 ]
13261342 new_path = torch .zeros_like (current_path )
13271343 if request .py_return_log_probs :
1328- # Check that logprobs are initialized in the request
1329- if getattr (request .py_result ._log_probs , "log_probs" , None ) is not None :
1330- current_logprobs , current_logprobs_indices = self ._get_logprobs_from_request (
1331- request
1332- )
1333- # concatenate the newly generated logprobs and newly
1334- # generated tokens to the current logprobs and logprobs indices
1335- current_logprobs = torch .cat (
1336- [
1337- current_logprobs ,
1338- self .store .new_log_probs [request .py_seq_slot , :num_beams ].view (- 1 , 1 , 1 ),
1339- ],
1340- dim = 1 ,
1341- )
1342- current_logprobs_indices = torch .cat (
1343- [
1344- current_logprobs_indices ,
1345- self .store .new_tokens [0 , request .py_seq_slot , :num_beams ].view (- 1 , 1 , 1 ),
1346- ],
1347- dim = 1 ,
1348- )
1349- else :
1350- current_logprobs = self .store .new_log_probs [request .py_seq_slot , :num_beams ].view (
1351- - 1 , 1 , 1
1352- )
1353- current_logprobs_indices = self .store .new_tokens [
1354- 0 , request .py_seq_slot , :num_beams
1355- ].view (- 1 , 1 , 1 )
1344+ current_logprobs , current_logprobs_indices = self ._get_logprobs_from_request (request )
1345+ # concatenate the newly generated logprobs and newly
1346+ # generated tokens to the current logprobs and logprobs indices
1347+ current_logprobs = torch .cat (
1348+ [
1349+ current_logprobs ,
1350+ self .store .new_log_probs [request .py_seq_slot , :num_beams ].view (- 1 , 1 , 1 ),
1351+ ],
1352+ dim = 1 ,
1353+ )
1354+ current_logprobs_indices = torch .cat (
1355+ [
1356+ current_logprobs_indices ,
1357+ self .store .new_tokens [0 , request .py_seq_slot , :num_beams ].view (- 1 , 1 , 1 ),
1358+ ],
1359+ dim = 1 ,
1360+ )
13561361 # Initialize the buffers to store the results
13571362 new_logprobs = torch .zeros_like (current_logprobs )
13581363 new_logprobs_indices = torch .zeros_like (current_logprobs_indices )
13591364 # initialize each beam with its own index
1360- basic_beams = torch .arange (num_beams , device = cache_indirection .device , dtype = torch .int32 )
1361- # Traverse the cache indirection backwards to obtain the correct tokens and logprobsfor each beam.
1362- for token_idx in range (num_generated_tokens - 1 , 0 , - 1 ):
1363- active_beams = cache_indirection [basic_beams , token_idx ]
1364- # set the current token and logprob
1365- new_path [:, token_idx ] = current_path [active_beams , token_idx ]
1366- if request .py_return_log_probs :
1367- new_logprobs [:, token_idx ] = current_logprobs [active_beams , token_idx ]
1368- new_logprobs_indices [:, token_idx ] = current_logprobs_indices [
1369- active_beams , token_idx
1370- ]
1371- # update the active beams
1372- active_beams = cache_indirection [basic_beams , 0 ]
1373- # set the first generated token and logprob
1374- new_path [:, 0 ] = current_path [active_beams , 0 ]
13751365
1366+ # Gather the correct tokens and logprobs for each beam
1367+ torch .gather (input = current_path , dim = 0 , index = cache_indirection , out = new_path )
13761368 if request .py_return_log_probs :
1377- new_logprobs [:, 0 ] = current_logprobs [active_beams , 0 ]
1378- new_logprobs_indices [:, 0 ] = current_logprobs_indices [active_beams , 0 ]
1379- cum_logprobs = self .store .cum_log_probs [request .py_seq_slot , :]
1369+ cache_indirection_for_logprobs = cache_indirection .unsqueeze (- 1 ).expand (
1370+ - 1 , - 1 , current_logprobs .shape [2 ]
1371+ )
1372+ torch .gather (
1373+ input = current_logprobs ,
1374+ dim = 0 ,
1375+ index = cache_indirection_for_logprobs ,
1376+ out = new_logprobs ,
1377+ )
1378+ torch .gather (
1379+ input = current_logprobs_indices ,
1380+ dim = 0 ,
1381+ index = cache_indirection_for_logprobs ,
1382+ out = new_logprobs_indices ,
1383+ )
1384+ cum_logprobs = self .store .cum_log_probs [request .py_seq_slot , :num_beams ]
13801385 return BeamHistory (
13811386 tokens = new_path ,
13821387 logprobs = new_logprobs ,
@@ -1477,7 +1482,7 @@ def _add_metadata_to_grouped_requests(
14771482 grouped_requests_with_metadata : dict [RequestGroupKey , RequestGroupValueWithMetadata ] = {}
14781483 for key , value in grouped_requests .items ():
14791484 match key .strategy :
1480- case ("beam_search" , _, _) | ( "beam_search_for_prefill" , _ , _):
1485+ case ("beam_search" , _, _, _):
14811486 assert seq_lens is not None , "seq_lens is required for beam search"
14821487 metadata = BeamSearchMetadata (
14831488 cache_indirection = self .store .cache_indirection ,
@@ -1584,7 +1589,7 @@ def update_requests(
15841589 else :
15851590 for beam_idx in range (req .sampling_config .beam_width ):
15861591 add_token (req , new_tokens_list , beam_idx = beam_idx )
1587- self .handle_logprobs (req , state , count = 1 )
1592+ self .handle_logprobs (req , count = 1 )
15881593 self ._handle_finish_reasons (req , state .host .finish_reasons , finish_reasons )
15891594 req .py_decoding_iter += 1
15901595
@@ -1605,7 +1610,7 @@ def update_requests(
16051610 for beam_idx in range (req .sampling_config .beam_width ):
16061611 # Beam search does not support speculative decoding.
16071612 add_token (req , new_tokens_list , beam_idx = beam_idx )
1608- self .handle_logprobs (req , state , count = 1 )
1613+ self .handle_logprobs (req , count = 1 )
16091614 self ._handle_finish_reasons (req , state .host .finish_reasons , finish_reasons )
16101615 req .py_num_accepted_draft_tokens = 0
16111616 req .py_rewind_len = 0
@@ -1627,7 +1632,7 @@ def update_requests(
16271632 req .py_num_accepted_draft_tokens = 0
16281633 req .py_rewind_len = 0
16291634 processed += num_accepted
1630- self .handle_logprobs (req , state , count = processed )
1635+ self .handle_logprobs (req , count = processed )
16311636 req .py_decoding_iter += 1
16321637
16331638 def return_log_probs (self , scheduled_requests : ScheduledRequests ) -> bool :
@@ -1648,10 +1653,8 @@ def sample_async(
16481653 # case there are 1 + get_draft_token_length(request) tokens per request. In the
16491654 # latter case, there is always only 1 token per request because draft
16501655 # tokens are sampled one-by-one.
1656+ self .setup_sampler_step (scheduled_requests )
16511657 requests = scheduled_requests .all_requests ()
1652- if self ._use_beam_search :
1653- # prepare the new beams for the current iteration
1654- self ._prepare_beam_search (requests )
16551658 new_tokens = self .store .new_tokens
16561659 return_log_probs = self .return_log_probs (scheduled_requests )
16571660 seq_slots_host = torch .tensor (
0 commit comments