File tree Expand file tree Collapse file tree 3 files changed +17
-13
lines changed Expand file tree Collapse file tree 3 files changed +17
-13
lines changed Original file line number Diff line number Diff line change @@ -535,14 +535,15 @@ def update_resources(self,
535535 scheduled_batch : ScheduledRequests ,
536536 attn_metadata : "AttentionMetadata" = None ,
537537 kv_cache_dtype_byte_size : float = None ):
538- self .update_kv_cache_draft_token_location (scheduled_batch ,
539- attn_metadata ,
540- kv_cache_dtype_byte_size )
541- # rewind kv cache
542- for request in scheduled_batch .generation_requests :
543- if request .state != LlmRequestState .GENERATION_COMPLETE :
544- if request .py_rewind_len > 0 :
545- self .rewind_kv_cache (request , request .py_rewind_len )
538+ if not self .is_draft :
539+ self .update_kv_cache_draft_token_location (scheduled_batch ,
540+ attn_metadata ,
541+ kv_cache_dtype_byte_size )
542+ # rewind kv cache
543+ for request in scheduled_batch .generation_requests :
544+ if request .state != LlmRequestState .GENERATION_COMPLETE :
545+ if request .py_rewind_len > 0 :
546+ self .rewind_kv_cache (request , request .py_rewind_len )
546547
547548 # For context requests, we store the blocks for reuse.
548549 for request in scheduled_batch .context_requests :
Original file line number Diff line number Diff line change @@ -879,9 +879,9 @@ def _process_draft_tokens_tree(
879879
880880 assert num_accepted_draft_tokens <= longest_accepted_len
881881
882- request . py_num_accepted_draft_tokens_indices = eagle_paths [longest_match_path_idx ][
883- 1 : num_accepted_draft_tokens
884- ]. tolist () # exclude the root node
882+ tree_node_indices = eagle_paths [longest_match_path_idx ][1 : num_accepted_draft_tokens ]
883+ request . py_num_accepted_draft_tokens_indices = ( tree_node_indices - 1 ). tolist ()
884+
885885 return num_accepted_draft_tokens - 1
886886
887887 @torch .inference_mode ()
Original file line number Diff line number Diff line change @@ -202,7 +202,11 @@ def prepare(self):
202202 elif is_first_draft and spec_tree_manager is not None :
203203 assert req_id in self .request_accepted_path .keys (
204204 ), f"Request { req_id } not found in request_accepted_path"
205- accepted_path = self .request_accepted_path [req_id ]
205+ # sampler.py stores each accepted node index minus 1 (for the KV cache rewind), so add 1 back here to recover the original node index.
206+ accepted_path = [
207+ node_idx + 1
208+ for node_idx in self .request_accepted_path [req_id ]
209+ ]
206210
207211 if accepted_path == []:
207211 # Case 1: This is a context request, so we need to read all the hidden states.
@@ -218,7 +222,6 @@ def prepare(self):
218222 assert len (accepted_path_pad ) == seq_len
219223 hidden_states_read_indices .extend ([
220224 start_idx + accepted_draft_token_offset
221- # for accepted_draft_token_offset in accepted_path
222225 for accepted_draft_token_offset in accepted_path_pad
223226 ])
224227
You can’t perform that action at this time.
0 commit comments