from __future__ import annotations

from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

import torch
from einops import rearrange
except ImportError as e:
    deep_gemm = e

+
from sglang.srt.layers import deep_gemm_wrapper
-from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM
-from sglang.srt.layers.dp_attention import get_attention_tp_group
+from sglang.srt.layers.attention.nsa.utils import (
+    NSA_DUAL_STREAM,
+    cp_all_gather_rerange_output,
+    is_nsa_enable_prefill_cp,
+)
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_group,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+)
from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.rotary_embedding import get_rope_wrapper
@@ -112,6 +121,13 @@ def __init__(
        self.layer_id = layer_id
        self.alt_stream = alt_stream
        self.fuse_wk_and_weights_proj = fuse_wk_and_weights_proj
+        self.nsa_enable_prefill_cp = is_nsa_enable_prefill_cp()
+        if self.nsa_enable_prefill_cp:
+            self.cp_size = get_attention_tp_size()
+            self.cp_rank = get_attention_tp_rank()
+        else:
+            self.cp_size = None
+            self.cp_rank = None
        if is_cuda():
            self.sm_count = deep_gemm.get_num_sms()
            self.half_device_sm_count = ceil_align(self.sm_count // 2, 8)
@@ -171,6 +187,7 @@ def _get_q_k_bf16(
        x: torch.Tensor,
        positions: torch.Tensor,
        enable_dual_stream: bool,
+        forward_batch: ForwardBatch,
    ):
        weights = None
        if enable_dual_stream:
@@ -228,6 +245,15 @@ def _get_q_k_bf16(
        query[..., : self.rope_head_dim] = q_rope
        key[..., : self.rope_head_dim] = k_rope

+        # allgather + rearrange the key across CP ranks for context-parallel prefill
+        if forward_batch.nsa_cp_metadata is not None and self.nsa_enable_prefill_cp:
+            key = cp_all_gather_rerange_output(
+                key.contiguous(),
+                self.cp_size,
+                forward_batch,
+                torch.cuda.current_stream(),
+            )
+
        if enable_dual_stream:
            current_stream = torch.cuda.current_stream()
            self.alt_stream.wait_stream(current_stream)
@@ -469,6 +495,153 @@ def _forward_cuda_k_only(
        )
        return metadata.topk_transform(dummy_logits, self.index_topk)

+    def _get_topk_ragged_with_cp(
+        self,
+        forward_batch: ForwardBatch,
+        layer_id: int,
+        q_fp8: torch.Tensor,
+        weights: torch.Tensor,
+        metadata: BaseIndexerMetadata,
+        kv_len: int,
+        actual_seq_q: int,
+        cp_index: Optional[List[Tuple[int, int, int]]] = None,
+    ) -> torch.Tensor:
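+        # Compute NSA top-k token indices over a ragged (per-request) KV layout when
+        # context-parallel prefill is enabled. cp_index, if given, lists one
+        # (batch_idx, start_seq_position, end_seq_position) span per request.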
+        if TYPE_CHECKING:
+            assert isinstance(forward_batch.token_to_kv_pool, NSATokenToKVPool)
+
+        page_size = forward_batch.token_to_kv_pool.page_size
+        assert page_size == 64, "only support page size 64"
+        assert len(weights.shape) == 3
+        weights = weights.squeeze(-1)
+        k_fp8_list = []
+        k_scale_list = []
+        ks_list = []
+        ke_offset_list = []
+        offset = 0
+        actual_seq_q_list = []
+        batch_idx_list = []
+
+        block_tables = metadata.get_page_table_64()
+
+        assert (
+            forward_batch.seq_lens_cpu is not None
+            and forward_batch.extend_seq_lens_cpu is not None
+        )
+        if cp_index is not None:
+            # TODO: multi-batch support has accuracy issues
+            for batch_idx, start_seq_position, end_seq_position in cp_index:
+                pre_chunk_offset = (
+                    forward_batch.seq_lens_cpu[batch_idx].item()
+                    - forward_batch.extend_seq_lens_cpu[batch_idx]
+                )
+                start_seq_position += pre_chunk_offset
+                end_seq_position += pre_chunk_offset
+                if offset == 0 and batch_idx != 0:
+                    offset += forward_batch.extend_seq_lens_cpu[batch_idx - 1]
+                k_fp8 = forward_batch.token_to_kv_pool.get_index_k_continuous(
+                    layer_id,
+                    end_seq_position,
+                    block_tables[batch_idx],
+                )
+                k_scale = forward_batch.token_to_kv_pool.get_index_k_scale_continuous(
+                    layer_id,
+                    end_seq_position,
+                    block_tables[batch_idx],
+                )
+
+                extend_seq_len = end_seq_position - start_seq_position
+                ks = torch.full(
+                    (extend_seq_len,), offset, dtype=torch.int32, device="cuda"
+                )
+                k_fp8_list.append(k_fp8)
+                k_scale_list.append(k_scale)
+                ks_list.append(ks)
+                ke_offset = torch.arange(
+                    start_seq_position + 1,
+                    end_seq_position + 1,
+                    dtype=torch.int32,
+                    device="cuda",
+                )
+                ke_offset_list.append(ke_offset)
+                actual_seq_q = torch.tensor(
+                    [extend_seq_len], dtype=torch.int32, device="cuda"
+                )
+                actual_seq_q_list.append(actual_seq_q)
+                batch_idx_list.append(batch_idx)
+
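+            # Concatenate the per-request pieces, compute FP8 MQA logits over the
+            # combined ragged KV, and select top-k indices per query token.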
+            k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn)
+            k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1)
+            kv_fp8 = (k_fp8, k_scale)
+            ks = torch.cat(ks_list, dim=0)
+            ke_offset = torch.cat(ke_offset_list, dim=0)
+            ke = ks + ke_offset
+            actual_seq_q = torch.cat(actual_seq_q_list, dim=0)
+            logits = deep_gemm.fp8_mqa_logits(
+                q_fp8,
+                kv_fp8,
+                weights,
+                ks,
+                ke,
+                clean_logits=False,
+            )
+            topk_result = metadata.topk_transform(
+                logits,
+                self.index_topk,
+                ks=ks,
+                cu_seqlens_q=actual_seq_q,
+                ke_offset=ke_offset,
+                batch_idx_list=batch_idx_list,
+            )
+        else:
+            kv_len = (
+                forward_batch.seq_lens_cpu[0].item()
+                - forward_batch.extend_seq_lens_cpu[0]
+                + kv_len
+            )
+            k_fp8 = forward_batch.token_to_kv_pool.get_index_k_continuous(
+                layer_id,
+                kv_len,
+                block_tables[0],
+            )
+            k_scale = forward_batch.token_to_kv_pool.get_index_k_scale_continuous(
+                layer_id,
+                kv_len,
+                block_tables[0],
+            )
+
+            k_fp8 = k_fp8.view(torch.float8_e4m3fn)
+            k_scale = k_scale.view(torch.float32).squeeze(-1)
+            kv_fp8 = (k_fp8, k_scale)
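+            # ks/ke give, per query token, the [start, end) range of KV tokens it may
+            # attend to; ke grows by one per token, preserving the causal mask.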
+            ks = torch.full((actual_seq_q,), offset, dtype=torch.int32, device="cuda")
+            ke_offset = torch.arange(
+                (kv_len - actual_seq_q) + 1,
+                kv_len + 1,
+                dtype=torch.int32,
+                device="cuda",
+            )
+            ke = ks + ke_offset
+
+            logits = deep_gemm.fp8_mqa_logits(
+                q_fp8,
+                kv_fp8,
+                weights,
+                ks,
+                ke,
+                clean_logits=False,
+            )
+            actual_seq_q = torch.tensor([actual_seq_q], dtype=torch.int32).to(
+                device="cuda", non_blocking=True
+            )
+            topk_result = metadata.topk_transform(
+                logits,
+                self.index_topk,
+                ks=ks,
+                cu_seqlens_q=actual_seq_q,
+                ke_offset=ke_offset,
+            )
+
+        return topk_result
+
    def forward_indexer(
        self,
        q_fp8: torch.Tensor,
@@ -594,7 +767,7 @@ def forward_cuda(
        skip_logits_computation = max_kv_len <= self.index_topk

        # Optimization: fast path when skipping topk computation
-        if skip_logits_computation:
+        if skip_logits_computation and (not self.nsa_enable_prefill_cp):
            return self._forward_cuda_k_only(
                x,
                positions,
@@ -607,7 +780,7 @@ def forward_cuda(
        )

        query, key, weights = self._get_q_k_bf16(
-            q_lora, x, positions, enable_dual_stream
+            q_lora, x, positions, enable_dual_stream, forward_batch=forward_batch
        )

        if enable_dual_stream:
@@ -660,9 +833,49 @@ def forward_cuda(
                    forward_batch, layer_id, q_fp8, weights, metadata
                )
            else:
-                topk_result = self._get_topk_ragged(
-                    forward_batch, layer_id, q_fp8, weights, metadata
-                )
+                if (
+                    forward_batch.nsa_cp_metadata is not None
+                    and self.nsa_enable_prefill_cp
+                ):
+                    kv_len_prev = forward_batch.nsa_cp_metadata.kv_len_prev
+                    kv_len_next = forward_batch.nsa_cp_metadata.kv_len_next
+                    actual_seq_q_prev = forward_batch.nsa_cp_metadata.actual_seq_q_prev
+                    actual_seq_q_next = forward_batch.nsa_cp_metadata.actual_seq_q_next
+
+                    # TODO: support multi-batch
+                    # cp_batch_seq_index_prev = forward_batch.nsa_cp_metadata["cp_batch_seq_index_prev"]
+                    # cp_batch_seq_index_next = forward_batch.nsa_cp_metadata["cp_batch_seq_index_next"]
+                    # TODO: combine the prev and next chunks into a single call
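+                    # The local tokens are split into prev/next CP chunks; run top-k
+                    # for each chunk against its own KV length, then concatenate.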
+                    q_fp8_prev, q_fp8_next = torch.split(
+                        q_fp8, (q_fp8.shape[0] + 1) // 2, dim=0
+                    )
+                    weights_prev, weights_next = torch.split(
+                        weights, (weights.shape[0] + 1) // 2, dim=0
+                    )
+                    topk_result_prev = self._get_topk_ragged_with_cp(
+                        forward_batch,
+                        layer_id,
+                        q_fp8_prev,
+                        weights_prev,
+                        metadata,
+                        kv_len_prev,
+                        actual_seq_q_prev,
+                    )
+
+                    topk_result_next = self._get_topk_ragged_with_cp(
+                        forward_batch,
+                        layer_id,
+                        q_fp8_next,
+                        weights_next,
+                        metadata,
+                        kv_len_next,
+                        actual_seq_q_next,
+                    )
+                    return torch.cat([topk_result_prev, topk_result_next], dim=0)
+                else:
+                    topk_result = self._get_topk_ragged(
+                        forward_batch, layer_id, q_fp8, weights, metadata
+                    )
        else:
            topk_result = self.forward_indexer(
                q_fp8.contiguous(),