
Commit 3e3f2ee: fix step
1 parent: 2bc8bbf

2 files changed: +24 -6 lines changed


examples/llm-api/llm_inference_logprob.py

Lines changed: 18 additions & 3 deletions
@@ -1,12 +1,21 @@
+import torch
+
 from tensorrt_llm import LLM, SamplingParams
+from tensorrt_llm._tensorrt_engine import LLM as TrtLLM


 def main():
     llm = LLM(
-        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-        gather_generation_logits=True  # Required. TODO: Actual API TBD.
+        model="/scratch/llm-models/llama-models-v2/TinyLlama-1.1B-Chat-v1.0",
+        gather_generation_logits=True,  # Required. TODO: Actual API TBD.
+        orchestrator_type="ray"
     )

+    # llm = TrtLLM(
+    #     model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+    #     gather_generation_logits=True
+    # )
+
     # Sample prompts.
     prompts = [
         "Hello, my name is",
@@ -31,6 +40,12 @@ def main():
         print(f"Generated text: {output.outputs[0].text!r}")
         print(f"Generated token IDs: {output.outputs[0].token_ids}")

+        if output.outputs[0].generation_logits is not None:
+            logits = output.outputs[0].generation_logits
+            # Extract logits for the sampled tokens only
+            sampled_logits = [logits[i, token_id].item() for i, token_id in enumerate(output.outputs[0].token_ids)]
+            print(f"Logits of sampled tokens: {sampled_logits}")
+
         if output.outputs[0].logprobs:
             print(f"\nLogprobs for each generated token:")
             for i, (token_id, token_logprobs) in enumerate(
@@ -39,7 +54,7 @@ def main():
                print(f"\n Token {i}: ID={token_id}, Text={llm.tokenizer.decode([token_id])!r}")

                # TODO. move to proper test
-               assert len(token_logprobs) == 1
+               # assert len(token_logprobs) == 1
                assert token_id in token_logprobs, f"Sampled token {token_id} not in logprobs dict."

                for tid, logprob_obj in token_logprobs.items():
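For reference, the indexing pattern the example relies on can be sketched standalone on a dummy tensor. The shapes and variable names below are illustrative stand-ins for the fields returned by the LLM API, not the actual interface:

import torch
import torch.nn.functional as F

# Stand-in for output.outputs[0].generation_logits: one row of vocab
# logits per generated token.
seq_len, vocab_size = 4, 320
generation_logits = torch.randn(seq_len, vocab_size)
token_ids = torch.randint(0, vocab_size, (seq_len,)).tolist()

# Logit of each sampled token: row i holds the distribution for step i.
sampled_logits = [generation_logits[i, tid].item()
                  for i, tid in enumerate(token_ids)]

# The matching logprob is log_softmax over the vocab dimension at the
# sampled token (the same transform sampler.py applies below).
logprobs = F.log_softmax(generation_logits.float(), dim=-1)
sampled_logprobs = [logprobs[i, tid].item()
                    for i, tid in enumerate(token_ids)]
print(sampled_logits)
print(sampled_logprobs)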

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 6 additions & 3 deletions
@@ -756,9 +756,12 @@ def handle_logprobs(
     if request.py_return_generation_logits:
         generation_logits_storage = request.py_result._generation_logits
         if generation_logits_storage and generation_logits_storage._storage is not None:
-            # Compute log_softmax to get logprobs for the sampled token
-            # Iinternal storage tensor: [seq_length, beam_width, vocab_size]
-            logits_for_step = generation_logits_storage._storage[step]  # [beam_width, vocab_size]
+            # Internal storage tensor: [seq_length, beam_width, vocab_size]
+            # Calculate absolute step index in the generation sequence
+            num_generated_tokens = len(request.get_tokens(beam)) - request.py_prompt_len
+            absolute_step = num_generated_tokens - count + step
+
+            logits_for_step = generation_logits_storage._storage[absolute_step]  # [beam_width, vocab_size]
             logprobs_for_step = F.log_softmax(logits_for_step[beam].float(), dim=-1)
             sampled_logprob = logprobs_for_step[sampled_token].item()
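This is the "fix step" of the commit message: `step` is relative to the tokens handled in the current call, while the generation-logits storage is indexed by absolute position in the generated sequence, so indexing with the relative `step` reads the wrong rows whenever tokens arrive in chunks. Assuming `count` is the number of tokens handled in this call (its definition sits outside the diff), the mapping can be checked with plain arithmetic:

# Hypothetical numbers, no TensorRT-LLM internals: 7 tokens generated so
# far, and this call handles the last count = 3 of them. Relative steps
# 0..2 must map to absolute storage rows 4..6, not 0..2 as the old
# indexing assumed.
num_generated_tokens = 7
count = 3
for step in range(count):
    absolute_step = num_generated_tokens - count + step
    print(step, "->", absolute_step)  # 0 -> 4, 1 -> 5, 2 -> 6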
