
Commit b9cd277

option 2 - call work.wait inside wrapped work
1 parent 405dc6e commit b9cd277
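
In short: the allreduce wait and the gradient-normalization continuation move out of Manager.allreduce() and into the wrapped _Work object, so they run only once the caller actually waits on (or requests the future of) the returned work.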

File tree

2 files changed: +59 −46 lines

torchft/manager.py

Lines changed: 59 additions & 31 deletions
@@ -39,13 +39,14 @@
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast

 import torch
+import torch.distributed as dist
 from torch.distributed import ReduceOp, TCPStore
 from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp, Work

 from torchft._torchft import ManagerClient, ManagerServer
 from torchft.checkpointing import CheckpointTransport, HTTPTransport
 from torchft.futures import future_timeout
-from torchft.work import _DummyWork, _WorkWrapper
+from torchft.work import _DummyWork

 if TYPE_CHECKING:
     from torchft.process_group import ProcessGroup
@@ -383,37 +384,8 @@ def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work
                 )
             else:
                 work = self._pg.allreduce([tensor], ReduceOp.SUM)
-                work.wait()

-                fut = work.get_future()
-
-                stream: Optional[torch.cuda.Stream] = (
-                    torch.cuda.current_stream() if torch.cuda.is_available() else None
-                )
-
-                # schedule grad normalization as a continuation
-                # on the Future
-                @torch.profiler.record_function("torchft::manager::allreduce::callback")
-                def callback(
-                    fut: torch.futures.Future[List[torch.Tensor]],
-                ) -> torch.Tensor:
-                    nonlocal tensor, stream, num_participants
-
-                    # change the stream to avoid making the callback stream
-                    # dependent on process group stream running the allreduce
-                    with torch.cuda.stream(stream) if stream is not None else nullcontext():
-                        # Setup stream dependency
-                        fut.wait()
-                        fut.value()
-                        tensor /= num_participants
-
-                    return tensor
-
-                fut = fut.then(callback)
-
-                fut = self.wrap_future(fut, tensor)
-
-                return _WorkWrapper(work, fut)
+                return _WorkWrapper(work, self, tensor, num_participants)

         except Exception as e:
             self._logger.exception(
@@ -933,3 +905,59 @@ def warn(self, msg: str) -> None:

     def exception(self, msg: str) -> None:
         self._logger.exception(f"{self.prefix()} {msg}")
+
+
+class _WorkWrapper(dist._Work):
+    def __init__(
+        self,
+        work: dist._Work,
+        manager: Manager,
+        tensor: torch.Tensor,
+        num_participants: int,
+    ) -> None:
+        super().__init__()
+        self._manager = manager
+        self._work = work
+        self._tensor = tensor
+        self._num_participants = num_participants
+
+        self._fut = self._work.get_future()
+        self._stream: Optional[torch.cuda.Stream] = (
+            torch.cuda.current_stream() if torch.cuda.is_available() else None
+        )
+
+    def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        with (
+            torch.cuda.stream(self._stream)
+            if self._stream is not None
+            else nullcontext()
+        ):
+            self._work.wait()
+
+        # schedule grad normalization as a continuation
+        # on the Future
+        @torch.profiler.record_function("torchft::manager::allreduce::callback")
+        def callback(
+            fut: torch.futures.Future[List[torch.Tensor]],
+        ) -> torch.Tensor:
+            # change the stream to avoid making the callback stream
+            # dependent on process group stream running the allreduce
+            with (
+                torch.cuda.stream(self._stream)
+                if self._stream is not None
+                else nullcontext()
+            ):
+                # Setup stream dependency
+                fut.wait()
+                self._tensor /= self._num_participants
+
+            return self._tensor
+
+        self._fut = self._fut.then(callback)
+        self._fut = self._manager.wrap_future(self._fut, self._tensor)
+
+        return True
+
+    def get_future(self) -> torch.futures.Future[torch.Tensor]:
+        self.wait()
+        return self._fut
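
Caller-side, the effect of this change can be sketched as follows (hypothetical usage; the names `manager` and `grad` are illustrative and not part of this commit):

    # Hypothetical caller code; assumes a configured Manager and a gradient tensor.
    work = manager.allreduce(grad)   # returns the new _WorkWrapper immediately

    # ... other computation can overlap with the asynchronous allreduce here ...

    fut = work.get_future()          # get_future() calls wait() internally, which
                                     # waits on the pg allreduce and then chains
                                     # grad /= num_participants via Future.then()
    fut.wait()                       # grad now holds the averaged gradient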

torchft/work.py

Lines changed: 0 additions & 15 deletions
@@ -19,18 +19,3 @@ def wait(self, timeout: Optional[timedelta] = None) -> bool:

     def get_future(self) -> torch.futures.Future[object]:
         return self.future_
-
-
-class _WorkWrapper(dist._Work):
-    def __init__(
-        self, work: dist._Work, fut: torch.futures.Future[torch.Tensor]
-    ) -> None:
-        super().__init__()
-        self._work = work
-        self._fut = fut
-
-    def wait(self, timeout: Optional[timedelta] = None) -> bool:
-        return True
-
-    def get_future(self) -> torch.futures.Future[torch.Tensor]:
-        return self._fut
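
For reference, the continuation mechanism the wrapper relies on is torch.futures.Future.then(). A minimal self-contained sketch of that pattern (pure torch.futures, no process group; the divisor 4 stands in for num_participants):

    import torch

    # Upstream future standing in for the one returned by work.get_future().
    fut: torch.futures.Future = torch.futures.Future()

    def normalize(f: torch.futures.Future) -> torch.Tensor:
        # Runs once the upstream future completes, like the wrapper's callback.
        return f.value() / 4  # stand-in for num_participants == 4

    chained = fut.then(normalize)  # new Future that holds normalize's result

    fut.set_result(torch.full((2,), 8.0))
    print(chained.wait())  # tensor([2., 2.])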
