
Commit ede907b

fix step
1 parent 2bc8bbf commit ede907b

2 files changed: +28 -8 lines changed


examples/llm-api/llm_inference_logprob.py

Lines changed: 22 additions & 5 deletions
@@ -1,12 +1,21 @@
+import torch
+
 from tensorrt_llm import LLM, SamplingParams
+from tensorrt_llm._tensorrt_engine import LLM as TrtLLM


 def main():
     llm = LLM(
-        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-        gather_generation_logits=True  # Required. TODO: Acutal API TBD.
+        model="/scratch/llm-models/llama-models-v2/TinyLlama-1.1B-Chat-v1.0",
+        gather_generation_logits=True,  # Required. TODO: Actual API TBD.
+        orchestrator_type="ray"
     )

+    # llm = TrtLLM(
+    #     model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+    #     gather_generation_logits=True
+    # )
+
     # Sample prompts.
     prompts = [
         "Hello, my name is",
@@ -19,8 +28,8 @@ def main():
     # - Without return_generation_logits=True: Returns top-K tokens (sampled token NOT guaranteed)
     sampling_params = SamplingParams(
         max_tokens=10,
-        temperature=0.7,
-        top_p=0.95,
+        # temperature=0.7,
+        # top_p=0.95,
         logprobs=1,
         return_generation_logits=True,
     )
@@ -31,6 +40,14 @@ def main():
         print(f"Generated text: {output.outputs[0].text!r}")
         print(f"Generated token IDs: {output.outputs[0].token_ids}")

+        if output.outputs[0].generation_logits is not None:
+            logits = output.outputs[0].generation_logits
+
+            # sanity check on sampled logits
+            num_logits = logits.shape[0]
+            sampled_logits = [logits[i, token_id].item() for i, token_id in enumerate(output.outputs[0].token_ids[:num_logits])]
+            print(f"Logits of sampled tokens: {sampled_logits}")
+
         if output.outputs[0].logprobs:
             print(f"\nLogprobs for each generated token:")
             for i, (token_id, token_logprobs) in enumerate(
@@ -39,7 +56,7 @@ def main():
                 print(f"\n Token {i}: ID={token_id}, Text={llm.tokenizer.decode([token_id])!r}")

                 # TODO. move to proper test
-                assert len(token_logprobs) == 1
+                # assert len(token_logprobs) == 1
                 assert token_id in token_logprobs, f"Sampled token {token_id} not in logprobs dict."

                 for tid, logprob_obj in token_logprobs.items():
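
For context, the sanity check added above can be pushed one step further: cross-check the reported logprobs against log_softmax of the raw generation logits. The sketch below is not part of the commit and rests on the same assumptions the example makes (generation_logits shaped [num_generated_tokens, vocab_size], logprobs as a list of {token_id: logprob_obj} dicts, and a .logprob attribute on the logprob objects):

import torch.nn.functional as F

def cross_check(completion, atol=1e-3):
    """Compare reported logprobs with log_softmax of the stored generation logits."""
    logits = completion.generation_logits  # assumed shape: [num_generated_tokens, vocab_size]
    if logits is None or not completion.logprobs:
        return
    num_logits = logits.shape[0]
    for i, (token_id, token_logprobs) in enumerate(
            zip(completion.token_ids[:num_logits], completion.logprobs)):
        # Recompute the sampled token's logprob from the corresponding logits row
        recomputed = F.log_softmax(logits[i].float(), dim=-1)[token_id].item()
        reported = token_logprobs[token_id].logprob  # assumed attribute name
        assert abs(recomputed - reported) < atol, (i, recomputed, reported)

Called as cross_check(output.outputs[0]) inside the loop above, a check like this would fail whenever the sampler reads logits from the wrong step, which is what the sampler.py change below corrects.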

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 6 additions & 3 deletions
@@ -756,9 +756,12 @@ def handle_logprobs(
     if request.py_return_generation_logits:
         generation_logits_storage = request.py_result._generation_logits
         if generation_logits_storage and generation_logits_storage._storage is not None:
-            # Compute log_softmax to get logprobs for the sampled token
-            # Iinternal storage tensor: [seq_length, beam_width, vocab_size]
-            logits_for_step = generation_logits_storage._storage[step]  # [beam_width, vocab_size]
+            # Internal storage tensor: [seq_length, beam_width, vocab_size]
+            # Calculate absolute step index in the generation sequence
+            num_generated_tokens = len(request.get_tokens(beam)) - request.py_prompt_len
+            absolute_step = num_generated_tokens - count + step
+
+            logits_for_step = generation_logits_storage._storage[absolute_step]  # [beam_width, vocab_size]
             logprobs_for_step = F.log_softmax(logits_for_step[beam].float(), dim=-1)
             sampled_logprob = logprobs_for_step[sampled_token].item()
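
The fix replaces the call-relative step with an absolute index into the generation-logits storage, whose first dimension spans the whole generated sequence. A small illustration of the arithmetic with hypothetical numbers (variable names mirror the diff; the values and the calling pattern are assumptions, not taken from the code):

# Assume handle_logprobs() is invoked with `count` freshly sampled tokens
# at a time, after earlier tokens have already been written to storage.
py_prompt_len = 8                # prompt tokens (not stored as generation logits)
tokens_in_request = 8 + 6        # prompt + 6 generated tokens after this call
count = 2                        # tokens handled in this call: step = 0, 1

num_generated_tokens = tokens_in_request - py_prompt_len    # 6
for step in range(count):
    absolute_step = num_generated_tokens - count + step     # 4, then 5
    # The old code indexed storage[step] (rows 0 and 1); the fixed code reads
    # rows 4 and 5, i.e. the logits of the 5th and 6th generated tokens.
    print(step, "->", absolute_step)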
