Skip to content

Commit 67593dd

Browse files
committed
[TRTLLM-6756][refactor] Enhance beam search handling in TorchSampler and sampling utilities
- Introduced new functions to retrieve beam width parameters for input and output, improving clarity and modularity.
- Updated UtilsSamplingParams to include separate beam width parameters and a flag for beam search usage.
- Refactored beam search sampling logic to accommodate changes in beam width handling, ensuring compatibility with new parameters.
- Unified beam search sampling for context and generation requests.
- Simplified code for beam history creation.
- Adjusted test cases to reflect changes in beam width handling and improved logprob validation.

Signed-off-by: Stefan Niebler <[email protected]>
1 parent 2c734eb commit 67593dd

File tree

5 files changed

+215

-244

lines changed

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 76 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -274,20 +274,42 @@ def _unwrap_singleton(p: Optional[List[T]]) -> Optional[T]:
274274
return t
275275

276276

277+
def _get_beam_width_in(request: LlmRequest) -> int:
278+
return (
279+
1
280+
if request.is_context_init_state
281+
else request.get_beam_width_by_iter(for_next_iteration=False)
282+
)
283+
284+
285+
def _get_beam_width_out(request: LlmRequest) -> int:
286+
return request.get_beam_width_by_iter(for_next_iteration=True)
287+
288+
289+
def _get_max_beam_width(request: LlmRequest) -> int:
290+
sampling_config = request.sampling_config
291+
max_beam_width = sampling_config.beam_width
292+
if sampling_config.beam_width_array is not None:
293+
max_beam_width = max(max_beam_width, sampling_config.beam_width_array.max())
294+
return max_beam_width
295+
296+
277297
def _request_get_sampling_params(request: LlmRequest) -> UtilsSamplingParams:
278298
sampling_config = request.sampling_config
279299
temperature = _unwrap_singleton(cast(Optional[list[float]], sampling_config.temperature))
280300
top_p = _unwrap_singleton(cast(Optional[list[float]], sampling_config.top_p))
281301
top_k = _unwrap_singleton(cast(Optional[list[int]], sampling_config.top_k))
282-
beam_width = sampling_config.beam_width
283-
is_context_init_state = request.is_context_init_state
302+
beam_width_out = _get_beam_width_out(request)
303+
beam_width_in = _get_beam_width_in(request)
304+
use_beam_search = _get_max_beam_width(request) > 1
284305

285306
return UtilsSamplingParams(
286307
temperature=temperature,
287308
top_p=top_p,
288309
top_k=top_k,
289-
beam_width=beam_width,
290-
is_context_init_state=is_context_init_state,
310+
beam_width_in=beam_width_in,
311+
beam_width_out=beam_width_out,
312+
use_beam_search=use_beam_search,
291313
)
292314

293315

@@ -933,7 +955,6 @@ def _convert_logprobs_tensor_to_list(
933955
def handle_logprobs(
934956
self,
935957
request: LlmRequest,
936-
state: SampleState,
937958
*,
938959
count: int,
939960
):
@@ -1095,7 +1116,7 @@ def setup_sampler_step(self, requests: ScheduledRequests):
10951116
requests: list[LlmRequest]. The requests to setup the sampler step for
10961117
"""
10971118
if self._use_beam_search:
1098-
self._prepare_beam_search(requests)
1119+
self._prepare_beam_search(requests.all_requests())
10991120

11001121
def _prepare_beam_search(
11011122
self,
@@ -1260,8 +1281,6 @@ def _get_logprobs_from_request(self, request: LlmRequest) -> tuple[torch.Tensor,
12601281
logprobs_tensor: A tensor of shape (beam_width, num_generated_tokens, num_logprobs)
12611282
logprobs_indices_tensor: A tensor of shape (beam_width, num_generated_tokens, num_logprobs)
12621283
"""
1263-
1264-
logprobs_list = request.py_result.log_probs
12651284
num_generated_tokens = request.get_num_tokens(0) - request.py_prompt_len
12661285
assert request.py_num_logprobs == 1, "Beam search only supports one logprob per token"
12671286
logprobs_tensor = torch.empty(
@@ -1282,17 +1301,19 @@ def _get_logprobs_from_request(self, request: LlmRequest) -> tuple[torch.Tensor,
12821301
device="cuda",
12831302
dtype=torch.int32,
12841303
)
1285-
for beam_idx, beam_logprobs in enumerate(logprobs_list):
1286-
for token_idx, token_logprobs in enumerate(beam_logprobs):
1287-
for key, value in token_logprobs.items():
1288-
logprobs_tensor[beam_idx, token_idx, value.rank - 1] = value.logprob
1289-
logprobs_indices_tensor[beam_idx, token_idx, value.rank - 1] = key
1304+
if hasattr(request.py_result._log_probs, "log_probs"):
1305+
logprobs_list = request.py_result.log_probs
1306+
for beam_idx, beam_logprobs in enumerate(logprobs_list):
1307+
for token_idx, token_logprobs in enumerate(beam_logprobs):
1308+
for key, value in token_logprobs.items():
1309+
logprobs_tensor[beam_idx, token_idx, value.rank - 1] = value.logprob
1310+
logprobs_indices_tensor[beam_idx, token_idx, value.rank - 1] = key
12901311
return logprobs_tensor, logprobs_indices_tensor
12911312

12921313
def _create_beam_history(
12931314
self,
12941315
request: LlmRequest,
1295-
) -> BeamHistory:
1316+
) -> BeamHistory | None:
12961317
"""Correct the stored tokens for each beam and return it as a BeamHistory object.
12971318
12981319
Beam Search sampling only adds new tokens to the beam.
@@ -1311,12 +1332,7 @@ def _create_beam_history(
13111332

13121333
if num_generated_tokens == 0 or request.state == LlmRequestState.GENERATION_COMPLETE:
13131334
# early return if no tokens have been generated yet or the request is already finished
1314-
return BeamHistory(
1315-
tokens=None,
1316-
logprobs=None,
1317-
logprobs_indices=None,
1318-
cum_logprobs=None,
1319-
)
1335+
return None
13201336
cache_indirection = self.store.cache_indirection[
13211337
request.py_seq_slot, :num_beams, prompt_length:num_tokens
13221338
]
@@ -1325,58 +1341,47 @@ def _create_beam_history(
13251341
]
13261342
new_path = torch.zeros_like(current_path)
13271343
if request.py_return_log_probs:
1328-
# Check that logprobs are initialized in the request
1329-
if getattr(request.py_result._log_probs, "log_probs", None) is not None:
1330-
current_logprobs, current_logprobs_indices = self._get_logprobs_from_request(
1331-
request
1332-
)
1333-
# concatenate the newly generated logprobs and newly
1334-
# generated tokens to the current logprobs and logprobs indices
1335-
current_logprobs = torch.cat(
1336-
[
1337-
current_logprobs,
1338-
self.store.new_log_probs[request.py_seq_slot, :num_beams].view(-1, 1, 1),
1339-
],
1340-
dim=1,
1341-
)
1342-
current_logprobs_indices = torch.cat(
1343-
[
1344-
current_logprobs_indices,
1345-
self.store.new_tokens[0, request.py_seq_slot, :num_beams].view(-1, 1, 1),
1346-
],
1347-
dim=1,
1348-
)
1349-
else:
1350-
current_logprobs = self.store.new_log_probs[request.py_seq_slot, :num_beams].view(
1351-
-1, 1, 1
1352-
)
1353-
current_logprobs_indices = self.store.new_tokens[
1354-
0, request.py_seq_slot, :num_beams
1355-
].view(-1, 1, 1)
1344+
current_logprobs, current_logprobs_indices = self._get_logprobs_from_request(request)
1345+
# concatenate the newly generated logprobs and newly
1346+
# generated tokens to the current logprobs and logprobs indices
1347+
current_logprobs = torch.cat(
1348+
[
1349+
current_logprobs,
1350+
self.store.new_log_probs[request.py_seq_slot, :num_beams].view(-1, 1, 1),
1351+
],
1352+
dim=1,
1353+
)
1354+
current_logprobs_indices = torch.cat(
1355+
[
1356+
current_logprobs_indices,
1357+
self.store.new_tokens[0, request.py_seq_slot, :num_beams].view(-1, 1, 1),
1358+
],
1359+
dim=1,
1360+
)
13561361
# Initialize the buffers to store the results
13571362
new_logprobs = torch.zeros_like(current_logprobs)
13581363
new_logprobs_indices = torch.zeros_like(current_logprobs_indices)
13591364
# initialize each beam with its own index
1360-
basic_beams = torch.arange(num_beams, device=cache_indirection.device, dtype=torch.int32)
1361-
# Traverse the cache indirection backwards to obtain the correct tokens and logprobs for each beam.
1362-
for token_idx in range(num_generated_tokens - 1, 0, -1):
1363-
active_beams = cache_indirection[basic_beams, token_idx]
1364-
# set the current token and logprob
1365-
new_path[:, token_idx] = current_path[active_beams, token_idx]
1366-
if request.py_return_log_probs:
1367-
new_logprobs[:, token_idx] = current_logprobs[active_beams, token_idx]
1368-
new_logprobs_indices[:, token_idx] = current_logprobs_indices[
1369-
active_beams, token_idx
1370-
]
1371-
# update the active beams
1372-
active_beams = cache_indirection[basic_beams, 0]
1373-
# set the first generated token and logprob
1374-
new_path[:, 0] = current_path[active_beams, 0]
13751365

1366+
# Gather the correct tokens and logprobs for each beam
1367+
torch.gather(input=current_path, dim=0, index=cache_indirection, out=new_path)
13761368
if request.py_return_log_probs:
1377-
new_logprobs[:, 0] = current_logprobs[active_beams, 0]
1378-
new_logprobs_indices[:, 0] = current_logprobs_indices[active_beams, 0]
1379-
cum_logprobs = self.store.cum_log_probs[request.py_seq_slot, :]
1369+
cache_indirection_for_logprobs = cache_indirection.unsqueeze(-1).expand(
1370+
-1, -1, current_logprobs.shape[2]
1371+
)
1372+
torch.gather(
1373+
input=current_logprobs,
1374+
dim=0,
1375+
index=cache_indirection_for_logprobs,
1376+
out=new_logprobs,
1377+
)
1378+
torch.gather(
1379+
input=current_logprobs_indices,
1380+
dim=0,
1381+
index=cache_indirection_for_logprobs,
1382+
out=new_logprobs_indices,
1383+
)
1384+
cum_logprobs = self.store.cum_log_probs[request.py_seq_slot, :num_beams]
13801385
return BeamHistory(
13811386
tokens=new_path,
13821387
logprobs=new_logprobs,
@@ -1477,7 +1482,7 @@ def _add_metadata_to_grouped_requests(
14771482
grouped_requests_with_metadata: dict[RequestGroupKey, RequestGroupValueWithMetadata] = {}
14781483
for key, value in grouped_requests.items():
14791484
match key.strategy:
1480-
case ("beam_search", _, _) | ("beam_search_for_prefill", _, _):
1485+
case ("beam_search", _, _, _):
14811486
assert seq_lens is not None, "seq_lens is required for beam search"
14821487
metadata = BeamSearchMetadata(
14831488
cache_indirection=self.store.cache_indirection,
@@ -1584,7 +1589,7 @@ def update_requests(
15841589
else:
15851590
for beam_idx in range(req.sampling_config.beam_width):
15861591
add_token(req, new_tokens_list, beam_idx=beam_idx)
1587-
self.handle_logprobs(req, state, count=1)
1592+
self.handle_logprobs(req, count=1)
15881593
self._handle_finish_reasons(req, state.host.finish_reasons, finish_reasons)
15891594
req.py_decoding_iter += 1
15901595

@@ -1605,7 +1610,7 @@ def update_requests(
16051610
for beam_idx in range(req.sampling_config.beam_width):
16061611
# Beam search does not support speculative decoding.
16071612
add_token(req, new_tokens_list, beam_idx=beam_idx)
1608-
self.handle_logprobs(req, state, count=1)
1613+
self.handle_logprobs(req, count=1)
16091614
self._handle_finish_reasons(req, state.host.finish_reasons, finish_reasons)
16101615
req.py_num_accepted_draft_tokens = 0
16111616
req.py_rewind_len = 0
@@ -1627,7 +1632,7 @@ def update_requests(
16271632
req.py_num_accepted_draft_tokens = 0
16281633
req.py_rewind_len = 0
16291634
processed += num_accepted
1630-
self.handle_logprobs(req, state, count=processed)
1635+
self.handle_logprobs(req, count=processed)
16311636
req.py_decoding_iter += 1
16321637

16331638
def return_log_probs(self, scheduled_requests: ScheduledRequests) -> bool:
@@ -1648,10 +1653,8 @@ def sample_async(
16481653
# case there are 1 + get_draft_token_length(request) tokens per request. In the
16491654
# latter case, there is always only 1 token per request because draft
16501655
# tokens are sampled one-by-one.
1656+
self.setup_sampler_step(scheduled_requests)
16511657
requests = scheduled_requests.all_requests()
1652-
if self._use_beam_search:
1653-
# prepare the new beams for the current iteration
1654-
self._prepare_beam_search(requests)
16551658
new_tokens = self.store.new_tokens
16561659
return_log_probs = self.return_log_probs(scheduled_requests)
16571660
seq_slots_host = torch.tensor(

0 commit comments

Comments (0)