
Commit de4637b

support mtp (beta) pd disaggregation and dp attention; draft extend graph (npu); support dsv3_2 mtp
1 parent: 8be0e1b

19 files changed: +666 −341 lines

python/sglang/srt/disaggregation/decode.py

Lines changed: 7 additions & 1 deletion
@@ -1029,7 +1029,13 @@ def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
         # construct fake completed prefill
         new_batch.prepare_for_prebuilt_extend()
         new_batch.process_prebuilt_extend(self.server_args, self.model_config)
-
+        if self.spec_algorithm.is_eagle() and self.enable_overlap:
+            new_batch.spec_info.future_indices = self.future_map.alloc_future_indices(
+                len(new_batch.seq_lens)
+            )
+            self.future_map.store_to_map_for_new_batch(
+                new_batch.spec_info.future_indices, new_batch.spec_info
+            )
         return new_batch

     def process_decode_queue(self: Scheduler):
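
For context: with the overlap scheduler enabled, a later batch reads its speculative inputs out of the future map rather than directly from the previous batch's output. The change above makes a prebuilt decode batch (handed over from the prefill side) allocate future indices for its spec_info and publish them immediately, so the overlap path can resolve it like any locally produced batch. Below is a minimal sketch of that alloc/store/resolve handshake; FutureMap here is a hypothetical stand-in, not sglang's real class:

import torch

class FutureMap:  # hypothetical stand-in for illustration only
    def __init__(self, capacity: int, hidden: int):
        self.cursor = 0
        self.verified_id_buf = torch.zeros(capacity, dtype=torch.int64)
        self.hidden_states_buf = torch.zeros(capacity, hidden)

    def alloc_future_indices(self, n: int) -> slice:
        # Reserve n slots; the real map recycles slots ring-buffer style.
        intv = slice(self.cursor, self.cursor + n)
        self.cursor += n
        return intv

    def store_to_map_for_new_batch(self, intv: slice, verified_id, hidden_states):
        # A prebuilt batch publishes its draft inputs up front ...
        self.verified_id_buf[intv] = verified_id
        self.hidden_states_buf[intv] = hidden_states

    def resolve(self, intv: slice):
        # ... and a later batch reads them back when it is scheduled.
        return self.verified_id_buf[intv], self.hidden_states_buf[intv]

fm = FutureMap(capacity=16, hidden=4)
intv = fm.alloc_future_indices(2)
fm.store_to_map_for_new_batch(intv, torch.tensor([7, 9]), torch.randn(2, 4))
verified_id, hidden_states = fm.resolve(intv)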

python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py

Lines changed: 4 additions & 0 deletions
@@ -165,6 +165,10 @@ def process_prebuilt_extend(
                 topk_index=topk_index,
                 hidden_states=hidden_states,
                 verified_id=self.output_ids,
+                new_seq_lens=self.seq_lens,
+                allocate_lens=self.seq_lens,
+                num_tokens_per_batch=1,
+                num_tokens_for_logprob_per_batch=1,
             )
             spec_info.prepare_for_extend(self)
             spec_info.capture_hidden_mode = CaptureHiddenMode.LAST

python/sglang/srt/layers/attention/ascend_backend.py

Lines changed: 52 additions & 9 deletions
@@ -161,9 +161,25 @@ def init_forward_metadata_capture_cuda_graph(
         metadata.block_tables = self.graph_metadata["block_tables"][:bs, :]
         metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist()
         metadata.seq_lens = seq_lens
-        metadata.actual_seq_lengths_q = torch.tensor(
-            [1 + i * 1 for i in range(bs)], dtype=torch.int32, device=seq_lens.device
-        )
+        if (
+            forward_mode.is_target_verify()
+            or forward_mode.is_draft_extend_v2()
+            or forward_mode.is_draft_extend()
+        ):
+            metadata.actual_seq_lengths_q = torch.arange(
+                self.speculative_num_draft_tokens,
+                self.speculative_num_draft_tokens
+                + bs * self.speculative_num_draft_tokens,
+                self.speculative_num_draft_tokens,
+                dtype=torch.int32,
+                device=seq_lens.device,
+            )
+        else:
+            metadata.actual_seq_lengths_q = torch.tensor(
+                [1 + i * 1 for i in range(bs)],
+                dtype=torch.int32,
+                device=seq_lens.device,
+            )

         self.graph_metadata[bs] = metadata
         self.forward_metadata = metadata

@@ -193,7 +209,8 @@ def init_forward_metadata_replay_cuda_graph(
         )
         metadata.block_tables[:bs, max_seq_pages:].fill_(0)
         metadata.block_tables[bs:, :].fill_(0)
-
+        if forward_mode.is_target_verify():
+            seq_lens = seq_lens + self.speculative_num_draft_tokens
         metadata.seq_lens[:bs].copy_(seq_lens[:bs])

         self.forward_metadata = metadata

@@ -217,7 +234,12 @@ def forward_sparse(
         topk_indices: torch.Tensor = None,
     ):

-        is_prefill = forward_batch.forward_mode.is_extend()
+        is_prefill = (
+            forward_batch.forward_mode.is_extend()
+            and not forward_batch.forward_mode.is_draft_extend_v2()
+            and not forward_batch.forward_mode.is_draft_extend()
+            and not forward_batch.forward_mode.is_target_verify()
+        )

         if save_kv_cache:
             k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)

@@ -232,9 +254,30 @@
                 actual_seq_qlen = torch.cumsum(forward_batch.seq_lens, dim=0)
         else:
             if self.forward_metadata.actual_seq_lengths_q is None:
-                actual_seq_qlen = (
-                    torch.arange(1, q.shape[0] + 1).to(q.device).to(torch.int32)
-                )
+                if (
+                    forward_batch.forward_mode.is_draft_extend_v2()
+                    or forward_batch.forward_mode.is_target_verify()
+                ):
+                    actual_seq_qlen = (
+                        torch.arange(
+                            self.speculative_num_draft_tokens,
+                            self.speculative_num_draft_tokens + q.shape[0],
+                            self.speculative_num_draft_tokens,
+                            dtype=torch.int32,
+                        )
+                        .to(q.device)
+                        .to(torch.int32)
+                    )
+                elif forward_batch.forward_mode.is_draft_extend():
+                    actual_seq_qlen = (
+                        forward_batch.extend_seq_lens.cumsum()
+                        .to(q.device)
+                        .to(torch.int32)
+                    )
+                else:
+                    actual_seq_qlen = (
+                        torch.arange(1, q.shape[0] + 1).to(q.device).to(torch.int32)
+                    )
             else:
                 actual_seq_qlen = self.forward_metadata.actual_seq_lengths_q
             if self.forward_metadata.seq_lens_cpu_int is None:

@@ -914,7 +957,7 @@ def call_fn(i, forward_batch):
                 encoder_lens=None,
                 forward_mode=ForwardMode.DECODE,
                 spec_info=forward_batch.spec_info,
-                seq_lens_cpu=None,
+                seq_lens_cpu=forward_batch.seq_lens_cpu,
             )

         self.common_template(forward_batch, call_fn)
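
A note on the metadata built above: actual_seq_lengths_q holds the cumulative end offset of each request's query tokens. In decode, every request contributes exactly one query token, so the offsets are 1, 2, ..., bs; in target-verify / draft-extend, each request contributes speculative_num_draft_tokens, which is what the stepped torch.arange produces. A small self-contained check, with values chosen for illustration:

import torch

num_draft, bs = 4, 3  # illustrative values

# Speculative modes: each request contributes num_draft query tokens.
spec_ends = torch.arange(
    num_draft, num_draft + bs * num_draft, num_draft, dtype=torch.int32
)
print(spec_ends)  # tensor([ 4,  8, 12], dtype=torch.int32)

# Decode: one query token per request (the [1 + i * 1 ...] branch).
decode_ends = torch.tensor([1 + i * 1 for i in range(bs)], dtype=torch.int32)
print(decode_ends)  # tensor([1, 2, 3], dtype=torch.int32)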

python/sglang/srt/layers/attention/nsa/nsa_indexer.py

Lines changed: 27 additions & 4 deletions
@@ -666,7 +666,12 @@ def forward_npu(
         enable_index_cp = (
             get_bool_env_var("SGLANG_USE_AG_AFTER_QLORA") and layer_id >= 4
         )
-        is_prefill = forward_batch.forward_mode.is_extend()
+        is_prefill = (
+            forward_batch.forward_mode.is_extend()
+            and not forward_batch.forward_mode.is_draft_extend_v2()
+            and not forward_batch.forward_mode.is_target_verify()
+            and not forward_batch.forward_mode.is_draft_extend()
+        )

         attention_tp_rank = get_attention_tp_rank()
         attention_tp_size = get_attention_tp_size()

@@ -757,9 +762,27 @@

         else:
             if forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q is None:
-                actual_seq_lengths_q = torch.tensor(
-                    [1 + i * 1 for i in range(bs)], dtype=torch.int32, device=k.device
-                )
+                if (
+                    forward_batch.forward_mode.is_draft_extend_v2()
+                    or forward_batch.forward_mode.is_target_verify()
+                    or forward_batch.forward_mode.is_draft_extend()
+                ):
+                    num_draft_tokens = (
+                        forward_batch.attn_backend.speculative_num_draft_tokens
+                    )
+                    actual_seq_lengths_q = torch.arange(
+                        num_draft_tokens,
+                        num_draft_tokens + bs,
+                        num_draft_tokens,
+                        dtype=torch.int32,
+                        device=k.device,
+                    )
+                else:
+                    actual_seq_lengths_q = torch.tensor(
+                        [1 + i * 1 for i in range(bs)],
+                        dtype=torch.int32,
+                        device=k.device,
+                    )
             else:
                 actual_seq_lengths_q = (
                     forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q

python/sglang/srt/managers/overlap_utils.py

Lines changed: 36 additions & 6 deletions
@@ -102,7 +102,12 @@ def resolve_future(self, model_worker_batch: ModelWorkerBatch):
         if self.spec_algo.is_eagle():
             # TODO(lsyin): write future indices into spec_info.future_indices
             draft_input: EagleDraftInput = model_worker_batch.spec_info
-            if draft_input is None:
+            if (
+                draft_input is None
+                or model_worker_batch.forward_mode.is_idle()
+                or draft_input.future_indices is None
+                or not self.buf_initialized
+            ):
                 # FIXME(lsyin): No future exists, only for prefill batch, not compatible with mixed mode
                 return
             indices = draft_input.future_indices.indices

@@ -114,17 +119,42 @@
         else:
             _resolve_future_token_ids(model_worker_batch.input_ids, self.token_ids_buf)

+    def is_empty_slice(self, s: slice) -> bool:
+        start, stop, step = s.indices(self.future_buffer_len)
+        if step > 0:
+            return start >= stop
+        else:
+            return start <= stop
+
     def store_to_map(
         self, future_indices: FutureIndices, batch_result: GenerationBatchResult
     ):
         intv = future_indices.interval
         if self.spec_algo.is_eagle():
+            if self.is_empty_slice(intv):
+                return
             draft_input: EagleDraftInput = batch_result.next_draft_input
             self._lazy_init_buf(draft_input)
-            self.topk_p_buf[intv] = draft_input.topk_p
-            self.topk_index_buf[intv] = draft_input.topk_index
-            self.hidden_states_buf[intv] = draft_input.hidden_states
-            self.verified_id_buf[intv] = draft_input.verified_id
-            self.new_seq_lens_buf[intv] = draft_input.new_seq_lens
+            if self.buf_initialized:
+                self.topk_p_buf[intv] = draft_input.topk_p
+                self.topk_index_buf[intv] = draft_input.topk_index
+                self.hidden_states_buf[intv] = draft_input.hidden_states
+                self.verified_id_buf[intv] = draft_input.verified_id
+                self.new_seq_lens_buf[intv] = draft_input.new_seq_lens
         else:
             self.token_ids_buf[intv] = batch_result.next_token_ids
+
+    def store_to_map_for_new_batch(
+        self, future_indices: FutureIndices, draft_input: EagleDraftInput
+    ):
+        intv = future_indices.interval
+        if self.spec_algo.is_eagle():
+            if self.is_empty_slice(intv):
+                return
+            self._lazy_init_buf(draft_input)
+            if self.buf_initialized:
+                self.topk_p_buf[intv] = draft_input.topk_p
+                self.topk_index_buf[intv] = draft_input.topk_index
+                self.hidden_states_buf[intv] = draft_input.hidden_states
+                self.verified_id_buf[intv] = draft_input.verified_id
+                self.new_seq_lens_buf[intv] = draft_input.new_seq_lens
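
The new is_empty_slice guard leans on Python's slice.indices(), which clamps a slice to a buffer length and returns concrete (start, stop, step) values; an interval is empty when a positive step leaves start >= stop (mirrored for negative steps). A quick illustration with an assumed buffer length:

buf_len = 8  # stands in for self.future_buffer_len

start, stop, step = slice(5, 5).indices(buf_len)  # -> (5, 5, 1)
print(start >= stop)  # True: empty interval, store_to_map returns early

start, stop, step = slice(2, 6).indices(buf_len)  # -> (2, 6, 1)
print(start >= stop)  # False: four future slots to fill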

python/sglang/srt/managers/scheduler.py

Lines changed: 20 additions & 10 deletions
@@ -859,14 +859,19 @@ def init_disaggregation(self):
            )

            # The decode requests pending for pre-allocation
+            if self.draft_worker is None or self.spec_algorithm.is_ngram():
+                draft_token_to_kv_pool = None
+            elif self.spec_algorithm.is_eagle() and self.enable_overlap:
+                draft_token_to_kv_pool = (
+                    self.draft_worker.draft_worker.draft_runner.token_to_kv_pool
+                )
+            else:
+                draft_token_to_kv_pool = self.draft_worker.model_runner.token_to_kv_pool
+
            self.disagg_decode_prealloc_queue = DecodePreallocQueue(
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
-                draft_token_to_kv_pool=(
-                    None
-                    if self.draft_worker is None or self.spec_algorithm.is_ngram()
-                    else self.draft_worker.model_runner.token_to_kv_pool
-                ),
+                draft_token_to_kv_pool=draft_token_to_kv_pool,
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                scheduler=self,

@@ -897,13 +902,18 @@
                custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
            )

+            if self.draft_worker is None or self.spec_algorithm.is_ngram():
+                draft_token_to_kv_pool = None
+            elif self.spec_algorithm.is_eagle() and self.enable_overlap:
+                draft_token_to_kv_pool = (
+                    self.draft_worker.draft_worker.draft_runner.token_to_kv_pool
+                )
+            else:
+                draft_token_to_kv_pool = self.draft_worker.model_runner.token_to_kv_pool
+
            self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
-                draft_token_to_kv_pool=(
-                    None
-                    if self.draft_worker is None or self.spec_algorithm.is_ngram()
-                    else self.draft_worker.model_runner.token_to_kv_pool
-                ),
+                draft_token_to_kv_pool=draft_token_to_kv_pool,
                req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
                metadata_buffers=self.disagg_metadata_buffers,
                tp_rank=self.tp_rank,
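
The same three-way pool selection is now spelled out twice, once for the decode prealloc queue and once for the prefill bootstrap queue. A hedged sketch of how it could be factored into a single helper; this refactor is hypothetical, the commit inlines both copies:

def _select_draft_token_to_kv_pool(self):  # hypothetical helper, not in the commit
    if self.draft_worker is None or self.spec_algorithm.is_ngram():
        return None  # no draft model, or ngram speculation needs no draft KV pool
    if self.spec_algorithm.is_eagle() and self.enable_overlap:
        # The overlap EAGLE worker wraps the draft worker one level deeper.
        return self.draft_worker.draft_worker.draft_runner.token_to_kv_pool
    return self.draft_worker.model_runner.token_to_kv_pool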

python/sglang/srt/model_executor/forward_batch_info.py

Lines changed: 4 additions & 0 deletions
@@ -828,6 +828,10 @@ def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):
             self.spec_info.accept_length = self.spec_info.accept_length[:bs]
             logits_output.next_token_logits = logits_output.next_token_logits[:bs]
             logits_output.hidden_states = logits_output.hidden_states[:bs]
+        elif self.forward_mode.is_draft_extend_v2():  # draft extend_v2
+            bs = bs * self.spec_info.num_tokens_per_batch
+            logits_output.next_token_logits = logits_output.next_token_logits[:bs]
+            logits_output.hidden_states = logits_output.hidden_states[:bs]
         elif self.forward_mode.is_extend() or self.forward_mode.is_idle():
             logits_output.next_token_logits = logits_output.next_token_logits[:bs]
             logits_output.hidden_states = logits_output.hidden_states[:bs]
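
The rescaling matters because draft-extend-v2 emits logits for every draft token of every request rather than one row per request, so the slice that trims rows after the MLP sync (presumably padding added for DP attention) must cover bs * num_tokens_per_batch rows. With illustrative numbers:

bs, num_tokens_per_batch = 2, 4  # illustrative values
rows = bs * num_tokens_per_batch  # 8 real rows of next_token_logits
# next_token_logits[:rows] keeps all draft-token logits for both requests;
# any rows past that point are padding and are dropped.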

python/sglang/srt/model_executor/model_runner.py

Lines changed: 1 addition & 1 deletion
@@ -2200,7 +2200,7 @@ def _forward_raw(
                reinit_attn_backend=reinit_attn_backend,
                forward_count=split_forward_count,
            )
-        elif forward_batch.forward_mode.is_extend():
+        elif forward_batch.forward_mode.is_extend(include_draft_extend_v2=True):
            ret = self.forward_extend(
                forward_batch,
                skip_attn_backend_init=skip_attn_backend_init,

python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py

Lines changed: 26 additions & 13 deletions
@@ -88,7 +88,7 @@ def __init__(self, eagle_worker: EAGLEWorker):
            set_torch_compile_config()

        # Graph inputs
-        with torch.device("cuda"):
+        with torch.device(model_runner.device):
            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
            self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
            self.seq_lens = torch.full(

@@ -157,11 +157,28 @@ def can_run(self, forward_batch: ForwardBatch):

        return is_bs_supported

+    def _create_graph(self):
+        return torch.cuda.CUDAGraph()
+
+    def _capture_init(self, run_once_fn):
+        for _ in range(2):
+            torch.cuda.synchronize()
+            self.model_runner.tp_group.barrier()
+            run_once_fn()
+
+    def _capture_graph(self, graph, pool, stream, run_once_fn):
+        with torch.cuda.graph(graph, pool=pool, stream=stream):
+            out = run_once_fn()
+        return out
+
+    def _update_and_replay(self, forward_batch: ForwardBatch):
+        self.graphs[self.bs].replay()
+
    def capture(self):
        CudaGraphRunner.capture(self)

    def capture_one_batch_size(self, num_seqs: int, forward: Callable):
-        graph = torch.cuda.CUDAGraph()
+        graph = self._create_graph()
        stream = self.stream
        num_tokens = num_seqs * self.num_tokens_per_bs

@@ -282,16 +299,10 @@ def run_once():

        self.deepep_adapter.capture(is_extend_in_batch=False)

-        for _ in range(2):
-            torch.cuda.synchronize()
-            self.model_runner.tp_group.barrier()
-
-            run_once()
-
-        with torch.cuda.graph(
-            graph, pool=get_global_graph_memory_pool(), stream=stream
-        ):
-            out = run_once()
+        self._capture_init(run_once)
+        out = self._capture_graph(
+            graph, get_global_graph_memory_pool(), stream, run_once
+        )

        set_global_graph_memory_pool(graph.pool())
        return graph, out

@@ -359,10 +370,12 @@ def replay(self, forward_batch: ForwardBatch):
        self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
            forward_batch, bs
        )
+        self.raw_bs = raw_bs
+        self.bs = bs
        # TODO: The forward_batch.seq_len_sum might need to be updated to reflect the padding in the cuda graph

        # Replay
-        self.graphs[bs].replay()
+        self._update_and_replay(forward_batch)
        out = self.output_buffers[bs]

        if bs != raw_bs:
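
The four new hooks (_create_graph, _capture_init, _capture_graph, _update_and_replay) isolate the CUDA-specific steps so a subclass, such as the NPU draft-extend graph runner this commit targets, can override each one. What they wrap is the standard torch.cuda.CUDAGraph capture/replay pattern, sketched standalone below (warmup on a side stream per the PyTorch docs; the runner's version adds TP barriers and a shared memory pool):

import torch

x = torch.randn(8, 64, device="cuda")
w = torch.randn(64, 64, device="cuda")

def run_once():
    return x @ w  # any static-shape computation

# Warm up outside the graph, on a side stream (mirrors _capture_init).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(2):
        run_once()
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()  # mirrors _create_graph
with torch.cuda.graph(graph):   # mirrors _capture_graph
    out = run_once()

# Refresh inputs in place, then replay (mirrors _update_and_replay):
x.copy_(torch.randn(8, 64, device="cuda"))
graph.replay()  # `out` now holds the recomputed result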
