Skip to content

Commit cdd8710

Browse files
committed
vllm : add logits extraction example
1 parent 5bb5669 commit cdd8710

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

fundamentals/vllm/src/logits.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from vllm import LLM, SamplingParams
2+
import torch
3+
4+
class LogitsPrinter:
    """Logits processor that prints per-step diagnostics during generation.

    Intended for vLLM's ``SamplingParams(logits_processors=[...])`` hook:
    vLLM invokes the instance once per decoding step with the token ids
    generated so far and the raw logits tensor for the next token.
    """

    def __init__(self, top_k: int = 5) -> None:
        """Create the printer.

        Args:
            top_k: How many of the highest-logit tokens to display each
                step (previously hard-coded to 5; default preserves the
                original behavior).
        """
        self.top_k = top_k

    def __call__(self, token_ids, logits):
        """Print logits statistics and return ``logits`` unchanged.

        Args:
            token_ids: Sequence of token ids generated so far; its length
                is used as the step number.
            logits: 1-D tensor of vocabulary logits for the next token.

        Returns:
            The same ``logits`` tensor, unmodified, so sampling proceeds
            normally — vLLM uses the returned tensor for sampling.
        """
        print(f"\n=== Logits at step {len(token_ids)} ===")
        print(f"Logits shape: {logits.shape}")
        print(f"Min logit: {logits.min().item():.4f}")
        print(f"Max logit: {logits.max().item():.4f}")

        # Clamp so torch.topk cannot raise when the vocabulary (or a test
        # tensor) has fewer than self.top_k entries.
        top_k = min(self.top_k, logits.shape[-1])
        top_values, top_indices = torch.topk(logits, top_k)
        print(f"\nTop {top_k} tokens:")
        for idx, (val, token_id) in enumerate(zip(top_values, top_indices)):
            print(f"  {idx+1}. Token {token_id}: logit={val.item():.4f}")

        return logits  # Must return logits unchanged
19+
20+
# Sampling configuration: attach a LogitsPrinter so every decoding step
# dumps logits diagnostics to stdout before a token is sampled.
sampling_params = SamplingParams(
    temperature=0.8,
    max_tokens=10,
    logits_processors=[LogitsPrinter()],
)

# Load the model (replace the path with a real checkpoint directory).
llm = LLM(model="path/to-model", trust_remote_code=True)

# Generate a short completion; the processor prints at each step.
outputs = llm.generate(["Hello, my name is"], sampling_params)
29+

0 commit comments

Comments
 (0)