Commit 2bc8bbf

WAR for sampled logprob
1 parent 07343bb commit 2bc8bbf

2 files changed: +93 −9 lines changed
Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
from tensorrt_llm import LLM, SamplingParams


def main():
    llm = LLM(
        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        gather_generation_logits=True  # Required. TODO: Actual API TBD.
    )

    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Current behavior:
    # - With return_generation_logits=True: Returns ONLY the sampled token's logprob
    # - Without return_generation_logits=True: Returns top-K tokens (sampled token NOT guaranteed)
    sampling_params = SamplingParams(
        max_tokens=10,
        temperature=0.7,
        top_p=0.95,
        logprobs=1,
        return_generation_logits=True,
    )

    for output in llm.generate(prompts, sampling_params):
        print(f"\n{'='*80}")
        print(f"Prompt: {output.prompt!r}")
        print(f"Generated text: {output.outputs[0].text!r}")
        print(f"Generated token IDs: {output.outputs[0].token_ids}")

        if output.outputs[0].logprobs:
            print(f"\nLogprobs for each generated token:")
            for i, (token_id, token_logprobs) in enumerate(
                zip(output.outputs[0].token_ids, output.outputs[0].logprobs)
            ):
                print(f"\n Token {i}: ID={token_id}, Text={llm.tokenizer.decode([token_id])!r}")

                # TODO: move to a proper test
                assert len(token_logprobs) == 1
                assert token_id in token_logprobs, f"Sampled token {token_id} not in logprobs dict."

                for tid, logprob_obj in token_logprobs.items():
                    token_text = llm.tokenizer.decode([tid])
                    is_sampled = "← SAMPLED" if tid == token_id else ""
                    print(f" • Token {tid:5d} ({token_text:15s}): "
                          f"logprob={logprob_obj.logprob:8.4f}, "
                          f"rank={logprob_obj.rank} {is_sampled}")
        print(f"{'='*80}\n")


if __name__ == '__main__':
    main()
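For comparison, a hypothetical variant of the sampling configuration above (not part of this commit): with return_generation_logits left unset, the returned logprobs follow the original top-K behavior described in the comment, and the sampled token is not guaranteed to appear among the entries.

# Hypothetical variant (illustrative only, not in this commit): without
# return_generation_logits, logprobs=K requests top-K entries per step and the
# sampled token is NOT guaranteed to be among them.
sampling_params_topk = SamplingParams(
    max_tokens=10,
    temperature=0.7,
    top_p=0.95,
    logprobs=2,  # top-2 logprobs per step
)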

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 38 additions & 9 deletions
@@ -745,15 +745,44 @@ def handle_logprobs(
     topk_log_probs_vals = request.py_topk_logprobs_vals[:count]
     topk_log_probs_indices = request.py_topk_logprobs_indices[:count]
 
-    token_log_probs = [
-        {
-            token: Logprob(logprob=logprob, rank=rank + 1)
-            for rank, (token, logprob) in enumerate(
-                zip(topk_token.tolist(), topk_logprob.tolist())
-            )
-        }
-        for topk_token, topk_logprob in zip(topk_log_probs_indices, topk_log_probs_vals)
-    ]
+    sampled_tokens = request.get_tokens(beam)[-count:]
+
+    token_log_probs = []
+    for step, (topk_token, topk_logprob) in enumerate(zip(topk_log_probs_indices, topk_log_probs_vals)):
+        sampled_token = sampled_tokens[step]
+
+        # TODO. WAR: If both gather_generation_logits and return_generation_logits are set,
+        # return ONLY the sampled token's logprob (not top-K).
+        if request.py_return_generation_logits:
+            generation_logits_storage = request.py_result._generation_logits
+            if generation_logits_storage and generation_logits_storage._storage is not None:
+                # Compute log_softmax to get logprobs for the sampled token.
+                # Internal storage tensor: [seq_length, beam_width, vocab_size]
+                logits_for_step = generation_logits_storage._storage[step]  # [beam_width, vocab_size]
+                logprobs_for_step = F.log_softmax(logits_for_step[beam].float(), dim=-1)
+                sampled_logprob = logprobs_for_step[sampled_token].item()
+
+                rank = (logprobs_for_step > sampled_logprob).sum().item() + 1
+
+                step_dict = {sampled_token: Logprob(logprob=sampled_logprob, rank=rank)}
+            else:
+                step_dict = {
+                    token: Logprob(logprob=logprob, rank=rank + 1)
+                    for rank, (token, logprob) in enumerate(
+                        zip(topk_token.tolist(), topk_logprob.tolist())
+                    )
+                }
+        else:
+            # Original behavior: return top-K.
+            step_dict = {
+                token: Logprob(logprob=logprob, rank=rank + 1)
+                for rank, (token, logprob) in enumerate(
+                    zip(topk_token.tolist(), topk_logprob.tolist())
+                )
+            }
+
+        token_log_probs.append(step_dict)
+
     assert beam == 0, (
         "The following call relies on beam_width to be 1 - hence the list with a single element"
     )
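For reference, a minimal standalone sketch of the per-step computation the WAR performs, assuming a single [vocab_size] logits row and a known sampled token id (the names below are illustrative, not the request/result API used in sampler.py):

import torch
import torch.nn.functional as F

def sampled_token_logprob(step_logits: torch.Tensor, sampled_token: int):
    # step_logits: raw logits for one generation step, shape [vocab_size].
    # Normalize over the full vocabulary so the value is a true log-probability.
    logprobs = F.log_softmax(step_logits.float(), dim=-1)
    sampled_logprob = logprobs[sampled_token].item()
    # 1-based rank: tokens with strictly higher logprob, plus one.
    rank = int((logprobs > sampled_logprob).sum().item()) + 1
    return sampled_logprob, rank

As in the branch added above, the rank is computed against the full vocabulary, so the single returned entry can carry a rank larger than the requested logprobs count even though only the sampled token is reported.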
