Commit b16e0cb

setup stream dependencies inside work wrapper

Summary:
- extend the work wrapper object to also do the division post allreduce
- add api to block_current_stream on work wrapper so it can be used for HSDP
1 parent cb39d98 commit b16e0cb

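As a minimal sketch (not part of this commit) of how gradient-averaging code, for example an HSDP setup, might consume the wrapped work object after this change: the allreduce_grads helpers, the manager argument, and the grads list are hypothetical names; only Manager.allreduce and the work's block_current_stream / wait calls come from the diff below.

from typing import List

import torch

from torchft.manager import Manager


def allreduce_grads(manager: Manager, grads: List[torch.Tensor]) -> None:
    # Issue fault-tolerant allreduces for all gradients first.
    works = [manager.allreduce(g) for g in grads]
    for work in works:
        # Make the current CUDA stream wait on the allreduce (and the deferred
        # division by the number of participants) without blocking the CPU.
        work.block_current_stream()


def allreduce_grads_blocking(manager: Manager, grads: List[torch.Tensor]) -> None:
    for g in grads:
        # Blocks the host until the averaged gradient is ready.
        manager.allreduce(g).wait()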
File tree

- torchft/manager.py
- torchft/work.py

2 files changed: +109 -50 lines changed

torchft/manager.py

Lines changed: 109 additions & 34 deletions
@@ -35,16 +35,28 @@
 from contextlib import nullcontext
 from datetime import timedelta
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    TypeAlias,
+    TypeVar,
+    Union,
+    cast,
+)
 
 import torch
+import torch.distributed as dist
 from torch.distributed import ReduceOp, TCPStore
 from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp, Work
 
 from torchft._torchft import ManagerClient, ManagerServer
 from torchft.checkpointing import CheckpointTransport, HTTPTransport
 from torchft.futures import future_timeout
-from torchft.work import _DummyWork, _WorkWrapper
+from torchft.work import _DummyWork
 
 if TYPE_CHECKING:
     from torchft.process_group import ProcessGroup
@@ -350,7 +362,11 @@ def shutdown(self, wait: bool = True) -> None:
         self._executor.shutdown(wait=wait)
 
     @torch.profiler.record_function("torchft::manager::allreduce")
-    def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work:
+    def allreduce(
+        self,
+        tensor: torch.Tensor,
+        should_quantize: bool = False,
+    ) -> Work:
         """
         Fault tolerant allreduce the tensor and return a Future that will be completed when
         the tensor is ready.
@@ -388,38 +404,8 @@ def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work
                 )
             else:
                 work = self._pg.allreduce([tensor], ReduceOp.SUM)
-            if torch.cuda.is_available():
-                work.block_current_stream()
-
-            fut = work.get_future()
-
-            stream: Optional[torch.cuda.Stream] = (
-                torch.cuda.current_stream() if torch.cuda.is_available() else None
-            )
-
-            # schedule grad normalization as a continuation
-            # on the Future
-            @torch.profiler.record_function("torchft::manager::allreduce::callback")
-            def callback(
-                fut: torch.futures.Future[List[torch.Tensor]],
-            ) -> torch.Tensor:
-                nonlocal tensor, stream, num_participants
-
-                # change the stream to avoid making the callback stream
-                # dependent on process group stream running the allreduce
-                with torch.cuda.stream(stream) if stream is not None else nullcontext():
-                    # Setup stream dependency
-                    fut.wait()
-                    fut.value()
-                    tensor /= num_participants
 
-                return tensor
-
-            fut = fut.then(callback)
-
-            fut = self.wrap_future(fut, tensor)
-
-            return _WorkWrapper(work, fut)
+            return _WorkWrapper(work, self, tensor, num_participants)
 
         except Exception as e:
             self._logger.exception(
@@ -939,3 +925,92 @@ def warn(self, msg: str) -> None:
 
     def exception(self, msg: str) -> None:
         self._logger.exception(f"{self.prefix()} {msg}")
+
+
+class _WorkWrapper(dist._Work):
+    def __init__(
+        self,
+        work: dist._Work,
+        manager: Manager,
+        tensor: torch.Tensor,
+        num_participants: int,
+    ) -> None:
+        super().__init__()
+        self._manager = manager
+        self._work = work
+        self._tensor = tensor
+        self._num_participants = num_participants
+        self._fut: Union[
+            torch.futures.Future[torch.Tensor], torch.futures.Future[None]
+        ] = work.get_future()
+
+        self._stream: Optional[torch.cuda.Stream] = (
+            torch.cuda.current_stream() if torch.cuda.is_available() else None
+        )
+
+        self._is_set_future_callback_called = False
+
+    def _set_future_callback(
+        self,
+    ) -> None:
+        if self._is_set_future_callback_called:
+            return
+
+        # schedule grad normalization as a continuation
+        # on the Future
+        @torch.profiler.record_function("torchft::manager::allreduce::callback")
+        def callback(
+            fut: torch.futures.Future[List[torch.Tensor]],
+        ) -> torch.Tensor:
+            # change the stream to avoid making the callback stream
+            # dependent on process group stream running the allreduce
+            with (
+                torch.cuda.stream(self._stream)
+                if self._stream is not None
+                else nullcontext()
+            ):
+                # Setup stream dependency
+                fut.wait()
+                self._tensor /= self._num_participants
+
+            return self._tensor
+
+        fut = self._fut
+        fut = fut.then(callback)
+        fut = self._manager.wrap_future(fut, self._tensor)
+        self._fut = fut
+
+        self._is_set_future_callback_called = True
+
+    def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        with (
+            torch.cuda.stream(self._stream)
+            if self._stream is not None
+            else nullcontext()
+        ):
+            self._work.wait()
+
+        self._set_future_callback()
+
+        return True
+
+    def block_current_stream(self, timeout: Optional[timedelta] = None) -> None:
+        with (
+            torch.cuda.stream(self._stream)
+            if self._stream is not None
+            else nullcontext()
+        ):
+            self._work.block_current_stream()
+
+        self._set_future_callback()
+
+    def get_future(
+        self,
+    ) -> Union[torch.futures.Future[torch.Tensor], torch.futures.Future[None]]:
+        if torch.cuda.is_available():
+            self.block_current_stream()
+        else:
+            # No stream dependencies need to be set
+            self._set_future_callback()
+
+        return self._fut

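Below is a rough, standalone sketch of the lazy-continuation idea the new _WorkWrapper follows: hold on to the raw future and only chain the division callback the first time the caller waits on the work or asks for its future. Assumptions: CPU only, no process group, no CUDA streams; _LazyAverageWork and _divide are made-up names and the class is deliberately not a dist._Work subclass so the snippet runs with plain PyTorch.

from typing import List

import torch


class _LazyAverageWork:
    def __init__(
        self, fut: torch.futures.Future, tensor: torch.Tensor, n: int
    ) -> None:
        self._fut = fut
        self._tensor = tensor
        self._n = n
        self._callback_attached = False

    def _attach(self) -> None:
        # Attach the normalization continuation at most once, mirroring
        # _set_future_callback in the real wrapper.
        if self._callback_attached:
            return

        def _divide(fut: torch.futures.Future) -> torch.Tensor:
            fut.wait()
            self._tensor /= self._n
            return self._tensor

        self._fut = self._fut.then(_divide)
        self._callback_attached = True

    def wait(self) -> bool:
        self._attach()
        self._fut.wait()
        return True


# usage: simulate a completed "allreduce" that produced a summed tensor across 4 ranks
fut: torch.futures.Future = torch.futures.Future()
summed = torch.full((3,), 8.0)
fut.set_result([summed])

work = _LazyAverageWork(fut, summed, n=4)
work.wait()
print(summed)  # tensor([2., 2., 2.])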
torchft/work.py

Lines changed: 0 additions & 16 deletions
@@ -18,19 +18,3 @@ def wait(self, timeout: Optional[timedelta] = None) -> bool:
 
     def get_future(self) -> torch.futures.Future[object]:
         return self.future_
-
-
-class _WorkWrapper(dist._Work):
-    def __init__(
-        self, work: dist._Work, fut: torch.futures.Future[torch.Tensor]
-    ) -> None:
-        super().__init__()
-        self._work = work
-        self._fut = fut
-
-    def wait(self, timeout: Optional[timedelta] = None) -> bool:
-        self._fut.wait()
-        return True
-
-    def get_future(self) -> torch.futures.Future[torch.Tensor]:
-        return self._fut
