Skip to content

Commit 2aec8b6

Browse files
[Feature] Spec-Overlap supporting DP-ATTN; PD-Disaggregation; npugraph mode (#12443)
1 parent 0d41ddf commit 2aec8b6

24 files changed

+656
-365
lines changed

.github/workflows/release-docker-npu-nightly.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,6 @@ jobs:
7373
push: ${{ github.repository == 'sgl-project/sglang' && github.event_name != 'pull_request' }}
7474
provenance: false
7575
build-args: |
76-
SGLANG_KERNEL_NPU_TAG=20251030
76+
SGLANG_KERNEL_NPU_TAG=20251110
7777
CANN_VERSION=${{ matrix.cann_version }}
7878
DEVICE_TYPE=${{ matrix.device_type }}

.github/workflows/release-docker-npu.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,6 @@ jobs:
6969
push: ${{ github.repository == 'sgl-project/sglang' && github.event_name != 'pull_request' }}
7070
provenance: false
7171
build-args: |
72-
SGLANG_KERNEL_NPU_TAG=20251030
72+
SGLANG_KERNEL_NPU_TAG=20251110
7373
CANN_VERSION=${{ matrix.cann_version }}
7474
DEVICE_TYPE=${{ matrix.device_type }}

python/sglang/srt/disaggregation/decode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -930,7 +930,7 @@ def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
930930

931931
# construct fake completed prefill
932932
new_batch.prepare_for_prebuilt()
933-
new_batch.process_prebuilt(self.server_args, self.model_config)
933+
new_batch.process_prebuilt(self.server_args, self.future_map)
934934

935935
return new_batch
936936

python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
logger = logging.getLogger(__name__)
1515

1616
if TYPE_CHECKING:
17-
from sglang.srt.configs.model_config import ModelConfig
17+
from sglang.srt.managers.overlap_utils import FutureMap
1818
from sglang.srt.managers.schedule_batch import ScheduleBatch
1919
from sglang.srt.server_args import ServerArgs
2020

@@ -102,7 +102,9 @@ def prepare_for_prebuilt(self: ScheduleBatch):
102102
)
103103

104104
def process_prebuilt(
105-
self: ScheduleBatch, server_args: ServerArgs, model_config: ModelConfig
105+
self: ScheduleBatch,
106+
server_args: ServerArgs,
107+
future_map: FutureMap,
106108
):
107109
"""Assign the buffered last input id to schedule batch"""
108110
self.output_ids = []
@@ -166,7 +168,16 @@ def process_prebuilt(
166168
topk_index=topk_index,
167169
hidden_states=hidden_states,
168170
verified_id=self.output_ids,
171+
new_seq_lens=self.seq_lens,
172+
allocate_lens=self.seq_lens,
169173
)
170174
spec_info.prepare_for_extend(self)
171175
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
176+
if self.enable_overlap:
177+
spec_info.future_indices = future_map.alloc_future_indices(
178+
len(self.seq_lens)
179+
)
180+
future_map.store_to_map_for_new_batch(
181+
spec_info.future_indices, spec_info
182+
)
172183
self.spec_info = spec_info

python/sglang/srt/layers/attention/ascend_backend.py

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,25 @@ def init_forward_metadata_capture_cuda_graph(
161161
metadata.block_tables = self.graph_metadata["block_tables"][:bs, :]
162162
metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist()
163163
metadata.seq_lens = seq_lens
164-
metadata.actual_seq_lengths_q = torch.tensor(
165-
[1 + i * 1 for i in range(bs)], dtype=torch.int32, device=seq_lens.device
166-
)
164+
if (
165+
forward_mode.is_target_verify()
166+
or forward_mode.is_draft_extend_v2()
167+
or forward_mode.is_draft_extend()
168+
):
169+
metadata.actual_seq_lengths_q = torch.arange(
170+
self.speculative_num_draft_tokens,
171+
self.speculative_num_draft_tokens
172+
+ bs * self.speculative_num_draft_tokens,
173+
self.speculative_num_draft_tokens,
174+
dtype=torch.int32,
175+
device=seq_lens.device,
176+
)
177+
else:
178+
metadata.actual_seq_lengths_q = torch.tensor(
179+
[1 + i * 1 for i in range(bs)],
180+
dtype=torch.int32,
181+
device=seq_lens.device,
182+
)
167183

168184
self.graph_metadata[bs] = metadata
169185
self.forward_metadata = metadata
@@ -193,7 +209,8 @@ def init_forward_metadata_replay_cuda_graph(
193209
)
194210
metadata.block_tables[:bs, max_seq_pages:].fill_(0)
195211
metadata.block_tables[bs:, :].fill_(0)
196-
212+
if forward_mode.is_target_verify():
213+
seq_lens = seq_lens + self.speculative_num_draft_tokens
197214
metadata.seq_lens[:bs].copy_(seq_lens[:bs])
198215

199216
self.forward_metadata = metadata
@@ -217,7 +234,12 @@ def forward_sparse(
217234
topk_indices: torch.Tensor = None,
218235
):
219236

220-
is_prefill = forward_batch.forward_mode.is_extend()
237+
is_prefill = (
238+
forward_batch.forward_mode.is_extend()
239+
and not forward_batch.forward_mode.is_draft_extend_v2()
240+
and not forward_batch.forward_mode.is_draft_extend()
241+
and not forward_batch.forward_mode.is_target_verify()
242+
)
221243

222244
if save_kv_cache:
223245
k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
@@ -232,9 +254,30 @@ def forward_sparse(
232254
actual_seq_qlen = torch.cumsum(forward_batch.seq_lens, dim=0)
233255
else:
234256
if self.forward_metadata.actual_seq_lengths_q is None:
235-
actual_seq_qlen = (
236-
torch.arange(1, q.shape[0] + 1).to(q.device).to(torch.int32)
237-
)
257+
if (
258+
forward_batch.forward_mode.is_draft_extend_v2()
259+
or forward_batch.forward_mode.is_target_verify()
260+
):
261+
actual_seq_qlen = (
262+
torch.arange(
263+
self.speculative_num_draft_tokens,
264+
self.speculative_num_draft_tokens + q.shape[0],
265+
self.speculative_num_draft_tokens,
266+
dtype=torch.int32,
267+
)
268+
.to(q.device)
269+
.to(torch.int32)
270+
)
271+
elif forward_batch.forward_mode.is_draft_extend():
272+
actual_seq_qlen = (
273+
forward_batch.extend_seq_lens.cumsum()
274+
.to(q.device)
275+
.to(torch.int32)
276+
)
277+
else:
278+
actual_seq_qlen = (
279+
torch.arange(1, q.shape[0] + 1).to(q.device).to(torch.int32)
280+
)
238281
else:
239282
actual_seq_qlen = self.forward_metadata.actual_seq_lengths_q
240283
if self.forward_metadata.seq_lens_cpu_int is None:
@@ -477,7 +520,7 @@ def forward_mtp(
477520
-1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank
478521
)
479522

480-
q_nope = q.view(-1, layer.tp_q_head_num, self.kv_lora_rank)
523+
q_nope = q.view(-1, layer.tp_q_head_num, self.kv_lora_rank).contiguous()
481524
q_rope = q_rope.view(-1, layer.tp_q_head_num, self.qk_rope_head_dim)
482525
if not self.graph_mode:
483526
num_token_padding = q.shape[0]
@@ -919,7 +962,7 @@ def call_fn(i, forward_batch):
919962
encoder_lens=None,
920963
forward_mode=ForwardMode.DECODE,
921964
spec_info=forward_batch.spec_info,
922-
seq_lens_cpu=None,
965+
seq_lens_cpu=forward_batch.seq_lens_cpu,
923966
)
924967

925968
self.common_template(forward_batch, call_fn)

python/sglang/srt/layers/attention/nsa/nsa_indexer.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,12 @@ def forward_npu(
699699
enable_index_cp = (
700700
get_bool_env_var("SGLANG_USE_AG_AFTER_QLORA") and layer_id >= 4
701701
)
702-
is_prefill = forward_batch.forward_mode.is_extend()
702+
is_prefill = (
703+
forward_batch.forward_mode.is_extend()
704+
and not forward_batch.forward_mode.is_draft_extend_v2()
705+
and not forward_batch.forward_mode.is_target_verify()
706+
and not forward_batch.forward_mode.is_draft_extend()
707+
)
703708

704709
attention_tp_rank = get_attention_tp_rank()
705710
attention_tp_size = get_attention_tp_size()
@@ -790,9 +795,27 @@ def forward_npu(
790795

791796
else:
792797
if forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q is None:
793-
actual_seq_lengths_q = torch.tensor(
794-
[1 + i * 1 for i in range(bs)], dtype=torch.int32, device=k.device
795-
)
798+
if (
799+
forward_batch.forward_mode.is_draft_extend_v2()
800+
or forward_batch.forward_mode.is_target_verify()
801+
or forward_batch.forward_mode.is_draft_extend()
802+
):
803+
num_draft_tokens = (
804+
forward_batch.attn_backend.speculative_num_draft_tokens
805+
)
806+
actual_seq_lengths_q = torch.arange(
807+
num_draft_tokens,
808+
num_draft_tokens + bs,
809+
num_draft_tokens,
810+
dtype=torch.int32,
811+
device=k.device,
812+
)
813+
else:
814+
actual_seq_lengths_q = torch.tensor(
815+
[1 + i * 1 for i in range(bs)],
816+
dtype=torch.int32,
817+
device=k.device,
818+
)
796819
else:
797820
actual_seq_lengths_q = (
798821
forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q

python/sglang/srt/managers/overlap_utils.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -114,17 +114,33 @@ def resolve_future(self, model_worker_batch: ModelWorkerBatch):
114114
else:
115115
_resolve_future_token_ids(model_worker_batch.input_ids, self.token_ids_buf)
116116

117+
def is_empty_slice(self, s: slice) -> bool:
118+
start, stop, step = s.indices(self.future_buffer_len)
119+
if step > 0:
120+
return start >= stop
121+
else:
122+
return start <= stop
123+
117124
def store_to_map(
118125
self, future_indices: FutureIndices, batch_result: GenerationBatchResult
119126
):
120-
intv = future_indices.interval
121127
if self.spec_algo.is_eagle():
122128
draft_input: EagleDraftInput = batch_result.next_draft_input
123-
self._lazy_init_buf(draft_input)
124-
self.topk_p_buf[intv] = draft_input.topk_p
125-
self.topk_index_buf[intv] = draft_input.topk_index
126-
self.hidden_states_buf[intv] = draft_input.hidden_states
127-
self.verified_id_buf[intv] = draft_input.verified_id
128-
self.new_seq_lens_buf[intv] = draft_input.new_seq_lens
129+
self.store_to_map_for_new_batch(future_indices, draft_input)
129130
else:
131+
intv = future_indices.interval
130132
self.token_ids_buf[intv] = batch_result.next_token_ids
133+
134+
def store_to_map_for_new_batch(
135+
self, future_indices: FutureIndices, draft_input: EagleDraftInput
136+
):
137+
intv = future_indices.interval
138+
# idle indices do not need store info
139+
if self.is_empty_slice(intv):
140+
return
141+
self._lazy_init_buf(draft_input)
142+
self.topk_p_buf[intv] = draft_input.topk_p
143+
self.topk_index_buf[intv] = draft_input.topk_index
144+
self.hidden_states_buf[intv] = draft_input.hidden_states
145+
self.verified_id_buf[intv] = draft_input.verified_id
146+
self.new_seq_lens_buf[intv] = draft_input.new_seq_lens

python/sglang/srt/managers/scheduler.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -846,6 +846,15 @@ def init_disaggregation(self):
846846
self.server_args.disaggregation_transfer_backend
847847
)
848848

849+
if self.draft_worker is None or self.spec_algorithm.is_ngram():
850+
draft_token_to_kv_pool = None
851+
elif self.spec_algorithm.is_eagle() and self.enable_overlap:
852+
draft_token_to_kv_pool = (
853+
self.draft_worker.draft_worker.draft_runner.token_to_kv_pool
854+
)
855+
else:
856+
draft_token_to_kv_pool = self.draft_worker.model_runner.token_to_kv_pool
857+
849858
if (
850859
self.disaggregation_mode == DisaggregationMode.DECODE
851860
): # *2 for the headroom.
@@ -874,11 +883,7 @@ def init_disaggregation(self):
874883
self.disagg_decode_prealloc_queue = DecodePreallocQueue(
875884
req_to_token_pool=self.req_to_token_pool,
876885
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
877-
draft_token_to_kv_pool=(
878-
None
879-
if self.draft_worker is None or self.spec_algorithm.is_ngram()
880-
else self.draft_worker.model_runner.token_to_kv_pool
881-
),
886+
draft_token_to_kv_pool=draft_token_to_kv_pool,
882887
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
883888
metadata_buffers=self.disagg_metadata_buffers,
884889
scheduler=self,
@@ -911,11 +916,7 @@ def init_disaggregation(self):
911916

912917
self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
913918
token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
914-
draft_token_to_kv_pool=(
915-
None
916-
if self.draft_worker is None or self.spec_algorithm.is_ngram()
917-
else self.draft_worker.model_runner.token_to_kv_pool
918-
),
919+
draft_token_to_kv_pool=draft_token_to_kv_pool,
919920
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
920921
metadata_buffers=self.disagg_metadata_buffers,
921922
tp_rank=self.tp_rank,
@@ -935,6 +936,8 @@ def init_disaggregation(self):
935936
self.disagg_prefill_inflight_queue: List[Req] = []
936937

937938
def init_overlap(self):
939+
self.future_map = None
940+
938941
if not self.enable_overlap:
939942
return
940943

python/sglang/srt/model_executor/forward_batch_info.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -866,6 +866,10 @@ def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):
866866
self.spec_info.accept_length = self.spec_info.accept_length[:bs]
867867
logits_output.next_token_logits = logits_output.next_token_logits[:bs]
868868
logits_output.hidden_states = logits_output.hidden_states[:bs]
869+
elif self.forward_mode.is_draft_extend_v2(): # draft extend_v2
870+
bs = bs * self.spec_info.num_tokens_per_batch
871+
logits_output.next_token_logits = logits_output.next_token_logits[:bs]
872+
logits_output.hidden_states = logits_output.hidden_states[:bs]
869873
elif self.forward_mode.is_extend() or self.forward_mode.is_idle():
870874
logits_output.next_token_logits = logits_output.next_token_logits[:bs]
871875
logits_output.hidden_states = logits_output.hidden_states[:bs]

python/sglang/srt/model_executor/model_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2236,7 +2236,7 @@ def _forward_raw(
22362236
reinit_attn_backend=reinit_attn_backend,
22372237
forward_count=split_forward_count,
22382238
)
2239-
elif forward_batch.forward_mode.is_extend():
2239+
elif forward_batch.forward_mode.is_extend(include_draft_extend_v2=True):
22402240
ret = self.forward_extend(
22412241
forward_batch,
22422242
skip_attn_backend_init=skip_attn_backend_init,

0 commit comments

Comments (0)