@@ -1,12 +1,21 @@
+import torch
+
 from tensorrt_llm import LLM, SamplingParams
+from tensorrt_llm._tensorrt_engine import LLM as TrtLLM
 
 
 def main():
     llm = LLM(
-        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-        gather_generation_logits=True  # Required. TODO: Actual API TBD.
+        model="/scratch/llm-models/llama-models-v2/TinyLlama-1.1B-Chat-v1.0",
+        gather_generation_logits=True,  # Required. TODO: Actual API TBD.
+        orchestrator_type="ray"
     )
 
+    # llm = TrtLLM(
+    #     model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+    #     gather_generation_logits=True
+    # )
+
     # Sample prompts.
     prompts = [
         "Hello, my name is",
@@ -31,6 +40,12 @@ def main():
         print(f"Generated text: {output.outputs[0].text!r}")
         print(f"Generated token IDs: {output.outputs[0].token_ids}")
 
+        if output.outputs[0].generation_logits is not None:
+            logits = output.outputs[0].generation_logits
+            # Extract logits for the sampled tokens only
+            sampled_logits = [logits[i, token_id].item() for i, token_id in enumerate(output.outputs[0].token_ids)]
+            print(f"Logits of sampled tokens: {sampled_logits}")
+
         if output.outputs[0].logprobs:
             print(f"\nLogprobs for each generated token:")
             for i, (token_id, token_logprobs) in enumerate(
@@ -39,7 +54,7 @@ def main():
                 print(f"\nToken {i}: ID={token_id}, Text={llm.tokenizer.decode([token_id])!r}")
 
                 # TODO. move to proper test
-                assert len(token_logprobs) == 1
+                # assert len(token_logprobs) == 1
                 assert token_id in token_logprobs, f"Sampled token {token_id} not in logprobs dict."
 
                 for tid, logprob_obj in token_logprobs.items():
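
A minimal cross-check sketch, not part of the commit above: it recomputes log-probabilities from the gathered logits and compares them against the reported logprobs. It assumes generation_logits is a [num_generated_tokens, vocab_size] tensor (as the indexing in the diff implies) and that each logprob entry exposes a .logprob attribute; both are assumptions about the API rather than confirmed behavior, and the values only line up when sampling does not rescale the logits (e.g. temperature 1.0).

import torch

def check_logits_vs_logprobs(output):
    out = output.outputs[0]
    if out.generation_logits is None or not out.logprobs:
        return
    # Normalize the raw logits into log-probabilities over the vocabulary.
    log_probs = torch.log_softmax(out.generation_logits.float(), dim=-1)
    for i, (token_id, token_logprobs) in enumerate(zip(out.token_ids, out.logprobs)):
        recomputed = log_probs[i, token_id].item()
        reported = token_logprobs[token_id].logprob  # .logprob is an assumed attribute
        # The two values should agree up to numerical precision.
        assert abs(recomputed - reported) < 1e-3, (i, recomputed, reported)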