Skip to content

Commit 2dd3186

Browse files
authored
fix: remove cudaStreamSynchronize when using relaxed acceptance (NVIDIA#5262)
Signed-off-by: Yue Weng <[email protected]>
1 parent 908f49a commit 2dd3186

File tree

1 file changed

+4
-2
lines changed
  • tensorrt_llm/_torch/speculative

1 file changed

+4
-2
lines changed

tensorrt_llm/_torch/speculative/mtp.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -67,15 +67,17 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests):
6767
if req.is_first_context_chunk:
6868
slot_id = self.slot_manager.add_slot(req.request_id)
6969
if self.use_relaxed_acceptance_for_thinking:
70-
self.mtp_relaxed_delta_pool[slot_id] = 0.
70+
self.mtp_relaxed_delta_pool[slot_id].copy_(
71+
0, non_blocking=True)
7172

7273
def update_resources(self, scheduled_batch: ScheduledRequests):
7374
pass
7475

7576
def free_resources(self, request: LlmRequest):
7677
free_slot_id = self.slot_manager.get_slot(request.request_id)
7778
if self.use_relaxed_acceptance_for_thinking:
78-
self.mtp_relaxed_delta_pool[free_slot_id] = 0.
79+
self.mtp_relaxed_delta_pool[free_slot_id].copy_(0,
80+
non_blocking=True)
7981
self.slot_manager.remove_slot(request.request_id)
8082

8183
def add_dummy_requests(self, request_ids: List[int]):

0 commit comments

Comments
 (0)