1+ import torch
2+
13from tensorrt_llm import LLM , SamplingParams
4+ from tensorrt_llm ._tensorrt_engine import LLM as TrtLLM
25
36
# NOTE(review): this span is a unified-diff rendering (a GitHub diff-page
# scrape), not clean source. Each line carries fused old/new diff line
# numbers (e.g. "47", "9+", "22-"), and the "@@ ... @@" lines below are hunk
# markers standing in for runs of the real file that were NOT captured.
# Only comments are added/fixed here; every code byte is left exactly as
# captured so the diff stays faithful.
#
# Purpose (from the visible code): demo script that builds a tensorrt_llm
# LLM with generation-logit gathering enabled, samples a few prompts with
# per-token logprobs, and prints sanity checks on the sampled tokens.
47def main ():
58 llm = LLM (
# Old (pre-diff) configuration: Hugging Face model id, no Ray orchestrator.
6- model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" ,
7- gather_generation_logits = True # Required. TODO: Actual API TBD.
# NOTE(review): hard-coded machine-local path — breaks for anyone without
# /scratch; prefer the HF model id (old line above) or an env/CLI override.
9+ model = "/scratch/llm-models/llama-models-v2/TinyLlama-1.1B-Chat-v1.0" ,
10+ gather_generation_logits = True , # Required. TODO: Actual API TBD.
11+ orchestrator_type = "ray"
812 )
913
# Commented-out alternative backend (TrtLLM, imported at top of file).
# Dead code — consider deleting or gating behind a CLI flag instead.
14+ # llm = TrtLLM(
15+ # model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
16+ # gather_generation_logits=True
17+ # )
18+
1019 # Sample prompts.
1120 prompts = [
1221 "Hello, my name is" ,
# @@ hunk marker: the remaining prompt literals and intervening source
# (new-file lines ~22-27) were omitted from this capture.
@@ -19,8 +28,8 @@ def main():
1928 # - Without return_generation_logits=True: Returns top-K tokens (sampled token NOT guaranteed)
2029 sampling_params = SamplingParams (
2130 max_tokens = 10 ,
22- temperature = 0.7 ,
23- top_p = 0.95 ,
# NOTE(review): sampling knobs disabled by this diff — decoding falls back
# to the backend default (presumably greedy); confirm that is intended
# before relying on the logit sanity check below.
31+ # temperature=0.7,
32+ # top_p=0.95,
2433 logprobs = 1 ,
2534 return_generation_logits = True ,
2635 )
# @@ hunk marker: new-file lines ~36-39 — including the llm.generate(...)
# call and the loop header that binds `output` — were omitted here.
@@ -31,6 +40,14 @@ def main():
3140 print (f"Generated text: { output .outputs [0 ].text !r} " )
3241 print (f"Generated token IDs: { output .outputs [0 ].token_ids } " )
3342
43+ if output .outputs [0 ].generation_logits is not None :
44+ logits = output .outputs [0 ].generation_logits
45+
46+ # sanity check on sampled logits
# NOTE(review): assumes generation_logits is 2-D (num_generated, vocab) so
# that logits[i, token_id] is the sampled token's logit — TODO confirm;
# some APIs return (beams, seq_len, vocab) instead.
47+ num_logits = logits .shape [0 ]
48+ sampled_logits = [logits [i , token_id ].item () for i , token_id in enumerate (output .outputs [0 ].token_ids [:num_logits ])]
49+ print (f"Logits of sampled tokens: { sampled_logits } " )
50+
3451 if output .outputs [0 ].logprobs :
3552 print (f"\n Logprobs for each generated token:" )
3653 for i , (token_id , token_logprobs ) in enumerate (
# @@ hunk marker: the argument lines closing this enumerate(...) call
# (new-file lines ~54-55, presumably a zip of token_ids and logprobs)
# were omitted from this capture.
@@ -39,7 +56,7 @@ def main():
3956 print (f"\n Token { i } : ID={ token_id } , Text={ llm .tokenizer .decode ([token_id ])!r} " )
4057
4158 # TODO. move to proper test
# NOTE(review): assertion disabled rather than fixed. Per the comment above
# ("sampled token NOT guaranteed" in top-K), the dict can presumably hold
# an extra entry when the sampled token is not top-1; a bounded check such
# as `1 <= len(token_logprobs) <= 2` would be stronger — verify against the
# SamplingParams.logprobs contract before restoring.
42- assert len (token_logprobs ) == 1
59+ # assert len(token_logprobs) == 1
4360 assert token_id in token_logprobs , f"Sampled token { token_id } not in logprobs dict."
4461
4562 for tid , logprob_obj in token_logprobs .items ():
# (scrape artifact: "0 commit comments" is the GitHub diff-page footer, not part of the source file)