 from ...inputs.multimodal import MultimodalParams
 from ..expert_statistic import ExpertStatistic
 from ..modules.multi_stream_utils import with_multi_stream
+from ..speculative.eagle3 import Eagle3ResourceManager
 from ..utils import make_weak_ref, piecewise_cuda_graph
-from .resource_manager import ResourceManager, ResourceManagerType
+from .resource_manager import (BaseResourceManager, ResourceManager,
+                               ResourceManagerType)
 from .scheduler import ScheduledRequests

 if TYPE_CHECKING:
@@ -25,7 +27,7 @@ class CUDAGraphRunner:

     This unified class handles high-level orchestration (padding, eligibility)
     and low-level execution (capturing, resource management, replaying) for
-    multiple graphs, keyed by (batch size, draft_len).
+    multiple graphs, keyed by (batch size, draft_len, is_first_draft).
     """
     WARMUP_STEPS = 2

@@ -41,10 +43,10 @@ def __init__(self, engine: "PyTorchModelEngine"):
         self.max_beam_width = engine.max_beam_width
         self.spec_config = engine.spec_config

-        self.graphs: Dict[Tuple[int, int], torch.cuda.CUDAGraph] = {}
-        self.graph_outputs: Dict[Tuple[int, int],
+        self.graphs: Dict[Tuple[int, int, int], torch.cuda.CUDAGraph] = {}
+        self.graph_outputs: Dict[Tuple[int, int, int],
                                  Callable[[], Optional[torch.Tensor]]] = {}
-        self.graph_metadata: Dict[Tuple[int, int], Dict[str, Any]] = {}
+        self.graph_metadata: Dict[Tuple[int, int, int], Dict[str, Any]] = {}
         self.memory_pool = engine._cuda_graph_mem_pool
         self.padding_dummy_request: Optional["Request"] = None

@@ -56,7 +58,7 @@ def _create_shared_static_tensors(self):
5658 """Allocates static tensors sized for the largest possible batch."""
5759 engine = self ._get_engine ()
5860
59- token_per_request = self .draft_len + 1
61+ token_per_request = self .max_possible_draft_len + 1
6062 max_total_tokens = (self .max_supported_batch_size *
6163 self .max_beam_width * token_per_request )
6264 max_total_tokens = min (max_total_tokens , engine .max_num_tokens )
@@ -87,8 +89,23 @@ def enable_spec_decode(self):
         return self._get_engine().enable_spec_decode

     @property
-    def draft_len(self):
-        return self.spec_config.max_draft_len if self.enable_spec_decode else 0
+    def max_possible_draft_len(self):
+        engine = self._get_engine()
+        return (engine.original_max_draft_len if self.enable_spec_decode else 0)
+
+    def get_graph_key(
+            self,
+            batch_size,
+            spec_resource_manager: Optional[BaseResourceManager] = None):
+        engine = self._get_engine()
+        if engine.is_draft_model and spec_resource_manager is not None and isinstance(
+                spec_resource_manager, Eagle3ResourceManager):
+            draft_len = engine.original_max_draft_len if spec_resource_manager.is_first_draft else 0
+            key = (batch_size, draft_len, spec_resource_manager.is_first_draft)
+        else:
+            draft_len = self.spec_config.max_draft_len if self.enable_spec_decode else 0
+            key = (batch_size, draft_len, False)
+        return key

     @property
     def spec_metadata(self):
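
For orientation, here is a minimal sketch (values are illustrative, not part of this change) of the keys the new `get_graph_key` logic produces:

```python
# Hypothetical example assuming batch_size == 8 and original_max_draft_len == 3.
# Eagle3 draft model, first draft forward:    key == (8, 3, True)
# Eagle3 draft model, later draft forwards:   key == (8, 0, False)
# Target model with spec decode enabled:      key == (8, max_draft_len, False)
# Spec decode disabled:                       key == (8, 0, False)
```
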
@@ -113,21 +130,25 @@ def _get_engine(self) -> "PyTorchModelEngine":
113130 "The parent PyTorchModelEngine has been garbage collected." )
114131 return engine
115132
116- def maybe_get_cuda_graph (self , batch : ScheduledRequests ):
133+ def maybe_get_cuda_graph (
134+ self ,
135+ batch : ScheduledRequests ,
136+ spec_resource_manager : Optional [BaseResourceManager ] = None ):
117137 """
118138 Determines if the current batch can be run with a CUDA graph.
119139
120140 Returns a tuple containing:
121141 - A boolean indicating if a graph can be used.
122142 - The attn_metadata for the graph, if applicable.
123143 - The spec_metadata for the graph, if applicable.
144+ - The key for the graph.
124145 """
125146 engine = self ._get_engine ()
126147
127148 # disable when doing statistic
128149 if hasattr (engine , 'iter_counter' ) and ExpertStatistic .set_iter (
129150 engine .iter_counter ):
130- return False , None , None
151+ return False , None , None , None
131152
132153 can_run_cuda_graph = batch .can_run_cuda_graph
133154 batch_size = batch .batch_size
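
A hedged sketch of how a caller might consume the widened return value; `runner`, `forward_fn`, and `inputs` are placeholder names, only the method names and the four-element tuple come from this diff:

```python
# Hypothetical call site (names other than the runner methods are made up).
can_run, attn_metadata, spec_metadata, key = runner.maybe_get_cuda_graph(
    batch, spec_resource_manager)
if can_run:
    if runner.needs_capture(key):
        runner.capture(key, forward_fn, inputs)
    output = runner.replay(key, inputs)
else:
    output = forward_fn(inputs)  # fall back to an eager forward pass
```
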
@@ -141,22 +162,22 @@ def maybe_get_cuda_graph(self, batch: ScheduledRequests):
             for all_gen_only in all_can_graph_batch)

         if not is_all_gen_only or not all_batch_size_equal:
-            return False, None, None
+            return False, None, None, None

         if not self.enabled or not can_run_cuda_graph:
-            return False, None, None
+            return False, None, None, None
+        key = self.get_graph_key(batch_size, spec_resource_manager)

-        key = (batch_size, self.draft_len)
         if key in self.graphs:
             return True, self.graph_metadata[key][
-                "attn_metadata"], self.graph_metadata[key]["spec_metadata"]
+                "attn_metadata"], self.graph_metadata[key]["spec_metadata"], key

         if batch_size not in self.supported_batch_sizes:
-            return False, None, None
+            return False, None, None, None

         num_sequences_in_batch = batch_size * self.max_beam_width
         attn_metadata = self.attn_metadata.create_cuda_graph_metadata(
-            num_sequences_in_batch, False, self.draft_len)
+            num_sequences_in_batch, False, key[1])
         assert attn_metadata.is_cuda_graph

         if self.enable_spec_decode:
@@ -165,23 +186,25 @@ def maybe_get_cuda_graph(self, batch: ScheduledRequests):
             spec_metadata.draft_tokens = self.draft_tokens_cuda
         else:
             spec_metadata = None
-        return True, attn_metadata, spec_metadata
+        return True, attn_metadata, spec_metadata, key
+
+    def needs_capture(self, key: Tuple[int, int, int]):

-    def needs_capture(self, batch_size: int):
-        return (batch_size, self.draft_len) not in self.graph_outputs
+        return key not in self.graph_outputs

     def capture(self,
-                batch_size: int,
+                key: Tuple[int, int, int],
                 forward_fn: Callable,
                 initial_inputs: Dict[str, Any],
                 postprocess_fn: Optional[Callable] = None):
         """Captures the forward pass for a given batch size."""
         engine = self._get_engine()
-        key = (batch_size, self.draft_len)
+        batch_size = key[0]
         # [CUDA graph spec decode padding]
         # We pad input IDs/position IDs to the maximum draft length (token per request).
         # We're forced to do this because we cannot reallocate inputs over many graph runs.
-        token_per_request = self.draft_len + 1
+        max_draft_len = key[1]
+        token_per_request = max_draft_len + 1
         num_tokens_for_capture = (batch_size * self.max_beam_width *
                                   token_per_request)

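To make the padding comment above concrete, a small worked example (numbers are made up):

```python
# Suppose key == (4, 3, False) and max_beam_width == 1.
max_draft_len = 3                      # key[1]
token_per_request = max_draft_len + 1  # 4 tokens per request
num_tokens_for_capture = 4 * 1 * 4     # batch_size * beam_width * tokens = 16
# The graph is captured at this padded size because the static input tensors
# cannot be reallocated between replays.
```
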
@@ -207,30 +230,43 @@ def capture(self,
207230 "spec_metadata" : initial_inputs .get ("spec_metadata" , None ),
208231 }
209232
233+ def _setup_spec_decoding_and_forward (key : Tuple [int , int , int ],
234+ forward_fn : Callable ,
235+ capture_inputs : Dict [str , Any ]):
236+ engine = self ._get_engine ()
237+ # for the first inference of draft model, we need to set the use_spec_decoding to True when capture the graph for multiple runs.
238+ is_first_draft = key [2 ]
239+ needs_kv_cache_recompute = True if engine .enable_spec_decode and engine .spec_config .spec_dec_mode .needs_kv_cache_recompute (
240+ ) else False
241+ if is_first_draft and engine .is_draft_model and needs_kv_cache_recompute :
242+ capture_inputs ['attn_metadata' ].use_spec_decoding = True
243+ return forward_fn (capture_inputs )
244+
210245 # We have to do warm up runs to initialize PyTorch's
211246 # internal states according to the docs:
212247 # https://pytorch.org/docs/stable/notes/cuda.html#cuda-graph-semantics
213248 # This also lets us initialize states in the attn_metadata.
214249 graph = torch .cuda .CUDAGraph ()
215250 with with_multi_stream (True ), piecewise_cuda_graph (False ):
216251 for _ in range (self .WARMUP_STEPS ):
217- forward_fn (capture_inputs )
252+ _setup_spec_decoding_and_forward (key , forward_fn ,
253+ capture_inputs )
218254 if postprocess_fn is not None :
219255 postprocess_fn (capture_inputs )
220256 with torch .cuda .graph (graph , pool = self .memory_pool ):
221- output = forward_fn (capture_inputs )
257+ output = _setup_spec_decoding_and_forward (
258+ key , forward_fn , capture_inputs )
222259 if postprocess_fn is not None :
223260 postprocess_fn (capture_inputs )
224261
225262 self .graphs [key ] = graph
226263 self .graph_outputs [key ] = make_weak_ref (output )
227264 self .memory_pool = graph .pool ()
228265
229- def replay (self , batch_size : int ,
266+ def replay (self , key : Tuple [ int , int , int ] ,
230267 current_inputs : Dict [str , Any ]) -> Optional [torch .Tensor ]:
231268 """Replays a previously captured graph."""
232269 engine = self ._get_engine ()
233- key = (batch_size , self .draft_len )
234270 stored_meta = self .graph_metadata [key ]
235271 assert current_inputs ["attn_metadata" ] is stored_meta ["attn_metadata" ]
236272 if stored_meta ["spec_metadata" ] is not None :