
Commit d91b16e

Opt tp: tp attn support tp reduce scattered input (#10568)
1 parent 4a10e37 commit d91b16e

7 files changed: +275 lines, −36 lines


docs/advanced_features/server_arguments.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -394,6 +394,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `--enable-return-hidden-states` | Enable returning hidden states with responses. | `False` | bool flag (set to enable) |
 | `--scheduler-recv-interval` | The interval to poll requests in scheduler. Can be set to >1 to reduce the overhead of this. | `1` | Type: int |
 | `--numa-node` | Sets the numa node for the subprocesses. i-th element corresponds to i-th subprocess. | `None` | List[int] |
+| `--enable-attn-tp-input-scattered` | Allow input of attention to be scattered when only using tensor parallelism, to reduce the computational load of operations such as qkv latent. | `False` | bool flag (set to enable) |
 
 ## Debug tensor dumps
 | Argument | Description | Defaults | Options |
```
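A hedged usage note: the new switch is a plain boolean flag, so on a TP-only deployment it is passed alongside the usual server arguments (entry point and the other flags below are the standard sglang launch options; the model path is a placeholder):

```
python -m sglang.launch_server --model-path <model> --tp-size 8 --enable-attn-tp-input-scattered
```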

python/sglang/srt/layers/communicator.py

Lines changed: 190 additions & 9 deletions
```diff
@@ -11,15 +11,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
+import logging
+from contextlib import contextmanager
 from dataclasses import dataclass
 from enum import Enum, auto
 from functools import partial
-from typing import Dict, List, Optional
+from typing import Callable, Dict, List, Optional, Tuple
 
 import torch
 
 from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     get_tp_group,
     tensor_model_parallel_all_reduce,
@@ -59,9 +61,10 @@
     prepare_weight_cache,
 )
 
+_is_cuda = is_cuda()
 _is_flashinfer_available = is_flashinfer_available()
-_is_sm90_supported = is_cuda() and is_sm90_supported()
-_is_sm100_supported = is_cuda() and is_sm100_supported()
+_is_sm90_supported = _is_cuda and is_sm90_supported()
+_is_sm100_supported = _is_cuda and is_sm100_supported()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
 _is_gfx95_supported = is_gfx95_supported()
```

```diff
@@ -92,6 +95,119 @@ def model_input_output():
         return ScatterMode.TP_ATTN_FULL
 
 
+class AttentionInputs:
+
+    def __init__(
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        qkv_latent_func: Callable,
+    ):
+        self.hidden_states_local = hidden_states
+        self.forward_batch = forward_batch
+        self.qkv_latent_func = qkv_latent_func
+        self.hidden_states_ = None
+        self.qkv_latent_ = None
+
+    def tp_all_gather_hidden_states(self, hidden_states, forward_batch):
+        total_tokens = forward_batch.input_ids.shape[0]
+        output = hidden_states.new_empty((total_tokens, hidden_states.shape[-1]))
+        get_tp_group().all_gather_into_tensor(output, hidden_states)
+        return output
+
+    def fetch_qkv_latent(self):
+        if self.qkv_latent_ is not None:
+            return self.qkv_latent_
+        assert self.qkv_latent_func is not None
+        self.qkv_latent_ = self.qkv_latent_func(
+            self.hidden_states_local, self.forward_batch
+        )
+        if get_attn_tp_context().input_scattered:
+            self.qkv_latent_ = self.tp_all_gather_hidden_states(
+                self.qkv_latent_, self.forward_batch
+            )
+        return self.qkv_latent_
+
+    def fetch_hidden_states(self):
+        if self.hidden_states_ is not None:
+            return self.hidden_states_
+        self.hidden_states_ = self.hidden_states_local
+        if get_attn_tp_context().input_scattered:
+            self.hidden_states_ = self.tp_all_gather_hidden_states(
+                self.hidden_states_, self.forward_batch
+            )
+        return self.hidden_states_
+
+
+class AttnTpContext:
+    def __init__(self):
+        self.allow_input_scattered = False
+        self.input_scattered_ = False
+        self.attn_inputs_: Optional[AttentionInputs] = None
+
+    def init_context(self, q_lora_rank, is_nsa):
+        self.allow_input_scattered = (
+            get_global_server_args().enable_attn_tp_input_scattered
+            and _is_cuda
+            and q_lora_rank is not None
+            and not is_nsa
+            and get_tensor_model_parallel_world_size() > 1
+            and not is_dp_attention_enabled()
+            and get_moe_a2a_backend().is_none()
+            and not enable_moe_dense_fully_dp()
+            and not get_global_server_args().enable_piecewise_cuda_graph
+            and get_global_server_args().speculative_algorithm != "EAGLE3"
+        )
+        if get_global_server_args().enable_attn_tp_input_scattered:
+            if not self.allow_input_scattered:
+                logging.info(
+                    "attn_tp_input_scattered is not enabled while other conditions are not met"
+                )
+            else:
+                logging.info("attn_tp_input_scattered is enabled")
+
+    def use_input_scattered(self, forward_batch: ForwardBatch):
+        return (
+            self.allow_input_scattered
+            and forward_batch.forward_mode.is_extend()
+            and not forward_batch.forward_mode.is_target_verify()
+            and not forward_batch.forward_mode.is_draft_extend()
+            and forward_batch.input_ids is not None
+            and not forward_batch.can_run_tbo
+        )
+
+    @property
+    def input_scattered(self):
+        return self.input_scattered_
+
+    def set_attn_inputs(self, attn_inputs: AttentionInputs):
+        self.attn_inputs_ = attn_inputs
+
+    def fetch_qkv_latent(self):
+        assert self.attn_inputs_ is not None
+        return self.attn_inputs_.fetch_qkv_latent()
+
+    def fetch_hidden_states(self):
+        assert self.attn_inputs_ is not None
+        return self.attn_inputs_.fetch_hidden_states()
+
+    @contextmanager
+    def maybe_input_scattered(self, forward_batch: ForwardBatch):
+        flag = self.use_input_scattered(forward_batch)
+        old_flag = self.input_scattered
+        self.input_scattered_ = flag
+        yield
+        self.input_scattered_ = old_flag
+        self.attn_inputs_ = None
+
+
+ATTN_TP_CONTEXT = AttnTpContext()
+
+
+def get_attn_tp_context():
+    return ATTN_TP_CONTEXT
+
+
 @dataclass
 class _LayerModeComputationContext:
     num_layers: int
```
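The classes above are the whole state machine: `AttnTpContext` decides per batch whether the attention input may stay scattered, and `AttentionInputs` lazily all-gathers whichever form (raw hidden states or qkv latent) a layer actually asks for. A hedged wiring sketch, not part of this commit, of how a caller might scope one forward pass; the `run_one_forward` helper and the way the model is invoked are illustrative assumptions:

```python
# Hypothetical wiring sketch (not shown in this diff): scope one forward pass
# with the new context, assuming `model`, `model_runner`, and `forward_batch`
# objects with the interfaces referenced in this commit.
from sglang.srt.layers.communicator import get_attn_tp_context


def run_one_forward(model, model_runner, forward_batch):
    # Pad the batch so its token count divides evenly across TP ranks.
    forward_batch.prepare_attn_tp_scatter_input(model_runner)
    ctx = get_attn_tp_context()
    # input_scattered is flipped on only for this pass, and only if
    # use_input_scattered(forward_batch) approves the batch.
    with ctx.maybe_input_scattered(forward_batch):
        return model(forward_batch)
```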
```diff
@@ -188,12 +304,14 @@ def __init__(
         # Reduce scatter requires skipping all-reduce in model code after MoE/MLP, so only enable for models which have that implemented. Remove flag once done for all models that use LayerCommunicator.
         allow_reduce_scatter: bool = False,
         is_last_layer: bool = False,
+        qkv_latent_func: Optional[Callable] = None,
     ):
         self.layer_scatter_modes = layer_scatter_modes
         self.input_layernorm = input_layernorm
         self.post_attention_layernorm = post_attention_layernorm
         self.allow_reduce_scatter = allow_reduce_scatter
         self.is_last_layer = is_last_layer
+        self.qkv_latent_func = qkv_latent_func
 
         self._context = CommunicateContext.init_new()
         self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
```
```diff
@@ -252,6 +370,11 @@ def prepare_attn(
         forward_batch: ForwardBatch,
         quant_format: str = "",
     ):
+        if get_attn_tp_context().input_scattered:
+            hidden_states, residual = self._tp_reduce_scatter(
+                hidden_states,
+                residual,
+            )
         if hidden_states.shape[0] == 0:
             residual = hidden_states
         else:
```
```diff
@@ -335,9 +458,32 @@ def prepare_attn(
             forward_batch=forward_batch,
             context=self._context,
         )
-
+        if self.qkv_latent_func is not None:
+            attn_inputs = AttentionInputs(
+                hidden_states, forward_batch, self.qkv_latent_func
+            )
+            get_attn_tp_context().set_attn_inputs(attn_inputs)
         return hidden_states, residual
 
+    def _tp_reduce_scatter(
+        self,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if hidden_states.shape[0] == 0:
+            return hidden_states, hidden_states
+        assert (
+            hidden_states.shape[0] % self._context.tp_size == 0
+        ), f"Expected total tokens {hidden_states.shape[0]} % tp_size {self._context.tp_size} to be 0"
+        local_tokens = hidden_states.shape[0] // self._context.tp_size
+        output = hidden_states.new_empty(local_tokens, *hidden_states.shape[1:])
+        get_tp_group().reduce_scatter_tensor(output, hidden_states)
+        if residual is not None:
+            residual = residual.tensor_split(self._context.tp_size)[
+                self._context.tp_rank
+            ]
+        return output, residual
+
     def prepare_mlp(
         self,
         hidden_states: torch.Tensor,
```
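The `_tp_reduce_scatter` helper leans on the usual collective identity: reduce-scattering a tensor over the TP group gives each rank the same slice it would get by all-reducing the full tensor and then keeping only its own token chunk. A minimal single-process sketch of that identity in plain PyTorch (per-rank tensors emulate TP ranks; no `torch.distributed` group is created):

```python
import torch

tp_size, total_tokens, hidden = 4, 8, 16
# One partial-sum tensor per emulated TP rank (e.g. partial embedding outputs).
per_rank = [torch.randn(total_tokens, hidden) for _ in range(tp_size)]

# Reference: all-reduce (sum over ranks), then each rank keeps its token slice.
full_sum = torch.stack(per_rank).sum(dim=0)
expected = list(full_sum.tensor_split(tp_size))

# Reduce-scatter semantics: rank r directly receives the summed slice r.
local_tokens = total_tokens // tp_size
reduce_scattered = [
    sum(t.tensor_split(tp_size)[r] for t in per_rank) for r in range(tp_size)
]

for r in range(tp_size):
    assert reduce_scattered[r].shape == (local_tokens, hidden)
    assert torch.allclose(reduce_scattered[r], expected[r], atol=1e-5)
```

This is why the batch must be padded to a multiple of `tp_size` before the pass: the token dimension has to split evenly for every rank to own an equal slice.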
```diff
@@ -371,12 +517,17 @@ def postprocess_layer(
         )
 
     def should_use_reduce_scatter(self, forward_batch: ForwardBatch):
-        return (
-            self.allow_reduce_scatter
-            and self._communicate_summable_tensor_pair_fn
+        if not self.allow_reduce_scatter:
+            return False
+        if (
+            self._communicate_summable_tensor_pair_fn
             is CommunicateSummableTensorPairFn._scatter_hidden_states
             and forward_batch.dp_padding_mode.is_max_len()
-        )
+        ):
+            return True
+        if get_attn_tp_context().input_scattered and not self.is_last_layer:
+            return True
+        return False
 
     def should_fuse_mlp_allreduce_with_next_layer(
         self, forward_batch: ForwardBatch
```
```diff
@@ -388,6 +539,9 @@ def should_fuse_mlp_allreduce_with_next_layer(
         ):
             return False
 
+        if get_attn_tp_context().input_scattered:
+            return False
+
         batch_size = (
             forward_batch.input_ids.shape[0]
             if hasattr(forward_batch, "input_ids")
```
```diff
@@ -422,6 +576,7 @@ class CommunicateContext:
     attn_dp_size: int
     tp_size: int
     cache = None
+    tp_rank: int
 
     def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
         return self.process_group_sizes[a] == self.process_group_sizes[b]
@@ -432,6 +587,7 @@ def init_new(cls):
         attn_tp_size = get_attention_tp_size()
         attn_dp_size = get_attention_dp_size()
         tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
         process_group_sizes = {
             ScatterMode.SCATTERED: 1,
             ScatterMode.TP_ATTN_FULL: attn_tp_size,
@@ -444,6 +600,7 @@ def init_new(cls):
             attn_tp_size=attn_tp_size,
             attn_dp_size=attn_dp_size,
             tp_size=tp_size,
+            tp_rank=tp_rank,
         )
 
 
```
```diff
@@ -566,6 +723,14 @@ def _gather_hidden_states_and_residual(
         *,
         residual_input_mode,
     ):
+        if get_attn_tp_context().input_scattered:
+            return CommunicateWithAllReduceAndLayerNormFn._tp_all_reduce_with_scattered_residual(
+                hidden_states,
+                residual,
+                layernorm,
+                context,
+            )
+
         if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
             residual, local_residual = (
                 get_local_dp_buffer(),
```
```diff
@@ -637,6 +802,22 @@ def _scatter_hidden_states_and_residual(
             hidden_states, residual = layernorm(hidden_states, residual)
         return hidden_states, residual
 
+    @staticmethod
+    def _tp_all_reduce_with_scattered_residual(
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        layernorm: torch.nn.Module,
+        context: CommunicateContext,
+    ):
+        if hidden_states.shape[0] == 0:
+            return hidden_states, hidden_states
+
+        scattered_states = hidden_states.tensor_split(context.tp_size)[context.tp_rank]
+        scattered_states += residual
+        residual = tensor_model_parallel_all_reduce(hidden_states)
+        hidden_states = layernorm(residual)
+        return hidden_states, residual
+
 
 class CommunicateSummableTensorPairFn:
     """It is allowed to make (hidden_states, residual) := (hidden_states + residual, None) if needed."""
```
python/sglang/srt/layers/vocab_parallel_embedding.py

Lines changed: 5 additions & 5 deletions
```diff
@@ -18,6 +18,7 @@
     use_symmetric_memory,
 )
 from sglang.srt.layers.amx_utils import PackWeightMethod
+from sglang.srt.layers.communicator import get_attn_tp_context
 from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.parameter import BasevLLMParameter
 from sglang.srt.layers.quantization.base_config import (
@@ -478,11 +479,10 @@ def forward(self, input_):
         # Mask the output embedding.
         if self.tp_size > 1:
             output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
-            # Reduce across all the model parallel GPUs.
-            output = tensor_model_parallel_all_reduce(output_parallel)
-        else:
-            output = output_parallel
-        return output
+        if not get_attn_tp_context().input_scattered:
+            # Reduce across all the model parallel GPUs.
+            output_parallel = tensor_model_parallel_all_reduce(output_parallel)
+        return output_parallel
 
     def extra_repr(self) -> str:
         s = f"num_embeddings={self.num_embeddings_per_partition}"
```

python/sglang/srt/model_executor/forward_batch_info.py

Lines changed: 24 additions & 4 deletions
```diff
@@ -38,7 +38,10 @@
 import triton
 import triton.language as tl
 
-from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
+from sglang.srt.distributed.parallel_state import (
+    get_moe_expert_parallel_world_size,
+    get_tensor_model_parallel_world_size,
+)
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import (
     DpPaddingMode,
@@ -766,6 +769,13 @@ def prepare_mlp_sync_batch(self, model_runner: ModelRunner):
         else:
             bs = self.batch_size = num_tokens
 
+        # padding
+        self._pad_inputs_to_size(model_runner, num_tokens, bs)
+        self.global_num_tokens_cpu = global_num_tokens
+        global_num_tokens_pinned = torch.tensor(global_num_tokens, pin_memory=True)
+        self.global_num_tokens_gpu.copy_(global_num_tokens_pinned, non_blocking=True)
+
+    def _pad_inputs_to_size(self, model_runner: ModelRunner, num_tokens, bs):
         # padding
         self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens)
         self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs)
@@ -788,9 +798,6 @@ def prepare_mlp_sync_batch(self, model_runner: ModelRunner):
         if self.encoder_lens is not None:
             self.encoder_lens = self._pad_tensor_to_size(self.encoder_lens, bs)
         self.positions = self._pad_tensor_to_size(self.positions, num_tokens)
-        self.global_num_tokens_cpu = global_num_tokens
-        global_num_tokens_pinned = torch.tensor(global_num_tokens, pin_memory=True)
-        self.global_num_tokens_gpu.copy_(global_num_tokens_pinned, non_blocking=True)
 
         if self.mrope_positions is not None:
             self.mrope_positions = self._pad_tensor_to_size(self.mrope_positions, bs)
@@ -818,6 +825,19 @@ def prepare_mlp_sync_batch(self, model_runner: ModelRunner):
                 spec_info.hidden_states, num_tokens
             )
 
+    def prepare_attn_tp_scatter_input(self, model_runner: ModelRunner):
+        from sglang.srt.layers.communicator import get_attn_tp_context
+
+        attn_tp_context = get_attn_tp_context()
+        input_scattered = attn_tp_context.use_input_scattered(self)
+        if not input_scattered:
+            return
+        assert self.forward_mode.is_extend()
+        tokens = self.input_ids.shape[0]
+        rank_size = get_tensor_model_parallel_world_size()
+        tokens_padded = (tokens + rank_size - 1) // rank_size * rank_size
+        self._pad_inputs_to_size(model_runner, tokens_padded, self.batch_size)
+
     def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):
 
         self.forward_mode = getattr(self, "_original_forward_mode", self.forward_mode)
```
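The only arithmetic in `prepare_attn_tp_scatter_input` is rounding the token count up to a multiple of the TP world size so the later reduce-scatter splits evenly. A small sketch of the round-up (the helper name is illustrative, not part of the commit):

```python
def pad_to_multiple(tokens: int, rank_size: int) -> int:
    # Same ceiling-division round-up used in prepare_attn_tp_scatter_input.
    return (tokens + rank_size - 1) // rank_size * rank_size

assert pad_to_multiple(13, 4) == 16   # 3 pad tokens appended
assert pad_to_multiple(16, 4) == 16   # already aligned, no padding
assert pad_to_multiple(1, 8) == 8
```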
