
Commit 897c4dd

[https://nvbugs/5517404][fix] Use the correct cuda graph for dynamic spec dec (NVIDIA#7728)
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent 4509d97 commit 897c4dd

File tree

9 files changed: +174 additions, -129 deletions

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 3 additions & 6 deletions
```diff
@@ -40,9 +40,6 @@ def __init__(self, engine: "PyTorchModelEngine"):
         self.max_beam_width = engine.max_beam_width
         self.spec_config = engine.spec_config
 
-        self.max_possible_draft_len = (self.spec_config.max_draft_len
-                                       if self.enable_spec_decode else 0)
-
         self.graphs: Dict[Tuple[int, int], torch.cuda.CUDAGraph] = {}
         self.graph_outputs: Dict[Tuple[int, int],
                                  Callable[[], Optional[torch.Tensor]]] = {}
@@ -58,7 +55,7 @@ def _create_shared_static_tensors(self):
         """Allocates static tensors sized for the largest possible batch."""
         engine = self._get_engine()
 
-        token_per_request = self.max_possible_draft_len + 1
+        token_per_request = self.draft_len + 1
         max_total_tokens = (self.max_supported_batch_size *
                             self.max_beam_width * token_per_request)
         max_total_tokens = min(max_total_tokens, engine.max_num_tokens)
@@ -78,7 +75,7 @@ def _create_shared_static_tensors(self):
 
     @property
     def enable_spec_decode(self):
-        return self._get_engine().is_spec_decode
+        return self._get_engine().enable_spec_decode
 
     @property
     def draft_len(self):
@@ -174,7 +171,7 @@ def capture(self,
         # [CUDA graph spec decode padding]
         # We pad input IDs/position IDs to the maximum draft length (token per request).
         # We're forced to do this because we cannot reallocate inputs over many graph runs.
-        token_per_request = self.max_possible_draft_len + 1
+        token_per_request = self.draft_len + 1
         num_tokens_for_capture = (batch_size * self.max_beam_width *
                                   token_per_request)
 
```
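The change above drops the init-time `max_possible_draft_len` and sizes/pads everything from the runner's `draft_len` property, which reflects whether spec decode is enabled for the current iteration. A minimal sketch of why that matters when graphs are keyed per draft length (the class and names below are hypothetical illustrations, not TensorRT-LLM APIs):

```python
# Hypothetical illustration only: keying graphs on the per-iteration draft
# length selects a correctly sized graph when spec decode toggles on and off.
from typing import Dict, Tuple


class DummyGraphRunner:

    def __init__(self, max_draft_len: int):
        self.max_draft_len = max_draft_len
        self.spec_decode_enabled = True  # may change every iteration
        self.graphs: Dict[Tuple[int, int], str] = {}

    @property
    def draft_len(self) -> int:
        # 0 when spec decode is off this iteration, max_draft_len otherwise.
        return self.max_draft_len if self.spec_decode_enabled else 0

    def get_or_capture(self, batch_size: int) -> str:
        # Padding to draft_len + 1 tokens per request and keying on it means a
        # graph captured with spec decode on is never replayed when it is off.
        key = (batch_size, self.draft_len)
        if key not in self.graphs:
            self.graphs[key] = f"graph(bs={batch_size}, tokens/req={self.draft_len + 1})"
        return self.graphs[key]


runner = DummyGraphRunner(max_draft_len=3)
print(runner.get_or_capture(8))  # graph padded to 4 tokens per request
runner.spec_decode_enabled = False
print(runner.get_or_capture(8))  # a different graph with 1 token per request
```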

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -1511,7 +1511,6 @@ def _prepare_tp_inputs(
                 prompt_lengths.append(1 + self.runtime_draft_len)
             else:
                 prompt_lengths.append(request.py_prompt_len)
-
         for request in generation_requests:
             request_ids.append(request.py_request_id)
             beam_width = request.sampling_config.beam_width
@@ -1534,7 +1533,6 @@ def _prepare_tp_inputs(
                 if beam == first_beam:
                     previous_batch_indices.append(request.py_batch_idx)
                     past_seen_token_num = request.max_beam_num_tokens
-
                 position_ids.append(past_seen_token_num)
                 num_cached_tokens_per_seq.append(past_seen_token_num)
                 prompt_lengths.append(request.py_prompt_len)
```

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 33 additions & 25 deletions
```diff
@@ -1198,13 +1198,18 @@ def _executor_loop_overlap(self):
                 previous_tensors = self.previous_batch and self.previous_batch.sample_state
                 target_inputs = None
                 draft_outputs = None
-                if self.drafter is not None and self.use_spec_decode:
+                # If there are previous draft tokens, we need to update the target requests to accept some draft tokens.
+                # When there's any accepted tokens, we can't directly use the previous batch's outputs in this iteration for the target model,
+                # so we'll set the target model's input to None and skip updating the target requests after target model forward.
+                use_previous_draft_tokens = self.has_previous_draft_tokens
+                if self.drafter is not None and (self.use_spec_decode or
+                                                 use_previous_draft_tokens):
                     target_inputs, draft_outputs, draft_batch = self._handle_speculative_decoding(
                         scheduled_batch, previous_tensors)
 
                 # Use the draft_model's outputs if we've launched the draft model.
                 # Otherwise, use the previous batch's outputs.
-                if target_inputs is not None:
+                if target_inputs is not None or use_previous_draft_tokens:
                     previous_tensors_device = target_inputs
                 else:
                     previous_tensors_device = self.previous_batch and self.previous_batch.sample_state and self.previous_batch.sample_state.device
@@ -1215,7 +1220,7 @@ def _executor_loop_overlap(self):
                 if target_inputs is not None:
                     self._process_draft_results(scheduled_batch,
                                                 draft_outputs, draft_batch)
-                elif self.previous_batch is not None:
+                elif self.previous_batch is not None and not use_previous_draft_tokens:
                     self._update_requests(self.previous_batch.sample_state)
 
                 if self.guided_decoder is not None:
@@ -1968,19 +1973,21 @@ def _remove_inflight_ids(self, scheduled_requests):
             self.inflight_req_ids.erase(req.request_id)
 
     def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
-        with request_context(is_draft=True, scheduled_requests=scheduled_batch):
+        with request_context(is_draft=self.draft_model_engine is not None,
+                             scheduled_requests=scheduled_batch):
             # Do an early checking to see if we need to forward the draft model.
             # If needed, the overlap should happen between the target requests and the draft requests.
            # Otherwise, we can still do overlap between the previous target requests and the current target requests.
             has_draft_batch = (
-                self.previous_batch is not None
+                self.previous_batch is not None and self.use_spec_decode
                 and self.drafter.should_forward_draft_model(scheduled_batch))
 
-            if has_draft_batch:
+            if has_draft_batch or self.has_previous_draft_tokens:
                 self._update_requests(self.previous_batch.sample_state)
                 if self.has_previous_draft_tokens:
                     self._prepare_draft_requests()
 
+            if has_draft_batch:
                 target_inputs, draft_outputs, draft_batch = self.drafter.generate_draft_tokens_with_overlap(
                     scheduled_batch, self.resource_manager,
                     previous_tensors.device if previous_tensors else None)
@@ -1997,26 +2004,27 @@ def _process_draft_results(self, scheduled_batch, draft_outputs,
         """
         Append the draft tokens to the target requests, and clean up the draft resources.
         """
-        req_id_to_old_request = {
-            req.py_request_id: req
-            for req in scheduled_batch.all_requests()
-        }
+        with request_context(is_draft=self.draft_model_engine is not None,
+                             scheduled_requests=scheduled_batch):
+            req_id_to_old_request = {
+                req.py_request_id: req
+                for req in scheduled_batch.all_requests()
+            }
 
-        if self.drafter.use_static_draft_loop:
-            self.drafter.process_static_draft_outputs(draft_outputs,
-                                                      draft_batch,
-                                                      req_id_to_old_request)
-        elif draft_outputs is not None:
-            self.drafter.process_dynamic_draft_outputs(draft_outputs,
-                                                       req_id_to_old_request)
-
-        # Pad draft tokens to the max draft length. This is for CUDA graph compatibility.
-        self.drafter.pad_draft_tokens_for_cuda_graph(scheduled_batch)
-        # add_batch must be called again to restore to target requests with updated draft tokens.
-        if self.guided_decoder is not None:
-            self.guided_decoder.add_batch(scheduled_batch)
-            if hasattr(self.drafter, "guided_decoder"):
-                self.guided_decoder.rollback_draft_tokens()
+            if self.drafter.use_static_draft_loop:
+                self.drafter.process_static_draft_outputs(
+                    draft_outputs, draft_batch, req_id_to_old_request)
+            elif draft_outputs is not None:
+                self.drafter.process_dynamic_draft_outputs(
+                    draft_outputs, req_id_to_old_request)
+
+            # Pad draft tokens to the max draft length. This is for CUDA graph compatibility.
+            self.drafter.pad_draft_tokens_for_cuda_graph(scheduled_batch)
+            # add_batch must be called again to restore to target requests with updated draft tokens.
+            if self.guided_decoder is not None:
+                self.guided_decoder.add_batch(scheduled_batch)
+                if hasattr(self.drafter, "guided_decoder"):
+                    self.guided_decoder.rollback_draft_tokens()
 
 
 class DisaggPPTerminationHandler:
```
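The executor-loop changes gate the drafter on either `use_spec_decode` or leftover draft tokens from the previous iteration, and stop re-consuming the previous batch's outputs once those draft tokens have been accepted. A simplified, hypothetical sketch of that decision flow (the function and the `maybe_generate_draft_tokens` helper are illustrative stand-ins, not the PyExecutor API):

```python
def plan_iteration(drafter, use_spec_decode, has_previous_draft_tokens,
                   previous_sample_state):
    target_inputs = None
    if drafter is not None and (use_spec_decode or has_previous_draft_tokens):
        # May come back as None when the draft model was not actually run.
        target_inputs = drafter.maybe_generate_draft_tokens()

    if target_inputs is not None or has_previous_draft_tokens:
        # Feed the drafter-produced inputs to the target model; the previous
        # batch's outputs were already consumed when accepting draft tokens,
        # so they must not be applied again after the target forward.
        device_tensors = target_inputs
        update_previous_batch_later = False
    else:
        device_tensors = previous_sample_state
        update_previous_batch_later = previous_sample_state is not None
    return device_tensors, update_previous_batch_later


class _SkippedDrafter:

    def maybe_generate_draft_tokens(self):
        return None  # pretend the draft model was skipped this iteration


print(plan_iteration(_SkippedDrafter(), use_spec_decode=False,
                     has_previous_draft_tokens=True,
                     previous_sample_state="prev_state"))
# (None, False): once draft tokens from the prior step have been accepted,
# the stale previous outputs are not reused for the target model.
```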

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 3 additions & 4 deletions
```diff
@@ -41,11 +41,10 @@
 from .py_executor import PyExecutor
 
 
-# Development flag to control chain drafter feature
+# Development function to control chain drafter feature.
+# It's here so that unit tests can mock it and turn it off.
 def _get_allow_chain_drafter() -> bool:
-    """Get the chain drafter flag from environment variable."""
-    # Use environment variable for cross-process compatibility
-    return os.getenv("TRTLLM_ALLOW_CHAIN_DRAFTER", "0") == "1"
+    return True
 
 
 class _ExecutorCreationStage(enum.Enum):
```
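Since `_get_allow_chain_drafter()` now exists mainly so unit tests can mock it, a test could disable the chain drafter roughly as below (a sketch; only the patch target path comes from the file above, the test body is hypothetical):

```python
from unittest.mock import patch


def test_without_chain_drafter():
    with patch(
            "tensorrt_llm._torch.pyexecutor.py_executor_creator._get_allow_chain_drafter",
            return_value=False):
        # Build the executor here; creation code that calls
        # _get_allow_chain_drafter() will see False and take the
        # non-chain-drafter path.
        ...
```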

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -563,6 +563,9 @@ def update_requests(self, state: SampleState) -> None:
             if get_draft_token_length(req) > 0:
                 req.py_num_accepted_draft_tokens = num_accepted
                 req.py_rewind_len = req.py_draft_pages_allocated - num_accepted
+            else:
+                req.py_num_accepted_draft_tokens = 0
+                req.py_rewind_len = 0
             processed += num_accepted
             self.handle_logprobs(req, state, beam=self.BEAM, count=processed)
             req.py_decoding_iter += 1
```
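The new `else` branch clears the acceptance bookkeeping for requests that carried no draft tokens this iteration (for example, right after spec decode is switched off), so stale values from an earlier step are not reused. A toy illustration under that reading (the `Request` dataclass is a stand-in, not the real `LlmRequest`):

```python
from dataclasses import dataclass


@dataclass
class Request:
    py_draft_pages_allocated: int = 0
    py_num_accepted_draft_tokens: int = 0
    py_rewind_len: int = 0


def update(req: Request, draft_token_length: int, num_accepted: int) -> None:
    if draft_token_length > 0:
        req.py_num_accepted_draft_tokens = num_accepted
        req.py_rewind_len = req.py_draft_pages_allocated - num_accepted
    else:
        # Spec decode was off for this request: clear any stale counters.
        req.py_num_accepted_draft_tokens = 0
        req.py_rewind_len = 0


r = Request(py_draft_pages_allocated=3)
update(r, draft_token_length=3, num_accepted=2)  # rewind the 1 unused page
update(r, draft_token_length=0, num_accepted=0)  # counters reset to 0
```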

tensorrt_llm/_torch/speculative/model_drafter.py

Lines changed: 6 additions & 6 deletions
```diff
@@ -396,6 +396,7 @@ def _convert_draft_tensors(
         new_tokens_lens = None
         next_draft_tokens = None
         has_draft_tokens = False
+        batch_size = new_tokens.shape[1]
         # Iterate through generation requests and copy tokens based on accepted draft tokens
         for request in scheduled_batch.all_requests():
             idx = request.py_seq_slot
@@ -411,9 +412,8 @@ def _convert_draft_tensors(
 
         if has_draft_tokens:
             # We already updated the target state, so the new_tokens_lens should be all ones.
-            new_tokens_lens = torch.ones(scheduled_batch.batch_size,
-                                         device=device)
-            next_draft_tokens = torch.zeros(scheduled_batch.batch_size,
+            new_tokens_lens = torch.ones(batch_size, device=device)
+            next_draft_tokens = torch.zeros(batch_size,
                                             self.max_draft_tokens,
                                             device=device)
 
@@ -438,15 +438,15 @@ def _update_target_inputs_with_draft_tokens(
         Update target inputs with new draft tokens from sample state.
         """
         if draft_tensors is not None:
-            for request in draft_batch.all_requests():
+            for req_idx, request in enumerate(draft_batch.all_requests()):
                 # Skip prefill requests
                 if target_inputs.next_draft_tokens is None:
                     continue
 
                 # Get the index of the draft/target tokens in the device tensor
-                draft_idx = request.py_seq_slot
+                draft_idx = req_idx if self.use_static_draft_loop else request.py_batch_idx
                 target_idx = req_id_to_old_request[
-                    request.py_request_id].py_seq_slot
+                    request.py_request_id].py_batch_idx
                 target_inputs.new_tokens[draft_position + 1:draft_position +
                                          draft_length + 1, target_idx,
                                          0] = draft_tensors[0:draft_length,
```
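The indexing change copies draft tokens by the request's batch position (`py_batch_idx`, or the enumeration index in the static draft loop) instead of `py_seq_slot`, so the column written in the device tensor matches how the drafter laid out its outputs. A purely illustrative example with made-up shapes and values:

```python
import torch

# new_tokens is laid out [max_new_tokens, batch_size, beam]; draft tokens are
# copied into the column that corresponds to the request's batch position.
max_new_tokens, batch_size, beam = 4, 3, 1
new_tokens = torch.zeros(max_new_tokens, batch_size, beam, dtype=torch.long)
draft_tensors = torch.tensor([[11, 12, 13],
                              [21, 22, 23]])  # [draft_length, batch_size]

draft_length, draft_position = 2, 0
for batch_idx in range(batch_size):
    # Indexing the column by batch position keeps the draft tokens aligned
    # with the target request even when sequence slots were handed out in a
    # different order than the batch.
    new_tokens[draft_position + 1:draft_position + draft_length + 1, batch_idx,
               0] = draft_tensors[0:draft_length, batch_idx]

print(new_tokens[:, :, 0])
# tensor([[ 0,  0,  0],
#         [11, 12, 13],
#         [21, 22, 23],
#         [ 0,  0,  0]])
```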

tests/unittest/_torch/helpers.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -188,6 +188,7 @@ def create_mock_engine(batch_size: int):
         max_beam_width=1,
         max_num_tokens=8192,
         is_spec_decode=False,
+        enable_spec_decode=False,
         spec_config=None,
         _cuda_graph_mem_pool=None,
         use_mrope=False,
```

tests/unittest/_torch/speculative/test_dynamic_spec_decode.py

Lines changed: 42 additions & 22 deletions
```diff
@@ -1,22 +1,32 @@
 import os
 import sys
 import unittest
-from unittest.mock import patch
+from unittest.mock import Mock, patch
 
 import pytest
 import torch
 from utils.llm_data import llm_models_root
 
 from tensorrt_llm import LLM, SamplingParams
+from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequestState
 from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
                                  KvCacheConfig)
 
 sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
 
 
+@pytest.fixture(scope="function")
+def enforce_single_worker(monkeypatch):
+    monkeypatch.setenv("TLLM_WORKER_USE_SINGLE_PROCESS", "1")
+    yield
+
+
 @pytest.mark.parametrize("disable_overlap_scheduler", [True, False])
 @pytest.mark.high_cuda_memory
-def test_dynamic_spec_decode(disable_overlap_scheduler: bool):
+def test_dynamic_spec_decode(enforce_single_worker,
+                             disable_overlap_scheduler: bool):
+    # mock_should_use_spec_decode doesn't work with multiple processes,
+    # so we use the enforce_single_worker fixture to set the environment variable.
     total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
     if total_mem_gb < 35:
         pytest.skip("Not enough memory to load target + draft model")
@@ -51,32 +61,42 @@ def test_dynamic_spec_decode(disable_overlap_scheduler: bool):
         eagle3_one_model=False,
     )
 
-    # Mock should_use_spec_decode to return True for first two calls, then False
+    llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
+    sampling_params = SamplingParams(max_tokens=128, temperature=0)
+
+    # Output tests
+    prompts = [
+        "The president of the United States is",
+    ]
+    sampling_params = SamplingParams(max_tokens=20, temperature=0)
+
+    # Mock should_use_spec_decode to turn on/off spec decode dynamically.
     def mock_should_use_spec_decode(requests, max_batch_size, max_num_tokens,
                                     max_draft_len):
-        if not hasattr(mock_should_use_spec_decode, 'call_count'):
-            mock_should_use_spec_decode.call_count = 0
-        mock_should_use_spec_decode.call_count += 1
-        return mock_should_use_spec_decode.call_count <= 2
+        for req in requests:
+            if req.state != LlmRequestState.GENERATION_IN_PROGRESS:
+                continue
+
+            mock_should_use_spec_decode.call_count += 1
+            # Turn off spec decode when we've called it 5 times.
+            # In the current case, at the 5th call, there are 2 accepted draft tokens,
+            # so we can have better coverage for the switching between spec decode on and off.
+            if mock_should_use_spec_decode.call_count > 5:
+                return False
+        return True
+
+    # Create a Mock object with the mock function as side_effect
+    mock_should_use_spec_decode = Mock(side_effect=mock_should_use_spec_decode)
+    # Reset mock state before using it
+    mock_should_use_spec_decode.reset_mock()
+    mock_should_use_spec_decode.call_count = 0
 
     with patch(
             'tensorrt_llm._torch.speculative.model_drafter.ModelDrafter.should_use_spec_decode',
-            side_effect=mock_should_use_spec_decode):
-        llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
-        sampling_params = SamplingParams(max_tokens=128, temperature=0)
-
-        # Output tests
-        prompts = [
-            "The capital of France is",
-            "The president of the United States is",
-        ]
-        sampling_params = SamplingParams(max_tokens=10, temperature=0)
-
+            mock_should_use_spec_decode):
         results_spec = llm_spec.generate(prompts, sampling_params)
-        generated_text_spec = [
-            result.outputs[0].text for result in results_spec
-        ]
-        llm_spec.shutdown()
+    generated_text_spec = [result.outputs[0].text for result in results_spec]
+    llm_spec.shutdown()
 
     llm_ref = LLM(**llm_common_config)
     results_ref = llm_ref.generate(prompts, sampling_params)
```
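For reference, the mocking pattern the updated test relies on, reduced to a standalone sketch: a counter-carrying function is wrapped in `Mock(side_effect=...)` and patched over the method so the production path flips behavior after a fixed number of calls (the `Drafter` class below is a stand-in, not the TensorRT-LLM `ModelDrafter`):

```python
from unittest.mock import Mock, patch


class Drafter:

    def should_use_spec_decode(self, requests):
        return True  # real logic elided


def fake_should_use_spec_decode(requests):
    fake_should_use_spec_decode.call_count += 1
    return fake_should_use_spec_decode.call_count <= 2


fake_should_use_spec_decode.call_count = 0
mock_fn = Mock(side_effect=fake_should_use_spec_decode)

with patch.object(Drafter, "should_use_spec_decode", mock_fn):
    d = Drafter()
    print([d.should_use_spec_decode([]) for _ in range(4)])
    # [True, True, False, False]; mock_fn.call_count also tracks invocations.
```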
