@@ -1198,13 +1198,18 @@ def _executor_loop_overlap(self):
11981198 previous_tensors = self .previous_batch and self .previous_batch .sample_state
11991199 target_inputs = None
12001200 draft_outputs = None
1201- if self .drafter is not None and self .use_spec_decode :
1201+ # If there are previous draft tokens, we need to update the target requests to accept some draft tokens.
1202+ # When there's any accepted tokens, we can't directly use the previous batch's outputs in this iteration for the target model,
1203+ # so we'll set the target model's input to None and skip updating the target requests after target model forward.
1204+ use_previous_draft_tokens = self .has_previous_draft_tokens
1205+ if self .drafter is not None and (self .use_spec_decode or
1206+ use_previous_draft_tokens ):
12021207 target_inputs , draft_outputs , draft_batch = self ._handle_speculative_decoding (
12031208 scheduled_batch , previous_tensors )
12041209
12051210 # Use the draft_model's outputs if we've launched the draft model.
12061211 # Otherwise, use the previous batch's outputs.
1207- if target_inputs is not None :
1212+ if target_inputs is not None or use_previous_draft_tokens :
12081213 previous_tensors_device = target_inputs
12091214 else :
12101215 previous_tensors_device = self .previous_batch and self .previous_batch .sample_state and self .previous_batch .sample_state .device
@@ -1215,7 +1220,7 @@ def _executor_loop_overlap(self):
12151220 if target_inputs is not None :
12161221 self ._process_draft_results (scheduled_batch ,
12171222 draft_outputs , draft_batch )
1218- elif self .previous_batch is not None :
1223+ elif self .previous_batch is not None and not use_previous_draft_tokens :
12191224 self ._update_requests (self .previous_batch .sample_state )
12201225
12211226 if self .guided_decoder is not None :
@@ -1968,19 +1973,21 @@ def _remove_inflight_ids(self, scheduled_requests):
19681973 self .inflight_req_ids .erase (req .request_id )
19691974
19701975 def _handle_speculative_decoding (self , scheduled_batch , previous_tensors ):
1971- with request_context (is_draft = True , scheduled_requests = scheduled_batch ):
1976+ with request_context (is_draft = self .draft_model_engine is not None ,
1977+ scheduled_requests = scheduled_batch ):
19721978 # Do an early checking to see if we need to forward the draft model.
19731979 # If needed, the overlap should happen between the target requests and the draft requests.
19741980 # Otherwise, we can still do overlap between the previous target requests and the current target requests.
19751981 has_draft_batch = (
1976- self .previous_batch is not None
1982+ self .previous_batch is not None and self . use_spec_decode
19771983 and self .drafter .should_forward_draft_model (scheduled_batch ))
19781984
1979- if has_draft_batch :
1985+ if has_draft_batch or self . has_previous_draft_tokens :
19801986 self ._update_requests (self .previous_batch .sample_state )
19811987 if self .has_previous_draft_tokens :
19821988 self ._prepare_draft_requests ()
19831989
1990+ if has_draft_batch :
19841991 target_inputs , draft_outputs , draft_batch = self .drafter .generate_draft_tokens_with_overlap (
19851992 scheduled_batch , self .resource_manager ,
19861993 previous_tensors .device if previous_tensors else None )
@@ -1997,26 +2004,27 @@ def _process_draft_results(self, scheduled_batch, draft_outputs,
19972004 """
19982005 Append the draft tokens to the target requests, and clean up the draft resources.
19992006 """
2000- req_id_to_old_request = {
2001- req .py_request_id : req
2002- for req in scheduled_batch .all_requests ()
2003- }
2007+ with request_context (is_draft = self .draft_model_engine is not None ,
2008+ scheduled_requests = scheduled_batch ):
2009+ req_id_to_old_request = {
2010+ req .py_request_id : req
2011+ for req in scheduled_batch .all_requests ()
2012+ }
20042013
2005- if self .drafter .use_static_draft_loop :
2006- self .drafter .process_static_draft_outputs (draft_outputs ,
2007- draft_batch ,
2008- req_id_to_old_request )
2009- elif draft_outputs is not None :
2010- self .drafter .process_dynamic_draft_outputs (draft_outputs ,
2011- req_id_to_old_request )
2012-
2013- # Pad draft tokens to the max draft length. This is for CUDA graph compatibility.
2014- self .drafter .pad_draft_tokens_for_cuda_graph (scheduled_batch )
2015- # add_batch must be called again to restore to target requests with updated draft tokens.
2016- if self .guided_decoder is not None :
2017- self .guided_decoder .add_batch (scheduled_batch )
2018- if hasattr (self .drafter , "guided_decoder" ):
2019- self .guided_decoder .rollback_draft_tokens ()
2014+ if self .drafter .use_static_draft_loop :
2015+ self .drafter .process_static_draft_outputs (
2016+ draft_outputs , draft_batch , req_id_to_old_request )
2017+ elif draft_outputs is not None :
2018+ self .drafter .process_dynamic_draft_outputs (
2019+ draft_outputs , req_id_to_old_request )
2020+
2021+ # Pad draft tokens to the max draft length. This is for CUDA graph compatibility.
2022+ self .drafter .pad_draft_tokens_for_cuda_graph (scheduled_batch )
2023+ # add_batch must be called again to restore to target requests with updated draft tokens.
2024+ if self .guided_decoder is not None :
2025+ self .guided_decoder .add_batch (scheduled_batch )
2026+ if hasattr (self .drafter , "guided_decoder" ):
2027+ self .guided_decoder .rollback_draft_tokens ()
20202028
20212029
20222030class DisaggPPTerminationHandler :
0 commit comments