
Commit a7aaf50

[TRTLLM-8084][feat] Enhance the overlap scheduler for two-model spec decoding (#8706)
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent 121140c commit a7aaf50

6 files changed (+620, -221 lines)

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 266 additions & 39 deletions
Large diffs are not rendered by default.

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 124 additions & 55 deletions
@@ -39,6 +39,7 @@
 from ..models.modeling_utils import DecoderModelForCausalLM
 from ..modules.decoder_layer import DecoderLayer
 from ..speculative.drafter import Drafter
+from ..speculative.mtp import SampleStateTensorsMTP
 from ..speculative.speculation_gate import SpeculationGate
 from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem
 from .guided_decoder import GuidedDecoder
@@ -275,7 +276,7 @@ def __init__(self,
         if self.dist.pp_size > 1:
             self.event_loop = self._executor_loop_pp
         else:
-            self.event_loop = self._executor_loop if disable_overlap_scheduler else self._executor_loop_overlap
+            self.event_loop = self._executor_loop if self.disable_overlap_scheduler else self._executor_loop_overlap
         if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"):
             self.event_loop = trace_func(self.event_loop)

@@ -1059,14 +1060,11 @@ def _prepare_and_schedule_batch(self):
                 0
             ] * max_total_draft_tokens if max_total_draft_tokens > 0 else []
 
-            # When overlap scheduler is enabled, and we already prepared the draft tokens in the previous batch,
-            # we don't need to initialize py_draft_tokens at this stage because we haven't append the accepted tokens to the request yet.
-            if not self.has_previous_draft_tokens:
-                # If speculation is off, this function sets py_draft_tokens to []
-                # for all active requests. If it's on, we initialize py_draft_tokens
-                # with dummy draft tokens to make the scheduler aware of the fact
-                # that speculation is about to happen.
-                self._prepare_draft_requests()
+            # If speculation is off, this function sets py_draft_tokens to []
+            # for all active requests. If it's on, we initialize py_draft_tokens
+            # with dummy draft tokens to make the scheduler aware of the fact
+            # that speculation is about to happen.
+            self._prepare_draft_requests()
 
         scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
         )
@@ -1315,6 +1313,8 @@ def _executor_loop_overlap(self):
         with self._profiler() as profile_step:
             iter_start_time = time.time()
             iter_stats = None
+            target_inputs = None
+            previous_tensors_device = None
             can_forward = False if self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver else True
             while True:
                 profile_step()
@@ -1395,31 +1395,29 @@ def _executor_loop_overlap(self):
                     self.guided_decoder.init_disagg_gen_requests()
 
                 previous_tensors = self.previous_batch and self.previous_batch.sample_state
-                target_inputs = None
-                draft_outputs = None
                 # If there are previous draft tokens, we need to update the target requests to accept some draft tokens.
                 # When there's any accepted tokens, we can't directly use the previous batch's outputs in this iteration for the target model,
                 # so we'll set the target model's input to None and skip updating the target requests after target model forward.
                 use_previous_draft_tokens = self.has_previous_draft_tokens
                 if self.drafter is not None and (self.use_spec_decode or
                                                  use_previous_draft_tokens):
-                    target_inputs, draft_outputs, draft_batch = self._handle_speculative_decoding(
-                        scheduled_batch, previous_tensors)
+                    target_inputs = self._handle_speculative_decoding(
+                        scheduled_batch, previous_tensors,
+                        previous_tensors_device)
 
                 # Use the draft_model's outputs if we've launched the draft model.
                 # Otherwise, use the previous batch's outputs.
-                if target_inputs is not None or use_previous_draft_tokens:
+                if (target_inputs is not None
+                        and target_inputs.next_draft_tokens
+                        is not None) or use_previous_draft_tokens:
                     previous_tensors_device = target_inputs
                 else:
                     previous_tensors_device = self.previous_batch and self.previous_batch.sample_state and self.previous_batch.sample_state.device
 
                 batch_outputs = self._forward_step(scheduled_batch,
                                                    previous_tensors_device)
 
-                if target_inputs is not None:
-                    self._process_draft_results(scheduled_batch,
-                                                draft_outputs, draft_batch)
-                elif self.previous_batch is not None and not use_previous_draft_tokens:
+                if self.previous_batch is not None:
                     self._update_requests(self.previous_batch.sample_state)
 
                 if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver:
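For reviewers, the new input-selection branch above can be distilled as follows. This is only a sketch with a hypothetical helper name, not code from this commit: the drafter-updated tensors are fed to `_forward_step` only when they actually carry `next_draft_tokens`; otherwise the loop falls back to the previous batch's device tensors.

```python
# Hypothetical helper distilling the branch above; not part of the commit.
def select_previous_device_tensors(target_inputs, previous_batch,
                                   use_previous_draft_tokens: bool):
    """Prefer the drafter-updated inputs only when they really contain
    next_draft_tokens; otherwise reuse the previous target-model outputs."""
    if (target_inputs is not None
            and getattr(target_inputs, "next_draft_tokens", None) is not None
            ) or use_previous_draft_tokens:
        return target_inputs
    # Falls back to the previous batch's device tensors (None on the very first iteration).
    return previous_batch and previous_batch.sample_state and previous_batch.sample_state.device
```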
@@ -1434,6 +1432,10 @@ def _executor_loop_overlap(self):
                             (req, block_id,
                              self.ctx_in_transmission_counter))
 
+                if self.drafter is not None and self.use_spec_decode:
+                    # Cleanup previous draft resources used in the draft model
+                    self.drafter.cleanup_previous_draft_resources()
+
                 if self.guided_decoder is not None:
                     # add_batch must be called again to have updated new tokens.
                     self.guided_decoder.add_batch(scheduled_batch)
@@ -1468,6 +1470,94 @@ def _executor_loop_overlap(self):
 
                 self._kv_connector_terminate_requests()
 
+    def _accept_draft_tokens(
+            self, scheduled_batch: ScheduledRequests,
+            target_outputs: SampleStateTensors,
+            target_inputs: Optional[SampleStateTensors]
+    ) -> Tuple[SampleStateTensorsMTP, Optional[torch.Tensor]]:
+        """
+        Prepare target device inputs after computing draft token acceptance.
+
+        This function:
+        1. If draft tokens exist: compares sampled tokens with draft tokens to compute acceptance
+        2. If no draft tokens: directly uses the first sampled token
+        3. Creates new_tokens by extracting accepted tokens per request
+
+        Args:
+            scheduled_batch: The scheduled requests
+            target_outputs: Contains new_tokens [max_draft_len + 1, batch_size, beam_width]
+                or [1, batch_size, beam_width] if no draft tokens
+            target_inputs: Contains next_draft_tokens [batch_size, max_draft_len]
+        Returns:
+            Tuple of:
+            - SampleStateTensorsMTP with new_tokens set to accepted tokens,
+              new_tokens_lens and next_draft_tokens set to None
+            - num_accepted_tokens: [batch_size] tensor with acceptance counts per request,
+              or None if no draft tokens
+        """
+        has_draft_tokens = target_inputs is not None and isinstance(
+            target_inputs, SampleStateTensorsMTP
+        ) and target_inputs.next_draft_tokens is not None
+        target_tokens = target_outputs.new_tokens  # [max_draft_len + 1, batch_size, beam_width] or [1, batch_size, beam_width]
+        new_tokens = torch.zeros_like(target_tokens)
+
+        # Squeeze the beam dimension (beam_width=1 for greedy or single beam)
+        target_tokens = target_tokens.squeeze(
+            -1)  # [max_draft_len + 1, batch_size] or [1, batch_size]
+
+        batch_size = target_tokens.shape[1]
+        device = target_tokens.device
+        # Compute number of accepted tokens per request
+        num_accepted_tokens = torch.zeros(batch_size,
+                                          dtype=torch.int32,
+                                          device=device)
+
+        if has_draft_tokens:
+            # Draft tokens exist, compute acceptance
+            draft_tokens = target_inputs.next_draft_tokens  # [batch_size, max_draft_len]
+            max_draft_len = draft_tokens.shape[1]
+
+            # Compute number of accepted tokens per request
+            # Generation requests: compare with draft tokens to find acceptance
+            num_contexts = len(scheduled_batch.context_requests)
+            if batch_size > num_contexts:
+                # Use .T to transpose: [max_draft_len + 1, num_gens] -> [num_gens, max_draft_len + 1]
+                gen_target_tokens = target_tokens[:,
+                                                  num_contexts:].T  # [num_gens, max_draft_len + 1]
+
+                # Compare draft tokens with target tokens to find acceptance
+                # Use cumprod to find the first rejection point
+                draft_tokens_gen = draft_tokens[
+                    num_contexts:, :].int()  # [num_gens, max_draft_len]
+                num_accepted_tokens[num_contexts:] += torch.cumprod(
+                    (draft_tokens_gen == gen_target_tokens[:, :max_draft_len]
+                     ).int(),
+                    dim=-1).sum(dim=1)
+
+            # Vectorized extraction using advanced indexing (no GPU-CPU sync)
+            # Use num_accepted_tokens as indices to gather the right tokens
+            batch_indices = torch.arange(batch_size, device=device)
+            new_tokens[0, :, 0] = target_tokens[num_accepted_tokens,
+                                                batch_indices]
+        else:
+            # No draft tokens to accept, just use the first (and only) sampled token
+            batch_indices = torch.arange(batch_size, device=device)
+            new_tokens[0, :, 0] = target_tokens[0, batch_indices]
+
+        # Create the updated SampleStateTensorsMTP
+        # new_tokens_lens and next_draft_tokens are left as None
+        result_tensors = SampleStateTensorsMTP(
+            new_tokens=new_tokens,
+            log_probs=target_outputs.log_probs,
+            new_tokens_lens=None,
+            next_draft_tokens=None)
+
+        # Copy logits if available
+        if hasattr(target_outputs, 'logits'):
+            result_tensors.logits = target_outputs.logits
+
+        return result_tensors, num_accepted_tokens
+
     def _process_previous_batch(self):
         if self.kv_cache_transceiver and self.previous_batch.ctx_transmission_reqs:
             for req in self.previous_batch.ctx_transmission_reqs:
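The acceptance rule added above in `_accept_draft_tokens` is easiest to follow on toy tensors. Below is a minimal standalone sketch (illustration only, not code from this commit; shapes are transposed to [num_gens, max_draft_len + 1] for brevity):

```python
# Toy demonstration of the cumprod-based acceptance used in _accept_draft_tokens.
import torch

max_draft_len, num_gens = 3, 2
# Target model output for generation requests, transposed to [num_gens, max_draft_len + 1].
gen_target_tokens = torch.tensor([[5, 7, 9, 4],
                                  [2, 8, 1, 6]])
# Draft tokens proposed in the previous iteration: [num_gens, max_draft_len].
draft_tokens_gen = torch.tensor([[5, 7, 0],    # first two drafts match -> 2 accepted
                                 [3, 8, 1]])   # first draft rejected   -> 0 accepted

# cumprod over the elementwise match zeroes everything after the first mismatch,
# so the row-wise sum is the length of the accepted prefix.
matches = (draft_tokens_gen == gen_target_tokens[:, :max_draft_len]).int()
num_accepted = torch.cumprod(matches, dim=-1).sum(dim=1)             # tensor([2, 0])

# The token fed to the next iteration is the target token right after the
# accepted prefix, gathered with advanced indexing (no GPU-CPU sync).
batch_indices = torch.arange(num_gens)
next_input_tokens = gen_target_tokens[batch_indices, num_accepted]   # tensor([9, 2])
print(num_accepted.tolist(), next_input_tokens.tolist())              # [2, 0] [9, 2]
```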
@@ -2364,7 +2454,8 @@ def _remove_inflight_ids(self, scheduled_requests):
         for req in scheduled_requests.all_requests():
             self.inflight_req_ids.erase(req.request_id)
 
-    def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
+    def _handle_speculative_decoding(self, scheduled_batch, previous_tensors,
+                                     target_inputs):
         with request_context(is_draft=self.draft_model_engine is not None,
                              scheduled_requests=scheduled_batch):
             # Do an early checking to see if we need to forward the draft model.
@@ -2374,20 +2465,25 @@ def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
                 self.previous_batch is not None and self.use_spec_decode
                 and self.drafter.should_forward_draft_model(scheduled_batch))
 
-            if has_draft_batch or self.has_previous_draft_tokens:
-                self._update_requests(self.previous_batch.sample_state)
-                if self.has_previous_draft_tokens:
-                    self._prepare_draft_requests()
+            new_target_inputs = None
+            if has_draft_batch:
+                target_outputs = self.previous_batch.sample_state and self.previous_batch.sample_state.device
+                assert target_outputs is not None, "target_outputs should not be None"
+                new_target_inputs, num_accepted_tokens_device = self._accept_draft_tokens(
+                    scheduled_batch=scheduled_batch,
+                    target_inputs=target_inputs,
+                    target_outputs=target_outputs)
 
             if has_draft_batch:
-                target_inputs, draft_outputs, draft_batch = self.drafter.generate_draft_tokens_with_overlap(
+                self.drafter.generate_draft_tokens_with_overlap(
                     scheduled_batch, self.resource_manager,
-                    previous_tensors.device if previous_tensors else None)
+                    previous_tensors.device if previous_tensors else None,
+                    new_target_inputs, num_accepted_tokens_device)
 
-                self.has_previous_draft_tokens = target_inputs is not None and target_inputs.next_draft_tokens is not None
+                # Pad draft tokens to the max draft length for CUDA graph compatibility
+                self.has_previous_draft_tokens = new_target_inputs is not None and new_target_inputs.next_draft_tokens is not None
             else:
                 self.has_previous_draft_tokens = False
-                target_inputs, draft_outputs, draft_batch = None, None, None
                 # We are not running the draft model. Remove the draft tokens and turn off spec
                 # decode so that the requests get handled correctly.
                 # One corner case: when we have at least one context request, we have to keep spec
@@ -2400,34 +2496,7 @@ def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
                 for request in scheduled_batch.all_requests():
                     request.py_draft_tokens = []
 
-            return target_inputs, draft_outputs, draft_batch
-
-    def _process_draft_results(self, scheduled_batch, draft_outputs,
-                               draft_batch):
-        """
-        Append the draft tokens to the target requests, and clean up the draft resources.
-        """
-        with request_context(is_draft=self.draft_model_engine is not None,
-                             scheduled_requests=scheduled_batch):
-            req_id_to_old_request = {
-                req.py_request_id: req
-                for req in scheduled_batch.all_requests()
-            }
-
-            if self.drafter.use_static_draft_loop:
-                self.drafter.process_static_draft_outputs(
-                    draft_outputs, draft_batch, req_id_to_old_request)
-            elif draft_outputs is not None:
-                self.drafter.process_dynamic_draft_outputs(
-                    draft_outputs, req_id_to_old_request)
-
-            # Pad draft tokens to the max draft length. This is for CUDA graph compatibility.
-            self.drafter.pad_draft_tokens_for_cuda_graph(scheduled_batch)
-            # add_batch must be called again to restore to target requests with updated draft tokens.
-            if self.guided_decoder is not None:
-                self.guided_decoder.add_batch(scheduled_batch)
-                if hasattr(self.drafter, "guided_decoder"):
-                    self.guided_decoder.rollback_draft_tokens()
+            return new_target_inputs
 
     def reset_prefix_cache(self):
         self.kv_cache_manager.reset_reuse_state()

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 11 additions & 0 deletions
@@ -25,6 +25,7 @@
 from tensorrt_llm.quantization import QuantAlgo
 
 from ..attention_backend.interface import AttentionRuntimeFeatures
+from ..attention_backend.trtllm import TrtllmAttention
 from ..distributed import MPIDist, TorchDist
 from ..speculative import (get_num_extra_kv_tokens, get_spec_drafter,
                            get_spec_resource_manager)
@@ -389,6 +390,16 @@ def drafting_loop_wrapper(model):
     else:
         draft_model_engine = None
 
+    # TODO: Overlap scheduler is not supported for below cases:
+    # 1. non-CDL is used
+    # 2. non-TrtllmAttention attention backend is used
+    if has_draft_model_engine and (not use_chain_drafter or not issubclass(
+            draft_model_engine.attn_backend, TrtllmAttention)):
+        logger.warning(
+            "Overlap scheduler is not supported for non-CDL or non-TrtllmAttention backend."
+        )
+        llm_args.disable_overlap_scheduler = True
+
     # PyTorchModelEngine modifies these fields, update them
     model_engine_max_seq_len = model_engine.max_seq_len
     net_max_seq_len = model_engine_max_seq_len
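The fallback condition can be illustrated in isolation. A minimal sketch with stand-in classes (only the name `TrtllmAttention` corresponds to a real TensorRT-LLM backend; the rest is hypothetical): the overlap scheduler is kept for two-model spec decoding only when the draft engine uses chain drafting (CDL) and a TrtllmAttention-derived attention backend.

```python
# Toy illustration of the guard above; FlashLikeAttention is a hypothetical non-TRTLLM backend.
class TrtllmAttention:
    """Stand-in for tensorrt_llm._torch.attention_backend.trtllm.TrtllmAttention."""

class FlashLikeAttention:
    """Hypothetical backend that does not derive from TrtllmAttention."""

def overlap_scheduler_supported(use_chain_drafter: bool, attn_backend: type) -> bool:
    # Two-model spec decoding keeps the overlap scheduler only with CDL + TrtllmAttention.
    return use_chain_drafter and issubclass(attn_backend, TrtllmAttention)

assert overlap_scheduler_supported(True, TrtllmAttention)
assert not overlap_scheduler_supported(True, FlashLikeAttention)   # non-TrtllmAttention backend
assert not overlap_scheduler_supported(False, TrtllmAttention)     # non-CDL drafting loop
```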

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 3 additions & 1 deletion
@@ -280,7 +280,9 @@ def _group_requests_by_strategy_key(
         )
         for req_index, req in enumerate(requests):
             strategy = _request_strategy(req, vocab_size=vocab_size)
-            speculation_needs_probs = req.py_draft_logits is not None and strategy is not GREEDY
+            # In the overlap path, py_draft_logits is not updated yet,
+            # so we use get_draft_token_length() for the checking.
+            speculation_needs_probs = get_draft_token_length(req) > 0 and strategy is not GREEDY
             strategy_key = strategy_to_key(strategy, speculation_needs_probs)
             group_dict_entry = group_dict[(strategy_key, speculation_needs_probs)]
             group_dict_entry[0].append(req_index)
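A rough standalone sketch of the grouping criterion above (the toy `Request` class and the simplified `get_draft_token_length` are hypothetical stand-ins for the TensorRT-LLM types, not the sampler's real implementation):

```python
# Toy illustration of the grouping key; Request and get_draft_token_length are simplified stand-ins.
from collections import defaultdict
from dataclasses import dataclass, field
from typing import List

GREEDY = "greedy"

@dataclass
class Request:
    strategy: str
    py_draft_tokens: List[int] = field(default_factory=list)

def get_draft_token_length(req: Request) -> int:
    # In the overlap path, draft logits are not populated yet, so count draft tokens instead.
    return len(req.py_draft_tokens)

def group_requests(requests):
    groups = defaultdict(list)
    for req_index, req in enumerate(requests):
        speculation_needs_probs = get_draft_token_length(req) > 0 and req.strategy != GREEDY
        groups[(req.strategy, speculation_needs_probs)].append(req_index)
    return dict(groups)

reqs = [Request(GREEDY, [1, 2]), Request("top_p", [3]), Request("top_p")]
print(group_requests(reqs))
# {('greedy', False): [0], ('top_p', True): [1], ('top_p', False): [2]}
```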
