 from ..models.modeling_utils import DecoderModelForCausalLM
 from ..modules.decoder_layer import DecoderLayer
 from ..speculative.drafter import Drafter
+from ..speculative.mtp import SampleStateTensorsMTP
 from ..speculative.speculation_gate import SpeculationGate
 from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem
 from .guided_decoder import GuidedDecoder
@@ -275,7 +276,7 @@ def __init__(self,
         if self.dist.pp_size > 1:
             self.event_loop = self._executor_loop_pp
         else:
-            self.event_loop = self._executor_loop if disable_overlap_scheduler else self._executor_loop_overlap
+            self.event_loop = self._executor_loop if self.disable_overlap_scheduler else self._executor_loop_overlap
         if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"):
             self.event_loop = trace_func(self.event_loop)
 
@@ -1059,14 +1060,11 @@ def _prepare_and_schedule_batch(self):
                     0
                 ] * max_total_draft_tokens if max_total_draft_tokens > 0 else []
 
-            # When overlap scheduler is enabled, and we already prepared the draft tokens in the previous batch,
-            # we don't need to initialize py_draft_tokens at this stage because we haven't append the accepted tokens to the request yet.
-            if not self.has_previous_draft_tokens:
-                # If speculation is off, this function sets py_draft_tokens to []
-                # for all active requests. If it's on, we initialize py_draft_tokens
-                # with dummy draft tokens to make the scheduler aware of the fact
-                # that speculation is about to happen.
-                self._prepare_draft_requests()
+            # If speculation is off, this function sets py_draft_tokens to []
+            # for all active requests. If it's on, we initialize py_draft_tokens
+            # with dummy draft tokens to make the scheduler aware of the fact
+            # that speculation is about to happen.
+            self._prepare_draft_requests()
 
         scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
         )
@@ -1315,6 +1313,8 @@ def _executor_loop_overlap(self):
         with self._profiler() as profile_step:
             iter_start_time = time.time()
             iter_stats = None
+            target_inputs = None
+            previous_tensors_device = None
             can_forward = False if self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver else True
             while True:
                 profile_step()
@@ -1395,31 +1395,29 @@ def _executor_loop_overlap(self):
                         self.guided_decoder.init_disagg_gen_requests()
 
                     previous_tensors = self.previous_batch and self.previous_batch.sample_state
-                    target_inputs = None
-                    draft_outputs = None
                     # If there are previous draft tokens, we need to update the target requests to accept some draft tokens.
                     # When there's any accepted tokens, we can't directly use the previous batch's outputs in this iteration for the target model,
                     # so we'll set the target model's input to None and skip updating the target requests after target model forward.
                     use_previous_draft_tokens = self.has_previous_draft_tokens
                     if self.drafter is not None and (self.use_spec_decode or
                                                      use_previous_draft_tokens):
-                        target_inputs, draft_outputs, draft_batch = self._handle_speculative_decoding(
-                            scheduled_batch, previous_tensors)
+                        target_inputs = self._handle_speculative_decoding(
+                            scheduled_batch, previous_tensors,
+                            previous_tensors_device)
 
                     # Use the draft_model's outputs if we've launched the draft model.
                     # Otherwise, use the previous batch's outputs.
-                    if target_inputs is not None or use_previous_draft_tokens:
+                    if (target_inputs is not None
+                            and target_inputs.next_draft_tokens
+                            is not None) or use_previous_draft_tokens:
                         previous_tensors_device = target_inputs
                     else:
                         previous_tensors_device = self.previous_batch and self.previous_batch.sample_state and self.previous_batch.sample_state.device
 
                     batch_outputs = self._forward_step(scheduled_batch,
                                                        previous_tensors_device)
 
-                    if target_inputs is not None:
-                        self._process_draft_results(scheduled_batch,
-                                                    draft_outputs, draft_batch)
-                    elif self.previous_batch is not None and not use_previous_draft_tokens:
+                    if self.previous_batch is not None:
                         self._update_requests(self.previous_batch.sample_state)
 
                     if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver:
@@ -1434,6 +1432,10 @@ def _executor_loop_overlap(self):
                                 (req, block_id,
                                  self.ctx_in_transmission_counter))
 
+                    if self.drafter is not None and self.use_spec_decode:
+                        # Cleanup previous draft resources used in the draft model
+                        self.drafter.cleanup_previous_draft_resources()
+
                     if self.guided_decoder is not None:
                         # add_batch must be called again to have updated new tokens.
                         self.guided_decoder.add_batch(scheduled_batch)
@@ -1468,6 +1470,94 @@ def _executor_loop_overlap(self):
 
                 self._kv_connector_terminate_requests()
 
+    def _accept_draft_tokens(
+            self, scheduled_batch: ScheduledRequests,
+            target_outputs: SampleStateTensors,
+            target_inputs: Optional[SampleStateTensors]
+    ) -> Tuple[SampleStateTensorsMTP, Optional[torch.Tensor]]:
+        """
+        Prepare target device inputs after computing draft token acceptance.
+
+        This function:
+        1. If draft tokens exist: compares sampled tokens with draft tokens to compute acceptance
+        2. If no draft tokens: directly uses the first sampled token
+        3. Creates new_tokens by extracting the accepted tokens per request
+
+        Args:
+            scheduled_batch: The scheduled requests
+            target_outputs: Contains new_tokens [max_draft_len + 1, batch_size, beam_width],
+                or [1, batch_size, beam_width] if there are no draft tokens
+            target_inputs: Contains next_draft_tokens [batch_size, max_draft_len]
+
+        Returns:
+            Tuple of:
+            - SampleStateTensorsMTP with new_tokens set to the accepted tokens,
+              new_tokens_lens and next_draft_tokens set to None
+            - num_accepted_tokens: [batch_size] tensor with per-request acceptance counts
+              (all zeros when there are no draft tokens)
+        """
+        has_draft_tokens = target_inputs is not None and isinstance(
+            target_inputs, SampleStateTensorsMTP
+        ) and target_inputs.next_draft_tokens is not None
+        target_tokens = target_outputs.new_tokens  # [max_draft_len + 1, batch_size, beam_width] or [1, batch_size, beam_width]
+        new_tokens = torch.zeros_like(target_tokens)
+
+        # Squeeze the beam dimension (beam_width=1 for greedy or single beam)
+        target_tokens = target_tokens.squeeze(-1)  # [max_draft_len + 1, batch_size] or [1, batch_size]
+
+        batch_size = target_tokens.shape[1]
+        device = target_tokens.device
+        # Compute number of accepted tokens per request
+        num_accepted_tokens = torch.zeros(batch_size,
+                                          dtype=torch.int32,
+                                          device=device)
+
+        if has_draft_tokens:
+            # Draft tokens exist, compute acceptance
+            draft_tokens = target_inputs.next_draft_tokens  # [batch_size, max_draft_len]
+            max_draft_len = draft_tokens.shape[1]
+
+            # Compute number of accepted tokens per request
+            # Generation requests: compare with draft tokens to find acceptance
+            num_contexts = len(scheduled_batch.context_requests)
+            if batch_size > num_contexts:
+                # Use .T to transpose: [max_draft_len + 1, num_gens] -> [num_gens, max_draft_len + 1]
+                gen_target_tokens = target_tokens[:, num_contexts:].T  # [num_gens, max_draft_len + 1]
+
+                # Compare draft tokens with target tokens to find acceptance
+                # Use cumprod to find the first rejection point
+                draft_tokens_gen = draft_tokens[num_contexts:, :].int()  # [num_gens, max_draft_len]
+                num_accepted_tokens[num_contexts:] += torch.cumprod(
+                    (draft_tokens_gen == gen_target_tokens[:, :max_draft_len]).int(),
+                    dim=-1).sum(dim=1)
+
+            # Vectorized extraction using advanced indexing (no GPU-CPU sync)
+            # Use num_accepted_tokens as indices to gather the right tokens
+            batch_indices = torch.arange(batch_size, device=device)
+            new_tokens[0, :, 0] = target_tokens[num_accepted_tokens,
+                                                batch_indices]
+        else:
+            # No draft tokens to accept, just use the first (and only) sampled token
+            batch_indices = torch.arange(batch_size, device=device)
+            new_tokens[0, :, 0] = target_tokens[0, batch_indices]
+
+        # Create the updated SampleStateTensorsMTP
+        # new_tokens_lens and next_draft_tokens are left as None
+        result_tensors = SampleStateTensorsMTP(
+            new_tokens=new_tokens,
+            log_probs=target_outputs.log_probs,
+            new_tokens_lens=None,
+            next_draft_tokens=None)
+
+        # Copy logits if available
+        if hasattr(target_outputs, 'logits'):
+            result_tensors.logits = target_outputs.logits
+
+        return result_tensors, num_accepted_tokens
+
     def _process_previous_batch(self):
         if self.kv_cache_transceiver and self.previous_batch.ctx_transmission_reqs:
             for req in self.previous_batch.ctx_transmission_reqs:
@@ -2364,7 +2454,8 @@ def _remove_inflight_ids(self, scheduled_requests):
         for req in scheduled_requests.all_requests():
             self.inflight_req_ids.erase(req.request_id)
 
-    def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
+    def _handle_speculative_decoding(self, scheduled_batch, previous_tensors,
+                                     target_inputs):
         with request_context(is_draft=self.draft_model_engine is not None,
                              scheduled_requests=scheduled_batch):
             # Do an early checking to see if we need to forward the draft model.
@@ -2374,20 +2465,25 @@ def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
                 self.previous_batch is not None and self.use_spec_decode
                 and self.drafter.should_forward_draft_model(scheduled_batch))
 
-            if has_draft_batch or self.has_previous_draft_tokens:
-                self._update_requests(self.previous_batch.sample_state)
-                if self.has_previous_draft_tokens:
-                    self._prepare_draft_requests()
+            new_target_inputs = None
+            if has_draft_batch:
+                target_outputs = self.previous_batch.sample_state and self.previous_batch.sample_state.device
+                assert target_outputs is not None, "target_outputs should not be None"
+                new_target_inputs, num_accepted_tokens_device = self._accept_draft_tokens(
+                    scheduled_batch=scheduled_batch,
+                    target_inputs=target_inputs,
+                    target_outputs=target_outputs)
 
             if has_draft_batch:
-                target_inputs, draft_outputs, draft_batch = self.drafter.generate_draft_tokens_with_overlap(
+                self.drafter.generate_draft_tokens_with_overlap(
                     scheduled_batch, self.resource_manager,
-                    previous_tensors.device if previous_tensors else None)
+                    previous_tensors.device if previous_tensors else None,
+                    new_target_inputs, num_accepted_tokens_device)
 
-                self.has_previous_draft_tokens = target_inputs is not None and target_inputs.next_draft_tokens is not None
+                # Pad draft tokens to the max draft length for CUDA graph compatibility
+                self.has_previous_draft_tokens = new_target_inputs is not None and new_target_inputs.next_draft_tokens is not None
             else:
                 self.has_previous_draft_tokens = False
-                target_inputs, draft_outputs, draft_batch = None, None, None
                 # We are not running the draft model. Remove the draft tokens and turn off spec
                 # decode so that the requests get handled correctly.
                 # One corner case: when we have at least one context request, we have to keep spec
@@ -2400,34 +2496,7 @@ def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
                 for request in scheduled_batch.all_requests():
                     request.py_draft_tokens = []
 
-            return target_inputs, draft_outputs, draft_batch
-
-    def _process_draft_results(self, scheduled_batch, draft_outputs,
-                               draft_batch):
-        """
-        Append the draft tokens to the target requests, and clean up the draft resources.
-        """
-        with request_context(is_draft=self.draft_model_engine is not None,
-                             scheduled_requests=scheduled_batch):
-            req_id_to_old_request = {
-                req.py_request_id: req
-                for req in scheduled_batch.all_requests()
-            }
-
-            if self.drafter.use_static_draft_loop:
-                self.drafter.process_static_draft_outputs(
-                    draft_outputs, draft_batch, req_id_to_old_request)
-            elif draft_outputs is not None:
-                self.drafter.process_dynamic_draft_outputs(
-                    draft_outputs, req_id_to_old_request)
-
-            # Pad draft tokens to the max draft length. This is for CUDA graph compatibility.
-            self.drafter.pad_draft_tokens_for_cuda_graph(scheduled_batch)
-            # add_batch must be called again to restore to target requests with updated draft tokens.
-            if self.guided_decoder is not None:
-                self.guided_decoder.add_batch(scheduled_batch)
-                if hasattr(self.drafter, "guided_decoder"):
-                    self.guided_decoder.rollback_draft_tokens()
+            return new_target_inputs
 
     def reset_prefix_cache(self):
         self.kv_cache_manager.reset_reuse_state()
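
The acceptance rule in `_accept_draft_tokens` above compares each generation request's draft tokens against the target model's sampled tokens and accepts the longest matching prefix; the `cumprod` trick counts matches up to the first mismatch without a GPU-CPU sync. A minimal standalone sketch of that computation, using hypothetical toy tensors laid out as `[num_gens, max_draft_len + 1]` (i.e. the transposed `gen_target_tokens` view), not part of this PR:

```python
import torch

# Toy shapes: 2 generation requests, max_draft_len = 3 (illustrative values only).
draft_tokens = torch.tensor([[5, 7, 9],
                             [4, 4, 4]])      # [num_gens, max_draft_len], proposed by the draft model
target_tokens = torch.tensor([[5, 7, 2, 11],
                              [1, 4, 4, 4]])  # [num_gens, max_draft_len + 1], sampled by the target model

# 1 while draft and target agree, 0 from the first mismatch onward.
matches = (draft_tokens == target_tokens[:, :-1]).int()
num_accepted = torch.cumprod(matches, dim=-1).sum(dim=1)   # tensor([2, 0])

# The next input token for each request is the target sample right after the accepted prefix.
batch_indices = torch.arange(draft_tokens.shape[0])
next_tokens = target_tokens[batch_indices, num_accepted]   # tensor([2, 1])
```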