@@ -161,9 +161,25 @@ def init_forward_metadata_capture_cuda_graph(
         metadata.block_tables = self.graph_metadata["block_tables"][:bs, :]
         metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist()
         metadata.seq_lens = seq_lens
-        metadata.actual_seq_lengths_q = torch.tensor(
-            [1 + i * 1 for i in range(bs)], dtype=torch.int32, device=seq_lens.device
-        )
+        if (
+            forward_mode.is_target_verify()
+            or forward_mode.is_draft_extend_v2()
+            or forward_mode.is_draft_extend()
+        ):
+            metadata.actual_seq_lengths_q = torch.arange(
+                self.speculative_num_draft_tokens,
+                self.speculative_num_draft_tokens
+                + bs * self.speculative_num_draft_tokens,
+                self.speculative_num_draft_tokens,
+                dtype=torch.int32,
+                device=seq_lens.device,
+            )
+        else:
+            metadata.actual_seq_lengths_q = torch.tensor(
+                [1 + i * 1 for i in range(bs)],
+                dtype=torch.int32,
+                device=seq_lens.device,
+            )

         self.graph_metadata[bs] = metadata
         self.forward_metadata = metadata
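
In the speculative modes the captured graph allots a fixed block of
speculative_num_draft_tokens query positions per request, so the cumulative
query lengths form an arithmetic sequence instead of the old [1, 2, ..., bs].
A quick check of the arange arithmetic, with illustrative values:

    import torch

    # Assumed for illustration: speculative_num_draft_tokens = 4, bs = 3.
    k, bs = 4, 3
    print(torch.arange(k, k + bs * k, k, dtype=torch.int32))
    # tensor([ 4,  8, 12], dtype=torch.int32)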
@@ -193,7 +209,8 @@ def init_forward_metadata_replay_cuda_graph(
         )
         metadata.block_tables[:bs, max_seq_pages:].fill_(0)
         metadata.block_tables[bs:, :].fill_(0)
-
+        if forward_mode.is_target_verify():
+            seq_lens = seq_lens + self.speculative_num_draft_tokens
         metadata.seq_lens[:bs].copy_(seq_lens[:bs])

         self.forward_metadata = metadata
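
During target-verify replay the KV cache already holds the draft tokens, so
the sequence lengths copied into the graph metadata must grow by
speculative_num_draft_tokens per request. A minimal sketch of the adjustment,
with assumed values:

    import torch

    # Assumed for illustration: 2 requests, 4 draft tokens per verify step.
    seq_lens = torch.tensor([7, 12])   # committed tokens per request
    seq_lens = seq_lens + 4            # draft tokens visible during verify
    print(seq_lens)                    # tensor([11, 16])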
@@ -217,7 +234,12 @@ def forward_sparse(
         topk_indices: torch.Tensor = None,
     ):

-        is_prefill = forward_batch.forward_mode.is_extend()
+        is_prefill = (
+            forward_batch.forward_mode.is_extend()
+            and not forward_batch.forward_mode.is_draft_extend_v2()
+            and not forward_batch.forward_mode.is_draft_extend()
+            and not forward_batch.forward_mode.is_target_verify()
+        )

         if save_kv_cache:
             k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
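
The extra clauses are needed because the speculative extend variants evidently
still report is_extend() as true (otherwise the new conditions would be
no-ops), yet they must take the decode-style path below. The predicate,
restated as a standalone sketch (is_true_prefill is a hypothetical name):

    def is_true_prefill(mode) -> bool:
        # True only for a plain prefill, not any speculative-decoding variant.
        return (
            mode.is_extend()
            and not mode.is_draft_extend_v2()
            and not mode.is_draft_extend()
            and not mode.is_target_verify()
        )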
@@ -232,9 +254,30 @@ def forward_sparse(
             actual_seq_qlen = torch.cumsum(forward_batch.seq_lens, dim=0)
         else:
             if self.forward_metadata.actual_seq_lengths_q is None:
-                actual_seq_qlen = (
-                    torch.arange(1, q.shape[0] + 1).to(q.device).to(torch.int32)
-                )
+                if (
+                    forward_batch.forward_mode.is_draft_extend_v2()
+                    or forward_batch.forward_mode.is_target_verify()
+                ):
+                    actual_seq_qlen = (
+                        torch.arange(
+                            self.speculative_num_draft_tokens,
+                            self.speculative_num_draft_tokens + q.shape[0],
+                            self.speculative_num_draft_tokens,
+                            dtype=torch.int32,
+                        )
+                        .to(q.device)
+                        .to(torch.int32)
+                    )
+                elif forward_batch.forward_mode.is_draft_extend():
+                    actual_seq_qlen = (
+                        forward_batch.extend_seq_lens.cumsum(dim=0)
+                        .to(q.device)
+                        .to(torch.int32)
+                    )
+                else:
+                    actual_seq_qlen = (
+                        torch.arange(1, q.shape[0] + 1).to(q.device).to(torch.int32)
+                    )
             else:
                 actual_seq_qlen = self.forward_metadata.actual_seq_lengths_q
             if self.forward_metadata.seq_lens_cpu_int is None:
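
The three branches produce three layouts of cumulative query lengths. A
side-by-side sketch with assumed values (speculative_num_draft_tokens = 2,
batch of 3 requests, so q.shape[0] = 6 in the first case):

    import torch

    # target-verify / draft-extend-v2: a fixed 2 draft tokens per request
    print(torch.arange(2, 2 + 6, 2))      # tensor([2, 4, 6])

    # draft-extend: variable extend lengths, hence a running sum
    extend_seq_lens = torch.tensor([3, 1, 2])
    print(extend_seq_lens.cumsum(dim=0))  # tensor([3, 4, 6])

    # plain decode: one query token per request
    print(torch.arange(1, 3 + 1))         # tensor([1, 2, 3])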
@@ -477,7 +520,7 @@ def forward_mtp(
             -1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank
         )

-        q_nope = q.view(-1, layer.tp_q_head_num, self.kv_lora_rank)
+        q_nope = q.view(-1, layer.tp_q_head_num, self.kv_lora_rank).contiguous()
         q_rope = q_rope.view(-1, layer.tp_q_head_num, self.qk_rope_head_dim)
         if not self.graph_mode:
             num_token_padding = q.shape[0]
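
The appended .contiguous() is cheap insurance: it is a no-op when the view is
already dense and copies otherwise, which matters if the downstream kernel
assumes a dense buffer (an assumption here, not stated in the diff). A generic
illustration:

    import torch

    t = torch.randn(2, 3, 4).transpose(0, 1)   # strided view, not dense
    print(t.is_contiguous())                   # False
    print(t.contiguous().is_contiguous())      # True (copied to dense storage)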
@@ -919,7 +962,7 @@ def call_fn(i, forward_batch):
                encoder_lens=None,
                forward_mode=ForwardMode.DECODE,
                spec_info=forward_batch.spec_info,
-                seq_lens_cpu=None,
+                seq_lens_cpu=forward_batch.seq_lens_cpu,
            )

        self.common_template(forward_batch, call_fn)
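
Forwarding the batch's existing host-side copy instead of None spares the
callee from either skipping CPU-dependent metadata or re-deriving it with a
blocking device-to-host transfer (the exact callee behavior is an assumption
here). The pattern in isolation:

    # Reuse the host copy the batch already carries...
    seq_lens_cpu = forward_batch.seq_lens_cpu
    # ...rather than forcing a sync on every replay:
    # seq_lens_cpu = forward_batch.seq_lens.cpu()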