tensorrt_llm/_torch/speculative: 1 file changed, +12 -1 lines changed

```diff
@@ -441,7 +441,18 @@ def process_decoded_tokens(
                 req.py_request_id] = [0] * self.max_total_draft_tokens
             self.draft_tokens_accumulator[req.py_request_id][
                 draft_position - 1] = req.get_last_tokens(0)
-            target_model_req.py_draft_logits = req.py_result.generation_logits  # forwards Nones
+
+            generation_logits = req.py_result.generation_logits  # forwards Nones
+            if generation_logits is not None:
+                # generation_logits returns [beam_width, seq_length, vocab_size]
+                beam_width = generation_logits.size(0)
+                assert beam_width == 1, f"expected beam_width=1, got {beam_width}"
+                generation_logits.squeeze_(0)
+                # Transfer to CUDA if needed (chunked LogitsStorage stores on CPU)
+                if generation_logits.device.type == 'cpu':
+                    generation_logits = generation_logits.to('cuda',
+                                                             non_blocking=True)
+            target_model_req.py_draft_logits = generation_logits

             if req.state != LlmRequestState.GENERATION_COMPLETE and draft_position < target_model_req.py_draft_pages_allocated:
                 new_requests.append(req)
```
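The substance of this change is the shape and device normalization applied to the drafter's logits before they are handed to the target model. Below is a minimal, self-contained sketch of that logic under the same assumptions the diff states (logits arrive as `[beam_width, seq_length, vocab_size]` with `beam_width == 1`, and may sit on CPU when chunked logits storage is in use). `normalize_draft_logits`, the dummy tensor shape, and the `torch.cuda.is_available()` guard are illustrative additions, not part of the PR.

```python
from typing import Optional

import torch


def normalize_draft_logits(
        generation_logits: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Mirror the diff's post-processing of drafter logits (hypothetical helper)."""
    if generation_logits is None:
        # The pre-change one-liner forwarded None unchanged; keep that behavior.
        return None
    # Logits are laid out as [beam_width, seq_length, vocab_size].
    beam_width = generation_logits.size(0)
    assert beam_width == 1, f"expected beam_width=1, got {beam_width}"
    # Drop the beam dimension in place: [1, S, V] -> [S, V].
    generation_logits.squeeze_(0)
    # Chunked logits storage can hand back CPU tensors; move them to the GPU.
    # The is_available() guard is added only so this sketch runs on CPU-only hosts.
    if generation_logits.device.type == 'cpu' and torch.cuda.is_available():
        generation_logits = generation_logits.to('cuda', non_blocking=True)
    return generation_logits


# Usage: a dummy [1, 4, 32] logits tensor becomes [4, 32], on CUDA if available.
logits = normalize_draft_logits(torch.randn(1, 4, 32))
print(None if logits is None else (tuple(logits.shape), logits.device))
```

One caveat worth noting: `non_blocking=True` only makes the host-to-device copy asynchronous when the source tensor is in pinned (page-locked) memory; for pageable CPU tensors the copy falls back to a synchronous transfer, so the flag is harmless but not necessarily a speedup.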