2 changes: 1 addition & 1 deletion .github/workflows/release-docker-npu-nightly.yml
@@ -73,6 +73,6 @@ jobs:
push: ${{ github.repository == 'sgl-project/sglang' && github.event_name != 'pull_request' }}
provenance: false
build-args: |
SGLANG_KERNEL_NPU_TAG=20251030
SGLANG_KERNEL_NPU_TAG=20251110
CANN_VERSION=${{ matrix.cann_version }}
DEVICE_TYPE=${{ matrix.device_type }}
2 changes: 1 addition & 1 deletion .github/workflows/release-docker-npu.yml
@@ -69,6 +69,6 @@ jobs:
push: ${{ github.repository == 'sgl-project/sglang' && github.event_name != 'pull_request' }}
provenance: false
build-args: |
SGLANG_KERNEL_NPU_TAG=20251030
SGLANG_KERNEL_NPU_TAG=20251110
CANN_VERSION=${{ matrix.cann_version }}
DEVICE_TYPE=${{ matrix.device_type }}
2 changes: 1 addition & 1 deletion python/sglang/srt/disaggregation/decode.py
@@ -930,7 +930,7 @@ def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:

# construct fake completed prefill
new_batch.prepare_for_prebuilt()
new_batch.process_prebuilt(self.server_args, self.model_config)
new_batch.process_prebuilt(self.server_args, self.future_map)

return new_batch

15 changes: 13 additions & 2 deletions python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py
@@ -14,7 +14,7 @@
logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.managers.overlap_utils import FutureMap
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.server_args import ServerArgs

@@ -102,7 +102,9 @@ def prepare_for_prebuilt(self: ScheduleBatch):
)

def process_prebuilt(
self: ScheduleBatch, server_args: ServerArgs, model_config: ModelConfig
self: ScheduleBatch,
server_args: ServerArgs,
future_map: FutureMap,
):
"""Assign the buffered last input id to schedule batch"""
self.output_ids = []
@@ -166,7 +168,16 @@ def process_prebuilt(
topk_index=topk_index,
hidden_states=hidden_states,
verified_id=self.output_ids,
new_seq_lens=self.seq_lens,
allocate_lens=self.seq_lens,
)
spec_info.prepare_for_extend(self)
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
if self.enable_overlap:
spec_info.future_indices = future_map.alloc_future_indices(
len(self.seq_lens)
)
future_map.store_to_map_for_new_batch(
spec_info.future_indices, spec_info
)
self.spec_info = spec_info
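Together with the `decode.py` change above, `process_prebuilt` now receives the scheduler's `FutureMap` so that, under overlap scheduling, the prebuilt EAGLE draft input can be registered via `alloc_future_indices` and `store_to_map_for_new_batch` before the overlapped worker needs it. A minimal sketch of that allocate-then-store pattern; the classes here are simplified stand-ins for illustration, not the sglang implementations:

```python
import torch
from dataclasses import dataclass


@dataclass
class FutureIndices:
    # A contiguous slice into the future buffers.
    interval: slice


class ToyFutureMap:
    """Simplified stand-in for sglang's FutureMap (illustration only)."""

    def __init__(self, capacity: int, hidden_size: int):
        self.cursor = 0
        self.verified_id_buf = torch.zeros(capacity, dtype=torch.int64)
        self.hidden_states_buf = torch.zeros(capacity, hidden_size)

    def alloc_future_indices(self, n: int) -> FutureIndices:
        # Hand out the next n slots; a real implementation recycles slots.
        intv = slice(self.cursor, self.cursor + n)
        self.cursor += n
        return FutureIndices(interval=intv)

    def store_for_new_batch(self, indices: FutureIndices, verified_id, hidden_states):
        # A prebuilt decode batch publishes its draft inputs here so the
        # overlapped worker can read them without waiting on the batch itself.
        self.verified_id_buf[indices.interval] = verified_id
        self.hidden_states_buf[indices.interval] = hidden_states


# Usage: a prebuilt batch of 2 requests registers its draft inputs up front.
fmap = ToyFutureMap(capacity=16, hidden_size=4)
idx = fmap.alloc_future_indices(2)
fmap.store_for_new_batch(idx, torch.tensor([11, 42]), torch.randn(2, 4))
print(fmap.verified_id_buf[idx.interval])  # tensor([11, 42])
```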
63 changes: 53 additions & 10 deletions python/sglang/srt/layers/attention/ascend_backend.py
@@ -161,9 +161,25 @@ def init_forward_metadata_capture_cuda_graph(
metadata.block_tables = self.graph_metadata["block_tables"][:bs, :]
metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist()
metadata.seq_lens = seq_lens
metadata.actual_seq_lengths_q = torch.tensor(
[1 + i * 1 for i in range(bs)], dtype=torch.int32, device=seq_lens.device
)
if (
forward_mode.is_target_verify()
or forward_mode.is_draft_extend_v2()
or forward_mode.is_draft_extend()
):
metadata.actual_seq_lengths_q = torch.arange(
self.speculative_num_draft_tokens,
self.speculative_num_draft_tokens
+ bs * self.speculative_num_draft_tokens,
self.speculative_num_draft_tokens,
dtype=torch.int32,
device=seq_lens.device,
)
else:
metadata.actual_seq_lengths_q = torch.tensor(
[1 + i * 1 for i in range(bs)],
dtype=torch.int32,
device=seq_lens.device,
)

self.graph_metadata[bs] = metadata
self.forward_metadata = metadata
@@ -193,7 +209,8 @@ def init_forward_metadata_replay_cuda_graph(
)
metadata.block_tables[:bs, max_seq_pages:].fill_(0)
metadata.block_tables[bs:, :].fill_(0)

if forward_mode.is_target_verify():
seq_lens = seq_lens + self.speculative_num_draft_tokens
metadata.seq_lens[:bs].copy_(seq_lens[:bs])

self.forward_metadata = metadata
@@ -217,7 +234,12 @@ def forward_sparse(
topk_indices: torch.Tensor = None,
):

is_prefill = forward_batch.forward_mode.is_extend()
is_prefill = (
forward_batch.forward_mode.is_extend()
and not forward_batch.forward_mode.is_draft_extend_v2()
and not forward_batch.forward_mode.is_draft_extend()
and not forward_batch.forward_mode.is_target_verify()
)

if save_kv_cache:
k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
@@ -232,9 +254,30 @@ def forward_sparse(
actual_seq_qlen = torch.cumsum(forward_batch.seq_lens, dim=0)
else:
if self.forward_metadata.actual_seq_lengths_q is None:
actual_seq_qlen = (
torch.arange(1, q.shape[0] + 1).to(q.device).to(torch.int32)
)
if (
forward_batch.forward_mode.is_draft_extend_v2()
or forward_batch.forward_mode.is_target_verify()
):
actual_seq_qlen = (
torch.arange(
self.speculative_num_draft_tokens,
self.speculative_num_draft_tokens + q.shape[0],
self.speculative_num_draft_tokens,
dtype=torch.int32,
)
.to(q.device)
.to(torch.int32)
)
elif forward_batch.forward_mode.is_draft_extend():
actual_seq_qlen = (
forward_batch.extend_seq_lens.cumsum()
.to(q.device)
.to(torch.int32)
)
else:
actual_seq_qlen = (
torch.arange(1, q.shape[0] + 1).to(q.device).to(torch.int32)
)
else:
actual_seq_qlen = self.forward_metadata.actual_seq_lengths_q
if self.forward_metadata.seq_lens_cpu_int is None:
@@ -477,7 +520,7 @@ def forward_mtp(
-1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank
)

q_nope = q.view(-1, layer.tp_q_head_num, self.kv_lora_rank)
q_nope = q.view(-1, layer.tp_q_head_num, self.kv_lora_rank).contiguous()
q_rope = q_rope.view(-1, layer.tp_q_head_num, self.qk_rope_head_dim)
if not self.graph_mode:
num_token_padding = q.shape[0]
@@ -919,7 +962,7 @@ def call_fn(i, forward_batch):
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
seq_lens_cpu=None,
seq_lens_cpu=forward_batch.seq_lens_cpu,
)

self.common_template(forward_batch, call_fn)
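The new branches build `actual_seq_lengths_q` as cumulative query end-offsets: one entry per request when each request contributes a single decode token, and a stride of `speculative_num_draft_tokens` when every request carries a fixed block of draft tokens (target verify / draft extend v2). A small sketch of the two shapes, assuming 3 requests and 4 draft tokens per request:

```python
import torch

bs = 3                            # requests in the batch
speculative_num_draft_tokens = 4  # draft tokens carried by each request

# Plain decode: each request contributes exactly one query token,
# so the cumulative end offsets are simply 1, 2, 3, ...
decode_ends = torch.arange(1, bs + 1, dtype=torch.int32)

# Target verify / draft extend v2: each request contributes a fixed block of
# draft tokens, so the end offsets advance by that block size instead.
spec_ends = torch.arange(
    speculative_num_draft_tokens,
    speculative_num_draft_tokens + bs * speculative_num_draft_tokens,
    speculative_num_draft_tokens,
    dtype=torch.int32,
)

print(decode_ends)  # tensor([1, 2, 3], dtype=torch.int32)
print(spec_ends)    # tensor([ 4,  8, 12], dtype=torch.int32)
```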
31 changes: 27 additions & 4 deletions python/sglang/srt/layers/attention/nsa/nsa_indexer.py
@@ -693,7 +693,12 @@ def forward_npu(
enable_index_cp = (
get_bool_env_var("SGLANG_USE_AG_AFTER_QLORA") and layer_id >= 4
)
is_prefill = forward_batch.forward_mode.is_extend()
is_prefill = (
forward_batch.forward_mode.is_extend()
and not forward_batch.forward_mode.is_draft_extend_v2()
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
)

attention_tp_rank = get_attention_tp_rank()
attention_tp_size = get_attention_tp_size()
@@ -784,9 +789,27 @@

else:
if forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q is None:
actual_seq_lengths_q = torch.tensor(
[1 + i * 1 for i in range(bs)], dtype=torch.int32, device=k.device
)
if (
forward_batch.forward_mode.is_draft_extend_v2()
or forward_batch.forward_mode.is_target_verify()
or forward_batch.forward_mode.is_draft_extend()
):
num_draft_tokens = (
forward_batch.attn_backend.speculative_num_draft_tokens
)
actual_seq_lengths_q = torch.arange(
num_draft_tokens,
num_draft_tokens + bs,
num_draft_tokens,
dtype=torch.int32,
device=k.device,
)
else:
actual_seq_lengths_q = torch.tensor(
[1 + i * 1 for i in range(bs)],
dtype=torch.int32,
device=k.device,
)
else:
actual_seq_lengths_q = (
forward_batch.attn_backend.forward_metadata.actual_seq_lengths_q
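Both this indexer and the attention backend above narrow `is_prefill` the same way: the diff implies that the extend family of forward modes also covers the speculative modes, so the true prefill path must exclude target-verify and draft-extend explicitly, while `model_runner.py` opts draft-extend-v2 in with a flag. A minimal sketch of that gating with a stand-in enum (the real `ForwardMode` has more members and may differ in detail):

```python
from enum import Enum, auto


class ToyForwardMode(Enum):
    """Stand-in for sglang's ForwardMode; only the members needed here."""

    EXTEND = auto()
    DECODE = auto()
    TARGET_VERIFY = auto()
    DRAFT_EXTEND = auto()
    DRAFT_EXTEND_V2 = auto()

    def is_extend(self, include_draft_extend_v2: bool = False) -> bool:
        # The extend "family" covers the speculative modes too, which is why
        # model_runner.py now calls is_extend(include_draft_extend_v2=True).
        members = {
            ToyForwardMode.EXTEND,
            ToyForwardMode.TARGET_VERIFY,
            ToyForwardMode.DRAFT_EXTEND,
        }
        if include_draft_extend_v2:
            members.add(ToyForwardMode.DRAFT_EXTEND_V2)
        return self in members


def is_true_prefill(mode: ToyForwardMode) -> bool:
    # Mirrors the new is_prefill expression: extend, minus every speculative mode.
    return mode.is_extend() and mode not in (
        ToyForwardMode.TARGET_VERIFY,
        ToyForwardMode.DRAFT_EXTEND,
        ToyForwardMode.DRAFT_EXTEND_V2,
    )


assert is_true_prefill(ToyForwardMode.EXTEND)
assert not is_true_prefill(ToyForwardMode.TARGET_VERIFY)
assert not is_true_prefill(ToyForwardMode.DECODE)
```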
30 changes: 23 additions & 7 deletions python/sglang/srt/managers/overlap_utils.py
@@ -114,17 +114,33 @@ def resolve_future(self, model_worker_batch: ModelWorkerBatch):
else:
_resolve_future_token_ids(model_worker_batch.input_ids, self.token_ids_buf)

def is_empty_slice(self, s: slice) -> bool:
start, stop, step = s.indices(self.future_buffer_len)
if step > 0:
return start >= stop
else:
return start <= stop

def store_to_map(
self, future_indices: FutureIndices, batch_result: GenerationBatchResult
):
intv = future_indices.interval
if self.spec_algo.is_eagle():
draft_input: EagleDraftInput = batch_result.next_draft_input
self._lazy_init_buf(draft_input)
self.topk_p_buf[intv] = draft_input.topk_p
self.topk_index_buf[intv] = draft_input.topk_index
self.hidden_states_buf[intv] = draft_input.hidden_states
self.verified_id_buf[intv] = draft_input.verified_id
self.new_seq_lens_buf[intv] = draft_input.new_seq_lens
self.store_to_map_for_new_batch(future_indices, draft_input)
else:
intv = future_indices.interval
self.token_ids_buf[intv] = batch_result.next_token_ids

def store_to_map_for_new_batch(
self, future_indices: FutureIndices, draft_input: EagleDraftInput
):
intv = future_indices.interval
# idle indices do not need store info
if self.is_empty_slice(intv):
return
self._lazy_init_buf(draft_input)
self.topk_p_buf[intv] = draft_input.topk_p
self.topk_index_buf[intv] = draft_input.topk_index
self.hidden_states_buf[intv] = draft_input.hidden_states
self.verified_id_buf[intv] = draft_input.verified_id
self.new_seq_lens_buf[intv] = draft_input.new_seq_lens
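`is_empty_slice` uses `slice.indices()` to normalize the interval against the buffer length before deciding whether there is anything to write; per the diff's comment, an idle batch hands the future map an empty interval and has no draft tensors to store, so `store_to_map_for_new_batch` returns early. A quick standalone check of the helper's behaviour (buffer length chosen arbitrarily):

```python
def is_empty_slice(s: slice, buffer_len: int) -> bool:
    # Mirror of FutureMap.is_empty_slice: normalize the slice against the
    # buffer length, then check whether it selects zero elements.
    start, stop, step = s.indices(buffer_len)
    if step > 0:
        return start >= stop
    return start <= stop


buffer_len = 128
print(is_empty_slice(slice(8, 8), buffer_len))      # True  - zero-length interval
print(is_empty_slice(slice(8, 12), buffer_len))     # False - four slots
print(is_empty_slice(slice(200, 300), buffer_len))  # True  - clamped past the end
print(is_empty_slice(slice(10, 2, -1), buffer_len)) # False - reversed but non-empty
```

An equivalent one-liner is `len(range(*s.indices(buffer_len))) == 0`; the explicit branches simply spell out the same test for positive and negative steps.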
23 changes: 13 additions & 10 deletions python/sglang/srt/managers/scheduler.py
@@ -846,6 +846,15 @@ def init_disaggregation(self):
self.server_args.disaggregation_transfer_backend
)

if self.draft_worker is None or self.spec_algorithm.is_ngram():
draft_token_to_kv_pool = None
elif self.spec_algorithm.is_eagle() and self.enable_overlap:
draft_token_to_kv_pool = (
self.draft_worker.draft_worker.draft_runner.token_to_kv_pool
)
else:
draft_token_to_kv_pool = self.draft_worker.model_runner.token_to_kv_pool

if (
self.disaggregation_mode == DisaggregationMode.DECODE
): # *2 for the headroom.
@@ -874,11 +883,7 @@ def init_disaggregation(self):
self.disagg_decode_prealloc_queue = DecodePreallocQueue(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
draft_token_to_kv_pool=(
None
if self.draft_worker is None or self.spec_algorithm.is_ngram()
else self.draft_worker.model_runner.token_to_kv_pool
),
draft_token_to_kv_pool=draft_token_to_kv_pool,
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
metadata_buffers=self.disagg_metadata_buffers,
scheduler=self,
Expand Down Expand Up @@ -911,11 +916,7 @@ def init_disaggregation(self):

self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
draft_token_to_kv_pool=(
None
if self.draft_worker is None or self.spec_algorithm.is_ngram()
else self.draft_worker.model_runner.token_to_kv_pool
),
draft_token_to_kv_pool=draft_token_to_kv_pool,
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
metadata_buffers=self.disagg_metadata_buffers,
tp_rank=self.tp_rank,
@@ -935,6 +936,8 @@ def init_disaggregation(self):
self.disagg_prefill_inflight_queue: List[Req] = []

def init_overlap(self):
self.future_map = None

if not self.enable_overlap:
return

4 changes: 4 additions & 0 deletions python/sglang/srt/model_executor/forward_batch_info.py
@@ -846,6 +846,10 @@ def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):
self.spec_info.accept_length = self.spec_info.accept_length[:bs]
logits_output.next_token_logits = logits_output.next_token_logits[:bs]
logits_output.hidden_states = logits_output.hidden_states[:bs]
elif self.forward_mode.is_draft_extend_v2(): # draft extend_v2
bs = bs * self.spec_info.num_tokens_per_batch
logits_output.next_token_logits = logits_output.next_token_logits[:bs]
logits_output.hidden_states = logits_output.hidden_states[:bs]
elif self.forward_mode.is_extend() or self.forward_mode.is_idle():
logits_output.next_token_logits = logits_output.next_token_logits[:bs]
logits_output.hidden_states = logits_output.hidden_states[:bs]
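In the draft-extend-v2 branch the outputs carry one row per draft token rather than one per request, so the batch size used to trim the synced outputs is scaled by `num_tokens_per_batch` first. A worked example of the arithmetic; the row counts here are illustrative, not taken from the code:

```python
import torch

bs = 3                    # requests kept after MLP-sync trimming
num_tokens_per_batch = 4  # draft tokens produced per request in extend_v2
vocab = 10

# Outputs come back with more rows than needed; only the first
# bs * num_tokens_per_batch rows belong to this rank's real draft tokens.
next_token_logits = torch.randn(32, vocab)

rows = bs * num_tokens_per_batch           # 12 rows to keep
next_token_logits = next_token_logits[:rows]
print(next_token_logits.shape)             # torch.Size([12, 10])
```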
2 changes: 1 addition & 1 deletion python/sglang/srt/model_executor/model_runner.py
@@ -2213,7 +2213,7 @@ def _forward_raw(
reinit_attn_backend=reinit_attn_backend,
forward_count=split_forward_count,
)
elif forward_batch.forward_mode.is_extend():
elif forward_batch.forward_mode.is_extend(include_draft_extend_v2=True):
ret = self.forward_extend(
forward_batch,
skip_attn_backend_init=skip_attn_backend_init,
39 changes: 26 additions & 13 deletions python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
@@ -89,7 +89,7 @@ def __init__(self, eagle_worker: EAGLEWorker):
set_torch_compile_config()

# Graph inputs
with torch.device("cuda"):
with torch.device(model_runner.device):
self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
self.seq_lens = torch.full(
Expand Down Expand Up @@ -158,13 +158,30 @@ def can_run(self, forward_batch: ForwardBatch):

return is_bs_supported

def _create_graph(self):
return torch.cuda.CUDAGraph()

def _capture_init(self, run_once_fn):
for _ in range(2):
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
run_once_fn()

def _capture_graph(self, graph, pool, stream, run_once_fn):
with torch.cuda.graph(graph, pool=pool, stream=stream):
out = run_once_fn()
return out

def _replay(self, forward_batch: ForwardBatch):
self.graphs[self.bs].replay()

def capture(self):
CudaGraphRunner.capture(self)

def capture_one_batch_size(
self, num_seqs: int, forward: Callable, stream_idx: int = 0
):
graph = torch.cuda.CUDAGraph()
graph = self._create_graph()
stream = self.stream
num_tokens = num_seqs * self.num_tokens_per_bs

@@ -285,16 +302,10 @@ def run_once():

self.deepep_adapter.capture(is_extend_in_batch=False)

for _ in range(2):
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()

run_once()

with torch.cuda.graph(
graph, pool=get_global_graph_memory_pool(), stream=stream
):
out = run_once()
self._capture_init(run_once)
out = self._capture_graph(
graph, get_global_graph_memory_pool(), stream, run_once
)

set_global_graph_memory_pool(graph.pool())
return graph, out
@@ -362,10 +373,12 @@ def replay(self, forward_batch: ForwardBatch):
self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
forward_batch, bs
)
self.raw_bs = raw_bs
self.bs = bs
# TODO: The forward_batch.seq_len_sum might need to be updated to reflect the padding in the cuda graph

# Replay
self.graphs[bs].replay()
self._replay(forward_batch)
out = self.output_buffers[bs]

if bs != raw_bs:
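The capture path is now split into overridable helpers: graph creation (`_create_graph`), warm-up iterations with synchronization and a TP barrier (`_capture_init`), the actual recording under `torch.cuda.graph` (`_capture_graph`), and replay (`_replay`), presumably so device-specific runners can substitute their own graph objects and replay logic. A stripped-down sketch of the same warm-up/capture/replay sequence for a single callable, under the usual CUDA-graph constraints (static input/output buffers, no TP barrier):

```python
import torch


def capture_and_replay(fn, static_input):
    """Warm up, capture fn(static_input) into a CUDA graph, then replay it.

    fn must only touch static tensors, which is the same constraint the
    draft runner's run_once closure satisfies.
    """
    # Warm-up on a side stream: run eagerly a couple of times so lazy
    # initialization (cuBLAS handles, allocator pools) happens outside the graph.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(2):
            static_output = fn(static_input)
    torch.cuda.current_stream().wait_stream(s)

    # Capture: kernels launched inside this context are recorded, not executed.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = fn(static_input)

    # Replay: copy new data into the static input, then rerun the recorded kernels.
    def replay(new_input):
        static_input.copy_(new_input)
        graph.replay()
        return static_output

    return replay


if torch.cuda.is_available():
    x = torch.zeros(8, device="cuda")
    replay = capture_and_replay(lambda t: t * 2 + 1, x)
    print(replay(torch.arange(8.0, device="cuda")))  # tensor([ 1.,  3.,  5., ...])
```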