
Commit c93ad11

wait for futures while syncing fragments
Summary:

- we currently wait on the pg work's future when preparing a fragment
- if we use gloo, this blocks the CPU
- move the wait call to when we perform the actual sync of the fragment
- the manager allreduce also returns the work object, so we can wait on that as well when performing the sync (the new calling convention is sketched after the file summary below)
- use the HTTP transport instead of the PG transport -- the PG transport fails to resolve the address when running locally
- deep copy the state dict when sending a checkpoint, because if the replica moves on to the next step, the state dict can change before the checkpoint is sent

Test Plan:

gloo overlaps now

<img width="1284" height="662" alt="image" src="https://github.com/user-attachments/assets/e9b88e52-8053-432b-83a3-e689bcc4f9d4" />

nccl still overlaps

<img width="1283" height="664" alt="image" src="https://github.com/user-attachments/assets/cbd0a352-1529-42f7-b8d9-d45bd0e84a97" />
1 parent 949a981 commit c93ad11

File tree: 10 files changed (+84 −52 lines)

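The core change is the calling convention of the fault-tolerant allreduce: it now returns the c10d `Work` handle alongside the `Future`, and both waits are deferred from `prepare_sync` to `perform_sync`. The sketch below is a minimal, self-contained illustration of that pattern, not torchft code; `_FakeWork`, `fake_allreduce`, `prepare_sync`, and `perform_sync` are hypothetical stand-ins that only mirror the shapes used in the diffs that follow.

```python
# Illustrative sketch of the new (work, future) contract and deferred waits.
from typing import Optional

import torch


class _FakeWork:
    """Stand-in for a torch.distributed Work handle; wait() is where gloo would block the CPU."""

    def wait(self) -> bool:
        return True


def fake_allreduce(t: torch.Tensor) -> tuple[Optional[_FakeWork], torch.futures.Future]:
    # Mirrors the new Manager.allreduce return shape: (Work | None, Future).
    fut: torch.futures.Future = torch.futures.Future()
    fut.set_result(t)
    return _FakeWork(), fut


def prepare_sync(grads: list[torch.Tensor]) -> tuple[list[_FakeWork], list[torch.futures.Future]]:
    # Kick off the collectives but do NOT wait here; with gloo, waiting now
    # would block the CPU and prevent overlap with the next local steps.
    works, futs = [], []
    for g in grads:
        work, fut = fake_allreduce(g)
        if work is not None:
            works.append(work)
        futs.append(fut)
    return works, futs


def perform_sync(works: list[_FakeWork], futs: list[torch.futures.Future]) -> None:
    # Block only when the reduced results are actually consumed.
    for w in works:
        w.wait()
    for f in futs:
        f.wait()


works, futs = prepare_sync([torch.ones(4), torch.zeros(4)])
# ... the CPU keeps doing useful work here (further local steps) ...
perform_sync(works, futs)
```

With gloo, `work.wait()` blocks the calling thread, so deferring it to `perform_sync` is what lets communication overlap with the following local steps, as shown in the test-plan traces above.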

torchft/collectives.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -18,6 +18,7 @@
     AllreduceOptions,
     AllToAllOptions,
     ReduceScatterOptions,
+    Work,
 )
 from torch.futures import Future

@@ -288,7 +289,7 @@ def allreduce_quantized(
     opts: AllreduceOptions | ReduceOp,
     process_group: "ProcessGroup",
     sync_stream: cuda.Stream | None = None,
-) -> Future[list[torch.Tensor]]:
+) -> tuple[Work, Future[list[torch.Tensor]]]:
     """
     Performs a quantized all-reduce operation on a list of tensors.

@@ -379,17 +380,17 @@ def allreduce_quantized(
         [torch.split(quantized_tensors_out.view(world_size, -1), 1)[rank]],
         _to_allgather_options(allreduce_opts),
     )
-    work.wait()
     fut = work.get_future()

     def callback(fut: Future[list[torch.Tensor]]) -> list[torch.Tensor]:
         # Dequantize and copy to output buffer.
         nonlocal tensors, quantized_tensors, world_size, sync_stream

         with torch.cuda.stream(sync_stream):
+            fut.wait()
             # Dequantize the result back to the original precision
             fused_dequantize_from_fp8(tensors, quantized_tensors, world_size)
             return tensors

     fut = fut.then(callback)
-    return fut
+    return (work, fut)
```
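A hedged caller sketch for the new signature (the helper `quantized_sum` below is illustrative, not part of torchft; it assumes the caller already has a process group and input tensors set up, as in `collectives_test.py`):

```python
import torch
from torch.distributed import ReduceOp

from torchft.collectives import allreduce_quantized


def quantized_sum(tensors: list[torch.Tensor], pg) -> list[torch.Tensor]:
    # Kick off the quantized allreduce; nothing is waited on here, so the
    # communication can overlap other work.
    work, fut = allreduce_quantized(tensors, ReduceOp.SUM, pg)
    # ... overlap compute here ...
    # Waiting on the returned future also covers the dequantize-and-copy
    # callback, which now runs on the sync stream; callers that only need the
    # result can discard the handle with `_, fut = allreduce_quantized(...)`.
    return fut.wait()
```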

torchft/collectives_test.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -94,7 +94,7 @@ def _run_all_reduce_collective(
         )
     ]

-    fut = allreduce_quantized(tensors, reduce_op, pg)
+    _, fut = allreduce_quantized(tensors, reduce_op, pg)
     fut.wait()

     work = pg.allreduce([expected], reduce_op)
```

torchft/ddp.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -68,7 +68,8 @@ def __init__(self, manager: "Manager", module: nn.Module, **kwargs: object) -> N
         def _comm_hook(
             state: "Manager", bucket: dist.GradBucket
         ) -> torch.futures.Future[torch.Tensor]:
-            return state.allreduce(bucket.buffer())
+            _, fut = state.allreduce(bucket.buffer())
+            return fut


 class PureDistributedDataParallel(nn.Module):
```

torchft/ddp_test.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -39,14 +39,16 @@ def test_ddp(self) -> None:

         call_count = 0

-        def allreduce(tensor: torch.Tensor) -> Future[torch.Tensor]:
+        def allreduce(
+            tensor: torch.Tensor,
+        ) -> tuple[torch.Tensor, Future[torch.Tensor]]:
             nonlocal call_count

             call_count += 1

             fut = Future()  # pyre-fixme[29]: not a function
             fut.set_result(tensor)
-            return fut
+            return tensor, fut

         manager.allreduce = allreduce
```

torchft/local_sgd.py

Lines changed: 31 additions & 13 deletions
```diff
@@ -18,6 +18,7 @@
 import torch
 import torch.distributed as dist
 from torch import nn, optim
+from torch.distributed.distributed_c10d import Work
 from torch.distributed.tensor import DTensor
 from torch.nn.parameter import Parameter
 from torch.optim.optimizer import Optimizer

@@ -154,7 +155,8 @@ def _average(self) -> list[torch.Tensor]:
         for p in self._model.parameters():
             # Create a new tensor to store the averaged parameter
             avg_param = extract_local_tensor(p)
-            works.append(self._manager.allreduce(avg_param))
+            (work, fut) = self._manager.allreduce(avg_param)
+            works.append(fut)
             averaged_parameters.append(avg_param)
         for work in works:
             work.wait()

@@ -201,6 +203,7 @@ def __init__(

         # Stores pending all reduce
         self._allreduce_futures: list[torch.futures.Future[torch.Tensor]] = []
+        self._allreduce_work: list[Work] = []
         self._stream: Optional[torch.cuda.Stream] = (
             torch.cuda.Stream() if torch.cuda.is_available() else None
         )

@@ -377,6 +380,7 @@ def wait(self) -> None:
         self._stop_event = None

         self._allreduce_futures = []
+        self._allreduce_work = []

     @torch.profiler.record_function("torchft::local_sgd::prepare_sync")
     def prepare_sync(self) -> None:

@@ -399,13 +403,6 @@ def prepare_sync(self) -> None:
         ):
             self._average_grads()

-            for work in self._allreduce_futures:
-                work.wait()
-
-            if self._stream is not None:
-                self._stop_event = torch.cuda.Event()
-                self._stop_event.record()
-
     @torch.profiler.record_function("torchft::local_sgd::perform_sync")
     def perform_sync(self) -> bool:
         """

@@ -415,6 +412,21 @@ def perform_sync(self) -> bool:
         # Waiting for an allreduce before it has been sent is currently not supported.
         assert len(self._allreduce_futures) > 0

+        with (
+            torch.cuda.stream(self._stream)
+            if self._stream is not None
+            else nullcontext()
+        ):
+            for work in self._allreduce_work:
+                work.wait()
+
+            for fut in self._allreduce_futures:
+                fut.wait()
+
+            if self._stream is not None:
+                self._stop_event = torch.cuda.Event()
+                self._stop_event.record()
+
         self.wait()

         # save the parameters so they can be used for merging

@@ -464,10 +476,13 @@ def _allreduce_per_param(self) -> None:
         """Performs allreduce on each gradient tensor separately (original method)."""
         for name, p in self._model_fragment.named_parameters():
             # Perform allreduce on the pseudogradients
-            work = self._manager.allreduce(
+            (work, fut) = self._manager.allreduce(
                 self._grads[name], should_quantize=self.should_quantize
             )
-            self._allreduce_futures.append(work)
+            self._allreduce_futures.append(fut)
+
+            if work is not None:
+                self._allreduce_work.append(work)

     def _bucketize_and_allreduce(
         self,

@@ -508,7 +523,7 @@ def _bucketize_and_allreduce(
                 pack_offset += numel
                 flat_index += 1

-            work = self._manager.allreduce(
+            (work, fut) = self._manager.allreduce(
                 flat_buffer, should_quantize=self.should_quantize
             )

@@ -517,8 +532,11 @@ def callback(fut: torch.futures.Future[torch.Tensor]) -> None:
                 for t, pack_offset, numel in bucket_tensors:
                     t.copy_(flat_buffer[pack_offset : pack_offset + numel].view_as(t))

-            work = work.then(callback)
-            self._allreduce_futures.append(work)
+            fut = fut.then(callback)
+
+            self._allreduce_futures.append(fut)
+            if work is not None:
+                self._allreduce_work.append(work)

             offset += chunk_size
```
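For reference, the pattern introduced in `perform_sync` above reduces to the following self-contained sketch (illustrative only; `pending_work` and `pending_futs` are placeholder lists standing in for `_allreduce_work` and `_allreduce_futures`). The waits are issued under the side stream so that, with NCCL, synchronization attaches to that stream; with gloo the wait still blocks the CPU, but only here at sync time rather than in `prepare_sync`.

```python
from contextlib import nullcontext
from typing import Optional

import torch

# Placeholders for self._allreduce_work / self._allreduce_futures.
pending_work: list = []
pending_futs: list = []

stream: Optional[torch.cuda.Stream] = (
    torch.cuda.Stream() if torch.cuda.is_available() else None
)
stop_event: Optional[torch.cuda.Event] = None

with torch.cuda.stream(stream) if stream is not None else nullcontext():
    # Block on the pending collectives only now, when the averaged results
    # are actually needed.
    for work in pending_work:
        work.wait()
    for fut in pending_futs:
        fut.wait()

    if stream is not None:
        # Record an event so later code can wait on it from another stream
        # instead of synchronizing the whole device.
        stop_event = torch.cuda.Event()
        stop_event.record()
```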

torchft/local_sgd_test.py

Lines changed: 15 additions & 5 deletions
```diff
@@ -11,6 +11,7 @@
 import torch
 from parameterized import parameterized
 from torch import Tensor, nn, optim
+from torch.distributed.distributed_c10d import Work
 from torch.distributed.tensor import DTensor

 from torchft.local_sgd import DiLoCo, LocalSGD, extract_local_tensor

@@ -26,6 +27,15 @@ def create_manager() -> MagicMock:

     manager.errored.return_value = None

+    def mock_allreduce(
+        tensor: torch.Tensor, should_quantize: bool = False
+    ) -> tuple[Work | None, torch.futures.Future[Tensor]]:
+        fut = torch.futures.Future()  # pyre-fixme[29]: not a function
+        fut.set_result(tensor)
+        return (None, fut)
+
+    manager.allreduce.side_effect = mock_allreduce
+
     return manager

@@ -66,7 +76,7 @@ class LocalSGDTest(TestCase):
     def test_local_sgd_healthy(self) -> None:
         model = SimpleModel()
         optimizer = optim.SGD(model.parameters())
-        manager = create_autospec(Manager)
+        manager = create_manager()
         with LocalSGD(manager, model, optimizer, sync_every=2) as local_sgd:
             self.assertEqual(local_sgd._local_step, 0)
             inp = torch.rand(2, 3)

@@ -242,11 +252,11 @@ def test_bucketization_correctness(self) -> None:
         # Define fake allreduce: multiplies buffer by 2
         def fake_allreduce(
             tensor: Tensor, should_quantize: bool
-        ) -> torch.futures.Future[Tensor]:
+        ) -> tuple[Work | None, torch.futures.Future[Tensor]]:
             tensor.mul_(2)
             fut = torch.futures.Future()  # pyre-fixme[29]: not a function
             fut.set_result(tensor)
-            return fut
+            return (None, fut)

         manager.allreduce.side_effect = fake_allreduce

@@ -286,11 +296,11 @@ def test_gradient_correctness(self) -> None:
         # Define fake allreduce: multiplies buffer by 2
         def fake_allreduce(
             tensor: Tensor, should_quantize: bool
-        ) -> torch.futures.Future[Tensor]:
+        ) -> tuple[Work | None, torch.futures.Future[Tensor]]:
             tensor.mul_(2)
             fut = torch.futures.Future()  # pyre-fixme[29]: not a function
             fut.set_result(tensor)
-            return fut
+            return (None, fut)

         manager.allreduce.side_effect = fake_allreduce
```

torchft/manager.py

Lines changed: 9 additions & 8 deletions
```diff
@@ -26,6 +26,7 @@
 """

 import concurrent.futures
+import copy
 import logging
 import os
 import socket

@@ -39,7 +40,7 @@

 import torch
 from torch.distributed import ReduceOp, TCPStore
-from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp
+from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp, Work

 from torchft._torchft import ManagerClient, ManagerServer
 from torchft.checkpointing import CheckpointTransport, HTTPTransport

@@ -345,7 +346,7 @@ def shutdown(self, wait: bool = True) -> None:
     @torch.profiler.record_function("torchft::manager::allreduce")
     def allreduce(
         self, tensor: torch.Tensor, should_quantize: bool = False
-    ) -> torch.futures.Future[torch.Tensor]:
+    ) -> tuple[Work | None, torch.futures.Future[torch.Tensor]]:
         """
         Fault tolerant allreduce the tensor and return a Future that will be completed when
         the tensor is ready.

@@ -367,7 +368,7 @@ def allreduce(
         if self.errored():
             fut = torch.futures.Future()  # pyre-fixme[29]: not a function
             fut.set_result(tensor)
-            return fut
+            return (None, fut)

         self.wait_quorum()
         num_participants: int = self.num_participants()

@@ -380,12 +381,11 @@ def allreduce(
             # Run the allreduce async and save the work object so we can wait on
             # it later.
             if should_quantize and IS_TRITON_AVAILABLE:
-                fut = allreduce_quantized(
+                (work, fut) = allreduce_quantized(
                     [tensor], ReduceOp.SUM, self._pg, torch.cuda.current_stream()
                 )
             else:
                 work = self._pg.allreduce([tensor], ReduceOp.SUM)
-                work.wait()
                 fut = work.get_future()

             stream: Optional[torch.cuda.Stream] = (

@@ -403,6 +403,7 @@ def callback(
                 # change the stream to avoid making the callback stream
                 # dependent on process group stream running the allreduce
                 with torch.cuda.stream(stream) if stream is not None else nullcontext():
+                    fut.wait()
                     fut.value()
                     tensor /= num_participants

@@ -411,7 +412,7 @@ def callback(
             fut = fut.then(callback)

             fut = self.wrap_future(fut, tensor)
-            return fut
+            return (work, fut)

         except Exception as e:
             self._logger.exception(

@@ -421,7 +422,7 @@ def callback(

             fut = torch.futures.Future()  # pyre-fixme[29]: not a function
             fut.set_result(tensor)
-            return fut
+            return (None, fut)

     def report_error(self, e: Exception) -> None:
         """

@@ -646,7 +647,7 @@ def _async_quorum(
                 self._checkpoint_transport.send_checkpoint(
                     dst_ranks=quorum.recover_dst_replica_ranks,
                     step=max_step,
-                    state_dict=self._manager_state_dict(),
+                    state_dict=copy.deepcopy(self._manager_state_dict()),
                     timeout=self._timeout,
                 )
```
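The `copy.deepcopy` in `_async_quorum` guards against the live state dict being mutated while the checkpoint is still being sent. A minimal, self-contained illustration of the race it avoids (the toy `params` dict below is a stand-in, not torchft's actual state dict):

```python
import copy

import torch

params = {"w": torch.zeros(3)}  # stand-in for the manager's state dict

snapshot_shallow = dict(params)        # still aliases the same tensor
snapshot_deep = copy.deepcopy(params)  # independent copy

params["w"] += 1.0  # the replica moves on to the next step (in-place update)

print(snapshot_shallow["w"])  # tensor([1., 1., 1.]) -- reflects the mutation
print(snapshot_deep["w"])     # tensor([0., 0., 0.]) -- consistent snapshot
```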

torchft/manager_integ_test.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -634,7 +634,7 @@ def all_reduce_callback(

         manager.start_quorum()
         t1 = torch.ones((1, 3), device=device)
-        fut = manager.allreduce(t1)
+        (_, fut) = manager.allreduce(t1)
         fut.wait()
         return t1
     return None
```
