Skip to content

Commit 5df907b

Browse files
authored
[https://nvbugs/5590408][fix] Fallback to greedy sampling in two-model overlap scheduler (#9321)
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent f2ebaf2 commit 5df907b

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,11 @@ def drafting_loop_wrapper(model):
414414
model_engine_max_seq_len += get_num_extra_kv_tokens(spec_config)
415415
model_engine_max_seq_len += spec_config.max_total_draft_tokens
416416

417+
if has_draft_model_engine and not llm_args.disable_overlap_scheduler:
418+
logger.warning(
419+
"Overlap scheduler is enabled for two-model speculative decoding. Rejection sampling will fallback to greedy sampling."
420+
)
421+
417422
max_seq_len = model_engine_max_seq_len
418423
max_num_tokens = model_engine.max_num_tokens
419424
sparse_attention_config = model_engine.sparse_attention_config

tensorrt_llm/_torch/speculative/model_drafter.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -625,7 +625,10 @@ def process_static_draft_outputs(self, outputs: dict[str, torch.Tensor]
625625
target_model_req.py_draft_tokens.append(
626626
draft_tokens_host[token_idx][req_idx])
627627
py_draft_logits.append(draft_logits[token_idx][req_idx])
628-
target_model_req.py_draft_logits = torch.stack(py_draft_logits)
628+
629+
# The overlap scheduler doesn't support rejection sampling yet, so we don't update the py_draft_logits to get it fallback to greedy sampling.
630+
if self.disable_overlap_scheduler:
631+
target_model_req.py_draft_logits = torch.stack(py_draft_logits)
629632

630633
def process_dynamic_draft_outputs(
631634
self,

0 commit comments

Comments
 (0)