Commit 3f37164

[TRTLLM-6756][chore] Fixed wrong shape when not using beam search
Unsqueeze the buffer returned from sampling so that it always contains the beam_width dimension.

Signed-off-by: Stefan Niebler <[email protected]>
1 parent 2fcd6c0 commit 3f37164
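
For context, here is a minimal sketch of the shape mismatch this change addresses (illustrative tensor names and sizes only, not the actual TensorRT-LLM sampler code): without beam search the sampled tokens come back as a 1-D (batch_size,) tensor, while the destination buffer always keeps a beam_width dimension, so the result must be unsqueezed before it is copied in.

import torch

# Illustrative sketch only; names and sizes are assumptions, not the sampler's real buffers.
batch_size, beam_width = 4, 1

# The destination buffer always carries a beam_width dimension: (batch_size, beam_width).
batch_next_tokens = torch.empty(batch_size, beam_width, dtype=torch.int32)

# Sampling without beam search typically yields a 1-D tensor of shape (batch_size,).
group_next_tokens = torch.randint(0, 32000, (batch_size,), dtype=torch.int32)

# Copying a (batch_size,) tensor into a (batch_size, 1) slice fails to broadcast,
# so add the missing beam dimension first.
if group_next_tokens.dim() == 1:
    group_next_tokens = group_next_tokens.unsqueeze(1)  # -> (batch_size, 1)

batch_next_tokens[:batch_size].copy_(group_next_tokens, non_blocking=True)
print(batch_next_tokens.shape)  # torch.Size([4, 1])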

File tree

1 file changed: +22 -15 lines

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 22 additions & 15 deletions
@@ -665,6 +665,9 @@ def get_spec_tree_manager(self, resource_manager: ResourceManager) -> Optional[S
             return None
         return spec_resource_manager.spec_tree_manager

+    def _use_beam_search(self) -> bool:
+        return self.max_beam_width > 1
+
     def _meet_max_token_stop_criteria(self, request: LlmRequest, beam: int = 0) -> bool:
         num_tokens = request.get_num_tokens(beam)
         return (num_tokens - request.py_orig_prompt_len >= request.py_max_new_tokens) or (
@@ -737,7 +740,7 @@ def _process_draft_tokens_greedy(
         new_tokens: list[list[list[int]]],
     ) -> int:
         new_token = add_token(request, new_tokens, beam=self.BEAM)
-        stop = self._handle_stop_criteria(request, new_token)
+        stop = self._handle_stop_criteria(request, new_token, beam=self.BEAM)
         if stop or get_draft_token_length(request) == 0:
             return 0
         num_accepted = 0
@@ -749,7 +752,7 @@ def _process_draft_tokens_greedy(

             num_accepted += 1
             new_token = add_token(request, new_tokens, beam=self.BEAM, step=num_accepted)
-            if self._handle_stop_criteria(request, new_token):
+            if self._handle_stop_criteria(request, new_token, beam=self.BEAM):
                 break
         return num_accepted

@@ -847,7 +850,7 @@ def _process_draft_tokens_tree(
                     request, new_tokens_list, beam=0, step=cast(int, idx.item())
                 )
                 num_accepted_draft_tokens += 1
-                if self._handle_stop_criteria(request, new_token):
+                if self._handle_stop_criteria(request, new_token, beam=self.BEAM):
                     break

         return num_accepted_draft_tokens - 1
@@ -995,7 +998,7 @@ def _process_draft_tokens_rejection_sampling(
                 new_token = request.py_draft_tokens[i]
                 new_tokens_tensor[i, request.seq_slot, self.BEAM] = new_token
                 request.add_new_token(new_token, self.BEAM)
-                stop = self._handle_stop_criteria(request, new_token)
+                stop = self._handle_stop_criteria(request, new_token, beam=self.BEAM)
                 if stop:
                     num_accepted = i + 1
                     return num_accepted
@@ -1005,7 +1008,7 @@ def _process_draft_tokens_rejection_sampling(
             request.add_new_token(new_token, self.BEAM)
         else:
             new_token = add_token(request, new_tokens_list, beam=self.BEAM, step=num_accepted)
-            stop = self._handle_stop_criteria(request, new_token)
+            stop = self._handle_stop_criteria(request, new_token, beam=self.BEAM)

         return num_accepted

@@ -1034,7 +1037,9 @@ def process_draft_tokens(
             )
             return num_accepted
         else:
-            return self._process_draft_tokens_rejection_sampling(request, new_tokens_list=new_tokens_list, new_tokens_tensor=new_tokens_tensor)
+            return self._process_draft_tokens_rejection_sampling(
+                request, new_tokens_list=new_tokens_list, new_tokens_tensor=new_tokens_tensor
+            )

     def _update_beam_history(self, request: LlmRequest) -> None:
         """Correct the stored tokens for each beam
@@ -1095,7 +1100,7 @@ def update_requests(
                     beams_finished += 1
                 self.handle_logprobs(req, state, beam=beam, count=1)
            req.py_decoding_iter += 1
-            if beams_finished == req.sampling_config.beam_width:
+            if self._use_beam_search() and beams_finished == req.sampling_config.beam_width:
                self._remove_active_request(req)
            assert beams_finished == 0 or beams_finished == req.sampling_config.beam_width, (
                "Partially finished beams are not supported yet."
@@ -1128,12 +1133,11 @@ def update_requests(
             else:
                 processed = 1
                 num_accepted = self.process_draft_tokens(
-
-                    req,
-                    new_tokens_tensor=new_tokens,
-                    new_tokens_list=new_tokens_list,
-                    state.host.new_tokens, resource_manager=resource_manager,
-                )
+                    req,
+                    new_tokens_tensor=new_tokens,
+                    new_tokens_list=new_tokens_list,
+                    resource_manager=resource_manager,
+                )
                 if get_draft_token_length(req) > 0:
                     req.py_num_accepted_draft_tokens = num_accepted
                     req.py_rewind_len = req.py_draft_pages_allocated - num_accepted
@@ -1164,7 +1168,7 @@ def sample_async(
         # tokens are sampled one-by-one.

         requests = scheduled_requests.all_requests()
-        if self.max_beam_width > 1:
+        if self._use_beam_search():
             self._prepare_beam_search(requests)
         new_tokens = self.store.new_tokens
         return_log_probs = self.return_log_probs(scheduled_requests)
@@ -1178,7 +1182,7 @@ def sample_async(
                 torch.tensor(
                     [r.get_num_tokens(0) for r in requests], dtype=torch.int32, pin_memory=True
                 )
-                if self.max_beam_width > 1
+                if self._use_beam_search()
                 else None
             )
             new_tokens_host = self._process_requests(
@@ -1405,6 +1409,9 @@ def _sample_batched_by_strategy(
             batch_next_tokens_offset_end = (
                 batch_next_tokens_offset_start + group_next_tokens_cuda.size(0)
             )
+            # Without beam search the sampled tokens have shape (batch_size,), so unsqueeze to (batch_size, 1).
+            if group_next_tokens_cuda.dim() == 1:
+                group_next_tokens_cuda = group_next_tokens_cuda.unsqueeze(1)
             batch_next_tokens_cuda_int[
                 batch_next_tokens_offset_start:batch_next_tokens_offset_end
             ].copy_(group_next_tokens_cuda, non_blocking=True)
