Skip to content

Commit 8a71d0b

Browse files
MrGevacodego7250
authored andcommitted
[NVIDIA#8921][feat] Added symetric memory AllReduce strategy (NVIDIA#8919)
Signed-off-by: Eran Geva <[email protected]>
1 parent e3ee462 commit 8a71d0b

File tree

4 files changed

+306
-7
lines changed

4 files changed

+306
-7
lines changed

tensorrt_llm/_torch/distributed/ops.py

Lines changed: 63 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import torch
88
from torch import nn
99

10+
from tensorrt_llm._torch.distributed.symm_mem_allreduce import \
11+
SymmetricMemoryAllReduce
1012
from tensorrt_llm._utils import mpi_comm, mpi_disabled
1113
from tensorrt_llm.bindings.internal.runtime import McastGPUBuffer
1214
from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams,
@@ -567,13 +569,17 @@ def __init__(self,
567569
strategy (AllReduceStrategy):
568570
The following all-reduce strategies are supported:
569571
572+
- SYMM_MEM: Uses PyTorch's symmetric memory with MULTIMEM hardware instructions.
573+
Falls back automatically if not supported.
574+
570575
- UB: AllReduce uses user-buffer based all-reduce kernel.
571576
572577
- NCCL: Use NCCL allreduce.
573578
574579
- MIN_LATENCY: AllReduce uses MIN_LATENCY mode kernel.
575580
576-
- AUTO: AUTO chooses between NCCL and MIN_LATENCY mode based on a heuristic policy.
581+
- AUTO: AUTO chooses the best available strategy. Will try MNNVL,
582+
then choose between NCCL and MIN_LATENCY based on a heuristic policy.
577583
578584
- LOWPRECISION: AllReduce quantizes data to lower precision for transmission.
579585
Should only be used on topologies with PCIe switches and without NVLink.
@@ -602,12 +608,42 @@ def __init__(self,
602608
self.workspace = None
603609
self.strategy = strategy
604610
self.mnnvl_allreduce = None
611+
self.symm_mem_allreduce = None
605612
self._disable_mpi = mpi_disabled()
606613

607614
self.all_reduce_op = torch.ops.trtllm.allreduce_pg if self._disable_mpi else torch.ops.trtllm.allreduce
608615

609616
if self.mapping.tp_size > 1:
610-
# When Strategy is UB, it is guaranteed that the workspace is not used.
617+
# Initialize Symmetric Memory AllReduce if needed (before workspace allocation)
618+
if self.strategy == AllReduceStrategy.SYMM_MEM:
619+
try:
620+
symm_mem = SymmetricMemoryAllReduce(
621+
self.mapping,
622+
dtype=dtype if dtype else torch.bfloat16,
623+
)
624+
if not symm_mem.disabled:
625+
self.symm_mem_allreduce = symm_mem
626+
logger.info(
627+
f"SymmetricMemoryAllReduce (MULTIMEM) is enabled with fallback support for world_size={self.mapping.tp_size}"
628+
)
629+
# Keep SYMM_MEM strategy but allocate workspace for fallback to regular allreduce
630+
else:
631+
logger.info(
632+
f"SymmetricMemoryAllReduce is disabled (not supported or unavailable), falling back to AUTO strategy"
633+
)
634+
# Fall back to AUTO if SYMM_MEM can't be enabled
635+
self.strategy = AllReduceStrategy.AUTO
636+
except Exception as e:
637+
logger.info(
638+
f"Symmetric Memory AllReduce can't be enabled due to {e}, falling back to AUTO strategy"
639+
)
640+
self.symm_mem_allreduce = None
641+
# Fall back to AUTO if SYMM_MEM initialization fails
642+
self.strategy = AllReduceStrategy.AUTO
643+
644+
# Allocate workspace for strategies that need it
645+
# Note: SYMM_MEM now also needs workspace for fallback scenarios (fused ops, etc.)
646+
# Only UB doesn't need workspace
611647
if self.strategy != AllReduceStrategy.UB:
612648
if self.strategy == AllReduceStrategy.LOWPRECISION:
613649
allocate_low_presicion_allreduce_workspace(self.mapping)
@@ -616,9 +652,10 @@ def __init__(self,
616652
AllReduceStrategy.NCCL_SYMMETRIC):
617653
self.workspace = get_allreduce_workspace(self.mapping)
618654

619-
# Initialize MNNVL AllReduce if needed
655+
# Initialize MNNVL if using AUTO or MNNVL strategy
620656
if self.strategy in (AllReduceStrategy.AUTO,
621657
AllReduceStrategy.MNNVL):
658+
# Try to initialize MNNVL
622659
if MNNVLAllReduce.is_mnnvl(self.mapping, dtype):
623660
# ALWAYS capture the exception when creating this instance
624661
try:
@@ -674,20 +711,39 @@ def forward(
674711
if all_reduce_params is None:
675712
all_reduce_params = AllReduceParams()
676713

677-
# Try MNNVL AllReduce first if available
714+
# Try Symmetric Memory AllReduce first if available
715+
# Note: Currently only supports NONE fusion op (plain allreduce)
716+
if self.symm_mem_allreduce and all_reduce_params.fusion_op == AllReduceFusionOp.NONE:
717+
symm_mem_output = self.symm_mem_allreduce(input)
718+
if symm_mem_output is not None:
719+
logger.debug(
720+
f"Using SymmetricMemoryAllReduce (MULTIMEM) for input shape {input.shape}"
721+
)
722+
return symm_mem_output
723+
elif self.symm_mem_allreduce and all_reduce_params.fusion_op != AllReduceFusionOp.NONE:
724+
# Log once per rank that we're skipping symm_mem due to fusion
725+
logger.debug_once(
726+
f"Skipping SymmetricMemoryAllReduce for fused operation (fusion_op={all_reduce_params.fusion_op}), using regular allreduce",
727+
key=(self.mapping.tp_rank, all_reduce_params.fusion_op,
728+
"debug_fusion_skip"),
729+
)
730+
731+
# Try MNNVL AllReduce if symm_mem didn't handle it
678732
if self.mnnvl_allreduce:
679733
mnnvl_output = self.mnnvl_allreduce(
680734
input, all_reduce_params=all_reduce_params)
681735
if mnnvl_output is not None:
682736
return mnnvl_output
683737

684-
# Fall back to regular AllReduce if MNNVL is not available or not applicable
685-
# Make sure the strategy is AUTO since allreduceOp does not have the branch for MNNVL
686-
if allreduce_strategy == AllReduceStrategy.MNNVL:
738+
# Fall back to regular AllReduce if specialized methods are not available or not applicable
739+
# Make sure the strategy is AUTO since allreduceOp does not have the branch for MNNVL/SYMM_MEM
740+
if allreduce_strategy in (AllReduceStrategy.MNNVL,
741+
AllReduceStrategy.SYMM_MEM):
687742
allreduce_strategy = AllReduceStrategy.AUTO
688743

689744
additional_args = {}
690745
if self._disable_mpi:
746+
# Get ProcessGroup from mapping
691747
pg = self.mapping.tp_group_pg
692748
assert pg is not None, "TP ProcessGroup not initialised"
693749
additional_args = {
Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
"""
4+
Symmetric Memory AllReduce
5+
6+
This module provides PyTorch Symmetric Memory-based allreduce operations,
7+
leveraging MULTIMEM hardware instructions.
8+
"""
9+
10+
from typing import Optional
11+
12+
import torch
13+
import torch.distributed as dist
14+
from torch import nn
15+
16+
from tensorrt_llm.logger import logger
17+
from tensorrt_llm.mapping import Mapping
18+
19+
try:
20+
import torch.distributed._symmetric_memory as torch_symm_mem
21+
22+
SYMM_MEM_AVAILABLE = True
23+
except ImportError:
24+
SYMM_MEM_AVAILABLE = False
25+
logger.warning(
26+
"PyTorch symmetric memory not available. Install PyTorch >= 2.8 for MULTIMEM support."
27+
)
28+
29+
30+
class SymmetricMemoryAllReduce(nn.Module):
31+
"""
32+
AllReduce implementation using PyTorch's symmetric memory operations.
33+
This leverages MULTIMEM hardware instructions for faster allreduce operations.
34+
35+
Supported configurations (world_size):
36+
- SM 9.0: 4, 6, 8 GPUs
37+
- SM 10.0: 6, 8 GPUs
38+
39+
"""
40+
41+
# World sizes that support MULTIMEM instructions
42+
_WORLD_SIZES_MULTIMEM = {
43+
"9.0": [4, 6, 8],
44+
"10.0": [6, 8],
45+
}
46+
47+
MiB = 1024 * 1024
48+
# Maximum buffer sizes for symmetric memory (bytes)
49+
_MAX_SIZES = {
50+
"9.0": {
51+
2: 64 * MiB, # 64 MB
52+
4: 32 * MiB, # 32 MB
53+
6: 64 * MiB, # 64 MB
54+
8: 64 * MiB, # 64 MB
55+
},
56+
"10.0": {
57+
2: 8 * MiB, # 8 MB
58+
4: 32 * MiB, # 32 MB
59+
6: 128 * MiB, # 128 MB
60+
8: 128 * MiB, # 128 MB
61+
},
62+
}
63+
64+
def __init__(
65+
self,
66+
mapping: Mapping,
67+
dtype: torch.dtype = torch.bfloat16,
68+
group: Optional[dist.ProcessGroup] = None,
69+
):
70+
super().__init__()
71+
72+
self.disabled = True
73+
self.mapping = mapping
74+
self.dtype = dtype
75+
self.world_size = mapping.tp_size
76+
77+
if not SYMM_MEM_AVAILABLE:
78+
logger.warning("SymmetricMemoryAllReduce: PyTorch symm_mem not available")
79+
return
80+
81+
if not torch.cuda.is_available():
82+
logger.warning("SymmetricMemoryAllReduce: CUDA not available")
83+
return
84+
85+
# Get device capability
86+
device = torch.device(f"cuda:{mapping.tp_rank}")
87+
capability = torch.cuda.get_device_capability(device)
88+
self.device_capability = f"{capability[0]}.{capability[1]}"
89+
90+
# Check if this configuration is supported
91+
if self.device_capability not in self._MAX_SIZES:
92+
logger.warning(
93+
f"SymmetricMemoryAllReduce: Device capability {self.device_capability} not supported"
94+
)
95+
return
96+
97+
if self.world_size not in self._MAX_SIZES[self.device_capability]:
98+
logger.info(
99+
f"SymmetricMemoryAllReduce: World size {self.world_size} not supported "
100+
f"for SM {self.device_capability}"
101+
)
102+
return
103+
104+
# Get max buffer size for this configuration
105+
self.max_size = self._MAX_SIZES[self.device_capability][self.world_size]
106+
107+
# Set up process group
108+
self.group = group
109+
if self.group is None:
110+
# Get or create TP group with correct ranks
111+
# For TP parallelism, we need ranks [0, 1, 2, ..., tp_size-1] globally
112+
# NOT starting from tp_rank!
113+
if not dist.is_initialized():
114+
logger.warning("SymmetricMemoryAllReduce: torch.distributed not initialized")
115+
self.disabled = True
116+
return
117+
# Get actual TP group ranks from mapping (tp_group is a property, not a method)
118+
tp_group_ranks = mapping.tp_group
119+
self.group = dist.new_group(tp_group_ranks) if len(tp_group_ranks) > 1 else None
120+
121+
# Enable symmetric memory for this group
122+
try:
123+
# Get group_name - this may fail if ProcessGroup doesn't have group_name set
124+
if not hasattr(self.group, "group_name"):
125+
logger.warning(
126+
"SymmetricMemoryAllReduce: ProcessGroup does not have group_name attribute"
127+
)
128+
self.disabled = True
129+
return
130+
131+
group_name_str = str(self.group.group_name)
132+
torch_symm_mem.enable_symm_mem_for_group(group_name_str)
133+
logger.debug(
134+
f"SymmetricMemoryAllReduce: Enabled symmetric memory for group {group_name_str}"
135+
)
136+
except Exception as e:
137+
logger.warning(
138+
f"SymmetricMemoryAllReduce: Failed to enable symmetric memory for group: {e}"
139+
)
140+
self.disabled = True
141+
return
142+
143+
# Allocate symmetric memory buffer
144+
try:
145+
self.buffer = torch_symm_mem.empty(
146+
self.max_size // self.dtype.itemsize,
147+
device=device,
148+
dtype=self.dtype,
149+
)
150+
# Pass group name string
151+
group_name_str = str(self.group.group_name)
152+
handle = torch_symm_mem.rendezvous(self.buffer, group_name_str)
153+
154+
if handle.multicast_ptr == 0:
155+
logger.warning(
156+
"SymmetricMemoryAllReduce: MULTIMEM operations not supported (multicast_ptr is 0)"
157+
)
158+
return
159+
160+
# Only enable if MULTIMEM is supported
161+
# Otherwise, no benefit over existing TensorRT-LLM strategies
162+
use_multimem = self.world_size in self._WORLD_SIZES_MULTIMEM.get(
163+
self.device_capability, []
164+
)
165+
166+
if not use_multimem:
167+
logger.info(
168+
f"SymmetricMemoryAllReduce: MULTIMEM not supported for "
169+
f"world_size={self.world_size}, SM={self.device_capability}. "
170+
f"Falling back to standard allreduce strategies."
171+
)
172+
return
173+
174+
self.disabled = False
175+
logger.info(
176+
f"SymmetricMemoryAllReduce (MULTIMEM) initialized: "
177+
f"world_size={self.world_size}, "
178+
f"max_size={self.max_size}, "
179+
f"SM={self.device_capability}"
180+
)
181+
182+
except Exception as e:
183+
logger.warning(f"SymmetricMemoryAllReduce initialization failed: {e}")
184+
return
185+
186+
@property
187+
def process_group(self) -> Optional[dist.ProcessGroup]:
188+
"""Expose the ProcessGroup for use in fallback scenarios."""
189+
return self.group if not self.disabled else None
190+
191+
def can_use_symm_mem(self, inp: torch.Tensor) -> bool:
192+
"""Check if symmetric memory can be used for this tensor."""
193+
if self.disabled:
194+
return False
195+
if inp.dtype != self.dtype:
196+
return False
197+
inp_size = inp.numel() * inp.element_size()
198+
if inp_size % 4 != 0:
199+
return False
200+
if inp_size >= self.max_size:
201+
return False
202+
return True
203+
204+
def forward(
205+
self,
206+
inp: torch.Tensor,
207+
out: Optional[torch.Tensor] = None,
208+
) -> torch.Tensor:
209+
"""
210+
Perform allreduce using symmetric memory operations.
211+
212+
Args:
213+
inp: Input tensor to reduce
214+
out: Optional output tensor (if None, will be allocated)
215+
216+
Returns:
217+
Reduced tensor
218+
"""
219+
if not self.can_use_symm_mem(inp):
220+
return None # Caller should fall back to other strategy
221+
222+
if out is None:
223+
out = torch.empty_like(inp)
224+
225+
# Copy input to symmetric memory buffer
226+
self.buffer[: inp.numel()].copy_(inp.view(-1))
227+
228+
# Perform MULTIMEM allreduce
229+
# Pass group name string (matching vLLM's implementation)
230+
group_name_str = str(self.group.group_name)
231+
torch.ops.symm_mem.multimem_all_reduce_(
232+
self.buffer[: inp.numel()],
233+
"sum",
234+
group_name_str,
235+
)
236+
237+
# Copy result back
238+
out.copy_(self.buffer[: inp.numel()].view(out.shape))
239+
240+
return out

tensorrt_llm/functional.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3881,6 +3881,7 @@ class AllReduceStrategy(IntEnum):
38813881
LOWPRECISION = 6
38823882
MNNVL = 7
38833883
NCCL_SYMMETRIC = 8
3884+
SYMM_MEM = 9 # PyTorch symmetric memory with MULTIMEM
38843885

38853886

38863887
class AllReduceFusionOp(IntEnum):

tests/unittest/_torch/auto_deploy/unit/multigpu/test_ad_allreduce_strategies.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def _prepare_dataset(root_dir: str, temp_dir: str, model_path_or_name: str, num_
115115
"TWOSHOT",
116116
"MIN_LATENCY",
117117
"NCCL",
118+
"SYMM_MEM",
118119
],
119120
)
120121
def test_allreduce_strategies(llm_root, shared_dataset, allreduce_strategy): # noqa: F811
@@ -230,6 +231,7 @@ def test_allreduce_strategies(llm_root, shared_dataset, allreduce_strategy): #
230231
"NCCL",
231232
"TWOSHOT",
232233
"MIN_LATENCY",
234+
"SYMM_MEM",
233235
],
234236
)
235237
def test_allreduce_strategy_propagation(strategy):

0 commit comments

Comments
 (0)