Skip to content

Commit 99ba723

Browse files
authored
[None][fix] logits device and shape issues in dynamic draft path (#9079)
Signed-off-by: jellysnack <[email protected]>
1 parent 782dfca commit 99ba723

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

tensorrt_llm/_torch/speculative/model_drafter.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,18 @@ def process_decoded_tokens(
441 441                  req.py_request_id] = [0] * self.max_total_draft_tokens
442 442              self.draft_tokens_accumulator[req.py_request_id][
443 443                  draft_position - 1] = req.get_last_tokens(0)
444     -            target_model_req.py_draft_logits = req.py_result.generation_logits  # forwards Nones
    444 +
    445 +            generation_logits = req.py_result.generation_logits  # forwards Nones
    446 +            if generation_logits is not None:
    447 +                # generation_logits returns [beam_width, seq_length, vocab_size]
    448 +                beam_width = generation_logits.size(0)
    449 +                assert beam_width == 1, f"expected beam_width=1, got {beam_width}"
    450 +                generation_logits.squeeze_(0)
    451 +                # Transfer to CUDA if needed (chunked LogitsStorage stores on CPU)
    452 +                if generation_logits.device.type == 'cpu':
    453 +                    generation_logits = generation_logits.to('cuda',
    454 +                                                             non_blocking=True)
    455 +            target_model_req.py_draft_logits = generation_logits
445 456
446 457              if req.state != LlmRequestState.GENERATION_COMPLETE and draft_position < target_model_req.py_draft_pages_allocated:
447 458                  new_requests.append(req)

0 commit comments

Comments (0)