We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 4ec56fc, commit d299b7a (Copy full SHA for d299b7a)
tensorrt_llm/_torch/auto_deploy/shim/demollm.py
@@ -235,11 +235,11 @@ def _sample(
235
logits_shape = logits.shape
236
logits = logits.view(-1, logits_shape[-1]) # sampling_batch expects 2D logits
237
if isinstance(sampling_params.top_k, int) and sampling_params.top_k > 1:
238
- idx_next, probs = top_k_sampling_batch(
+ idx_next, probs, _ = top_k_sampling_batch(
239
logits, top_k=sampling_params.top_k, temperature=1.0
240
)
241
else:
242
- idx_next, probs = greedy_search_sampling_batch(logits)
+ idx_next, probs, _ = greedy_search_sampling_batch(logits)
243
idx_next = idx_next.view(logits_shape[:-1])
244
return idx_next, probs
245
0 commit comments