
Commit eaba87c

Reduce-scatter implementation with FP32 accumulation
- Include plumbing to use this implementation for actual model training
- Add async_op=True option for reduce_scatter_with_fp32_accumulation
- Add unit test

Signed-off-by: Deepak Narayanan <[email protected]>
1 parent f6d1db9 commit eaba87c


6 files changed: +200 -2 lines changed


megatron/core/distributed/distributed_data_parallel_config.py

Lines changed: 5 additions & 0 deletions
@@ -49,6 +49,11 @@ class DistributedDataParallelConfig:
     message size (which for ring algorithms is bucket_size / dp_size) apparently needs
     to be divisible by a power of 2 for high busbw."""
 
+    reduce_scatter_with_fp32_accumulation: bool = False
+    """If true, use a reduce-scatter implementation which sends lower-precision values
+    over the wire (using an all-to-all to keep total communication overhead in line
+    with the standard ring implementation) but performs accumulation locally in FP32."""
+
     average_in_collective: bool = False
     """If true, compute average in collective directly, as opposed to dividing by the
     dp_size first and then computing sum in the collective."""
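For reference, the new option can also be set directly on the config object rather than through the training-script flag added later in this commit. A minimal sketch, assuming DistributedDataParallelConfig is exported from megatron.core.distributed as in current Megatron-LM; the distributed optimizer is enabled because the reduce-scatter path is only used in that case:

from megatron.core.distributed import DistributedDataParallelConfig

# Enable the low-precision-wire / FP32-accumulation reduce-scatter for grad sync.
ddp_config = DistributedDataParallelConfig(
    use_distributed_optimizer=True,               # reduce-scatter is used for grad sync here
    reduce_scatter_with_fp32_accumulation=True,   # new field added in this commit
)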

megatron/core/distributed/param_and_grad_buffer.py

Lines changed: 20 additions & 2 deletions
@@ -20,6 +20,7 @@
 from ..fp8_utils import is_float8tensor, is_mxfp8tensor, modify_underlying_storage
 from ..utils import is_torch_min_version, log_on_each_pipeline_stage
 from .distributed_data_parallel_config import DistributedDataParallelConfig
+from .reduce_scatter_with_fp32_accumulation import reduce_scatter_with_fp32_accumulation
 
 logger = logging.getLogger(__name__)
 
@@ -151,6 +152,13 @@ def __init__(
         if self.ddp_config.num_distributed_optimizer_instances > 1:
             self.inter_distributed_optimizer_instance_group = None
             self.communication_stream = None
+            assert (
+                not self.ddp_config.reduce_scatter_with_fp32_accumulation
+            ), "RS w/ FP32 accumulation not supported with num_distributed_optimizer_instances > 1"
+
+        global dist_reduce_scatter_func
+        if self.ddp_config.reduce_scatter_with_fp32_accumulation:
+            dist_reduce_scatter_func = reduce_scatter_with_fp32_accumulation
 
         self.reset()
         self.param_gather_handle = None
@@ -382,6 +390,7 @@ def start_grad_sync(self):
             communication_group = self.data_parallel_group
 
         # Coalesce communication kernels across buckets in the bucket group.
+        grad_reduce_handle = None
         with stream_context, _coalescing_manager(communication_group, async_ops=async_op) as cm:
             for idx, bucket in enumerate(self.buckets):
                 if self.ddp_config.use_distributed_optimizer:
@@ -392,7 +401,7 @@
                     local_data_view = self.cached_grad_buffer_shard_list[idx][
                         self.intra_distributed_optimizer_instance_rank
                     ]
-                    dist_reduce_scatter_func(
+                    grad_reduce_handle = dist_reduce_scatter_func(
                         local_data_view,
                         bucket.grad_data,
                         op=reduce_op,
@@ -434,7 +443,16 @@
                     )
 
         if async_op:
-            self.grad_reduce_handle = cm
+            if self.ddp_config.reduce_scatter_with_fp32_accumulation:
+                assert (
+                    len(self.buckets) == 1
+                ), "Only 1 bucket supported with reduce_scatter_with_fp32_accumulation=True"
+                # torch.distributed._coalescing_manager does not correctly handle calling our custom
+                # collective handle's .wait() method, so we take matters into our own hands here.
+                assert grad_reduce_handle is not None
+                self.grad_reduce_handle = grad_reduce_handle
+            else:
+                self.grad_reduce_handle = cm
         else:
             # When using `_coalescing_manager`, even if a synchronous op (async_op=False) is used,
             # `cm` is not None, which is different from when `_coalescing_manager` is not used in
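The async path above boils down to choosing which object's .wait() the caller should later invoke. A condensed sketch of that choice (names mirror the diff; this is an illustrative standalone function, not the actual Megatron method):

def pick_grad_reduce_handle(ddp_config, buckets, grad_reduce_handle, cm):
    # With the FP32-accumulation implementation, the handle returned by the collective
    # itself must be waited on: its wait() finishes the all-to-all *and* performs the
    # local FP32 sum plus downcast, which the coalescing manager cannot drive.
    if ddp_config.reduce_scatter_with_fp32_accumulation:
        assert len(buckets) == 1, "only a single bucket is supported with this option"
        assert grad_reduce_handle is not None
        return grad_reduce_handle
    # Standard path: the coalescing manager handle covers all buckets.
    return cm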

megatron/core/distributed/reduce_scatter_with_fp32_accumulation.py

Lines changed: 92 additions & 0 deletions

@@ -0,0 +1,92 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+
+from typing import Any
+
+import torch
+
+
+class _ReduceScatterWithFP32AccumulationWorkHandle:
+    """Work handle to return to user when using reduce_scatter_with_fp32_accumulation with
+    async_op=True."""
+
+    def __init__(
+        self,
+        all_to_all_handle: Any,
+        all_to_all_output_tensor: torch.Tensor,
+        output_tensor: torch.Tensor,
+        world_size: int,
+    ):
+        """Initialize WorkHandle object."""
+        self.all_to_all_handle = all_to_all_handle
+        self.all_to_all_output_tensor = all_to_all_output_tensor
+        self.output_tensor = output_tensor
+        self.world_size = world_size
+
+    def wait(self):
+        """Wait until communication (and associated computation) is completed."""
+        # Wait for communication to complete if needed.
+        if self.all_to_all_handle is not None:
+            self.all_to_all_handle.wait()
+
+        # Accumulate into an FP32 sum.
+        output_tensor_in_fp32 = torch.sum(
+            self.all_to_all_output_tensor.view((self.world_size, -1)), dim=0, dtype=torch.float32
+        )
+        assert output_tensor_in_fp32.dtype == torch.float32
+
+        # Copy downcasted sum into output_tensor.
+        self.output_tensor.copy_(output_tensor_in_fp32)
+
+
+def reduce_scatter_with_fp32_accumulation(
+    output_tensor: torch.Tensor,
+    input_tensor: torch.Tensor,
+    op: torch.distributed.ReduceOp,
+    group: torch.distributed.ProcessGroup,
+    async_op: bool,
+):
+    """Reduce-scatter with FP32 accumulation.
+
+    Collects input_tensor in lower precision using an all-to-all, then locally accumulates in FP32
+    precision, then downcasts the final sum back into the right location in input_tensor.
+
+
+    Args:
+        output_tensor (torch.Tensor): Output tensor with reduce-scattered output (only the shard).
+        input_tensor (torch.Tensor): Input tensor that needs to be reduce-scattered.
+        op (torch.distributed.ReduceOp): Only torch.distributed.ReduceOp.SUM is supported.
+        group (torch.distributed.ProcessGroup): Process group to use for reduce-scatter.
+        async_op (bool): If True, return a work handle whose .wait() completes the reduction.
+    """
+    # Make sure arguments conform to the implementation.
+    assert op == torch.distributed.ReduceOp.SUM
+
+    # Get world_size.
+    if group is None:
+        world_size = torch.distributed.get_world_size()
+    else:
+        world_size = group.size()
+
+    # Make sure input_tensor size is divisible by world size.
+    assert input_tensor.numel() % world_size == 0
+
+    # Call all_to_all (every rank should have their respective gradient shards collected from
+    # all ranks). We also create a tensor for the all-to-all output (the all-to-all collective
+    # cannot be performed in-place).
+    all_to_all_output_tensor = torch.empty_like(input_tensor)
+    all_to_all_handle = torch.distributed.all_to_all_single(
+        output=all_to_all_output_tensor, input=input_tensor, group=group, async_op=async_op
+    )
+
+    # Create a work handle to finish communication and reduction.
+    reduce_scatter_handle = _ReduceScatterWithFP32AccumulationWorkHandle(
+        all_to_all_handle, all_to_all_output_tensor, output_tensor, world_size
+    )
+    if async_op:
+        # Return work handle; consumers can call .wait() to ensure communication and associated
+        # reduction complete.
+        return reduce_scatter_handle
+    else:
+        # Wait on work handle.
+        reduce_scatter_handle.wait()
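A usage sketch for the new helper, assuming the module lands at megatron/core/distributed/reduce_scatter_with_fp32_accumulation.py (consistent with the relative import in param_and_grad_buffer.py), that torch.distributed is already initialized, and that the flat buffer's size is divisible by the group size; the wrapper name reduce_scatter_grad_bucket is illustrative only:

import torch

from megatron.core.distributed.reduce_scatter_with_fp32_accumulation import (
    reduce_scatter_with_fp32_accumulation,
)


def reduce_scatter_grad_bucket(grad_bucket: torch.Tensor, group=None, async_op=False):
    world_size = torch.distributed.get_world_size(group)
    rank = torch.distributed.get_rank(group)
    shard_size = grad_bucket.numel() // world_size
    # The output shard is a view into the input buffer, mirroring how grad-buffer
    # shards are passed in param_and_grad_buffer.py.
    output_shard = grad_bucket[rank * shard_size : (rank + 1) * shard_size]
    handle = reduce_scatter_with_fp32_accumulation(
        output_shard,
        grad_bucket,
        op=torch.distributed.ReduceOp.SUM,
        group=group,
        async_op=async_op,
    )
    if async_op:
        # wait() completes the all-to-all plus the local FP32 accumulation and downcast.
        handle.wait()
    return output_shard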

megatron/training/arguments.py

Lines changed: 4 additions & 0 deletions
@@ -2636,6 +2636,10 @@ def _add_distributed_args(parser):
                        'of 2 (2^16) to ensure NCCL collectives have high bus bandwidth at large DP counts, '
                        'since NCCL message size (which for ring algorithms is bucket_size / dp_size) '
                        'apparently needs to be divisible by a power of 2 for high busbw.')
+    group.add_argument('--ddp-reduce-scatter-with-fp32-accumulation', action='store_true',
+                       default=False, help='If set, use a reduce-scatter implementation which sends lower-precision '
+                       'values over the wire (using an all-to-all to keep total communication overhead in line '
+                       'with the standard ring implementation) but performs accumulation locally in FP32.')
     group.add_argument('--ddp-average-in-collective', action='store_true',
                        default=False, help='If set, average directly in data-parallel communication collective.')
     group.add_argument('--overlap-param-gather', action='store_true',

megatron/training/training.py

Lines changed: 1 addition & 0 deletions
@@ -971,6 +971,7 @@ def build_model():
         else:
             kwargs['bucket_size'] = args.ddp_bucket_size
         kwargs['pad_buckets_for_high_nccl_busbw'] = args.ddp_pad_buckets_for_high_nccl_busbw
+        kwargs['reduce_scatter_with_fp32_accumulation'] = args.ddp_reduce_scatter_with_fp32_accumulation
         kwargs['average_in_collective'] = args.ddp_average_in_collective
         if args.use_megatron_fsdp and args.use_precision_aware_optimizer:
             kwargs["preserve_fp32_weights"] = False
Lines changed: 78 additions & 0 deletions

@@ -0,0 +1,78 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+
+
+import pytest
+import torch
+
+# Import our reduce_scatter implementation and shard_buffer (used for
+# checks in the test).
+from megatron.core.distributed.param_and_grad_buffer import (
+    reduce_scatter_with_fp32_accumulation,
+    shard_buffer,
+)
+from tests.unit_tests.test_utilities import Utils
+
+
+def get_non_matching_values(tensor1_shard, tensor2_shard):
+    mask = torch.isclose(tensor1_shard, tensor2_shard)
+    indices = (~mask).nonzero()
+    return indices, tensor1_shard[indices], tensor2_shard[indices]
+
+
+class TestReduceScatterWithFP32Accumulation:
+    @classmethod
+    def setup_class(cls):
+        Utils.initialize_model_parallel()
+
+    @classmethod
+    def teardown_class(cls):
+        Utils.destroy_model_parallel()
+
+    @pytest.mark.parametrize("async_op", [True, False])
+    @pytest.mark.parametrize("baseline_reduce_scatter_in_fp32", [True, False])
+    def test_reduce_scatter_with_fp32_accumulation(
+        self, async_op: bool, baseline_reduce_scatter_in_fp32: bool
+    ):
+        num_tests = 20
+        rank = Utils.rank
+        world_size = Utils.world_size
+        for _ in range(num_tests):
+            # Initialize input tensors.
+            tensor1 = torch.rand(100000, device='cuda', dtype=torch.bfloat16)
+            tensor2 = tensor1.clone()
+
+            # Make sure the kwargs passed to the two APIs are *identical*.
+            kwargs = {"op": torch.distributed.ReduceOp.SUM, "group": None, "async_op": async_op}
+
+            # Reduce-scatter with all-to-alls.
+            args = [
+                shard_buffer(tensor1, world_size)[rank],
+                tensor1,
+            ]  # Output tensor is a view into the original input.
+            handle = reduce_scatter_with_fp32_accumulation(*args, **kwargs)
+            if async_op:
+                assert handle is not None
+                handle.wait()
+            tensor1_shard = shard_buffer(tensor1, world_size)[rank]
+
+            if baseline_reduce_scatter_in_fp32:
+                tensor2 = tensor2.float()
+
+            # Reduce-scatter with the standard reduce-scatter API.
+            args = [
+                shard_buffer(tensor2, world_size)[rank],
+                tensor2,
+            ]  # Output tensor is a view into the original input.
+            handle = torch.distributed.reduce_scatter_tensor(*args, **kwargs)
+            if async_op:
+                assert handle is not None
+                handle.wait()
+            tensor2_shard = shard_buffer(tensor2, world_size)[rank]
+            if baseline_reduce_scatter_in_fp32:  # Cast result back to bfloat16.
+                tensor2_shard = tensor2_shard.bfloat16()
+
+            # Compare results: results should match when doing FP32 reduction and not match when
+            # doing direct BF16 reduction. We only look at the relevant shard of tensor1 and tensor2.
+            assert (
+                torch.allclose(tensor1_shard, tensor2_shard) == baseline_reduce_scatter_in_fp32
+            ), f"{get_non_matching_values(tensor1_shard, tensor2_shard)}"

0 commit comments
