
Commit c8a8307

Merge branch 'main' into cheng/refactor/sbo
2 parents 852c545 + 254f62d commit c8a8307

13 files changed: +363 -339 lines changed

python/sglang/srt/lora/backend/base_backend.py

Lines changed: 5 additions & 6 deletions
@@ -1,8 +1,7 @@
-from typing import Optional, Tuple, Union
+from typing import Tuple, Union
 
 import torch
 
-from sglang.srt.lora.utils import LoRABatchInfo
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
@@ -97,8 +96,8 @@ def run_gate_up_lora(
 
     def init_cuda_graph_batch_info(
         self,
-        cuda_graph_batch_info: LoRABatchInfo,
         max_bs_in_cuda_graph: int,
+        num_tokens_per_bs: int,
     ):
         """Initialize the batch info for CUDA Graph mode.
 
@@ -108,6 +107,7 @@ def init_cuda_graph_batch_info(
         Args:
            cuda_graph_batch_info: the LoRABatchInfo object created in LoraManager
            max_bs_in_cuda_graph: maximum batch size for CUDA Graph mode
+           num_tokens_per_bs: number of tokens per sequence (1 for decoding, >1 for target_verify)
        """
        pass
 
@@ -117,7 +117,7 @@ def prepare_lora_batch(
        weight_indices: list[int],
        lora_ranks: list[int],
        scalings: list[float],
-        batch_info: Optional[LoRABatchInfo] = None,
+        use_cuda_graph: bool,
    ):
        """Prepare the lora weights and batch info for current forward batch.
 
@@ -129,7 +129,6 @@ def prepare_lora_batch(
            weight_indices: list of indices of lora weights to be applied for current batch
            lora_ranks: list of lora ranks corresponding to weight_indices
            scalings: list of scaling factors corresponding to weight_indices
-            batch_info: optional LoRABatchInfo object, if not provided, the backend should use its own
-                internal batch info (e.g., self.cuda_graph_batch_info for CUDA Graph mode)
+            use_cuda_graph: whether to use CUDA Graph for this batch
        """
        pass
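
The refactor above changes the backend contract: LoRAManager no longer hands a pre-built LoRABatchInfo to the backend. Instead, each backend receives the sizing hints (max_bs_in_cuda_graph, num_tokens_per_bs) once at graph-capture time plus a per-batch use_cuda_graph flag, and is expected to own its CUDA Graph buffers. A minimal, self-contained sketch of a backend honoring that contract; _SketchBackend and _SketchBatchInfo are illustrative stand-ins, not code from this commit:

from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class _SketchBatchInfo:
    # Reduced stand-in for sglang's LoRABatchInfo, keeping only what the sketch uses.
    bs: int
    use_cuda_graph: bool
    weight_indices: torch.Tensor


class _SketchBackend:
    def __init__(self, max_loras_per_batch: int):
        self.max_loras_per_batch = max_loras_per_batch
        self.cuda_graph_batch_info: Optional[_SketchBatchInfo] = None

    def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int, num_tokens_per_bs: int):
        # Allocate persistent buffers once, sized for the largest captured batch.
        self.cuda_graph_batch_info = _SketchBatchInfo(
            bs=max_bs_in_cuda_graph,
            use_cuda_graph=True,
            weight_indices=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
        )

    def prepare_lora_batch(self, bs: int, weight_indices: List[int], use_cuda_graph: bool):
        if use_cuda_graph:
            # Mutate the persistent buffers in place so a captured CUDA graph
            # keeps reading from the same tensor storage.
            info = self.cuda_graph_batch_info
            info.bs = bs
            info.weight_indices[:bs] = torch.tensor(weight_indices, dtype=torch.int32)
        else:
            # Build a throwaway batch info sized exactly for this batch.
            info = _SketchBatchInfo(
                bs=bs,
                use_cuda_graph=False,
                weight_indices=torch.tensor(weight_indices, dtype=torch.int32),
            )
        return info


backend = _SketchBackend(max_loras_per_batch=4)
backend.init_cuda_graph_batch_info(max_bs_in_cuda_graph=8, num_tokens_per_bs=1)
print(backend.prepare_lora_batch(bs=2, weight_indices=[0, 1], use_cuda_graph=True).bs)  # 2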

python/sglang/srt/lora/backend/chunked_backend.py

Lines changed: 43 additions & 12 deletions
@@ -1,5 +1,3 @@
-from typing import Optional
-
 import torch
 
 from sglang.srt.lora.backend.base_backend import BaseLoRABackend
@@ -52,7 +50,7 @@ def run_lora_b_sgemm(
         output_offset: torch.Tensor,
         base_output: torch.Tensor = None,
         *args,
-        **kwargs
+        **kwargs,
     ) -> torch.Tensor:
         # For simple lora B, we use slice offsets [0, output_dim]
         output_dim = weights.shape[-2]
@@ -75,7 +73,7 @@ def run_qkv_lora(
         max_qkv_out_dim: int,
         base_output: torch.Tensor = None,
         *args,
-        **kwargs
+        **kwargs,
     ) -> torch.Tensor:
 
         # x: (s, input_dim)
@@ -107,7 +105,7 @@ def run_gate_up_lora(
         output_offset: torch.Tensor,
         base_output: torch.Tensor = None,
         *args,
-        **kwargs
+        **kwargs,
     ) -> torch.Tensor:
 
         # x: (s, input_dim)
@@ -160,13 +158,36 @@ def _determine_chunk_size(self, forward_batch: ForwardBatch) -> int:
             chunk_size = 16
         return min(self.max_chunk_size, chunk_size)
 
+    def init_cuda_graph_batch_info(
+        self,
+        max_bs_in_cuda_graph: int,
+        num_tokens_per_bs: int,
+    ):
+        max_num_segments = (
+            (num_tokens_per_bs + MIN_CHUNK_SIZE - 1) // MIN_CHUNK_SIZE
+        ) * max_bs_in_cuda_graph
+        max_num_tokens = max_bs_in_cuda_graph * num_tokens_per_bs
+        with torch.device("cuda"):
+            self.cuda_graph_batch_info = LoRABatchInfo(
+                bs=max_bs_in_cuda_graph,
+                use_cuda_graph=True,
+                seg_lens=torch.zeros(max_num_segments, dtype=torch.int32),
+                seg_indptr=torch.zeros(max_num_segments + 1, dtype=torch.int32),
+                weight_indices=torch.zeros(max_num_segments, dtype=torch.int32),
+                permutation=torch.zeros(max_num_tokens, dtype=torch.int32),
+                lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
+                scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
+                num_segments=None,  # Set per batch
+                max_len=None,  # Not used in CSGMV backend
+            )
+
     def prepare_lora_batch(
         self,
         forward_batch: ForwardBatch,
         weight_indices: list[int],
         lora_ranks: list[int],
         scalings: list[float],
-        batch_info: Optional[LoRABatchInfo] = None,
+        use_cuda_graph: bool,
     ):
         chunk_size = self._determine_chunk_size(forward_batch)
 
@@ -188,7 +209,7 @@ def prepare_lora_batch(
             scalings, dtype=torch.float, pin_memory=True, device="cpu"
         )
 
-        if batch_info is None:
+        if not use_cuda_graph:
             batch_info = LoRABatchInfo(
                 bs=forward_batch.batch_size,
                 num_segments=num_segments,
@@ -213,6 +234,7 @@ def prepare_lora_batch(
                 seg_lens=None,
             )
         else:
+            batch_info = self.cuda_graph_batch_info
             batch_info.bs = forward_batch.batch_size
             batch_info.num_segments = num_segments
             batch_info.max_len = chunk_size
@@ -262,14 +284,23 @@ def _get_permutation(seq_weight_indices, forward_batch: ForwardBatch):
         with torch.device("cpu"):
             seq_weight_indices = torch.tensor(seq_weight_indices, dtype=torch.int32)
 
-            seg_lens_cpu = (
-                torch.tensor(
+            if forward_batch.forward_mode.is_decode():
+                seg_lens_cpu = torch.ones(forward_batch.batch_size, dtype=torch.int32)
+            elif forward_batch.forward_mode.is_target_verify():
+                seg_lens_cpu = torch.full(
+                    size=(forward_batch.batch_size,),
+                    fill_value=forward_batch.spec_info.draft_token_num,
+                    dtype=torch.int32,
+                )
+            elif forward_batch.forward_mode.is_extend():
+                seg_lens_cpu = torch.tensor(
                     forward_batch.extend_seq_lens_cpu,
                     dtype=torch.int32,
                 )
-                if forward_batch.forward_mode.is_extend()
-                else torch.ones(forward_batch.batch_size, dtype=torch.int32)
-            )
+            else:
+                raise ValueError(
+                    f"Unsupported forward mode: {forward_batch.forward_mode}"
+                )
 
         row_weight_indices = torch.repeat_interleave(
             seq_weight_indices, seg_lens_cpu
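
Two details of the chunked (CSGMV) backend change are easy to miss: the CUDA Graph buffers are sized with a ceiling division over MIN_CHUNK_SIZE (a module constant whose definition is not part of this hunk), and _get_permutation now derives per-sequence segment lengths from the forward mode, including the new target-verify case. A standalone sketch of that arithmetic, assuming MIN_CHUNK_SIZE is 16:

import torch

MIN_CHUNK_SIZE = 16  # assumed value for illustration; the real constant lives in chunked_backend.py


def max_segments_for_graph(max_bs_in_cuda_graph: int, num_tokens_per_bs: int) -> int:
    # Ceiling division: each captured sequence splits into at most this many chunks.
    per_seq = (num_tokens_per_bs + MIN_CHUNK_SIZE - 1) // MIN_CHUNK_SIZE
    return per_seq * max_bs_in_cuda_graph


def seg_lens_for_mode(mode: str, batch_size: int, draft_token_num: int = 0, extend_seq_lens=None):
    # Mirrors the branching added to _get_permutation: decode batches have one token
    # per sequence, target-verify batches have draft_token_num tokens per sequence,
    # and extend batches use the per-request prompt lengths.
    if mode == "decode":
        return torch.ones(batch_size, dtype=torch.int32)
    if mode == "target_verify":
        return torch.full((batch_size,), draft_token_num, dtype=torch.int32)
    if mode == "extend":
        return torch.tensor(extend_seq_lens, dtype=torch.int32)
    raise ValueError(f"Unsupported forward mode: {mode}")


# Example: graphs captured up to batch size 4, 8 draft tokens per sequence.
print(max_segments_for_graph(4, 8))              # 1 chunk per sequence -> 4 segments
print(seg_lens_for_mode("target_verify", 4, 8))  # tensor([8, 8, 8, 8], dtype=torch.int32)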

python/sglang/srt/lora/backend/triton_backend.py

Lines changed: 31 additions & 15 deletions
@@ -1,5 +1,3 @@
-from typing import Optional
-
 import torch
 
 from sglang.srt.lora.backend.base_backend import BaseLoRABackend
@@ -97,24 +95,41 @@ def run_gate_up_lora(
         return lora_output
 
     def init_cuda_graph_batch_info(
-        self, cuda_graph_batch_info: LoRABatchInfo, max_bs_in_cuda_graph: int
+        self,
+        max_bs_in_cuda_graph: int,
+        num_tokens_per_bs: int,
     ):
-        # Initialize seg_lens and seg_indptr for CUDA graph as they remain constant
-        # across batches.
-        cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph].fill_(1)
-        torch.cumsum(
-            cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph],
-            dim=0,
-            out=cuda_graph_batch_info.seg_indptr[1 : max_bs_in_cuda_graph + 1],
-        )
+        with torch.device("cuda"):
+            self.cuda_graph_batch_info = LoRABatchInfo(
+                bs=max_bs_in_cuda_graph,
+                use_cuda_graph=True,
+                num_segments=None,
+                seg_lens=torch.full(
+                    (max_bs_in_cuda_graph,), num_tokens_per_bs, dtype=torch.int32
+                ),
+                seg_indptr=torch.empty(max_bs_in_cuda_graph + 1, dtype=torch.int32),
+                max_len=num_tokens_per_bs,
+                weight_indices=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
+                lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
+                scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
+                permutation=None,
+            )
+
+        # Initialize seg_indptr for CUDA graph as they remain constant
+        # across batches.
+        torch.cumsum(
+            self.cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph],
+            dim=0,
+            out=self.cuda_graph_batch_info.seg_indptr[1 : max_bs_in_cuda_graph + 1],
+        )
 
     def prepare_lora_batch(
         self,
         forward_batch: ForwardBatch,
         weight_indices: list[int],
         lora_ranks: list[int],
         scalings: list[float],
-        batch_info: Optional[LoRABatchInfo] = None,
+        use_cuda_graph: bool,
     ):
         # Use pinned memory to avoid synchronizations during host-to-device transfer
         weight_indices_tensor = torch.tensor(
@@ -129,10 +144,11 @@ def prepare_lora_batch(
 
         bs = forward_batch.batch_size
 
-        if batch_info is not None:
+        if use_cuda_graph:
             assert (
-                batch_info.use_cuda_graph
-            ), "batch_info.use_cuda_graph must be True when batch_info is provided"
+                self.cuda_graph_batch_info is not None
+            ), "CUDA Graph batch info is not initialized."
+            batch_info = self.cuda_graph_batch_info
             batch_info.bs = forward_batch.batch_size
             batch_info.num_segments = forward_batch.batch_size
         else:
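
For the Triton backend, graph mode now treats every captured sequence as exactly num_tokens_per_bs tokens, so seg_indptr can be computed once at init as a prefix sum over a constant seg_lens and reused by every captured batch. A quick standalone check of what that cumsum produces, with illustrative sizes:

import torch

max_bs_in_cuda_graph, num_tokens_per_bs = 4, 8
seg_lens = torch.full((max_bs_in_cuda_graph,), num_tokens_per_bs, dtype=torch.int32)
seg_indptr = torch.zeros(max_bs_in_cuda_graph + 1, dtype=torch.int32)
# Same pattern as the backend: write the running sum into seg_indptr[1:].
torch.cumsum(seg_lens, dim=0, out=seg_indptr[1:])
print(seg_indptr.tolist())  # [0, 8, 16, 24, 32]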

python/sglang/srt/lora/lora_manager.py

Lines changed: 5 additions & 18 deletions
@@ -29,7 +29,6 @@
 from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.lora.mem_pool import LoRAMemoryPool
 from sglang.srt.lora.utils import (
-    LoRABatchInfo,
     LoRAType,
     get_layer_id,
     get_normalized_target_modules,
@@ -95,25 +94,13 @@ def __init__(
             lora_paths=lora_paths,
         )
 
-    def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
+    def init_cuda_graph_batch_info(
+        self, max_bs_in_cuda_graph: int, num_tokens_per_bs: int
+    ):
         self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
-        with torch.device("cuda"):
-            self.cuda_graph_batch_info = LoRABatchInfo(
-                bs=max_bs_in_cuda_graph,
-                use_cuda_graph=True,
-                num_segments=None,
-                seg_lens=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
-                seg_indptr=torch.zeros(max_bs_in_cuda_graph + 1, dtype=torch.int32),
-                max_len=1,
-                weight_indices=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
-                permutation=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
-                lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
-                scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
-            )
-
         self.lora_backend.init_cuda_graph_batch_info(
-            cuda_graph_batch_info=self.cuda_graph_batch_info,
             max_bs_in_cuda_graph=max_bs_in_cuda_graph,
+            num_tokens_per_bs=num_tokens_per_bs,
         )
 
     def create_lora_update_result(
@@ -297,7 +284,7 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch):
             weight_indices=weight_indices,
             lora_ranks=lora_ranks,
             scalings=scalings,
-            batch_info=self.cuda_graph_batch_info if use_cuda_graph else None,
+            use_cuda_graph=use_cuda_graph,
         )
 
     def update_lora_info(self):
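
After this change LoRAManager only records max_bs_in_cuda_graph and forwards both sizing hints; buffer allocation moves into the backend, which CudaGraphRunner reaches through the manager. A toy sketch of that delegation path; the _Stub* classes are stand-ins for the real ones:

class _StubBackend:
    def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int, num_tokens_per_bs: int):
        # In the real backends this allocates self.cuda_graph_batch_info on the GPU.
        print(f"backend allocates buffers: bs={max_bs_in_cuda_graph}, tokens/seq={num_tokens_per_bs}")


class _StubLoRAManager:
    def __init__(self, backend):
        self.lora_backend = backend

    def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int, num_tokens_per_bs: int):
        # The manager keeps only the bookkeeping value and delegates allocation.
        self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
        self.lora_backend.init_cuda_graph_batch_info(
            max_bs_in_cuda_graph=max_bs_in_cuda_graph,
            num_tokens_per_bs=num_tokens_per_bs,
        )


# CudaGraphRunner.__init__ now passes both sizing hints through the manager.
_StubLoRAManager(_StubBackend()).init_cuda_graph_batch_info(
    max_bs_in_cuda_graph=160, num_tokens_per_bs=8
)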

python/sglang/srt/lora/utils.py

Lines changed: 3 additions & 3 deletions
@@ -19,9 +19,6 @@ class LoRABatchInfo:
     # Number of segments. For triton backend, it is equal to batch size.
     num_segments: int
 
-    # Maximum segment length of current batch
-    max_len: int
-
     # Indice pointers of each segment in shape (num_segments + 1, )
     seg_indptr: torch.Tensor
 
@@ -34,6 +31,9 @@ class LoRABatchInfo:
     # scaling of each lora adapter, in shape (lora_num,)
     scalings: torch.Tensor
 
+    # Maximum segment length of current batch
+    max_len: Optional[int]
+
     # Lengths of each segments in shape (num_segments,)
     seg_lens: Optional[torch.Tensor]
 

python/sglang/srt/model_executor/cuda_graph_runner.py

Lines changed: 4 additions & 1 deletion
@@ -308,7 +308,10 @@ def __init__(self, model_runner: ModelRunner):
             set_torch_compile_config()
 
         if self.model_runner.server_args.enable_lora:
-            self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs)
+            self.model_runner.lora_manager.init_cuda_graph_batch_info(
+                max_bs_in_cuda_graph=self.max_bs,
+                num_tokens_per_bs=self.num_tokens_per_bs,
+            )
 
         # Graph inputs
         with torch.device(self.device):

python/sglang/srt/server_args.py

Lines changed: 7 additions & 0 deletions
@@ -3874,6 +3874,13 @@ def check_lora_server_args(self):
             )
 
         if self.enable_lora:
+            # Validate compatibility with speculative decoding
+            if self.speculative_algorithm not in ["NGRAM", None]:
+                raise ValueError(
+                    "Currently LoRA is only compatible with NGRAM speculative decoding."
+                )
+
+            # Parse lora_paths
             if isinstance(self.lora_paths, list):
                 lora_paths = self.lora_paths
                 self.lora_paths = []
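
The added guard ties LoRA to runs that either disable speculative decoding or use the NGRAM algorithm. A standalone reproduction of just that check; the helper below is hypothetical and not part of ServerArgs:

def check_lora_spec_compat(enable_lora: bool, speculative_algorithm):
    # Mirrors the added validation: LoRA is allowed only when speculative decoding
    # is off (None) or uses the NGRAM algorithm.
    if enable_lora and speculative_algorithm not in ["NGRAM", None]:
        raise ValueError(
            "Currently LoRA is only compatible with NGRAM speculative decoding."
        )


check_lora_spec_compat(True, "NGRAM")    # passes
check_lora_spec_compat(True, None)       # passes
# check_lora_spec_compat(True, "EAGLE")  # would raise ValueError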

python/sglang/test/runners.py

Lines changed: 12 additions & 0 deletions
@@ -528,6 +528,8 @@ def __init__(
         speculative_num_steps: Optional[int] = None,
         speculative_eagle_topk: Optional[int] = None,
         speculative_num_draft_tokens: Optional[int] = None,
+        speculative_ngram_min_match_window_size: Optional[int] = None,
+        speculative_ngram_max_match_window_size: Optional[int] = None,
         disable_overlap_schedule: bool = False,
         disable_custom_all_reduce: bool = False,
         torchao_config: Optional[str] = None,
@@ -539,6 +541,7 @@ def __init__(
         max_loaded_loras: Optional[int] = None,
         json_model_override_args: Optional[dict[str, Any]] = None,
         lora_eviction_policy: str = "lru",
+        enable_deterministic_inference: bool = False,
     ):
         self.model_type = model_type
         self.is_generation = model_type == "generation"
@@ -554,6 +557,14 @@ def __init__(
             spec_kwargs["speculative_num_steps"] = speculative_num_steps
             spec_kwargs["speculative_eagle_topk"] = speculative_eagle_topk
             spec_kwargs["speculative_num_draft_tokens"] = speculative_num_draft_tokens
+        elif speculative_algorithm == "NGRAM":
+            spec_kwargs["speculative_algorithm"] = speculative_algorithm
+            spec_kwargs["speculative_ngram_min_match_window_size"] = (
+                speculative_ngram_min_match_window_size
+            )
+            spec_kwargs["speculative_ngram_max_match_window_size"] = (
+                speculative_ngram_max_match_window_size
+            )
 
         self.engine = Engine(
             model_path=model_path,
@@ -594,6 +605,7 @@ def __init__(
                 else "{}"
             ),
             lora_eviction_policy=lora_eviction_policy,
+            enable_deterministic_inference=enable_deterministic_inference,
             **spec_kwargs,
         )
 
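
With the new constructor arguments, the test runner assembles NGRAM speculative settings the same way it already does for the EAGLE-style ones, then splats them into the Engine call together with enable_deterministic_inference. A standalone sketch of how spec_kwargs ends up for an NGRAM run; the window sizes are illustrative values, not defaults taken from the commit:

# Hypothetical inputs, mirroring the test runner's new keyword arguments.
speculative_algorithm = "NGRAM"
speculative_ngram_min_match_window_size = 1
speculative_ngram_max_match_window_size = 12

spec_kwargs = {}
if speculative_algorithm == "NGRAM":
    spec_kwargs["speculative_algorithm"] = speculative_algorithm
    spec_kwargs["speculative_ngram_min_match_window_size"] = (
        speculative_ngram_min_match_window_size
    )
    spec_kwargs["speculative_ngram_max_match_window_size"] = (
        speculative_ngram_max_match_window_size
    )

# spec_kwargs is then passed as **spec_kwargs to Engine(...) alongside the
# usual model arguments and enable_deterministic_inference.
print(spec_kwargs)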
