From fef4abc8772d34b9a50a8e7a4d3d37b2c1a8009e Mon Sep 17 00:00:00 2001
From: Tushar Jain
Date: Fri, 25 Jul 2025 18:30:02 -0700
Subject: [PATCH 1/3] use http transport

Summary: use the HTTP transport instead of the PG transport -- the PG
transport fails to resolve addresses when running locally.
---
 train_diloco.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/train_diloco.py b/train_diloco.py
index 0c6b9cf..e207e73 100644
--- a/train_diloco.py
+++ b/train_diloco.py
@@ -34,7 +34,7 @@
     ProcessGroupGloo,
     ProcessGroupNCCL,
 )
-from torchft.checkpointing.pg_transport import PGTransport
+from torchft.checkpointing.http_transport import HTTPTransport
 from torchft.local_sgd import DiLoCo
 
 logging.basicConfig(level=logging.INFO)
@@ -67,13 +67,12 @@ def state_dict():
         timeout=timedelta(seconds=10),
     )
     if torch.cuda.is_available() and USE_NCCL
-    else ProcessGroupGloo(timeout=timedelta(seconds=5))
+    else ProcessGroupGloo(timeout=timedelta(seconds=10))
 )
 
-transport = PGTransport(
-    pg,
+transport = HTTPTransport(
     timeout=timedelta(seconds=10),
-    device=device,
+    num_chunks=0,
 )
 
 manager = Manager(

From 09bbdea7ca569378f6ac0f9a39161bb9b6421bbb Mon Sep 17 00:00:00 2001
From: Tushar Jain
Date: Thu, 31 Jul 2025 18:55:23 -0700
Subject: [PATCH 2/3] fix stream dependencies in callbacks

Summary:
- call future.wait in callbacks to make sure the continuation executes
  after the future has completed
- set the stream correctly to execute the callback scheduled by the
  bucketized allreduce
---
 torchft/collectives.py |  2 ++
 torchft/local_sgd.py   | 11 ++++++++---
 torchft/manager.py     |  2 ++
 3 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/torchft/collectives.py b/torchft/collectives.py
index 837fbcd..927309a 100644
--- a/torchft/collectives.py
+++ b/torchft/collectives.py
@@ -387,6 +387,8 @@ def callback(fut: Future[list[torch.Tensor]]) -> list[torch.Tensor]:
         nonlocal tensors, quantized_tensors, world_size, sync_stream
 
         with torch.cuda.stream(sync_stream):
+            # Setup stream dependency
+            fut.wait()
             # Dequantize the result back to the original precision
             fused_dequantize_from_fp8(tensors, quantized_tensors, world_size)
             return tensors
diff --git a/torchft/local_sgd.py b/torchft/local_sgd.py
index c7230ee..d0eeccc 100644
--- a/torchft/local_sgd.py
+++ b/torchft/local_sgd.py
@@ -513,9 +513,14 @@ def _bucketize_and_allreduce(
             )
 
             def callback(fut: torch.futures.Future[torch.Tensor]) -> None:
-                nonlocal bucket_tensors, flat_buffer
-                for t, pack_offset, numel in bucket_tensors:
-                    t.copy_(flat_buffer[pack_offset : pack_offset + numel].view_as(t))
+                with torch.cuda.stream(self._stream) if self._stream else nullcontext():
+                    nonlocal bucket_tensors, flat_buffer
+                    # Setup stream dependency
+                    fut.wait()
+                    for t, pack_offset, numel in bucket_tensors:
+                        t.copy_(
+                            flat_buffer[pack_offset : pack_offset + numel].view_as(t)
+                        )
 
             work = work.then(callback)
             self._allreduce_futures.append(work)
diff --git a/torchft/manager.py b/torchft/manager.py
index e01a965..09100c3 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -403,6 +403,8 @@ def callback(
                 # change the stream to avoid making the callback stream
                 # dependent on process group stream running the allreduce
                 with torch.cuda.stream(stream) if stream is not None else nullcontext():
+                    # Setup stream dependency
+                    fut.wait()
                     fut.value()
                     tensor /= num_participants
 

From 9683ef4db489813a5811ff6a8266372945b33b96 Mon Sep 17 00:00:00 2001
From: Tushar Jain
Date: Thu, 31 Jul 2025 18:55:23 -0700
Subject: [PATCH 3/3] return work from manager allreduce

Summary: return the work object so callers have more flexibility in how
they use it.
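
An illustrative sketch of the new calling pattern (based on the updated
tests below; `manager` and `tensor` are assumed to be an already
initialized Manager and a local tensor):

    # Manager.allreduce() now returns a Work object instead of a Future
    work = manager.allreduce(tensor)
    # either wait on the work directly ...
    work.wait()
    # ... or grab the underlying future where one is still needed
    fut = work.get_future()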
---
 torchft/collectives.py        | 18 +++++++++++++-----
 torchft/collectives_test.py   |  4 ++--
 torchft/ddp.py                |  3 ++-
 torchft/ddp_test.py           | 10 ++++++----
 torchft/local_sgd.py          | 22 +++++++++++++---------
 torchft/local_sgd_test.py     | 25 ++++++++++++-------------
 torchft/manager.py            | 23 +++++++++++------------
 torchft/manager_integ_test.py |  4 ++--
 torchft/manager_test.py       | 11 ++++++-----
 torchft/process_group.py      | 16 +---------------
 torchft/process_group_test.py |  2 +-
 torchft/work.py               | 37 +++++++++++++++++++++++++++++++++++
 12 files changed, 105 insertions(+), 70 deletions(-)
 create mode 100644 torchft/work.py

diff --git a/torchft/collectives.py b/torchft/collectives.py
index 927309a..af95cbb 100644
--- a/torchft/collectives.py
+++ b/torchft/collectives.py
@@ -18,6 +18,7 @@
     AllreduceOptions,
     AllToAllOptions,
     ReduceScatterOptions,
+    Work,
 )
 
 from torch.futures import Future
@@ -288,7 +289,7 @@ def allreduce_quantized(
     opts: AllreduceOptions | ReduceOp,
     process_group: "ProcessGroup",
     sync_stream: cuda.Stream | None = None,
-) -> Future[list[torch.Tensor]]:
+) -> Work:
     """
     Performs a quantized all-reduce operation on a list of tensors.
 
@@ -379,10 +380,18 @@ def allreduce_quantized(
         [torch.split(quantized_tensors_out.view(world_size, -1), 1)[rank]],
         _to_allgather_options(allreduce_opts),
     )
 
+    # NOTE: This is not supposed to be used with gloo, only with NCCL.
+    # So we setup the stream dependency here by calling work.wait(),
+    # which doesn't block the CPU.
+    #
+    # The future callback below will run after the work has been
+    # completed.
+    work.wait()
     fut = work.get_future()
 
-    def callback(fut: Future[list[torch.Tensor]]) -> list[torch.Tensor]:
+    def callback(fut: Future[list[torch.Tensor]]) -> None:
+        # Dequantize and copy to output buffer.
         nonlocal tensors, quantized_tensors, world_size, sync_stream
 
         with torch.cuda.stream(sync_stream):
@@ -391,7 +400,6 @@ def callback(fut: Future[list[torch.Tensor]]) -> list[torch.Tensor]:
             fut.wait()
             # Dequantize the result back to the original precision
             fused_dequantize_from_fp8(tensors, quantized_tensors, world_size)
-            return tensors
 
-    fut = fut.then(callback)
-    return fut
+    fut.add_done_callback(callback)
+    return work
diff --git a/torchft/collectives_test.py b/torchft/collectives_test.py
index c4b826b..6660abe 100644
--- a/torchft/collectives_test.py
+++ b/torchft/collectives_test.py
@@ -94,8 +94,8 @@ def _run_all_reduce_collective(
         )
     ]
 
-    fut = allreduce_quantized(tensors, reduce_op, pg)
-    fut.wait()
+    work = allreduce_quantized(tensors, reduce_op, pg)
+    work.wait()
 
     work = pg.allreduce([expected], reduce_op)
     work.get_future().wait()
diff --git a/torchft/ddp.py b/torchft/ddp.py
index 6fbea8f..1355317 100644
--- a/torchft/ddp.py
+++ b/torchft/ddp.py
@@ -68,7 +68,8 @@ def __init__(self, manager: "Manager", module: nn.Module, **kwargs: object) -> N
     def _comm_hook(
         state: "Manager", bucket: dist.GradBucket
     ) -> torch.futures.Future[torch.Tensor]:
-        return state.allreduce(bucket.buffer())
+        work = state.allreduce(bucket.buffer())
+        return work.get_future()
 
 
 class PureDistributedDataParallel(nn.Module):
diff --git a/torchft/ddp_test.py b/torchft/ddp_test.py
index 1a56dce..690bfd0 100644
--- a/torchft/ddp_test.py
+++ b/torchft/ddp_test.py
@@ -10,11 +10,13 @@
 import torch
 import torch.distributed as dist
 from torch import nn
+from torch.distributed.distributed_c10d import Work
 from torch.futures import Future
 
 from torchft.ddp import DistributedDataParallel, PureDistributedDataParallel
 from torchft.manager import Manager
 from torchft.process_group import ProcessGroupBabyGloo, ProcessGroupGloo
+from torchft.work import _DummyWork
 
 
 class TestDDP(TestCase):
@@ -39,14 +41,14 @@ def test_ddp(self) -> None:
 
         call_count = 0
 
-        def allreduce(tensor: torch.Tensor) -> Future[torch.Tensor]:
+        def allreduce(
+            tensor: torch.Tensor,
+        ) -> Work:
             nonlocal call_count
 
             call_count += 1
 
-            fut = Future()  # pyre-fixme[29]: not a function
-            fut.set_result(tensor)
-            return fut
+            return _DummyWork(tensor)
 
         manager.allreduce = allreduce
 
diff --git a/torchft/local_sgd.py b/torchft/local_sgd.py
index d0eeccc..761a74c 100644
--- a/torchft/local_sgd.py
+++ b/torchft/local_sgd.py
@@ -18,6 +18,7 @@
 import torch
 import torch.distributed as dist
 from torch import nn, optim
+from torch.distributed.distributed_c10d import Work
 from torch.distributed.tensor import DTensor
 from torch.nn.parameter import Parameter
 from torch.optim.optimizer import Optimizer
@@ -200,7 +201,7 @@ def __init__(
         self._outer_optimizer = outer_optimizer
 
         # Stores pending all reduce
-        self._allreduce_futures: list[torch.futures.Future[torch.Tensor]] = []
+        self._allreduce_work: list[Work] = []
         self._stream: Optional[torch.cuda.Stream] = (
             torch.cuda.Stream() if torch.cuda.is_available() else None
         )
@@ -368,7 +369,7 @@ def wait(self) -> None:
         """
         Waits for the previously scheduled allreduce to finish
        """
-        if len(self._allreduce_futures) == 0:
+        if len(self._allreduce_work) == 0:
             return
 
         if self._stream is not None:
@@ -376,7 +377,7 @@ def wait(self) -> None:
             self._stop_event.synchronize()
             self._stop_event = None
 
-        self._allreduce_futures = []
+        self._allreduce_work = []
 
     @torch.profiler.record_function("torchft::local_sgd::prepare_sync")
     def prepare_sync(self) -> None:
@@ -386,7 +387,7 @@ def prepare_sync(self) -> None:
         """
         self._save_grads()
 
-        assert len(self._allreduce_futures) == 0
+        assert len(self._allreduce_work) == 0
 
         # Make sure tensors are available to `_stream`
         if self._stream is not None:
@@ -399,7 +400,7 @@ def prepare_sync(self) -> None:
         ):
             self._average_grads()
 
-            for work in self._allreduce_futures:
+            for work in self._allreduce_work:
                 work.wait()
 
             if self._stream is not None:
@@ -413,7 +414,7 @@ def perform_sync(self) -> bool:
         steps using the outer optimizer.
         """
         # Waiting for an allreduce before it has been sent is currently not supported.
-        assert len(self._allreduce_futures) > 0
+        assert len(self._allreduce_work) > 0
 
         self.wait()
 
@@ -467,7 +468,8 @@ def _allreduce_per_param(self) -> None:
             work = self._manager.allreduce(
                 self._grads[name], should_quantize=self.should_quantize
             )
-            self._allreduce_futures.append(work)
+
+            self._allreduce_work.append(work)
 
     def _bucketize_and_allreduce(
         self,
@@ -522,8 +524,10 @@ def callback(fut: torch.futures.Future[torch.Tensor]) -> None:
                             flat_buffer[pack_offset : pack_offset + numel].view_as(t)
                         )
 
-            work = work.then(callback)
-            self._allreduce_futures.append(work)
+            fut = work.get_future()
+            fut.add_done_callback(callback)
+
+            self._allreduce_work.append(work)
 
             offset += chunk_size
 
diff --git a/torchft/local_sgd_test.py b/torchft/local_sgd_test.py
index 04aede4..881b96e 100644
--- a/torchft/local_sgd_test.py
+++ b/torchft/local_sgd_test.py
@@ -11,10 +11,12 @@
 import torch
 from parameterized import parameterized
 from torch import Tensor, nn, optim
+from torch.distributed.distributed_c10d import Work
 from torch.distributed.tensor import DTensor
 
 from torchft.local_sgd import DiLoCo, LocalSGD, extract_local_tensor
 from torchft.manager import Manager
+from torchft.work import _DummyWork
 
 
 def create_manager() -> MagicMock:
@@ -26,6 +28,11 @@ def create_manager() -> MagicMock:
 
     manager.errored.return_value = None
 
+    def mock_allreduce(tensor: torch.Tensor, should_quantize: bool = False) -> Work:
+        return _DummyWork(tensor)
+
+    manager.allreduce.side_effect = mock_allreduce
+
     return manager
 
 
@@ -66,7 +73,7 @@ class LocalSGDTest(TestCase):
     def test_local_sgd_healthy(self) -> None:
         model = SimpleModel()
         optimizer = optim.SGD(model.parameters())
-        manager = create_autospec(Manager)
+        manager = create_manager()
         with LocalSGD(manager, model, optimizer, sync_every=2) as local_sgd:
             self.assertEqual(local_sgd._local_step, 0)
             inp = torch.rand(2, 3)
@@ -240,13 +247,9 @@ def test_bucketization_correctness(self) -> None:
         manager.should_commit.return_value = True
 
         # Define fake allreduce: multiplies buffer by 2
-        def fake_allreduce(
-            tensor: Tensor, should_quantize: bool
-        ) -> torch.futures.Future[Tensor]:
+        def fake_allreduce(tensor: Tensor, should_quantize: bool) -> Work:
             tensor.mul_(2)
-            fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-            fut.set_result(tensor)
-            return fut
+            return _DummyWork(tensor)
 
         manager.allreduce.side_effect = fake_allreduce
 
@@ -284,13 +287,9 @@ def test_gradient_correctness(self) -> None:
         manager.should_commit.return_value = True
 
         # Define fake allreduce: multiplies buffer by 2
-        def fake_allreduce(
-            tensor: Tensor, should_quantize: bool
-        ) -> torch.futures.Future[Tensor]:
+        def fake_allreduce(tensor: Tensor, should_quantize: bool) -> Work:
             tensor.mul_(2)
-            fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-            fut.set_result(tensor)
-            return fut
+            return _DummyWork(tensor)
 
         manager.allreduce.side_effect = fake_allreduce
 
diff --git a/torchft/manager.py b/torchft/manager.py
index 09100c3..0b6e63b 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -39,11 +39,12 @@
 
 import torch
 from torch.distributed import ReduceOp, TCPStore
-from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp
+from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp, Work
 
 from torchft._torchft import ManagerClient, ManagerServer
 from torchft.checkpointing import CheckpointTransport, HTTPTransport
 from torchft.futures import future_timeout
+from torchft.work import _DummyWork, _WorkWrapper
 
 if TYPE_CHECKING:
     from torchft.process_group import ProcessGroup
@@ -343,9 +344,7 @@ def shutdown(self, wait: bool = True) -> None:
             self._executor.shutdown(wait=wait)
 
     @torch.profiler.record_function("torchft::manager::allreduce")
-    def allreduce(
-        self, tensor: torch.Tensor, should_quantize: bool = False
-    ) -> torch.futures.Future[torch.Tensor]:
+    def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work:
         """
         Fault tolerant allreduce the tensor and return a Future that will
         be completed when the tensor is ready.
@@ -365,9 +364,7 @@ def allreduce(
             a Future that will be completed with the allreduced tensor
         """
         if self.errored():
-            fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-            fut.set_result(tensor)
-            return fut
+            return _DummyWork(tensor)
 
         self.wait_quorum()
         num_participants: int = self.num_participants()
@@ -380,13 +377,14 @@ def allreduce(
             # Run the allreduce async and save the work object so we can wait on
             # it later.
             if should_quantize and IS_TRITON_AVAILABLE:
-                fut = allreduce_quantized(
+                work = allreduce_quantized(
                     [tensor], ReduceOp.SUM, self._pg, torch.cuda.current_stream()
                 )
             else:
                 work = self._pg.allreduce([tensor], ReduceOp.SUM)
                 work.wait()
-                fut = work.get_future()
+
+            fut = work.get_future()
 
             stream: Optional[torch.cuda.Stream] = (
                 torch.cuda.current_stream() if torch.cuda.is_available() else None
             )
@@ -413,7 +411,8 @@ def callback(
 
             fut = fut.then(callback)
             fut = self.wrap_future(fut, tensor)
-            return fut
+
+            return _WorkWrapper(work, fut)
 
         except Exception as e:
             self._logger.exception(
@@ -421,9 +420,7 @@ def callback(
             )
             self.report_error(e)
 
-            fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-            fut.set_result(tensor)
-            return fut
+            return _DummyWork(tensor)
 
     def report_error(self, e: Exception) -> None:
         """
diff --git a/torchft/manager_integ_test.py b/torchft/manager_integ_test.py
index 6bdab58..ed2d11e 100644
--- a/torchft/manager_integ_test.py
+++ b/torchft/manager_integ_test.py
@@ -634,7 +634,7 @@ def all_reduce_callback(
         manager.start_quorum()
 
         t1 = torch.ones((1, 3), device=device)
-        fut = manager.allreduce(t1)
-        fut.wait()
+        work = manager.allreduce(t1)
+        work.wait()
         return t1
     return None
diff --git a/torchft/manager_test.py b/torchft/manager_test.py
index 3140319..d0c81a0 100644
--- a/torchft/manager_test.py
+++ b/torchft/manager_test.py
@@ -16,7 +16,8 @@
 from torchft._torchft import QuorumResult
 from torchft.checkpointing.transport import CheckpointTransport
 from torchft.manager import MANAGER_ADDR_KEY, REPLICA_ID_KEY, Manager, WorldSizeMode
-from torchft.process_group import ProcessGroup, _DummyWork
+from torchft.process_group import ProcessGroup
+from torchft.work import _DummyWork
 
 
 def mock_should_commit(
@@ -586,16 +587,16 @@ def test_manager_numerics(self, client_mock: MagicMock) -> None:
         manager._pg.allreduce.return_value = _DummyWork(None)
 
         self.assertTrue(manager.is_participating())
-        fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-        fut = manager.allreduce(torch.tensor([1.0]))
+        work = manager.allreduce(torch.tensor([1.0]))
+        fut = work.get_future()
         result = fut.value()
         torch.testing.assert_close(result, torch.tensor([1.0 / 5]))
 
         # check healing numerics
         manager._healing = True
         self.assertFalse(manager.is_participating())
-        fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-        fut = manager.allreduce(torch.tensor([1.0]))
+        work = manager.allreduce(torch.tensor([1.0]))
+        fut = work.get_future()
         result = fut.value()
         torch.testing.assert_close(result, torch.tensor([0.0]))
 
diff --git a/torchft/process_group.py b/torchft/process_group.py
index 8f9c27b..4750dc9 100644
--- a/torchft/process_group.py
+++ b/torchft/process_group.py
@@ -69,6 +69,7 @@
 from torchft.device_mesh import *  # noqa: F401
 from torchft.futures import context_timeout, stream_timeout
 from torchft.multiprocessing import _MonitoredPipe
+from torchft.work import _DummyWork
 
 if TYPE_CHECKING:
     from torchft.manager import Manager
@@ -790,21 +791,6 @@ def getBackendName(self) -> str:
         return "torchft-nccl"
 
 
-class _DummyWork(dist._Work):
-    def __init__(self, result: object) -> None:
-        super().__init__()
-        self.result_ = result
-        # pyre-fixme[29]: Future is not a function
-        self.future_: torch.futures.Future[object] = torch.futures.Future()
-        self.future_.set_result(result)
-
-    def wait(self, timeout: Optional[timedelta] = None) -> bool:
-        return True
-
-    def get_future(self) -> torch.futures.Future[object]:
-        return self.future_
-
-
 class ProcessGroupDummy(ProcessGroup):
     """
     This process group discards all data passed to it and returns success. This
diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py
index 4c3455d..072cf1e 100644
--- a/torchft/process_group_test.py
+++ b/torchft/process_group_test.py
@@ -47,12 +47,12 @@
     ProcessGroupGloo,
     ProcessGroupNCCL,
     ProcessGroupWrapper,
-    _DummyWork,
     _ErrorSwallowingWork,
     _ManagedWork,
     extend_device_mesh,
     ft_init_device_mesh,
 )
+from torchft.work import _DummyWork
 
 
 def dummy_init_pg() -> None:
diff --git a/torchft/work.py b/torchft/work.py
new file mode 100644
index 0000000..7211c0d
--- /dev/null
+++ b/torchft/work.py
@@ -0,0 +1,37 @@
+from contextlib import nullcontext
+from datetime import timedelta
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+
+class _DummyWork(dist._Work):
+    def __init__(self, result: object) -> None:
+        super().__init__()
+        self.result_ = result
+        # pyre-fixme[29]: Future is not a function
+        self.future_: torch.futures.Future[object] = torch.futures.Future()
+        self.future_.set_result(result)
+
+    def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        return True
+
+    def get_future(self) -> torch.futures.Future[object]:
+        return self.future_
+
+
+class _WorkWrapper(dist._Work):
+    def __init__(
+        self, work: dist._Work, fut: torch.futures.Future[torch.Tensor]
+    ) -> None:
+        super().__init__()
+        self._work = work
+        self._fut = fut
+
+    def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        self._fut.wait()
+        return True
+
+    def get_future(self) -> torch.futures.Future[torch.Tensor]:
+        return self._fut