 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
+import logging
+from contextlib import contextmanager
 from dataclasses import dataclass
 from enum import Enum, auto
 from functools import partial
-from typing import Dict, List, Optional
+from typing import Callable, Dict, List, Optional, Tuple
 
 import torch
 
 from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     get_tp_group,
     tensor_model_parallel_all_reduce,
@@ -59,9 +61,10 @@
     prepare_weight_cache,
 )
 
+_is_cuda = is_cuda()
 _is_flashinfer_available = is_flashinfer_available()
-_is_sm90_supported = is_cuda() and is_sm90_supported()
-_is_sm100_supported = is_cuda() and is_sm100_supported()
+_is_sm90_supported = _is_cuda and is_sm90_supported()
+_is_sm100_supported = _is_cuda and is_sm100_supported()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
 _is_gfx95_supported = is_gfx95_supported()
 
@@ -92,6 +95,119 @@ def model_input_output():
         return ScatterMode.TP_ATTN_FULL
 
 
+class AttentionInputs:
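+    """Lazily materializes full-batch activations under attention-TP input scattering.
+
+    Holds the rank-local hidden states and the callable that computes the QKV
+    latent; each is computed and all-gathered across the TP group at most once.
+    """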
+
+    def __init__(
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        qkv_latent_func: Callable,
+    ):
+        self.hidden_states_local = hidden_states
+        self.forward_batch = forward_batch
+        self.qkv_latent_func = qkv_latent_func
+        self.hidden_states_ = None
+        self.qkv_latent_ = None
+
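+    # All-gather rank-local rows so every TP rank holds activations for the
+    # full token batch (forward_batch.input_ids gives the global token count).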
+    def tp_all_gather_hidden_states(self, hidden_states, forward_batch):
+        total_tokens = forward_batch.input_ids.shape[0]
+        output = hidden_states.new_empty((total_tokens, hidden_states.shape[-1]))
+        get_tp_group().all_gather_into_tensor(output, hidden_states)
+        return output
+
+    def fetch_qkv_latent(self):
+        if self.qkv_latent_ is not None:
+            return self.qkv_latent_
+        assert self.qkv_latent_func is not None
+        self.qkv_latent_ = self.qkv_latent_func(
+            self.hidden_states_local, self.forward_batch
+        )
+        if get_attn_tp_context().input_scattered:
+            self.qkv_latent_ = self.tp_all_gather_hidden_states(
+                self.qkv_latent_, self.forward_batch
+            )
+        return self.qkv_latent_
+
+    def fetch_hidden_states(self):
+        if self.hidden_states_ is not None:
+            return self.hidden_states_
+        self.hidden_states_ = self.hidden_states_local
+        if get_attn_tp_context().input_scattered:
+            self.hidden_states_ = self.tp_all_gather_hidden_states(
+                self.hidden_states_, self.forward_batch
+            )
+        return self.hidden_states_
+
+
+class AttnTpContext:
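+    """Process-global state for the attention-TP input-scattering optimization.
+
+    init_context decides whether the optimization is allowed at all;
+    maybe_input_scattered toggles it for the duration of one forward pass.
+    """
+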
+    def __init__(self):
+        self.allow_input_scattered = False
+        self.input_scattered_ = False
+        self.attn_inputs_: Optional[AttentionInputs] = None
+
+    def init_context(self, q_lora_rank, is_nsa):
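+        # Scattering is enabled only when every condition below holds;
+        # otherwise we fall back to the regular replicated-input path.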
+        self.allow_input_scattered = (
+            get_global_server_args().enable_attn_tp_input_scattered
+            and _is_cuda
+            and q_lora_rank is not None
+            and not is_nsa
+            and get_tensor_model_parallel_world_size() > 1
+            and not is_dp_attention_enabled()
+            and get_moe_a2a_backend().is_none()
+            and not enable_moe_dense_fully_dp()
+            and not get_global_server_args().enable_piecewise_cuda_graph
+            and get_global_server_args().speculative_algorithm != "EAGLE3"
+        )
+        if get_global_server_args().enable_attn_tp_input_scattered:
+            if not self.allow_input_scattered:
+                logging.info(
+                    "attn_tp_input_scattered is requested but disabled because other required conditions are not met"
+                )
+            else:
+                logging.info("attn_tp_input_scattered is enabled")
+
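+    # Per-batch gate: scatter only plain extend batches, not target-verify,
+    # draft-extend, or two-batch-overlap (TBO) runs.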
+    def use_input_scattered(self, forward_batch: ForwardBatch):
+        return (
+            self.allow_input_scattered
+            and forward_batch.forward_mode.is_extend()
+            and not forward_batch.forward_mode.is_target_verify()
+            and not forward_batch.forward_mode.is_draft_extend()
+            and forward_batch.input_ids is not None
+            and not forward_batch.can_run_tbo
+        )
+
+    @property
+    def input_scattered(self):
+        return self.input_scattered_
+
+    def set_attn_inputs(self, attn_inputs: AttentionInputs):
+        self.attn_inputs_ = attn_inputs
+
+    def fetch_qkv_latent(self):
+        assert self.attn_inputs_ is not None
+        return self.attn_inputs_.fetch_qkv_latent()
+
+    def fetch_hidden_states(self):
+        assert self.attn_inputs_ is not None
+        return self.attn_inputs_.fetch_hidden_states()
+
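+    # Scope the scattering flag to a single forward pass; any cached
+    # attention inputs are dropped on exit.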
+    @contextmanager
+    def maybe_input_scattered(self, forward_batch: ForwardBatch):
+        flag = self.use_input_scattered(forward_batch)
+        old_flag = self.input_scattered
+        self.input_scattered_ = flag
+        try:
+            yield
+        finally:
+            # Restore the previous state even if the forward pass raises.
+            self.input_scattered_ = old_flag
+            self.attn_inputs_ = None
+
+
+ATTN_TP_CONTEXT = AttnTpContext()
+
+
+def get_attn_tp_context():
+    return ATTN_TP_CONTEXT
+
+
 @dataclass
 class _LayerModeComputationContext:
     num_layers: int
@@ -188,12 +304,14 @@ def __init__(
         # Reduce scatter requires skipping all-reduce in model code after MoE/MLP, so only enable for models which have that implemented. Remove flag once done for all models that use LayerCommunicator.
         allow_reduce_scatter: bool = False,
         is_last_layer: bool = False,
+        qkv_latent_func: Optional[Callable] = None,
     ):
         self.layer_scatter_modes = layer_scatter_modes
         self.input_layernorm = input_layernorm
         self.post_attention_layernorm = post_attention_layernorm
         self.allow_reduce_scatter = allow_reduce_scatter
         self.is_last_layer = is_last_layer
+        self.qkv_latent_func = qkv_latent_func
 
         self._context = CommunicateContext.init_new()
         self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
@@ -252,6 +370,11 @@ def prepare_attn(
         forward_batch: ForwardBatch,
         quant_format: str = "",
     ):
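+        # With input scattering active, swap the usual all-reduce for a
+        # reduce-scatter so each rank keeps only its 1/tp_size token shard
+        # (the residual is sliced to match).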
+        if get_attn_tp_context().input_scattered:
+            hidden_states, residual = self._tp_reduce_scatter(
+                hidden_states,
+                residual,
+            )
         if hidden_states.shape[0] == 0:
             residual = hidden_states
         else:
@@ -335,9 +458,32 @@ def prepare_attn(
             forward_batch=forward_batch,
             context=self._context,
         )
-
+        if self.qkv_latent_func is not None:
+            attn_inputs = AttentionInputs(
+                hidden_states, forward_batch, self.qkv_latent_func
+            )
+            get_attn_tp_context().set_attn_inputs(attn_inputs)
         return hidden_states, residual
 
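+    # Reduce-scatter hidden_states across the TP group: each rank receives
+    # the summed activations for its contiguous slice of the token batch.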
+    def _tp_reduce_scatter(
+        self,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if hidden_states.shape[0] == 0:
+            return hidden_states, hidden_states
+        assert (
+            hidden_states.shape[0] % self._context.tp_size == 0
+        ), f"Total tokens {hidden_states.shape[0]} must be divisible by tp_size {self._context.tp_size}"
+        local_tokens = hidden_states.shape[0] // self._context.tp_size
+        output = hidden_states.new_empty(local_tokens, *hidden_states.shape[1:])
+        get_tp_group().reduce_scatter_tensor(output, hidden_states)
+        if residual is not None:
+            residual = residual.tensor_split(self._context.tp_size)[
+                self._context.tp_rank
+            ]
+        return output, residual
+
     def prepare_mlp(
         self,
         hidden_states: torch.Tensor,
@@ -371,12 +517,17 @@ def postprocess_layer(
         )
 
     def should_use_reduce_scatter(self, forward_batch: ForwardBatch):
-        return (
-            self.allow_reduce_scatter
-            and self._communicate_summable_tensor_pair_fn
+        if not self.allow_reduce_scatter:
+            return False
+        if (
+            self._communicate_summable_tensor_pair_fn
             is CommunicateSummableTensorPairFn._scatter_hidden_states
             and forward_batch.dp_padding_mode.is_max_len()
-        )
+        ):
+            return True
+        if get_attn_tp_context().input_scattered and not self.is_last_layer:
+            return True
+        return False
 
     def should_fuse_mlp_allreduce_with_next_layer(
         self, forward_batch: ForwardBatch
@@ -388,6 +539,9 @@ def should_fuse_mlp_allreduce_with_next_layer(
         ):
             return False
 
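+        # The scattered-input path manages its own reduce-scatter/all-reduce
+        # schedule, so skip the fused all-reduce optimization.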
+        if get_attn_tp_context().input_scattered:
+            return False
+
         batch_size = (
             forward_batch.input_ids.shape[0]
             if hasattr(forward_batch, "input_ids")
@@ -422,6 +576,7 @@ class CommunicateContext:
     attn_dp_size: int
     tp_size: int
     cache = None
+    tp_rank: int
 
     def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
         return self.process_group_sizes[a] == self.process_group_sizes[b]
@@ -432,6 +587,7 @@ def init_new(cls):
         attn_tp_size = get_attention_tp_size()
         attn_dp_size = get_attention_dp_size()
         tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
         process_group_sizes = {
            ScatterMode.SCATTERED: 1,
            ScatterMode.TP_ATTN_FULL: attn_tp_size,
@@ -444,6 +600,7 @@ def init_new(cls):
             attn_tp_size=attn_tp_size,
             attn_dp_size=attn_dp_size,
             tp_size=tp_size,
+            tp_rank=tp_rank,
         )
 
 
@@ -566,6 +723,14 @@ def _gather_hidden_states_and_residual(
         *,
         residual_input_mode,
     ):
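+        # Scattered-input fast path: instead of the attention-DP gather below,
+        # all-reduce using the rank-local residual shard.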
+        if get_attn_tp_context().input_scattered:
+            return CommunicateWithAllReduceAndLayerNormFn._tp_all_reduce_with_scattered_residual(
+                hidden_states,
+                residual,
+                layernorm,
+                context,
+            )
+
         if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
             residual, local_residual = (
                 get_local_dp_buffer(),
@@ -637,6 +802,22 @@ def _scatter_hidden_states_and_residual(
             hidden_states, residual = layernorm(hidden_states, residual)
         return hidden_states, residual
 
+    @staticmethod
+    def _tp_all_reduce_with_scattered_residual(
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        layernorm: torch.nn.Module,
+        context: CommunicateContext,
+    ):
+        if hidden_states.shape[0] == 0:
+            return hidden_states, hidden_states
+
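+        # tensor_split returns views, so the += below writes this rank's
+        # residual shard directly into its slice of hidden_states; the single
+        # all-reduce then both sums partial results and rebuilds the full
+        # residual for every token.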
+        scattered_states = hidden_states.tensor_split(context.tp_size)[context.tp_rank]
+        scattered_states += residual
+        residual = tensor_model_parallel_all_reduce(hidden_states)
+        hidden_states = layernorm(residual)
+        return hidden_states, residual
+
 
 class CommunicateSummableTensorPairFn:
     """It is allowed to make (hidden_states, residual) := (hidden_states + residual, None) if needed."""