
Commit 2340798

Register allgather/reducescatter buffers with symm memory (#12572)
1 parent 1357ab0 commit 2340798

File tree

19 files changed: 250 additions, 114 deletions


python/sglang/srt/distributed/device_communicators/pynccl_allocator.py

Lines changed: 16 additions & 21 deletions
@@ -1,11 +1,11 @@
 import os
 import tempfile
+from contextlib import nullcontext

 import torch
 from torch.cuda.memory import CUDAPluggableAllocator

 from sglang.srt.distributed.parallel_state import GroupCoordinator
-from sglang.srt.server_args import get_global_server_args

 nccl_allocator_source = """
@@ -60,6 +60,9 @@


 def is_symmetric_memory_enabled():
+    # Import here to avoid circular import
+    from sglang.srt.server_args import get_global_server_args
+
     return get_global_server_args().enable_symm_mem


@@ -92,33 +95,25 @@ def get_nccl_mem_pool():
     return _mem_pool


-class use_symmetric_memory:
+class SymmetricMemoryContext:
     """
     Context manager for using symmetric memory with pynccl.

     To Utilize the symmetric memory feature in NCCL, the buffers need to be allocated
     by `ncclMemAlloc` and registered by `ncclCommWindowRegister`. Due to this, we introduce
     this context manager. All tensors created under this context will be correctly
     allocated and registered with a custom allocator.
-
-    In addition, developers need to manually tag the tensors that will be used as the input/output
-    of NCCL collectives with `tag(tensor)`.
     """

-    def __init__(self, group_coordinator: GroupCoordinator):
-        self.enabled = is_symmetric_memory_enabled()
-
-        if not self.enabled:
-            return
-
+    def __init__(
+        self,
+        group_coordinator: GroupCoordinator,
+    ):
         self.group_coordinator = group_coordinator
         self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
         self.is_graph_capture = torch.cuda.is_current_stream_capturing()

     def __enter__(self):
-        if not self.enabled:
-            return self
-
         assert (
             self.group_coordinator.pynccl_comm is not None
         ), f"Symmetric memory requires pynccl to be enabled in group '{self.group_coordinator.group_name}'"
@@ -139,16 +134,16 @@ def __enter__(self):
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
-        if not self.enabled:
-            return
-
         self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb)

         if self.is_graph_capture:
             torch._C._cuda_beginAllocateCurrentThreadToPool(_cur_device, _graph_pool_id)

-    def tag(self, tensor: torch.Tensor):
-        if not self.enabled:
-            return

-        tensor.symmetric_memory = True
+def use_symmetric_memory(group_coordinator: GroupCoordinator, disabled: bool = False):
+    disabled = (
+        not is_symmetric_memory_enabled()
+        or disabled
+        or group_coordinator.world_size == 1
+    )
+    return SymmetricMemoryContext(group_coordinator) if not disabled else nullcontext()

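A minimal usage sketch of the refactored allocator API (the group lookup, shape, and dtype below are illustrative assumptions, not taken from this commit): `use_symmetric_memory` is now a factory that returns a `SymmetricMemoryContext` only when `--enable-symm-mem` is set, the caller has not passed `disabled=True`, and the group spans more than one rank; otherwise it returns `nullcontext()`, so call sites need neither feature checks nor the removed per-tensor `tag()` step.

    import torch

    from sglang.srt.distributed import get_tp_group
    from sglang.srt.distributed.device_communicators.pynccl_allocator import (
        use_symmetric_memory,
    )

    tp_group = get_tp_group()
    with use_symmetric_memory(tp_group):
        # Every tensor allocated inside the context comes from the NCCL mem pool and
        # is window-registered, so collectives on it can take the symmetric-memory path.
        buf = torch.empty((8192, 4096), dtype=torch.bfloat16, device="cuda")
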
python/sglang/srt/distributed/parallel_state.py

Lines changed: 76 additions & 19 deletions
@@ -188,6 +188,27 @@ def reg_all_gather_into_tensor_fake(
     fake_impl=reg_all_gather_into_tensor_fake,
 )

+def reg_reduce_scatter_tensor(
+    output: torch.Tensor, input: torch.Tensor, group_name: str
+) -> None:
+    assert group_name in _groups, f"Group {group_name} is not found."
+    group = _groups[group_name]()
+    if group is None:
+        raise ValueError(f"Group {group_name} is destroyed.")
+    group._reduce_scatter_tensor(output, input)
+
+def reg_reduce_scatter_tensor_fake(
+    output: torch.Tensor, input: torch.Tensor, group_name: str
+) -> None:
+    pass
+
+direct_register_custom_op(
+    op_name="reg_reduce_scatter_tensor",
+    op_func=reg_reduce_scatter_tensor,
+    mutates_args=["output"],
+    fake_impl=reg_reduce_scatter_tensor_fake,
+)
+

 class GroupCoordinator:
     """
@@ -314,10 +335,16 @@ def __init__(
         from sglang.srt.distributed.device_communicators.pynccl import (
             PyNcclCommunicator,
         )
+        from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+            is_symmetric_memory_enabled,
+            use_symmetric_memory,
+        )
         from sglang.srt.distributed.device_communicators.torch_symm_mem import (
             TorchSymmMemCommunicator,
         )

+        self.is_symmetric_memory_enabled = is_symmetric_memory_enabled
+        self.use_symmetric_memory = use_symmetric_memory
         if is_hip():
             from sglang.srt.distributed.device_communicators.quick_all_reduce import (
                 QuickAllReduce,
@@ -552,7 +579,7 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         if self.npu_communicator is not None and not self.npu_communicator.disabled:
             return self.npu_communicator.all_reduce(input_)

-        if self.pynccl_comm is not None and getattr(input_, "symmetric_memory", False):
+        if self.pynccl_comm is not None and self.is_symmetric_memory_enabled():
             with self.pynccl_comm.change_state(
                 enable=True, stream=get_current_device_stream_fast()
             ):
@@ -627,15 +654,33 @@ def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
         else:
             torch.distributed.all_reduce(input_, group=self.device_group)

-    def reduce_scatter_tensor(
+    def _reduce_scatter_tensor(
         self,
         output: torch.Tensor,
         input: torch.Tensor,
-    ) -> None:
-        # TODO(ch-wan): support other backends
-        torch.distributed.reduce_scatter_tensor(output, input, group=self.device_group)
+    ) -> torch.Tensor:
+        pynccl_comm = self.pynccl_comm
+        if pynccl_comm is not None and (
+            not pynccl_comm.disabled or self.is_symmetric_memory_enabled()
+        ):
+            with pynccl_comm.change_state(
+                enable=True, stream=get_current_device_stream_fast()
+            ):
+                pynccl_comm.reduce_scatter(output, input)
+        else:
+            torch.distributed.reduce_scatter_tensor(
+                output, input, group=self.device_group
+            )
         return output

+    def reduce_scatter_tensor(self, output: torch.Tensor, input: torch.Tensor):
+        if _is_npu or not supports_custom_op():
+            self._reduce_scatter_tensor(output, input)
+        else:
+            torch.ops.sglang.reg_reduce_scatter_tensor(
+                output, input, group_name=self.unique_name
+            )
+
     def reduce_scatter(
         self,
         output: torch.Tensor,
@@ -682,8 +727,13 @@ def reduce_scatterv(

     def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
         pynccl_comm = self.pynccl_comm
-        if pynccl_comm is not None and not pynccl_comm.disabled:
-            pynccl_comm.all_gather(output, input)
+        if pynccl_comm is not None and (
+            not pynccl_comm.disabled or self.is_symmetric_memory_enabled()
+        ):
+            with pynccl_comm.change_state(
+                enable=True, stream=get_current_device_stream_fast()
+            ):
+                pynccl_comm.all_gather(output, input)
         else:
             torch.distributed.all_gather_into_tensor(
                 output, input, group=self.device_group
@@ -745,9 +795,10 @@ def all_gather(
         # torch.compile . see https://github.com/pytorch/pytorch/issues/138795
         output_size = (input_size[0] * world_size,) + input_size[1:]
         # Allocate output tensor.
-        output_tensor = torch.empty(
-            output_size, dtype=input_.dtype, device=input_.device
-        )
+        with self.use_symmetric_memory(self):
+            output_tensor = torch.empty(
+                output_size, dtype=input_.dtype, device=input_.device
+            )

         # All-gather.
         if input_.is_cpu:
@@ -787,7 +838,7 @@ def all_gatherv(
             pynccl_comm is not None and not pynccl_comm.disabled
         ), "pynccl is required for all_gatherv"

-        def _all_gather_single(
+        def _all_gather_allocate_output(
             input_: torch.Tensor, sizes: Optional[List[int]] = None
         ):
             input_size = input_.size()
@@ -801,19 +852,25 @@ def _all_gather_single(
             else:
                 output_size = (input_size[0] * world_size,) + input_size[1:]
             # Allocate output tensor.
-            output_tensor = torch.empty(
-                output_size, dtype=input_.dtype, device=input_.device
-            )
-            pynccl_comm.all_gather(output_tensor, input_, sizes=sizes)
-            return output_tensor
+            with self.use_symmetric_memory(self, disabled=sizes is not None):
+                output_tensor = torch.empty(
+                    output_size, dtype=input_.dtype, device=input_.device
+                )
+            return output_tensor, sizes

         if isinstance(input_, torch.Tensor):
-            return _all_gather_single(input_, sizes)
+            input_ = [input_]

         output_list = []
-        pynccl_comm.group_start()
+        size_list = []
         for inp in input_:
-            output_list.append(_all_gather_single(inp, sizes=sizes))
+            output_tensor, s = _all_gather_allocate_output(inp, sizes=sizes)
+            output_list.append(output_tensor)
+            size_list.append(s)
+
+        pynccl_comm.group_start()
+        for i, inp in enumerate(input_):
+            pynccl_comm.all_gather(output_list[i], inp, sizes=size_list[i])
         pynccl_comm.group_end()

         return output_list
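
A hedged caller-side sketch of the new reduce-scatter path (group handle, shapes, and dtype are assumptions for illustration): `reduce_scatter_tensor` now routes through the registered custom op `torch.ops.sglang.reg_reduce_scatter_tensor`, with the no-op `reg_reduce_scatter_tensor_fake` serving as the shape-only implementation for tracing; it falls back to calling `_reduce_scatter_tensor` directly on NPU or when custom ops are unsupported.

    import torch

    from sglang.srt.distributed import get_tp_group

    tp = get_tp_group()
    inp = torch.randn(tp.world_size * 128, 4096, dtype=torch.bfloat16, device="cuda")
    # Allocating the output under the symmetric-memory context lets the pynccl branch
    # of _reduce_scatter_tensor operate on a registered buffer when symm mem is enabled.
    with tp.use_symmetric_memory(tp):
        out = torch.empty(128, 4096, dtype=torch.bfloat16, device="cuda")
    tp.reduce_scatter_tensor(out, inp)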

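With the `all_gatherv` change, output allocation is split from communication: every output is allocated first (in symmetric memory when `sizes` is None), and the collectives for all entries are then issued inside one `group_start()`/`group_end()` window. A sketch of a list-input call follows (tensor shapes and the positional call signature are assumptions); note that, as written, a single-tensor input is normalized to a one-element list, so the result always comes back as a list.

    import torch

    from sglang.srt.distributed import get_tp_group

    tp = get_tp_group()
    q = torch.randn(64, 512, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(64, 128, dtype=torch.bfloat16, device="cuda")
    # Both outputs are gathered within a single batched NCCL group call.
    q_full, k_full = tp.all_gatherv([q, k])
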
python/sglang/srt/layers/communicator.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,12 @@
2121

2222
from sglang.srt.distributed import (
2323
get_tensor_model_parallel_world_size,
24+
get_tp_group,
2425
tensor_model_parallel_all_reduce,
2526
)
27+
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
28+
use_symmetric_memory,
29+
)
2630
from sglang.srt.layers.dp_attention import (
2731
attn_tp_all_gather_into_tensor,
2832
attn_tp_reduce_scatter_tensor,
@@ -34,6 +38,7 @@
3438
get_attention_tp_size,
3539
get_global_dp_buffer,
3640
get_local_dp_buffer,
41+
is_allocation_symmetric,
3742
is_dp_attention_enabled,
3843
)
3944
from sglang.srt.layers.moe import (
@@ -540,7 +545,12 @@ def _gather_hidden_states_and_residual(
540545
use_layer_norm_before_gather = context.attn_tp_size == 1
541546
if use_layer_norm_before_gather and hidden_states.shape[0] != 0:
542547
residual = hidden_states
543-
hidden_states = layernorm(hidden_states)
548+
with use_symmetric_memory(
549+
get_tp_group(),
550+
disabled=not is_allocation_symmetric(),
551+
):
552+
hidden_states = layernorm(hidden_states)
553+
544554
hidden_states, local_hidden_states = (
545555
get_global_dp_buffer(),
546556
hidden_states,

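The point of wrapping the layernorm call above is that any tensor an op allocates inside the context also lands in the NCCL mem pool, so the layernorm output that later feeds the DP gather is already a registered buffer; the wrap is skipped whenever `is_allocation_symmetric()` is false, i.e. when DP attention leaves per-rank shapes unequal. A standalone sketch of the same pattern (shapes, weight, and the use of `torch.nn.functional.layer_norm` in place of the model's layernorm module are assumptions):

    import torch

    from sglang.srt.distributed import get_tp_group
    from sglang.srt.distributed.device_communicators.pynccl_allocator import (
        use_symmetric_memory,
    )
    from sglang.srt.layers.dp_attention import is_allocation_symmetric

    x = torch.randn(256, 4096, dtype=torch.bfloat16, device="cuda")
    weight = torch.ones(4096, dtype=torch.bfloat16, device="cuda")
    with use_symmetric_memory(get_tp_group(), disabled=not is_allocation_symmetric()):
        # The output tensor allocated by layer_norm comes from the NCCL mem pool,
        # so the subsequent gather can treat it as a registered buffer.
        y = torch.nn.functional.layer_norm(x, (4096,), weight)
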
python/sglang/srt/layers/dp_attention.py

Lines changed: 34 additions & 11 deletions
@@ -17,6 +17,9 @@
     get_tp_group,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.utils import get_bool_env_var, is_hip

 if TYPE_CHECKING:
@@ -86,6 +89,7 @@ class _DpGatheredBufferWrapper:
     _device: torch.device
     _global_dp_buffer_len: int
     _local_dp_buffer_len: int
+    _dp_max_padding: bool
     _global_num_tokens: Optional[List[int]]
     _is_extend_in_batch: bool

@@ -100,27 +104,33 @@ def set_dp_buffer_len(
         cls,
         global_dp_buffer_len: int,
         local_dp_buffer_len: int,
+        dp_max_padding: bool,
         global_num_tokens: Optional[List[int]] = None,
     ):
         cls._global_dp_buffer_len = global_dp_buffer_len
         cls._local_dp_buffer_len = local_dp_buffer_len
+        cls._dp_max_padding = dp_max_padding
         cls._global_num_tokens = global_num_tokens

     @classmethod
     def get_global_dp_buffer(cls) -> torch.Tensor:
-        return torch.empty(
-            (cls._global_dp_buffer_len, cls._hidden_size),
-            dtype=cls._dtype,
-            device=cls._device,
-        )
+        with use_symmetric_memory(get_tp_group()):
+            buffer = torch.empty(
+                (cls._global_dp_buffer_len, cls._hidden_size),
+                dtype=cls._dtype,
+                device=cls._device,
+            )
+        return buffer

     @classmethod
     def get_local_dp_buffer(cls) -> torch.Tensor:
-        return torch.empty(
-            (cls._local_dp_buffer_len, cls._hidden_size),
-            dtype=cls._dtype,
-            device=cls._device,
-        )
+        with use_symmetric_memory(get_tp_group(), disabled=not cls._dp_max_padding):
+            buffer = torch.empty(
+                (cls._local_dp_buffer_len, cls._hidden_size),
+                dtype=cls._dtype,
+                device=cls._device,
+            )
+        return buffer

     @classmethod
     def get_global_dp_buffer_len(cls) -> int:
@@ -154,14 +164,19 @@ def set_is_extend_in_batch(cls, is_extend_in_batch: bool):
     def get_is_extend_in_batch(cls) -> bool:
         return cls._is_extend_in_batch

+    @classmethod
+    def is_dp_max_padding(cls) -> bool:
+        return cls._dp_max_padding
+

 def set_dp_buffer_len(
     global_dp_buffer_len: int,
     local_dp_buffer_len: int,
+    dp_max_padding: bool,
     global_num_tokens: Optional[List[int]] = None,
 ):
     _DpGatheredBufferWrapper.set_dp_buffer_len(
-        global_dp_buffer_len, local_dp_buffer_len, global_num_tokens
+        global_dp_buffer_len, local_dp_buffer_len, dp_max_padding, global_num_tokens
     )


@@ -205,6 +220,10 @@ def get_is_extend_in_batch() -> bool:
     return _DpGatheredBufferWrapper.get_is_extend_in_batch()


+def is_dp_max_padding() -> bool:
+    return _DpGatheredBufferWrapper.is_dp_max_padding()
+
+
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0
@@ -298,6 +317,10 @@ def is_dp_attention_enabled() -> bool:
     return _ENABLE_DP_ATTENTION_FLAG


+def is_allocation_symmetric() -> bool:
+    return not is_dp_attention_enabled() or is_dp_max_padding()
+
+
 def get_attention_tp_group() -> GroupCoordinator:
     assert _ATTN_TP_GROUP is not None, "dp attention not initialized!"
     return _ATTN_TP_GROUP

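The new `dp_max_padding` flag records whether every DP rank pads its batch to the same maximum length; only then do the per-rank local buffers have identical sizes, which is what symmetric-memory registration requires. An illustrative call (values assumed, not taken from this commit):

    from sglang.srt.layers.dp_attention import is_allocation_symmetric, set_dp_buffer_len

    # dp_max_padding=True: all ranks allocate identically sized local buffers.
    set_dp_buffer_len(
        global_dp_buffer_len=4096,
        local_dp_buffer_len=512,
        dp_max_padding=True,
        global_num_tokens=None,
    )

    # is_allocation_symmetric() is True when DP attention is disabled, or when it is
    # enabled and dp_max_padding is set; callers use it to gate symmetric allocation.
    allow_symmetric = is_allocation_symmetric()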