1- from typing import Optional
2-
31import torch
42
53from sglang .srt .lora .backend .base_backend import BaseLoRABackend
@@ -52,7 +50,7 @@ def run_lora_b_sgemm(
5250 output_offset : torch .Tensor ,
5351 base_output : torch .Tensor = None ,
5452 * args ,
55- ** kwargs
53+ ** kwargs ,
5654 ) -> torch .Tensor :
5755 # For simple lora B, we use slice offsets [0, output_dim]
5856 output_dim = weights .shape [- 2 ]
@@ -75,7 +73,7 @@ def run_qkv_lora(
7573 max_qkv_out_dim : int ,
7674 base_output : torch .Tensor = None ,
7775 * args ,
78- ** kwargs
76+ ** kwargs ,
7977 ) -> torch .Tensor :
8078
8179 # x: (s, input_dim)
@@ -107,7 +105,7 @@ def run_gate_up_lora(
107105 output_offset : torch .Tensor ,
108106 base_output : torch .Tensor = None ,
109107 * args ,
110- ** kwargs
108+ ** kwargs ,
111109 ) -> torch .Tensor :
112110
113111 # x: (s, input_dim)
@@ -160,13 +158,36 @@ def _determine_chunk_size(self, forward_batch: ForwardBatch) -> int:
160158 chunk_size = 16
161159 return min (self .max_chunk_size , chunk_size )
162160
161+ def init_cuda_graph_batch_info (
162+ self ,
163+ max_bs_in_cuda_graph : int ,
164+ num_tokens_per_bs : int ,
165+ ):
166+ max_num_segments = (
167+ (num_tokens_per_bs + MIN_CHUNK_SIZE - 1 ) // MIN_CHUNK_SIZE
168+ ) * max_bs_in_cuda_graph
169+ max_num_tokens = max_bs_in_cuda_graph * num_tokens_per_bs
170+ with torch .device ("cuda" ):
171+ self .cuda_graph_batch_info = LoRABatchInfo (
172+ bs = max_bs_in_cuda_graph ,
173+ use_cuda_graph = True ,
174+ seg_lens = torch .zeros (max_num_segments , dtype = torch .int32 ),
175+ seg_indptr = torch .zeros (max_num_segments + 1 , dtype = torch .int32 ),
176+ weight_indices = torch .zeros (max_num_segments , dtype = torch .int32 ),
177+ permutation = torch .zeros (max_num_tokens , dtype = torch .int32 ),
178+ lora_ranks = torch .zeros (self .max_loras_per_batch , dtype = torch .int32 ),
179+ scalings = torch .zeros (self .max_loras_per_batch , dtype = torch .float ),
180+ num_segments = None , # Set per batch
181+ max_len = None , # Not used in CSGMV backend
182+ )
183+
163184 def prepare_lora_batch (
164185 self ,
165186 forward_batch : ForwardBatch ,
166187 weight_indices : list [int ],
167188 lora_ranks : list [int ],
168189 scalings : list [float ],
169- batch_info : Optional [ LoRABatchInfo ] = None ,
190+ use_cuda_graph : bool ,
170191 ):
171192 chunk_size = self ._determine_chunk_size (forward_batch )
172193
@@ -188,7 +209,7 @@ def prepare_lora_batch(
188209 scalings , dtype = torch .float , pin_memory = True , device = "cpu"
189210 )
190211
191- if batch_info is None :
212+ if not use_cuda_graph :
192213 batch_info = LoRABatchInfo (
193214 bs = forward_batch .batch_size ,
194215 num_segments = num_segments ,
@@ -213,6 +234,7 @@ def prepare_lora_batch(
213234 seg_lens = None ,
214235 )
215236 else :
237+ batch_info = self .cuda_graph_batch_info
216238 batch_info .bs = forward_batch .batch_size
217239 batch_info .num_segments = num_segments
218240 batch_info .max_len = chunk_size
@@ -262,14 +284,23 @@ def _get_permutation(seq_weight_indices, forward_batch: ForwardBatch):
262284 with torch .device ("cpu" ):
263285 seq_weight_indices = torch .tensor (seq_weight_indices , dtype = torch .int32 )
264286
265- seg_lens_cpu = (
266- torch .tensor (
287+ if forward_batch .forward_mode .is_decode ():
288+ seg_lens_cpu = torch .ones (forward_batch .batch_size , dtype = torch .int32 )
289+ elif forward_batch .forward_mode .is_target_verify ():
290+ seg_lens_cpu = torch .full (
291+ size = (forward_batch .batch_size ,),
292+ fill_value = forward_batch .spec_info .draft_token_num ,
293+ dtype = torch .int32 ,
294+ )
295+ elif forward_batch .forward_mode .is_extend ():
296+ seg_lens_cpu = torch .tensor (
267297 forward_batch .extend_seq_lens_cpu ,
268298 dtype = torch .int32 ,
269299 )
270- if forward_batch .forward_mode .is_extend ()
271- else torch .ones (forward_batch .batch_size , dtype = torch .int32 )
272- )
300+ else :
301+ raise ValueError (
302+ f"Unsupported forward mode: { forward_batch .forward_mode } "
303+ )
273304
274305 row_weight_indices = torch .repeat_interleave (
275306 seq_weight_indices , seg_lens_cpu
0 commit comments