@@ -665,6 +665,9 @@ def get_spec_tree_manager(self, resource_manager: ResourceManager) -> Optional[S
             return None
         return spec_resource_manager.spec_tree_manager

+    def _use_beam_search(self) -> bool:
+        return self.max_beam_width > 1
+
     def _meet_max_token_stop_criteria(self, request: LlmRequest, beam: int = 0) -> bool:
         num_tokens = request.get_num_tokens(beam)
         return (num_tokens - request.py_orig_prompt_len >= request.py_max_new_tokens) or (
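
The new `_use_beam_search` helper simply names the existing `max_beam_width > 1` condition so later call sites read as intent rather than as a bare comparison. A minimal standalone sketch (the class and attribute below are illustrative stand-ins, not the real sampler) of how such a predicate behaves, including the pitfall of forgetting the call parentheses:

    from dataclasses import dataclass

    @dataclass
    class _SamplerSketch:
        # Hypothetical stand-in for the sampler's configured beam width.
        max_beam_width: int = 1

        def _use_beam_search(self) -> bool:
            # Beam search is only active when more than one beam is kept per request.
            return self.max_beam_width > 1

    # Usage: note the call parentheses; referencing the bound method itself
    # (e.g. `if sampler._use_beam_search:`) is always truthy and would silently
    # enable the beam-search path.
    sampler = _SamplerSketch(max_beam_width=4)
    assert sampler._use_beam_search()
    assert not _SamplerSketch(max_beam_width=1)._use_beam_search()
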
@@ -737,7 +740,7 @@ def _process_draft_tokens_greedy(
         new_tokens: list[list[list[int]]],
     ) -> int:
         new_token = add_token(request, new_tokens, beam=self.BEAM)
-        stop = self._handle_stop_criteria(request, new_token)
+        stop = self._handle_stop_criteria(request, new_token, beam=self.BEAM)
         if stop or get_draft_token_length(request) == 0:
             return 0
         num_accepted = 0
@@ -749,7 +752,7 @@ def _process_draft_tokens_greedy(

             num_accepted += 1
             new_token = add_token(request, new_tokens, beam=self.BEAM, step=num_accepted)
-            if self._handle_stop_criteria(request, new_token):
+            if self._handle_stop_criteria(request, new_token, beam=self.BEAM):
                 break
         return num_accepted

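
For readers unfamiliar with greedy speculative decoding, the two hunks above implement the standard accept-while-matching loop: a draft token is accepted as long as it equals the token the target model actually sampled, and acceptance ends early when a stop criterion fires. A simplified, self-contained sketch of that pattern, with plain lists and a callback standing in for `LlmRequest`, `add_token`, and `_handle_stop_criteria` (all names below are illustrative):

    from typing import Callable

    def greedy_accept(draft_tokens: list[int],
                      target_tokens: list[int],
                      is_stop: Callable[[int], bool]) -> int:
        """Count how many draft tokens the target model confirms.

        target_tokens[i] is the token the target model samples after having
        accepted i draft tokens; acceptance stops at the first mismatch or at
        the first token that satisfies the stop criteria.
        """
        new_token = target_tokens[0]
        if is_stop(new_token) or not draft_tokens:
            return 0
        num_accepted = 0
        for draft_token in draft_tokens:
            if draft_token != new_token:
                break
            num_accepted += 1
            new_token = target_tokens[num_accepted]
            if is_stop(new_token):
                break
        return num_accepted

    # Example: the target agrees with the first two draft tokens only.
    assert greedy_accept([5, 7, 9], [5, 7, 8, 8], lambda t: t == 0) == 2
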
@@ -847,7 +850,7 @@ def _process_draft_tokens_tree(
                 request, new_tokens_list, beam=0, step=cast(int, idx.item())
             )
             num_accepted_draft_tokens += 1
-            if self._handle_stop_criteria(request, new_token):
+            if self._handle_stop_criteria(request, new_token, beam=self.BEAM):
                 break

         return num_accepted_draft_tokens - 1
@@ -995,7 +998,7 @@ def _process_draft_tokens_rejection_sampling(
             new_token = request.py_draft_tokens[i]
             new_tokens_tensor[i, request.seq_slot, self.BEAM] = new_token
             request.add_new_token(new_token, self.BEAM)
-            stop = self._handle_stop_criteria(request, new_token)
+            stop = self._handle_stop_criteria(request, new_token, beam=self.BEAM)
             if stop:
                 num_accepted = i + 1
                 return num_accepted
@@ -1005,7 +1008,7 @@ def _process_draft_tokens_rejection_sampling(
             request.add_new_token(new_token, self.BEAM)
         else:
             new_token = add_token(request, new_tokens_list, beam=self.BEAM, step=num_accepted)
-            stop = self._handle_stop_criteria(request, new_token)
+            stop = self._handle_stop_criteria(request, new_token, beam=self.BEAM)

         return num_accepted

@@ -1034,7 +1037,9 @@ def process_draft_tokens(
             )
             return num_accepted
         else:
-            return self._process_draft_tokens_rejection_sampling(request, new_tokens_list=new_tokens_list, new_tokens_tensor=new_tokens_tensor)
+            return self._process_draft_tokens_rejection_sampling(
+                request, new_tokens_list=new_tokens_list, new_tokens_tensor=new_tokens_tensor
+            )

     def _update_beam_history(self, request: LlmRequest) -> None:
         """Correct the stored tokens for each beam
@@ -1095,7 +1100,7 @@ def update_requests(
                     beams_finished += 1
                 self.handle_logprobs(req, state, beam=beam, count=1)
             req.py_decoding_iter += 1
-            if beams_finished == req.sampling_config.beam_width:
+            if self._use_beam_search() and beams_finished == req.sampling_config.beam_width:
                 self._remove_active_request(req)
             assert beams_finished == 0 or beams_finished == req.sampling_config.beam_width, (
                 "Partially finished beams are not supported yet."
@@ -1128,12 +1133,11 @@ def update_requests(
             else:
                 processed = 1
                 num_accepted = self.process_draft_tokens(
-
-                    req,
-                    new_tokens_tensor=new_tokens,
-                    new_tokens_list=new_tokens_list,
-                    state.host.new_tokens, resource_manager=resource_manager,
-                )
+                    req,
+                    new_tokens_tensor=new_tokens,
+                    new_tokens_list=new_tokens_list,
+                    resource_manager=resource_manager,
+                )
             if get_draft_token_length(req) > 0:
                 req.py_num_accepted_draft_tokens = num_accepted
                 req.py_rewind_len = req.py_draft_pages_allocated - num_accepted
@@ -1164,7 +1168,7 @@ def sample_async(
         # tokens are sampled one-by-one.

         requests = scheduled_requests.all_requests()
-        if self.max_beam_width > 1:
+        if self._use_beam_search():
             self._prepare_beam_search(requests)
         new_tokens = self.store.new_tokens
         return_log_probs = self.return_log_probs(scheduled_requests)
@@ -1178,7 +1182,7 @@ def sample_async(
             torch.tensor(
                 [r.get_num_tokens(0) for r in requests], dtype=torch.int32, pin_memory=True
             )
-            if self.max_beam_width > 1
+            if self._use_beam_search()
             else None
         )
         new_tokens_host = self._process_requests(
@@ -1405,6 +1409,9 @@ def _sample_batched_by_strategy(
             batch_next_tokens_offset_end = (
                 batch_next_tokens_offset_start + group_next_tokens_cuda.size(0)
             )
+            # if no beam search is used, the shape is (batch_size,), so we need to unsqueeze it to (batch_size, 1)
+            if group_next_tokens_cuda.dim() == 1:
+                group_next_tokens_cuda = group_next_tokens_cuda.unsqueeze(1)
             batch_next_tokens_cuda_int[
                 batch_next_tokens_offset_start:batch_next_tokens_offset_end
             ].copy_(group_next_tokens_cuda, non_blocking=True)
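
The guard added above normalizes the per-group sampler output before the batched copy: without beam search a strategy may return a 1-D tensor of shape (batch_size,), while the destination slice is 2-D. A minimal PyTorch sketch of the same normalization, using made-up tensor names rather than the ones in `_sample_batched_by_strategy`:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Preallocated per-batch buffer of next tokens: one column per beam
    # (a single column when beam search is disabled).
    batch_next_tokens = torch.empty((8, 1), dtype=torch.int32, device=device)

    # A sampling strategy without beam search may return a 1-D tensor of shape (group_size,).
    group_next_tokens = torch.tensor([3, 1, 4], dtype=torch.int32, device=device)

    start, end = 2, 2 + group_next_tokens.size(0)
    if group_next_tokens.dim() == 1:
        # Normalize to (group_size, 1) so it matches the 2-D destination slice.
        group_next_tokens = group_next_tokens.unsqueeze(1)
    batch_next_tokens[start:end].copy_(group_next_tokens, non_blocking=True)
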