Skip to content

Commit a84e873

Browse files
Fix greedy search sampling batch return value mismatch in flashinfer sampling.
Signed-off-by: Wangshanshan <[email protected]>
1 parent d299b7a commit a84e873

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -137,7 +137,7 @@ def _sample_greedy_with_probs(
137137
group_logit_indices: Optional[torch.Tensor],
138138
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
139139
probs = self._prepare_probs_with_temperature(logits, group_logit_indices, None)
140-
new_tokens, _ = greedy_search_sampling_batch(probs, return_probs=False)
140+
new_tokens, _, _ = greedy_search_sampling_batch(probs, return_probs=False)
141141
return new_tokens, probs
142142

143143
@classmethod
@@ -370,7 +370,8 @@ def sample(
370370
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
371371
if group_logit_indices is not None:
372372
logits = logits[group_logit_indices]
373-
return greedy_search_sampling_batch(logits, return_probs=False)
373+
tokens, probs, _ = greedy_search_sampling_batch(logits, return_probs=False)
374+
return tokens, probs
374375

375376
class TopKTopPSampleOnly(StrategyImplSampleOnly):
376377
def __init__(self, top_k: torch.Tensor, top_p: torch.Tensor, temperature: torch.Tensor):

0 commit comments

Comments (0)