Commit 8977e4e (parent: 895581b)

fix kv cache rewind issue

Signed-off-by: Yue Weng <[email protected]>

File tree: 3 files changed, +17 −13 lines

tensorrt_llm/_torch/pyexecutor/resource_manager.py (9 additions, 8 deletions)

@@ -535,14 +535,15 @@ def update_resources(self,
                          scheduled_batch: ScheduledRequests,
                          attn_metadata: "AttentionMetadata" = None,
                          kv_cache_dtype_byte_size: float = None):
-        self.update_kv_cache_draft_token_location(scheduled_batch,
-                                                  attn_metadata,
-                                                  kv_cache_dtype_byte_size)
-        # rewind kv cache
-        for request in scheduled_batch.generation_requests:
-            if request.state != LlmRequestState.GENERATION_COMPLETE:
-                if request.py_rewind_len > 0:
-                    self.rewind_kv_cache(request, request.py_rewind_len)
+        if not self.is_draft:
+            self.update_kv_cache_draft_token_location(scheduled_batch,
+                                                      attn_metadata,
+                                                      kv_cache_dtype_byte_size)
+            # rewind kv cache
+            for request in scheduled_batch.generation_requests:
+                if request.state != LlmRequestState.GENERATION_COMPLETE:
+                    if request.py_rewind_len > 0:
+                        self.rewind_kv_cache(request, request.py_rewind_len)
 
         # For context requests, we store the blocks for reuse.
         for request in scheduled_batch.context_requests:
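The guard above changes when draft-token relocation and KV cache rewind happen at all: a manager constructed with is_draft=True now skips both steps in update_resources, so only the target model's KV cache is rewound. One plausible reading (an assumption on my part, not stated in the commit) is that this prevents the rewind from running a second time when a separate draft-model manager processes the same batch. A minimal sketch of that control flow; FakeRequest and FakeKVCacheManager are hypothetical stand-ins, not real TensorRT-LLM types:

    from dataclasses import dataclass

    @dataclass
    class FakeRequest:
        py_rewind_len: int
        rewound_total: int = 0  # toy counter standing in for real KV block bookkeeping

    class FakeKVCacheManager:
        def __init__(self, is_draft: bool):
            self.is_draft = is_draft

        def update_resources(self, generation_requests):
            # Mirrors the guard added by this commit: draft managers do nothing here.
            if not self.is_draft:
                for request in generation_requests:
                    if request.py_rewind_len > 0:
                        request.rewound_total += request.py_rewind_len

    req = FakeRequest(py_rewind_len=3)
    for manager in (FakeKVCacheManager(is_draft=False), FakeKVCacheManager(is_draft=True)):
        manager.update_resources([req])
    print(req.rewound_total)  # 3 -- rewound once by the target manager, never by the draft manager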

tensorrt_llm/_torch/pyexecutor/sampler.py (3 additions, 3 deletions)

@@ -879,9 +879,9 @@ def _process_draft_tokens_tree(
 
         assert num_accepted_draft_tokens <= longest_accepted_len
 
-        request.py_num_accepted_draft_tokens_indices = eagle_paths[longest_match_path_idx][
-            1:num_accepted_draft_tokens
-        ].tolist()  # exclude the root node
+        tree_node_indices = eagle_paths[longest_match_path_idx][1:num_accepted_draft_tokens]
+        request.py_num_accepted_draft_tokens_indices = (tree_node_indices - 1).tolist()
+
         return num_accepted_draft_tokens - 1
 
     @torch.inference_mode()
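The sampler change is an index-space conversion rather than a behavioral rewrite. Rows of eagle_paths hold tree node indices in which node 0 is the root (the already-verified token), so after slicing off the root, subtracting 1 yields zero-based positions among the draft tokens. A toy illustration with made-up path values, not real TensorRT-LLM data:

    import torch

    # Each row is a root-to-leaf path of tree node indices; node 0 is the root.
    eagle_paths = torch.tensor([[0, 1, 3, 7],
                                [0, 2, 5, 9]])
    longest_match_path_idx = 1
    num_accepted_draft_tokens = 3  # the root plus two accepted draft tokens

    tree_node_indices = eagle_paths[longest_match_path_idx][1:num_accepted_draft_tokens]
    print(tree_node_indices.tolist())        # [2, 5] -- before: raw tree node indices
    print((tree_node_indices - 1).tolist())  # [1, 4] -- after: zero-based draft-token offsets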

tensorrt_llm/_torch/speculative/eagle3.py (5 additions, 2 deletions)

@@ -202,7 +202,11 @@ def prepare(self):
             elif is_first_draft and spec_tree_manager is not None:
                 assert req_id in self.request_accepted_path.keys(
                 ), f"Request {req_id} not found in request_accepted_path"
-                accepted_path = self.request_accepted_path[req_id]
+                # 'node_idx + 1' undoes the '-1' applied in sampler.py for kv cache rewind.
+                accepted_path = [
+                    node_idx + 1
+                    for node_idx in self.request_accepted_path[req_id]
+                ]
 
                 if accepted_path == []:
                     # Case 1: This is a context request; we need to read all the hidden states.

@@ -218,7 +222,6 @@ def prepare(self):
                 assert len(accepted_path_pad) == seq_len
                 hidden_states_read_indices.extend([
                     start_idx + accepted_draft_token_offset
-                    # for accepted_draft_token_offset in accepted_path
                     for accepted_draft_token_offset in accepted_path_pad
                 ])
 