Skip to content

Commit eb265ac

Browse files
committed
[TRTLLM-6756][test] Updated test_beam_search.py to correctly test for the updated beam search in TorchSampler
- Enhanced logprob testing to verify that sum(logprobs) == cum_log_probs.
- Added testing for stop tokens.

Signed-off-by: Stefan Niebler <[email protected]>
1 parent 1528e92 commit eb265ac

File tree

2 files changed

+245
-79
lines changed

2 files changed

+245
-79
lines changed

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1331,23 +1331,34 @@ def _create_beam_history(
13311331
]
13321332
new_path = torch.zeros_like(current_path)
13331333
if request.py_return_log_probs:
1334-
current_logprobs, current_logprobs_indices = self._get_logprobs_from_request(request)
1335-
# concatenate the newly generated logprobs and newly
1336-
# generated tokens to the current logprobs and logprobs indices
1337-
current_logprobs = torch.cat(
1338-
[
1339-
current_logprobs,
1340-
self.store.new_log_probs[request.py_seq_slot, :num_beams].view(-1, 1, 1),
1341-
],
1342-
dim=1,
1343-
)
1344-
current_logprobs_indices = torch.cat(
1345-
[
1346-
current_logprobs_indices,
1347-
self.store.new_tokens[0, request.py_seq_slot, :num_beams].view(-1, 1, 1),
1348-
],
1349-
dim=1,
1350-
)
1334+
# Check that logprobs are initialized in the request
1335+
if getattr(request.py_result._log_probs, "log_probs", None) is not None:
1336+
current_logprobs, current_logprobs_indices = self._get_logprobs_from_request(
1337+
request
1338+
)
1339+
# concatenate the newly generated logprobs and newly
1340+
# generated tokens to the current logprobs and logprobs indices
1341+
current_logprobs = torch.cat(
1342+
[
1343+
current_logprobs,
1344+
self.store.new_log_probs[request.py_seq_slot, :num_beams].view(-1, 1, 1),
1345+
],
1346+
dim=1,
1347+
)
1348+
current_logprobs_indices = torch.cat(
1349+
[
1350+
current_logprobs_indices,
1351+
self.store.new_tokens[0, request.py_seq_slot, :num_beams].view(-1, 1, 1),
1352+
],
1353+
dim=1,
1354+
)
1355+
else:
1356+
current_logprobs = self.store.new_log_probs[request.py_seq_slot, :num_beams].view(
1357+
-1, 1, 1
1358+
)
1359+
current_logprobs_indices = self.store.new_tokens[
1360+
0, request.py_seq_slot, :num_beams
1361+
].view(-1, 1, 1)
13511362
# Initialize the buffers to store the results
13521363
new_logprobs = torch.zeros_like(current_logprobs)
13531364
new_logprobs_indices = torch.zeros_like(current_logprobs_indices)

0 commit comments

Comments
 (0)