diff --git a/torchft/ddp.py b/torchft/ddp.py
index 1355317..494a9b1 100644
--- a/torchft/ddp.py
+++ b/torchft/ddp.py
@@ -69,6 +69,7 @@ def _comm_hook(
     state: "Manager", bucket: dist.GradBucket
 ) -> torch.futures.Future[torch.Tensor]:
     work = state.allreduce(bucket.buffer())
+    work.synchronize()
     return work.get_future()


diff --git a/torchft/local_sgd.py b/torchft/local_sgd.py
index 761a74c..957680e 100644
--- a/torchft/local_sgd.py
+++ b/torchft/local_sgd.py
@@ -524,6 +524,7 @@ def callback(fut: torch.futures.Future[torch.Tensor]) -> None:
                         flat_buffer[pack_offset : pack_offset + numel].view_as(t)
                     )

+            work.synchronize()
             fut = work.get_future()
             fut.add_done_callback(callback)

diff --git a/torchft/manager.py b/torchft/manager.py
index 0b6e63b..c49f839 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -35,16 +35,28 @@
 from contextlib import nullcontext
 from datetime import timedelta
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    TypeAlias,
+    TypeVar,
+    Union,
+    cast,
+)

 import torch
+import torch.distributed as dist
 from torch.distributed import ReduceOp, TCPStore
 from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp, Work

 from torchft._torchft import ManagerClient, ManagerServer
 from torchft.checkpointing import CheckpointTransport, HTTPTransport
 from torchft.futures import future_timeout
-from torchft.work import _DummyWork, _WorkWrapper
+from torchft.work import _DummyWork

 if TYPE_CHECKING:
     from torchft.process_group import ProcessGroup
@@ -344,7 +356,11 @@ def shutdown(self, wait: bool = True) -> None:
         self._executor.shutdown(wait=wait)

     @torch.profiler.record_function("torchft::manager::allreduce")
-    def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work:
+    def allreduce(
+        self,
+        tensor: torch.Tensor,
+        should_quantize: bool = False,
+    ) -> Work:
         """
         Fault tolerant allreduce the tensor and return a Future that will be completed when
         the tensor is ready.
@@ -382,37 +398,8 @@ def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work
                 )
             else:
                 work = self._pg.allreduce([tensor], ReduceOp.SUM)
-            work.wait()
-
-            fut = work.get_future()
-
-            stream: Optional[torch.cuda.Stream] = (
-                torch.cuda.current_stream() if torch.cuda.is_available() else None
-            )
-
-            # schedule grad normalization as a continuation
-            # on the Future
-            @torch.profiler.record_function("torchft::manager::allreduce::callback")
-            def callback(
-                fut: torch.futures.Future[List[torch.Tensor]],
-            ) -> torch.Tensor:
-                nonlocal tensor, stream, num_participants
-
-                # change the stream to avoid making the callback stream
-                # dependent on process group stream running the allreduce
-                with torch.cuda.stream(stream) if stream is not None else nullcontext():
-                    # Setup stream dependency
-                    fut.wait()
-                    fut.value()
-                    tensor /= num_participants
-                    return tensor
-
-            fut = fut.then(callback)
-
-            fut = self.wrap_future(fut, tensor)
-
-            return _WorkWrapper(work, fut)
+            return _ManagedWork(work, self, tensor, num_participants)

         except Exception as e:
             self._logger.exception(
@@ -932,3 +919,103 @@ def warn(self, msg: str) -> None:

     def exception(self, msg: str) -> None:
         self._logger.exception(f"{self.prefix()} {msg}")
+
+
+class _ManagedWork(dist._Work):
+    def __init__(
+        self,
+        work: dist._Work,
+        manager: Manager,
+        tensor: torch.Tensor,
+        num_participants: int,
+    ) -> None:
+        super().__init__()
+        self._manager = manager
+        self._work = work
+        self._tensor = tensor
+        self._num_participants = num_participants
+        self._fut: Union[
+            torch.futures.Future[torch.Tensor], torch.futures.Future[None]
+        ] = work.get_future()
+
+        self._stream: Optional[torch.cuda.Stream] = (
+            torch.cuda.current_stream() if torch.cuda.is_available() else None
+        )
+
+        self._is_set_future_callback_called = False
+
+    def _set_future_callback(
+        self,
+    ) -> None:
+        if self._is_set_future_callback_called:
+            return
+
+        # schedule grad normalization as a continuation
+        # on the Future
+        @torch.profiler.record_function("torchft::manager::allreduce::callback")
+        def callback(
+            fut: torch.futures.Future[List[torch.Tensor]],
+        ) -> torch.Tensor:
+            # change the stream to avoid making the callback stream
+            # dependent on process group stream running the allreduce
+            with (
+                torch.cuda.stream(self._stream)
+                if self._stream is not None
+                else nullcontext()
+            ):
+                # Setup stream dependency
+                fut.wait()
+                self._tensor /= self._num_participants
+
+            return self._tensor
+
+        fut = self._fut
+        fut = fut.then(callback)
+        fut = self._manager.wrap_future(fut, self._tensor)
+        self._fut = fut
+
+        self._is_set_future_callback_called = True
+
+    def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        with (
+            torch.cuda.stream(self._stream)
+            if self._stream is not None
+            else nullcontext()
+        ):
+            self._work.wait()
+
+        self._set_future_callback()
+
+        with (
+            torch.cuda.stream(self._stream)
+            if self._stream is not None
+            else nullcontext()
+        ):
+            self._fut.wait()
+
+        return True
+
+    def block_current_stream(self, timeout: Optional[timedelta] = None) -> None:
+        with (
+            torch.cuda.stream(self._stream)
+            if self._stream is not None
+            else nullcontext()
+        ):
+            self._work.block_current_stream()
+
+        self._set_future_callback()
+
+    def synchronize(self) -> None:
+        if torch.cuda.is_available():
+            self.block_current_stream()
+        else:
+            # No stream dependencies need to be set
+            self._set_future_callback()
+
+    def get_future(
+        self,
+    ) -> Union[torch.futures.Future[torch.Tensor], torch.futures.Future[None]]:
+        assert (
+            self._is_set_future_callback_called
+        ), "getting the future without calling synchronize() is unsafe"
+        return self._fut
diff --git a/torchft/manager_test.py b/torchft/manager_test.py
index b5616bb..2a6ec29 100644
--- a/torchft/manager_test.py
+++ b/torchft/manager_test.py
@@ -588,6 +588,7 @@ def test_manager_numerics(self, client_mock: MagicMock) -> None:
         self.assertTrue(manager.is_participating())

         work = manager.allreduce(torch.tensor([1.0]))
+        work.synchronize()
         fut = work.get_future()
         result = fut.value()
         torch.testing.assert_close(result, torch.tensor([1.0 / 5]))
@@ -596,6 +597,7 @@ def test_manager_numerics(self, client_mock: MagicMock) -> None:
         manager._healing = True
         self.assertFalse(manager.is_participating())
         work = manager.allreduce(torch.tensor([1.0]))
+        work.synchronize()
         fut = work.get_future()
         result = fut.value()
         torch.testing.assert_close(result, torch.tensor([0.0]))
diff --git a/torchft/work.py b/torchft/work.py
index 8cb056a..9c6ff03 100644
--- a/torchft/work.py
+++ b/torchft/work.py
@@ -18,19 +18,3 @@ def wait(self, timeout: Optional[timedelta] = None) -> bool:

     def get_future(self) -> torch.futures.Future[object]:
         return self.future_
-
-
-class _WorkWrapper(dist._Work):
-    def __init__(
-        self, work: dist._Work, fut: torch.futures.Future[torch.Tensor]
-    ) -> None:
-        super().__init__()
-        self._work = work
-        self._fut = fut
-
-    def wait(self, timeout: Optional[timedelta] = None) -> bool:
-        self._fut.wait()
-        return True
-
-    def get_future(self) -> torch.futures.Future[torch.Tensor]:
-        return self._fut
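
Note on the new calling convention: `_ManagedWork.get_future()` now asserts that `synchronize()` (or `wait()` / `block_current_stream()`, which also attach the normalization callback) has been called first, which is why the call sites in ddp.py, local_sgd.py, and manager_test.py each gain a `work.synchronize()` line. A minimal sketch of how a caller drives the returned work handle under the new contract; the helper function and surrounding training-loop context below are illustrative only and not part of this patch:

    import torch

    from torchft.manager import Manager


    def _average_grad(manager: Manager, grad: torch.Tensor) -> torch.Tensor:
        # Fault tolerant allreduce; returns a Work handle (_ManagedWork on the
        # default, non-quantized path).
        work = manager.allreduce(grad)
        # Establish stream dependencies and attach the grad-normalization
        # callback before asking for the future -- get_future() asserts this.
        work.synchronize()
        fut = work.get_future()
        # Block until the reduced, normalized tensor is ready and return it
        # (the fut.value() pattern in manager_test.py is the non-blocking variant).
        return fut.wait()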