Commit 2e5850c

[TRTLLM-7330][feat] Eagle3 cuda graph support for the first draft model inference (NVIDIA#7363)

Signed-off-by: qgai <[email protected]>

1 parent f98fa0c, commit 2e5850c

21 files changed: +299 -113 lines
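
As context for the diffs below: before this change, captured CUDA graphs were cached per (batch_size, draft_len); with Eagle3, the first draft-model forward is captured with the full draft length while later draft forwards run with draft_len == 0, so the cache key gains an is_first_draft component. The snippet below is a minimal, self-contained sketch of that keying idea only, mirroring get_graph_key in the diff; GraphCache and its methods are hypothetical names for illustration, not the TensorRT-LLM API.

# Illustrative sketch of the (batch_size, draft_len, is_first_draft) keying
# scheme introduced by this commit. Hypothetical helper, not library code.
from typing import Any, Dict, Tuple

GraphKey = Tuple[int, int, bool]  # (batch_size, draft_len, is_first_draft)


class GraphCache:

    def __init__(self) -> None:
        self.graphs: Dict[GraphKey, Any] = {}

    def key_for(self, batch_size: int, max_draft_len: int,
                is_first_draft: bool) -> GraphKey:
        # Mirrors get_graph_key below: the first draft forward is keyed with
        # the full draft length, later draft forwards with draft_len == 0,
        # so each shape gets its own captured graph.
        draft_len = max_draft_len if is_first_draft else 0
        return (batch_size, draft_len, is_first_draft)

    def needs_capture(self, key: GraphKey) -> bool:
        return key not in self.graphs


cache = GraphCache()
first = cache.key_for(batch_size=8, max_draft_len=3, is_first_draft=True)
later = cache.key_for(batch_size=8, max_draft_len=3, is_first_draft=False)
assert first != later  # the two draft-forward shapes use separate graphs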

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 62 additions & 26 deletions
@@ -8,8 +8,10 @@
 from ...inputs.multimodal import MultimodalParams
 from ..expert_statistic import ExpertStatistic
 from ..modules.multi_stream_utils import with_multi_stream
+from ..speculative.eagle3 import Eagle3ResourceManager
 from ..utils import make_weak_ref, piecewise_cuda_graph
-from .resource_manager import ResourceManager, ResourceManagerType
+from .resource_manager import (BaseResourceManager, ResourceManager,
+                               ResourceManagerType)
 from .scheduler import ScheduledRequests
 
 if TYPE_CHECKING:
@@ -25,7 +27,7 @@ class CUDAGraphRunner:
 
     This unified class handles high-level orchestration (padding, eligibility)
     and low-level execution (capturing, resource management, replaying) for
-    multiple graphs, keyed by (batch size, draft_len).
+    multiple graphs, keyed by (batch size, draft_len, is_first_draft).
     """
     WARMUP_STEPS = 2
 
@@ -41,10 +43,10 @@ def __init__(self, engine: "PyTorchModelEngine"):
         self.max_beam_width = engine.max_beam_width
         self.spec_config = engine.spec_config
 
-        self.graphs: Dict[Tuple[int, int], torch.cuda.CUDAGraph] = {}
-        self.graph_outputs: Dict[Tuple[int, int],
+        self.graphs: Dict[Tuple[int, int, int], torch.cuda.CUDAGraph] = {}
+        self.graph_outputs: Dict[Tuple[int, int, int],
                                  Callable[[], Optional[torch.Tensor]]] = {}
-        self.graph_metadata: Dict[Tuple[int, int], Dict[str, Any]] = {}
+        self.graph_metadata: Dict[Tuple[int, int, int], Dict[str, Any]] = {}
         self.memory_pool = engine._cuda_graph_mem_pool
         self.padding_dummy_request: Optional["Request"] = None
 
@@ -56,7 +58,7 @@ def _create_shared_static_tensors(self):
         """Allocates static tensors sized for the largest possible batch."""
         engine = self._get_engine()
 
-        token_per_request = self.draft_len + 1
+        token_per_request = self.max_possible_draft_len + 1
         max_total_tokens = (self.max_supported_batch_size *
                             self.max_beam_width * token_per_request)
         max_total_tokens = min(max_total_tokens, engine.max_num_tokens)
@@ -87,8 +89,23 @@ def enable_spec_decode(self):
         return self._get_engine().enable_spec_decode
 
     @property
-    def draft_len(self):
-        return self.spec_config.max_draft_len if self.enable_spec_decode else 0
+    def max_possible_draft_len(self):
+        engine = self._get_engine()
+        return (engine.original_max_draft_len if self.enable_spec_decode else 0)
+
+    def get_graph_key(
+            self,
+            batch_size,
+            spec_resource_manager: Optional[BaseResourceManager] = None):
+        engine = self._get_engine()
+        if engine.is_draft_model and spec_resource_manager is not None and isinstance(
+                spec_resource_manager, Eagle3ResourceManager):
+            draft_len = engine.original_max_draft_len if spec_resource_manager.is_first_draft else 0
+            key = (batch_size, draft_len, spec_resource_manager.is_first_draft)
+        else:
+            draft_len = self.spec_config.max_draft_len if self.enable_spec_decode else 0
+            key = (batch_size, draft_len, False)
+        return key
 
     @property
     def spec_metadata(self):
@@ -113,21 +130,25 @@ def _get_engine(self) -> "PyTorchModelEngine":
                 "The parent PyTorchModelEngine has been garbage collected.")
         return engine
 
-    def maybe_get_cuda_graph(self, batch: ScheduledRequests):
+    def maybe_get_cuda_graph(
+            self,
+            batch: ScheduledRequests,
+            spec_resource_manager: Optional[BaseResourceManager] = None):
         """
         Determines if the current batch can be run with a CUDA graph.
 
         Returns a tuple containing:
        - A boolean indicating if a graph can be used.
        - The attn_metadata for the graph, if applicable.
        - The spec_metadata for the graph, if applicable.
+       - The key for the graph.
         """
         engine = self._get_engine()
 
         # disable when doing statistic
         if hasattr(engine, 'iter_counter') and ExpertStatistic.set_iter(
                 engine.iter_counter):
-            return False, None, None
+            return False, None, None, None
 
         can_run_cuda_graph = batch.can_run_cuda_graph
         batch_size = batch.batch_size
@@ -141,22 +162,22 @@ def maybe_get_cuda_graph(self, batch: ScheduledRequests):
                                        for all_gen_only in all_can_graph_batch)
 
         if not is_all_gen_only or not all_batch_size_equal:
-            return False, None, None
+            return False, None, None, None
 
         if not self.enabled or not can_run_cuda_graph:
-            return False, None, None
+            return False, None, None, None
+        key = self.get_graph_key(batch_size, spec_resource_manager)
 
-        key = (batch_size, self.draft_len)
         if key in self.graphs:
             return True, self.graph_metadata[key][
-                "attn_metadata"], self.graph_metadata[key]["spec_metadata"]
+                "attn_metadata"], self.graph_metadata[key]["spec_metadata"], key
 
         if batch_size not in self.supported_batch_sizes:
-            return False, None, None
+            return False, None, None, None
 
         num_sequences_in_batch = batch_size * self.max_beam_width
         attn_metadata = self.attn_metadata.create_cuda_graph_metadata(
-            num_sequences_in_batch, False, self.draft_len)
+            num_sequences_in_batch, False, key[1])
         assert attn_metadata.is_cuda_graph
 
         if self.enable_spec_decode:
@@ -165,23 +186,25 @@ def maybe_get_cuda_graph(self, batch: ScheduledRequests):
             spec_metadata.draft_tokens = self.draft_tokens_cuda
         else:
             spec_metadata = None
-        return True, attn_metadata, spec_metadata
+        return True, attn_metadata, spec_metadata, key
+
+    def needs_capture(self, key: Tuple[int, int, int]):
 
-    def needs_capture(self, batch_size: int):
-        return (batch_size, self.draft_len) not in self.graph_outputs
+        return key not in self.graph_outputs
 
     def capture(self,
-                batch_size: int,
+                key: Tuple[int, int, int],
                 forward_fn: Callable,
                 initial_inputs: Dict[str, Any],
                 postprocess_fn: Optional[Callable] = None):
         """Captures the forward pass for a given batch size."""
         engine = self._get_engine()
-        key = (batch_size, self.draft_len)
+        batch_size = key[0]
         # [CUDA graph spec decode padding]
         # We pad input IDs/position IDs to the maximum draft length (token per request).
         # We're forced to do this because we cannot reallocate inputs over many graph runs.
-        token_per_request = self.draft_len + 1
+        max_draft_len = key[1]
+        token_per_request = max_draft_len + 1
         num_tokens_for_capture = (batch_size * self.max_beam_width *
                                   token_per_request)
 
@@ -207,30 +230,43 @@ def capture(self,
             "spec_metadata": initial_inputs.get("spec_metadata", None),
         }
 
+        def _setup_spec_decoding_and_forward(key: Tuple[int, int, int],
+                                             forward_fn: Callable,
+                                             capture_inputs: Dict[str, Any]):
+            engine = self._get_engine()
+            # for the first inference of draft model, we need to set the use_spec_decoding to True when capture the graph for multiple runs.
+            is_first_draft = key[2]
+            needs_kv_cache_recompute = True if engine.enable_spec_decode and engine.spec_config.spec_dec_mode.needs_kv_cache_recompute(
+            ) else False
+            if is_first_draft and engine.is_draft_model and needs_kv_cache_recompute:
+                capture_inputs['attn_metadata'].use_spec_decoding = True
+            return forward_fn(capture_inputs)
+
         # We have to do warm up runs to initialize PyTorch's
         # internal states according to the docs:
         # https://pytorch.org/docs/stable/notes/cuda.html#cuda-graph-semantics
         # This also lets us initialize states in the attn_metadata.
         graph = torch.cuda.CUDAGraph()
         with with_multi_stream(True), piecewise_cuda_graph(False):
             for _ in range(self.WARMUP_STEPS):
-                forward_fn(capture_inputs)
+                _setup_spec_decoding_and_forward(key, forward_fn,
+                                                 capture_inputs)
                 if postprocess_fn is not None:
                     postprocess_fn(capture_inputs)
             with torch.cuda.graph(graph, pool=self.memory_pool):
-                output = forward_fn(capture_inputs)
+                output = _setup_spec_decoding_and_forward(
+                    key, forward_fn, capture_inputs)
                 if postprocess_fn is not None:
                     postprocess_fn(capture_inputs)
 
         self.graphs[key] = graph
         self.graph_outputs[key] = make_weak_ref(output)
         self.memory_pool = graph.pool()
 
-    def replay(self, batch_size: int,
+    def replay(self, key: Tuple[int, int, int],
                current_inputs: Dict[str, Any]) -> Optional[torch.Tensor]:
         """Replays a previously captured graph."""
         engine = self._get_engine()
-        key = (batch_size, self.draft_len)
         stored_meta = self.graph_metadata[key]
         assert current_inputs["attn_metadata"] is stored_meta["attn_metadata"]
         if stored_meta["spec_metadata"] is not None:
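
With the key threaded through the API, callers now receive it from maybe_get_cuda_graph and pass it on to needs_capture, capture, and replay. The call sites (in the model engine) are not part of this excerpt; the sketch below only illustrates the assumed flow, with runner, batch, forward_fn, and inputs as placeholder names.

# Assumed caller-side flow after this change (placeholder names, padding and
# error handling omitted); not an actual call site from this commit.
def run_with_optional_graph(runner, batch, spec_resource_manager, forward_fn,
                            inputs):
    can_graph, attn_metadata, spec_metadata, key = runner.maybe_get_cuda_graph(
        batch, spec_resource_manager)
    if not can_graph:
        # Fall back to eager execution when no graph applies.
        return forward_fn(inputs)
    if runner.needs_capture(key):
        # Capture once per (batch_size, draft_len, is_first_draft) key.
        runner.capture(key, forward_fn, inputs)
    return runner.replay(key, inputs)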

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 3 additions & 0 deletions
@@ -311,6 +311,7 @@ def __init__(
             is_draft: bool = False,
             seq_slot: Optional[int] = None,
             target_seq_slot: Optional[int] = None,
+            is_first_draft: bool = False,
             **kwargs):
 
         self.py_logits_post_processors = kwargs.pop("py_logits_post_processors",
@@ -365,6 +366,8 @@ def __init__(
         # If the request is a draft request, target_seq_slot is the sequence slot ID of its target request.
         self.py_target_seq_slot = target_seq_slot
         self.use_draft_model = is_draft
+        # Whether the request is for the first forward of the draft model.
+        self.py_is_first_draft = is_first_draft
 
         # TODO: remove this when use DynamicDecodeOp in pytorch flow.
         # currently, keep py_stop_words_list as python list, rather than tensor.
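
The new flag simply rides along on the request so downstream code can tell the first draft-model forward apart from later ones. The fragment below is a hypothetical illustration of tagging such a request, not a call site from this commit; LlmRequest takes many other arguments that are omitted here.

# Hypothetical illustration only: the flag is passed at construction and
# surfaces as py_is_first_draft on the request.
req = LlmRequest(
    # ... the usual request arguments are omitted here ...
    is_draft=True,
    is_first_draft=True,  # new in this commit
)
assert req.py_is_first_draft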
