
Commit e6715b0

return work from manager allreduce
Summary:
- returns the work object so we can be more flexible with the usage
- had to return a new future from ddp because the return type of work.get_future() is incompatible
1 parent 481d47d

12 files changed, +112 −79 lines
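
At the call sites, the change boils down to a new return type for Manager.allreduce: a torch.distributed Work handle instead of a torch.futures.Future. A minimal before/after sketch of the calling convention, mirroring the test updates in the diffs below (manager and tensor are assumed to already be in scope):

    # Before this commit: allreduce returned a Future
    fut = manager.allreduce(tensor)
    fut.wait()

    # After this commit: allreduce returns a Work handle; callers that need to
    # block go through its future
    work = manager.allreduce(tensor)
    work.get_future().wait()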

torchft/collectives.py

Lines changed: 11 additions & 2 deletions
@@ -18,6 +18,7 @@
     AllreduceOptions,
     AllToAllOptions,
     ReduceScatterOptions,
+    Work,
 )
 from torch.futures import Future

@@ -288,7 +289,7 @@ def allreduce_quantized(
     opts: AllreduceOptions | ReduceOp,
     process_group: "ProcessGroup",
     sync_stream: cuda.Stream | None = None,
-) -> Future[list[torch.Tensor]]:
+) -> Work:
     """
     Performs a quantized all-reduce operation on a list of tensors.

@@ -379,6 +380,14 @@ def allreduce_quantized(
         [torch.split(quantized_tensors_out.view(world_size, -1), 1)[rank]],
         _to_allgather_options(allreduce_opts),
     )
+
+    # NOTE: This is not supposed to be used with gloo, only with NCCL.
+    # So we setup the stream dependency here by calling work.wait(),
+    # which doesn't block the CPU.
+    #
+    # The future callback below will run after the work has been
+    # completed.
     work.wait()
     fut = work.get_future()

@@ -394,4 +403,4 @@ def callback(fut: Future[list[torch.Tensor]]) -> list[torch.Tensor]:
         return tensors

     fut = fut.then(callback)
-    return fut
+    return work
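
The NOTE added above leans on an NCCL-specific behavior: on an NCCL-backed process group, Work.wait() records a dependency on the current CUDA stream instead of blocking the host thread, so the CPU can keep queueing work while the collective runs. A hedged sketch of that pattern in isolation, using a plain allreduce rather than the allgather that allreduce_quantized actually issues (the function name is illustrative, and pg is assumed to be an NCCL process group):

    import torch
    import torch.distributed as dist
    from torch.distributed.distributed_c10d import Work


    def async_allreduce_with_postprocess(pg: dist.ProcessGroup, tensor: torch.Tensor) -> Work:
        # Sketch only: assumes NCCL, where wait() enqueues a dependency on the
        # current CUDA stream rather than blocking the CPU.
        work = pg.allreduce([tensor], dist.ReduceOp.SUM)
        work.wait()

        def postprocess(fut: torch.futures.Future) -> list[torch.Tensor]:
            # Runs after the collective completes; safe to post-process the output here.
            return [tensor]

        work.get_future().then(postprocess)
        return work  # callers wait via work.get_future().wait(), as in the test below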

torchft/collectives_test.py

Lines changed: 2 additions & 2 deletions
@@ -94,8 +94,8 @@ def _run_all_reduce_collective(
         )
     ]

-    fut = allreduce_quantized(tensors, reduce_op, pg)
-    fut.wait()
+    work = allreduce_quantized(tensors, reduce_op, pg)
+    work.get_future().wait()

     work = pg.allreduce([expected], reduce_op)
     work.get_future().wait()

torchft/ddp.py

Lines changed: 16 additions & 1 deletion
@@ -68,7 +68,22 @@ def __init__(self, manager: "Manager", module: nn.Module, **kwargs: object) -> None:
     def _comm_hook(
         state: "Manager", bucket: dist.GradBucket
     ) -> torch.futures.Future[torch.Tensor]:
-        return state.allreduce(bucket.buffer())
+        tensor = bucket.buffer()
+        work = state.allreduce(tensor)
+        fut = work.get_future()
+
+        result = torch.futures.Future()  # pyre-fixme[29]: Future is not a function
+
+        def callback(
+            fut: torch.futures.Future[torch.Tensor],
+        ) -> None:
+            fut.wait()
+            fut.value()
+            result.set_result(tensor)
+
+        fut.add_done_callback(callback)
+
+        return result


 class PureDistributedDataParallel(nn.Module):
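
The commit summary mentions that the hook can no longer just return work.get_future(): DDP's communication hook must return a torch.futures.Future[torch.Tensor] that resolves to the bucket's tensor, while the future attached to the Work resolves to a different payload (for a process-group allreduce it is a list of tensors). The hook above therefore bridges the two with a fresh future. The same idea as a standalone helper, purely illustrative (the helper name is not part of the codebase):

    import torch
    from torch.distributed.distributed_c10d import Work


    def work_to_tensor_future(work: Work, tensor: torch.Tensor) -> torch.futures.Future:
        # Complete a new Future with `tensor` once the Work's own future finishes.
        result = torch.futures.Future()  # pyre-fixme[29]: Future is not a function

        def callback(fut: torch.futures.Future) -> None:
            fut.wait()  # surfaces any error raised by the collective
            result.set_result(tensor)

        work.get_future().add_done_callback(callback)
        return result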

torchft/ddp_test.py

Lines changed: 6 additions & 4 deletions
@@ -10,11 +10,13 @@
 import torch
 import torch.distributed as dist
 from torch import nn
+from torch.distributed.distributed_c10d import Work
 from torch.futures import Future

 from torchft.ddp import DistributedDataParallel, PureDistributedDataParallel
 from torchft.manager import Manager
 from torchft.process_group import ProcessGroupBabyGloo, ProcessGroupGloo
+from torchft.work import _DummyWork


 class TestDDP(TestCase):

@@ -39,14 +41,14 @@ def test_ddp(self) -> None:

         call_count = 0

-        def allreduce(tensor: torch.Tensor) -> Future[torch.Tensor]:
+        def allreduce(
+            tensor: torch.Tensor,
+        ) -> Work:
             nonlocal call_count

             call_count += 1

-            fut = Future()  # pyre-fixme[29]: not a function
-            fut.set_result(tensor)
-            return fut
+            return _DummyWork(tensor)

         manager.allreduce = allreduce

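
_DummyWork comes from torchft/work.py, which is among the 12 changed files but is not shown in this excerpt. Judging from how the tests use it (constructed from a tensor, waited on, and asked for a future), a minimal implementation would look roughly like the sketch below; this is an assumption about its shape, not the actual file contents:

    import torch
    from torch.distributed.distributed_c10d import Work


    class _DummyWork(Work):
        # Assumed sketch: an already-completed Work that carries a fixed result.
        def __init__(self, result: object) -> None:
            super().__init__()
            self.result_ = result
            self.future_ = torch.futures.Future()  # pyre-fixme[29]: Future is not a function
            self.future_.set_result(result)

        def wait(self, timeout: object = None) -> bool:
            return True

        def get_future(self) -> torch.futures.Future:
            return self.future_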

torchft/local_sgd.py

Lines changed: 16 additions & 11 deletions
@@ -18,6 +18,7 @@
 import torch
 import torch.distributed as dist
 from torch import nn, optim
+from torch.distributed.distributed_c10d import Work
 from torch.distributed.tensor import DTensor
 from torch.nn.parameter import Parameter
 from torch.optim.optimizer import Optimizer

@@ -154,7 +155,8 @@ def _average(self) -> list[torch.Tensor]:
         for p in self._model.parameters():
             # Create a new tensor to store the averaged parameter
             avg_param = extract_local_tensor(p)
-            works.append(self._manager.allreduce(avg_param))
+            work = self._manager.allreduce(avg_param)
+            works.append(work)
             averaged_parameters.append(avg_param)
         for work in works:
             work.wait()

@@ -200,7 +202,7 @@ def __init__(
         self._outer_optimizer = outer_optimizer

         # Stores pending all reduce
-        self._allreduce_futures: list[torch.futures.Future[torch.Tensor]] = []
+        self._allreduce_work: list[Work] = []
         self._stream: Optional[torch.cuda.Stream] = (
             torch.cuda.Stream() if torch.cuda.is_available() else None
         )

@@ -368,15 +370,15 @@ def wait(self) -> None:
         """
         Waits for the previously scheduled allreduce to finish
         """
-        if len(self._allreduce_futures) == 0:
+        if len(self._allreduce_work) == 0:
             return

         if self._stream is not None:
             assert self._stop_event is not None
             self._stop_event.synchronize()
             self._stop_event = None

-        self._allreduce_futures = []
+        self._allreduce_work = []

     @torch.profiler.record_function("torchft::local_sgd::prepare_sync")
     def prepare_sync(self) -> None:

@@ -386,7 +388,7 @@ def prepare_sync(self) -> None:
         """
         self._save_grads()

-        assert len(self._allreduce_futures) == 0
+        assert len(self._allreduce_work) == 0

         # Make sure tensors are available to `_stream`
         if self._stream is not None:

@@ -399,8 +401,8 @@
         ):
             self._average_grads()

-            for work in self._allreduce_futures:
-                work.wait()
+            for work in self._allreduce_work:
+                work.get_future().wait()

             if self._stream is not None:
                 self._stop_event = torch.cuda.Event()

@@ -413,7 +415,7 @@ def perform_sync(self) -> bool:
         steps using the outer optimizer.
         """
         # Waiting for an allreduce before it has been sent is currently not supported.
-        assert len(self._allreduce_futures) > 0
+        assert len(self._allreduce_work) > 0

         self.wait()

@@ -467,7 +469,8 @@ def _allreduce_per_param(self) -> None:
             work = self._manager.allreduce(
                 self._grads[name], should_quantize=self.should_quantize
             )
-            self._allreduce_futures.append(work)
+
+            self._allreduce_work.append(work)

     def _bucketize_and_allreduce(
         self,

@@ -522,8 +525,10 @@ def callback(fut: torch.futures.Future[torch.Tensor]) -> None:
                         flat_buffer[pack_offset : pack_offset + numel].view_as(t)
                     )

-            work = work.then(callback)
-            self._allreduce_futures.append(work)
+            fut = work.get_future()
+            fut = fut.then(callback)
+
+            self._allreduce_work.append(work)

             offset += chunk_size

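
For context on the last hunk: _bucketize_and_allreduce packs a group of gradients into one flat buffer, issues a single allreduce on it, and chains a callback on the Work's future that copies the reduced slices back into the original tensors. A self-contained sketch of that pattern, with illustrative names rather than the actual torchft implementation:

    import torch
    from torch.distributed.distributed_c10d import Work


    def bucketized_allreduce(manager, tensors: list[torch.Tensor]) -> Work:
        # Pack all tensors into one contiguous buffer so a single collective suffices.
        numels = [t.numel() for t in tensors]
        flat_buffer = torch.cat([t.reshape(-1) for t in tensors])

        work = manager.allreduce(flat_buffer)

        def unpack(fut: torch.futures.Future) -> None:
            # Copy each reduced slice back into its original tensor.
            offset = 0
            for t, numel in zip(tensors, numels):
                t.copy_(flat_buffer[offset : offset + numel].view_as(t))
                offset += numel

        work.get_future().then(unpack)
        return work  # tracked by the caller, mirroring self._allreduce_work above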
torchft/local_sgd_test.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@
1111
import torch
1212
from parameterized import parameterized
1313
from torch import Tensor, nn, optim
14+
from torch.distributed.distributed_c10d import Work
1415
from torch.distributed.tensor import DTensor
1516

1617
from torchft.local_sgd import DiLoCo, LocalSGD, extract_local_tensor
1718
from torchft.manager import Manager
19+
from torchft.work import _DummyWork
1820

1921

2022
def create_manager() -> MagicMock:
@@ -26,6 +28,11 @@ def create_manager() -> MagicMock:
2628

2729
manager.errored.return_value = None
2830

31+
def mock_allreduce(tensor: torch.Tensor, should_quantize: bool = False) -> Work:
32+
return _DummyWork(tensor)
33+
34+
manager.allreduce.side_effect = mock_allreduce
35+
2936
return manager
3037

3138

@@ -66,7 +73,7 @@ class LocalSGDTest(TestCase):
6673
def test_local_sgd_healthy(self) -> None:
6774
model = SimpleModel()
6875
optimizer = optim.SGD(model.parameters())
69-
manager = create_autospec(Manager)
76+
manager = create_manager()
7077
with LocalSGD(manager, model, optimizer, sync_every=2) as local_sgd:
7178
self.assertEqual(local_sgd._local_step, 0)
7279
inp = torch.rand(2, 3)
@@ -240,13 +247,9 @@ def test_bucketization_correctness(self) -> None:
240247
manager.should_commit.return_value = True
241248

242249
# Define fake allreduce: multiplies buffer by 2
243-
def fake_allreduce(
244-
tensor: Tensor, should_quantize: bool
245-
) -> torch.futures.Future[Tensor]:
250+
def fake_allreduce(tensor: Tensor, should_quantize: bool) -> Work:
246251
tensor.mul_(2)
247-
fut = torch.futures.Future() # pyre-fixme[29]: not a function
248-
fut.set_result(tensor)
249-
return fut
252+
return _DummyWork(tensor)
250253

251254
manager.allreduce.side_effect = fake_allreduce
252255

@@ -284,13 +287,9 @@ def test_gradient_correctness(self) -> None:
284287
manager.should_commit.return_value = True
285288

286289
# Define fake allreduce: multiplies buffer by 2
287-
def fake_allreduce(
288-
tensor: Tensor, should_quantize: bool
289-
) -> torch.futures.Future[Tensor]:
290+
def fake_allreduce(tensor: Tensor, should_quantize: bool) -> Work:
290291
tensor.mul_(2)
291-
fut = torch.futures.Future() # pyre-fixme[29]: not a function
292-
fut.set_result(tensor)
293-
return fut
292+
return _DummyWork(tensor)
294293

295294
manager.allreduce.side_effect = fake_allreduce
296295

torchft/manager.py

Lines changed: 9 additions & 13 deletions
@@ -40,11 +40,12 @@
 import torch
 from torch.distributed import ReduceOp, TCPStore
-from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp
+from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp, Work

 from torchft._torchft import ManagerClient, ManagerServer
 from torchft.checkpointing import CheckpointTransport, HTTPTransport
 from torchft.futures import future_timeout
+from torchft.work import _DummyWork

 if TYPE_CHECKING:
     from torchft.process_group import ProcessGroup

@@ -344,9 +345,7 @@ def shutdown(self, wait: bool = True) -> None:
         self._executor.shutdown(wait=wait)

     @torch.profiler.record_function("torchft::manager::allreduce")
-    def allreduce(
-        self, tensor: torch.Tensor, should_quantize: bool = False
-    ) -> torch.futures.Future[torch.Tensor]:
+    def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work:
         """
         Fault tolerant allreduce the tensor and return a Future that will be completed when
         the tensor is ready.

@@ -366,9 +365,7 @@ def allreduce(
            a Future that will be completed with the allreduced tensor
         """
         if self.errored():
-            fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-            fut.set_result(tensor)
-            return fut
+            return _DummyWork(tensor)

         self.wait_quorum()
         num_participants: int = self.num_participants()

@@ -381,13 +378,14 @@
             # Run the allreduce async and save the work object so we can wait on
             # it later.
             if should_quantize and IS_TRITON_AVAILABLE:
-                fut = allreduce_quantized(
+                work = allreduce_quantized(
                    [tensor], ReduceOp.SUM, self._pg, torch.cuda.current_stream()
                )
             else:
                work = self._pg.allreduce([tensor], ReduceOp.SUM)
                work.wait()
-                fut = work.get_future()
+
+            fut = work.get_future()

             stream: Optional[torch.cuda.Stream] = (
                torch.cuda.current_stream() if torch.cuda.is_available() else None

@@ -414,17 +412,15 @@ def callback(
             fut = fut.then(callback)

             fut = self.wrap_future(fut, tensor)
-            return fut
+            return work

         except Exception as e:
             self._logger.exception(
                f"got exception in all reduce -- skipping remaining: {e}"
             )
             self.report_error(e)

-            fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-            fut.set_result(tensor)
-            return fut
+            return _DummyWork(tensor)

     def report_error(self, e: Exception) -> None:
         """

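One consequence of the new contract, visible in the error paths above: whether the collective succeeds, the manager is already errored, or an exception is raised mid-flight, the caller always receives a Work handle (a real one or a _DummyWork), so call sites never have to branch on the return type. A short usage sketch mirroring the integration test below (manager and device are assumed to be set up already):

    t1 = torch.ones((1, 3), device=device)
    work = manager.allreduce(t1)
    work.get_future().wait()  # completes even if the manager substituted a _DummyWork
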
torchft/manager_integ_test.py

Lines changed: 2 additions & 2 deletions
@@ -634,7 +634,7 @@ def all_reduce_callback(

        manager.start_quorum()
        t1 = torch.ones((1, 3), device=device)
-        fut = manager.allreduce(t1)
-        fut.wait()
+        work = manager.allreduce(t1)
+        work.get_future().wait()
        return t1
    return None
