@@ -745,15 +745,44 @@ def handle_logprobs(
745745 topk_log_probs_vals = request .py_topk_logprobs_vals [:count ]
746746 topk_log_probs_indices = request .py_topk_logprobs_indices [:count ]
747747
748- token_log_probs = [
749- {
750- token : Logprob (logprob = logprob , rank = rank + 1 )
751- for rank , (token , logprob ) in enumerate (
752- zip (topk_token .tolist (), topk_logprob .tolist ())
753- )
754- }
755- for topk_token , topk_logprob in zip (topk_log_probs_indices , topk_log_probs_vals )
756- ]
748+ sampled_tokens = request .get_tokens (beam )[- count :]
749+
750+ token_log_probs = []
751+ for step , (topk_token , topk_logprob ) in enumerate (zip (topk_log_probs_indices , topk_log_probs_vals )):
752+ sampled_token = sampled_tokens [step ]
753+
754+ # TODO. WAR: If both gather_generation_logits and return_generation_logits are set,
755+ # return ONLY the sampled token's logprob (not top-K).
756+ if request .py_return_generation_logits :
757+ generation_logits_storage = request .py_result ._generation_logits
758+ if generation_logits_storage and generation_logits_storage ._storage is not None :
759+ # Compute log_softmax to get logprobs for the sampled token
760+ # Internal storage tensor: [seq_length, beam_width, vocab_size]
761+ logits_for_step = generation_logits_storage ._storage [step ] # [beam_width, vocab_size]
762+ logprobs_for_step = F .log_softmax (logits_for_step [beam ].float (), dim = - 1 )
763+ sampled_logprob = logprobs_for_step [sampled_token ].item ()
764+
765+ rank = (logprobs_for_step > sampled_logprob ).sum ().item () + 1
766+
767+ step_dict = {sampled_token : Logprob (logprob = sampled_logprob , rank = rank )}
768+ else :
769+ step_dict = {
770+ token : Logprob (logprob = logprob , rank = rank + 1 )
771+ for rank , (token , logprob ) in enumerate (
772+ zip (topk_token .tolist (), topk_logprob .tolist ())
773+ )
774+ }
775+ else :
776+ # Original behavior: return top-K
777+ step_dict = {
778+ token : Logprob (logprob = logprob , rank = rank + 1 )
779+ for rank , (token , logprob ) in enumerate (
780+ zip (topk_token .tolist (), topk_logprob .tolist ())
781+ )
782+ }
783+
784+ token_log_probs .append (step_dict )
785+
757786 assert beam == 0 , (
758787 "The following call relies on beam_width to be 1 - hence the list with a single element"
759788 )
0 commit comments