
Commit 7bc9946

lixiaolx authored and root committed
(beta) support context parallel with deepseekv3.2-DSA
1 parent efc5d8f commit 7bc9946

File tree

17 files changed: +1247 -54 lines changed


docs/advanced_features/server_arguments.md

Lines changed: 1 addition & 0 deletions
@@ -396,6 +396,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `--numa-node` | Sets the numa node for the subprocesses. i-th element corresponds to i-th subprocess. | `None` | List[int] |
 | `--enable-layerwise-nvtx-marker` | Enable layerwise NVTX profiling annotations for the model. This adds NVTX markers to every layer for detailed per-layer performance analysis with Nsight Systems. | `False` | bool flag (set to enable) |
 | `--enable-attn-tp-input-scattered` | Allow input of attention to be scattered when only using tensor parallelism, to reduce the computational load of operations such as qkv latent. | `False` | bool flag (set to enable) |
+| `--enable-nsa-prefill-context-parallel` | Enable context parallelism for the long-sequence prefill phase of DeepSeek V3.2. | `False` | bool flag (set to enable) |
 
 ## Debug tensor dumps
 | Argument | Description | Defaults | Options |

docs/basic_usage/deepseek_v32.md

Lines changed: 20 additions & 0 deletions
@@ -142,3 +142,23 @@ The mean accuracy over 8 runs shows 0.797, which matches the number 79.9 in offici
 Repeat: 8, mean: 0.797
 Scores: ['0.808', '0.798', '0.808', '0.798', '0.783', '0.788', '0.803', '0.793']
 ```
+
+## DSA long-sequence context parallel optimization (experimental)
+
+Long-context accuracy can be benchmarked on the GPQA-Diamond dataset with long output tokens and thinking enabled.
+
+Example usage:
+```bash
+# Launch with EP + DP
+python -m sglang.launch_server --model deepseek-ai/DeepSeek-V3.2-Exp --tp 8 --ep 8 --dp 2 --enable-dp-attention --enable-nsa-prefill-context-parallel --max-running-requests 32
+```
+### Context-parallel tips
+`CP_size` reuses `atten_tp_size`, which is equal to `TP_size / DP_size`.
+Some features are not supported yet:
+- **Multi-batch prefill**: only single-request processing is supported during the prefill phase.
+- **Disaggregation**: P/D disaggregation is not supported.
+- **Cross-machine support**: currently only tested on a single machine (TP=8, EP=8).
+- **Other args**: currently requires `moe_dense_tp_size=1`, `kv_cache_dtype="bf16"`, and `moe_a2a_backend="deepep"`.
+- **DP_size**: for CP to work correctly, `TP_size` must be divisible by `DP_size`, and `TP_size / DP_size` must be greater than 1 (so that `CP_size > 1`); see the sizing sketch after this list.
+- **Detailed design reference**: https://github.com/sgl-project/sglang/pull/12065
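As a quick check of the sizing rule above, here is a minimal sketch; the helper `validate_cp_config` is hypothetical, for illustration only, and not part of sglang:

```python
# Hypothetical helper illustrating the CP sizing rule; not part of sglang.
def validate_cp_config(tp_size: int, dp_size: int) -> int:
    assert tp_size % dp_size == 0, "TP_size must be divisible by DP_size"
    cp_size = tp_size // dp_size  # CP_size reuses atten_tp_size
    assert cp_size > 1, "need TP_size / DP_size > 1 so that CP_size > 1"
    return cp_size

# Matches the launch example: --tp 8 --dp 2 gives CP_size == 4.
print(validate_cp_config(tp_size=8, dp_size=2))  # 4
```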

python/sglang/srt/distributed/device_communicators/pynccl.py

Lines changed: 28 additions & 0 deletions
@@ -209,6 +209,34 @@ def all_gather(
             cudaStream_t(stream.cuda_stream),
         )
 
+    def cp_all_gather_into_tensor(
+        self,
+        output_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
+        stream=None,
+        sizes: Optional[list[int]] = None,
+    ):
+        """
+        Non-blocking all-gather built on pynccl; currently used mainly for
+        context parallelism.
+        """
+        # A NCCL communicator created on a specific device will only work on
+        # tensors on the same device; otherwise it causes an
+        # "illegal memory access".
+        assert input_tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {input_tensor.device}"
+        )
+        stream = self._resolve_stream(stream)
+        self.nccl.ncclAllGather(
+            buffer_type(input_tensor.data_ptr()),
+            buffer_type(output_tensor.data_ptr()),
+            input_tensor.numel(),
+            ncclDataTypeEnum.from_torch(input_tensor.dtype),
+            self.comm,
+            cudaStream_t(stream.cuda_stream),
+        )
+
     def reduce_scatter(
         self,
         output_tensor: torch.Tensor,
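A hedged usage sketch of the method above, assuming an already initialized `PyNcclCommunicator` named `comm` on this rank's device (distributed setup omitted). The points to notice: the output buffer must hold every rank's shard, and the collective runs on a caller-chosen stream:

```python
# Sketch only: assumes `comm` is an initialized PyNcclCommunicator and that
# torch.distributed / NCCL setup has already happened on this rank.
import torch

side_stream = torch.cuda.Stream()
local = torch.randn(4096, 128, device="cuda", dtype=torch.bfloat16)
# The gathered output holds world_size shards concatenated along dim 0.
gathered = torch.empty(
    comm.world_size * local.shape[0], local.shape[1],
    device="cuda", dtype=local.dtype,
)
side_stream.wait_stream(torch.cuda.current_stream())  # order after writes to `local`
comm.cp_all_gather_into_tensor(gathered, local, stream=side_stream)
# Independent work on the default stream can overlap with the all-gather here.
torch.cuda.current_stream().wait_stream(side_stream)  # before reading `gathered`
```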

python/sglang/srt/distributed/parallel_state.py

Lines changed: 21 additions & 0 deletions
@@ -748,6 +748,27 @@ def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
             output, input, group_name=self.unique_name
         )
 
+    def cp_all_gather_into_tensor_async(
+        self, output: torch.Tensor, input: torch.Tensor, stream=None
+    ):
+        """
+        Perform an asynchronous all-gather on the specified stream. The
+        default `torch.distributed.all_gather_into_tensor` triggers event
+        synchronization, which blocks CPU-side kernel launches; the pynccl
+        interface used here drops that event-synchronization logic.
+        """
+        assert (
+            stream is not None
+        ), f"Invalid stream ({stream}); please specify the stream to use when calling cp_all_gather_into_tensor_async."
+        pynccl_comm = self.pynccl_comm
+        if pynccl_comm is not None:
+            pynccl_comm.cp_all_gather_into_tensor(output, input, stream=stream)
+        else:
+            logger.warning(
+                "pynccl communicator unavailable; falling back to the blocking all_gather_into_tensor path"
+            )
+            torch.ops.sglang.reg_all_gather_into_tensor(
+                output, input, group_name=self.unique_name
+            )
+
     def all_gather(
         self,
         input_: torch.Tensor,
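At the `GroupCoordinator` level, a sketch of the intended calling pattern; `group` stands in for an initialized coordinator with a pynccl communicator, and the stream choreography shown is the caller's responsibility, not prescribed by the method:

```python
# Sketch: overlap the async all-gather with compute on the default stream.
# `group` is assumed to be an initialized GroupCoordinator; the method itself
# only requires that a stream is passed (stream=None asserts).
import torch

comm_stream = torch.cuda.Stream()
inp = torch.randn(2048, 576, device="cuda", dtype=torch.bfloat16)
out = torch.empty(
    group.world_size * inp.shape[0], inp.shape[1],
    device="cuda", dtype=inp.dtype,
)

comm_stream.wait_stream(torch.cuda.current_stream())
group.cp_all_gather_into_tensor_async(out, inp, stream=comm_stream)
# No event synchronization happens on the CPU side, so kernel launches for
# unrelated work keep flowing on the default stream.
torch.cuda.current_stream().wait_stream(comm_stream)  # first read of `out`
```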

python/sglang/srt/layers/attention/nsa/nsa_indexer.py

Lines changed: 221 additions & 8 deletions
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
 import torch
 from einops import rearrange
@@ -16,9 +16,18 @@
 except ImportError as e:
     deep_gemm = e
 
+
 from sglang.srt.layers import deep_gemm_wrapper
-from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM
-from sglang.srt.layers.dp_attention import get_attention_tp_group
+from sglang.srt.layers.attention.nsa.utils import (
+    NSA_DUAL_STREAM,
+    cp_all_gather_rerange_output,
+    is_nsa_enable_prefill_cp,
+)
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_group,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+)
 from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.rotary_embedding import get_rope_wrapper
@@ -112,6 +121,13 @@ def __init__(
         self.layer_id = layer_id
         self.alt_stream = alt_stream
         self.fuse_wk_and_weights_proj = fuse_wk_and_weights_proj
+        self.nsa_enable_prefill_cp = is_nsa_enable_prefill_cp()
+        if self.nsa_enable_prefill_cp:
+            self.cp_size = get_attention_tp_size()
+            self.cp_rank = get_attention_tp_rank()
+        else:
+            self.cp_size = None
+            self.cp_rank = None
         if is_cuda():
             self.sm_count = deep_gemm.get_num_sms()
             self.half_device_sm_count = ceil_align(self.sm_count // 2, 8)
@@ -171,6 +187,7 @@ def _get_q_k_bf16(
         x: torch.Tensor,
         positions: torch.Tensor,
         enable_dual_stream: bool,
+        forward_batch: ForwardBatch,
     ):
         weights = None
         if enable_dual_stream:
@@ -228,6 +245,15 @@ def _get_q_k_bf16(
         query[..., : self.rope_head_dim] = q_rope
         key[..., : self.rope_head_dim] = k_rope
 
+        # all-gather + rearrange across CP ranks
+        if forward_batch.nsa_cp_metadata is not None and self.nsa_enable_prefill_cp:
+            key = cp_all_gather_rerange_output(
+                key.contiguous(),
+                self.cp_size,
+                forward_batch,
+                torch.cuda.current_stream(),
+            )
+
         if enable_dual_stream:
             current_stream = torch.cuda.current_stream()
             self.alt_stream.wait_stream(current_stream)
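The gather-and-rearrange step above is easiest to see on a toy example. The sketch below assumes a zigzag-style layout in which rank `i` owns chunk `i` and chunk `2*cp_size - 1 - i` (a common way to balance causal-attention work across CP ranks); the actual chunk layout used by `cp_all_gather_rerange_output` is defined in the design PR linked in the docs:

```python
# Toy, single-process illustration of why an all-gather of per-rank key
# chunks must be re-ordered afterwards. Assumes a zigzag layout: rank i owns
# chunks i and (2*cp_size - 1 - i). Not the real cp_all_gather_rerange_output.
import torch

cp_size = 4
# Chunk ids each rank holds, in the order an all-gather concatenates them.
per_rank = [torch.tensor([r, 2 * cp_size - 1 - r]) for r in range(cp_size)]
gathered = torch.cat(per_rank)   # tensor([0, 7, 1, 6, 2, 5, 3, 4])
order = torch.argsort(gathered)  # permutation back to sequence order
print(gathered[order])           # tensor([0, 1, 2, 3, 4, 5, 6, 7])
```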
@@ -469,6 +495,153 @@ def _forward_cuda_k_only(
         )
         return metadata.topk_transform(dummy_logits, self.index_topk)
 
+    def _get_topk_ragged_with_cp(
+        self,
+        forward_batch: ForwardBatch,
+        layer_id: int,
+        q_fp8: torch.Tensor,
+        weights: torch.Tensor,
+        metadata: BaseIndexerMetadata,
+        kv_len: int,
+        actual_seq_q: int,
+        cp_index: Optional[List[Tuple[int, int, int]]] = None,
+    ) -> torch.Tensor:
+        if TYPE_CHECKING:
+            assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)
+
+        page_size = forward_batch.token_to_kv_pool.page_size
+        assert page_size == 64, "only support page size 64"
+        assert len(weights.shape) == 3
+        weights = weights.squeeze(-1)
+        k_fp8_list = []
+        k_scale_list = []
+        ks_list = []
+        ke_offset_list = []
+        offset = 0
+        actual_seq_q_list = []
+        batch_idx_list = []
+
+        block_tables = metadata.get_page_table_64()
+
+        assert (
+            forward_batch.seq_lens_cpu is not None
+            and forward_batch.extend_seq_lens_cpu is not None
+        )
+        if cp_index is not None:
+            # TODO: multi-batch support still has accuracy issues
+            for batch_idx, start_seq_position, end_seq_position in cp_index:
+                pre_chunk_offset = (
+                    forward_batch.seq_lens_cpu[batch_idx].item()
+                    - forward_batch.extend_seq_lens_cpu[batch_idx]
+                )
+                start_seq_position += pre_chunk_offset
+                end_seq_position += pre_chunk_offset
+                if offset == 0 and batch_idx != 0:
+                    offset += forward_batch.extend_seq_lens_cpu[batch_idx - 1]
+                k_fp8 = forward_batch.token_to_kv_pool.get_index_k_continuous(
+                    layer_id,
+                    end_seq_position,
+                    block_tables[batch_idx],
+                )
+                k_scale = forward_batch.token_to_kv_pool.get_index_k_scale_continuous(
+                    layer_id,
+                    end_seq_position,
+                    block_tables[batch_idx],
+                )
+
+                extend_seq_len = end_seq_position - start_seq_position
+                ks = torch.full(
+                    (extend_seq_len,), offset, dtype=torch.int32, device="cuda"
+                )
+                k_fp8_list.append(k_fp8)
+                k_scale_list.append(k_scale)
+                ks_list.append(ks)
+                ke_offset = torch.arange(
+                    start_seq_position + 1,
+                    end_seq_position + 1,
+                    dtype=torch.int32,
+                    device="cuda",
+                )
+                ke_offset_list.append(ke_offset)
+                actual_seq_q = torch.tensor(
+                    [extend_seq_len], dtype=torch.int32, device="cuda"
+                )
+                actual_seq_q_list.append(actual_seq_q)
+                batch_idx_list.append(batch_idx)
+
+            k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn)
+            k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1)
+            kv_fp8 = (k_fp8, k_scale)
+            ks = torch.cat(ks_list, dim=0)
+            ke_offset = torch.cat(ke_offset_list, dim=0)
+            ke = ks + ke_offset
+            actual_seq_q = torch.cat(actual_seq_q_list, dim=0)
+            logits = deep_gemm.fp8_mqa_logits(
+                q_fp8,
+                kv_fp8,
+                weights,
+                ks,
+                ke,
+                clean_logits=False,
+            )
+            topk_result = metadata.topk_transform(
+                logits,
+                self.index_topk,
+                ks=ks,
+                cu_seqlens_q=actual_seq_q,
+                ke_offset=ke_offset,
+                batch_idx_list=batch_idx_list,
+            )
+        else:
+            kv_len = (
+                forward_batch.seq_lens_cpu[0].item()
+                - forward_batch.extend_seq_lens_cpu[0]
+                + kv_len
+            )
+            k_fp8 = forward_batch.token_to_kv_pool.get_index_k_continuous(
+                layer_id,
+                kv_len,
+                block_tables[0],
+            )
+            k_scale = forward_batch.token_to_kv_pool.get_index_k_scale_continuous(
+                layer_id,
+                kv_len,
+                block_tables[0],
+            )
+
+            k_fp8 = k_fp8.view(torch.float8_e4m3fn)
+            k_scale = k_scale.view(torch.float32).squeeze(-1)
+            kv_fp8 = (k_fp8, k_scale)
+            ks = torch.full((actual_seq_q,), offset, dtype=torch.int32, device="cuda")
+            ke_offset = torch.arange(
+                (kv_len - actual_seq_q) + 1,
+                kv_len + 1,
+                dtype=torch.int32,
+                device="cuda",
+            )
+            ke = ks + ke_offset
+
+            logits = deep_gemm.fp8_mqa_logits(
+                q_fp8,
+                kv_fp8,
+                weights,
+                ks,
+                ke,
+                clean_logits=False,
+            )
+            actual_seq_q = torch.tensor([actual_seq_q], dtype=torch.int32).to(
+                device="cuda", non_blocking=True
+            )
+            topk_result = metadata.topk_transform(
+                logits,
+                self.index_topk,
+                ks=ks,
+                cu_seqlens_q=actual_seq_q,
+                ke_offset=ke_offset,
+            )
+
+        return topk_result
+
     def forward_indexer(
         self,
         q_fp8: torch.Tensor,
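The `[ks, ke)` ranges built above encode per-query causal visibility for `deep_gemm.fp8_mqa_logits`. A worked toy example of the single-request (`cp_index is None`) branch, with made-up sizes:

```python
# Worked example of the single-request branch above: query token j may attend
# to keys in [ks[j], ke[j]). Sizes are made up for illustration.
import torch

kv_len = 10        # prior-chunk keys plus this chunk's keys
actual_seq_q = 4   # query tokens handled in this call
offset = 0

ks = torch.full((actual_seq_q,), offset, dtype=torch.int32)
ke_offset = torch.arange(kv_len - actual_seq_q + 1, kv_len + 1, dtype=torch.int32)
ke = ks + ke_offset
print(ks.tolist())  # [0, 0, 0, 0]
print(ke.tolist())  # [7, 8, 9, 10]: each later query sees one more key (causal)
```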
@@ -594,7 +767,7 @@ def forward_cuda(
         skip_logits_computation = max_kv_len <= self.index_topk
 
         # Optimization: fast path when skipping topk computation
-        if skip_logits_computation:
+        if skip_logits_computation and (not self.nsa_enable_prefill_cp):
             return self._forward_cuda_k_only(
                 x,
                 positions,
@@ -607,7 +780,7 @@ def forward_cuda(
         )
 
         query, key, weights = self._get_q_k_bf16(
-            q_lora, x, positions, enable_dual_stream
+            q_lora, x, positions, enable_dual_stream, forward_batch=forward_batch
         )
 
         if enable_dual_stream:
@@ -660,9 +833,49 @@ def forward_cuda(
                     forward_batch, layer_id, q_fp8, weights, metadata
                 )
             else:
-                topk_result = self._get_topk_ragged(
-                    forward_batch, layer_id, q_fp8, weights, metadata
-                )
+                if (
+                    forward_batch.nsa_cp_metadata is not None
+                    and self.nsa_enable_prefill_cp
+                ):
+                    kv_len_prev = forward_batch.nsa_cp_metadata.kv_len_prev
+                    kv_len_next = forward_batch.nsa_cp_metadata.kv_len_next
+                    actual_seq_q_prev = forward_batch.nsa_cp_metadata.actual_seq_q_prev
+                    actual_seq_q_next = forward_batch.nsa_cp_metadata.actual_seq_q_next
+
+                    # TODO: support multi-batch
+                    # cp_batch_seq_index_prev = forward_batch.nsa_cp_metadata["cp_batch_seq_index_prev"]
+                    # cp_batch_seq_index_next = forward_batch.nsa_cp_metadata["cp_batch_seq_index_next"]
+                    # TODO: combine the prev and next halves into a single call
+                    q_fp8_prev, q_fp8_next = torch.split(
+                        q_fp8, (q_fp8.shape[0] + 1) // 2, dim=0
+                    )
+                    weights_prev, weights_next = torch.split(
+                        weights, (weights.shape[0] + 1) // 2, dim=0
+                    )
+                    topk_result_prev = self._get_topk_ragged_with_cp(
+                        forward_batch,
+                        layer_id,
+                        q_fp8_prev,
+                        weights_prev,
+                        metadata,
+                        kv_len_prev,
+                        actual_seq_q_prev,
+                    )
+
+                    topk_result_next = self._get_topk_ragged_with_cp(
+                        forward_batch,
+                        layer_id,
+                        q_fp8_next,
+                        weights_next,
+                        metadata,
+                        kv_len_next,
+                        actual_seq_q_next,
+                    )
+                    return torch.cat([topk_result_prev, topk_result_next], dim=0)
+                else:
+                    topk_result = self._get_topk_ragged(
+                        forward_batch, layer_id, q_fp8, weights, metadata
+                    )
         else:
             topk_result = self.forward_indexer(
                 q_fp8.contiguous(),
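For reference, the prev/next split in the CP branch above is plain `torch.split` arithmetic: the first half gets the extra token when the count is odd, and the two top-k results are concatenated back in the original order. This presumably reflects each rank holding a head ("prev") and a tail ("next") chunk of the sequence, as in the design PR. A toy sketch:

```python
# Toy sketch of the prev/next query split used in the CP branch: the two
# halves are scored against different kv_len_* metadata, then recombined.
import torch

q_fp8 = torch.zeros(7, 64)  # 7 query tokens on this rank (odd on purpose)
prev, nxt = torch.split(q_fp8, (q_fp8.shape[0] + 1) // 2, dim=0)
print(prev.shape[0], nxt.shape[0])  # 4 3
# Results are recombined in order: torch.cat([topk_prev, topk_next], dim=0)
```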
