Commit fc31fd4

setup stream dependencies inside work wrapper
Summary:
- extend the work wrapper object to also do the division post-allreduce
- add an API to block_current_stream on the work wrapper so it can be used for HSDP (see the sketch below)
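
For context, a minimal sketch of the caller pattern this commit moves to (the `manager` and `grad` names are placeholders, not from the diff; the pattern itself matches the ddp.py, local_sgd.py, and manager_test.py changes below):

    # kick off the fault-tolerant allreduce; returns the managed work wrapper
    work = manager.allreduce(grad)
    # set up stream dependencies and the divide-by-num-participants callback
    work.synchronize()
    # only after synchronize() is it safe to fetch the wrapped future
    fut = work.get_future()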
1 parent fea3431 commit fc31fd4

File tree: 5 files changed, +124 -49 lines

- torchft/ddp.py
- torchft/local_sgd.py
- torchft/manager.py
- torchft/manager_test.py
- torchft/work.py

torchft/ddp.py

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@ def _comm_hook(
         state: "Manager", bucket: dist.GradBucket
     ) -> torch.futures.Future[torch.Tensor]:
         work = state.allreduce(bucket.buffer())
+        work.synchronize()
         return work.get_future()

torchft/local_sgd.py

Lines changed: 1 addition & 0 deletions
@@ -535,6 +535,7 @@ def callback(fut: torch.futures.Future[torch.Tensor]) -> None:
                     flat_buffer[pack_offset : pack_offset + numel].view_as(t)
                 )

+        work.synchronize()
         fut = work.get_future()
         fut.add_done_callback(callback)

torchft/manager.py

Lines changed: 120 additions & 33 deletions
@@ -35,17 +35,29 @@
 from contextlib import nullcontext
 from datetime import timedelta
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    TypeAlias,
+    TypeVar,
+    Union,
+    cast,
+)

 import torch
+import torch.distributed as dist
 from torch.distributed import ReduceOp, TCPStore
 from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp, Work

 from torchft._torchft import ManagerClient, ManagerServer
 from torchft.checkpointing import CheckpointTransport, HTTPTransport
 from torchft.checkpointing._rwlock import RWLock
 from torchft.futures import future_timeout
-from torchft.work import _DummyWork, _WorkWrapper
+from torchft.work import _DummyWork

 if TYPE_CHECKING:
     from torchft.process_group import ProcessGroup

@@ -363,7 +375,11 @@ def shutdown(self, wait: bool = True) -> None:
         self._executor.shutdown(wait=wait)

     @torch.profiler.record_function("torchft::manager::allreduce")
-    def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work:
+    def allreduce(
+        self,
+        tensor: torch.Tensor,
+        should_quantize: bool = False,
+    ) -> Work:
         """
         Fault tolerant allreduce the tensor and return a Future that will be completed when
         the tensor is ready.

@@ -401,37 +417,8 @@ def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work
                 )
             else:
                 work = self._pg.allreduce([tensor], ReduceOp.SUM)
-                work.wait()
-
-                fut = work.get_future()
-
-                stream: Optional[torch.cuda.Stream] = (
-                    torch.cuda.current_stream() if torch.cuda.is_available() else None
-                )
-
-                # schedule grad normalization as a continuation
-                # on the Future
-                @torch.profiler.record_function("torchft::manager::allreduce::callback")
-                def callback(
-                    fut: torch.futures.Future[List[torch.Tensor]],
-                ) -> torch.Tensor:
-                    nonlocal tensor, stream, num_participants
-
-                    # change the stream to avoid making the callback stream
-                    # dependent on process group stream running the allreduce
-                    with torch.cuda.stream(stream) if stream is not None else nullcontext():
-                        # Setup stream dependency
-                        fut.wait()
-                        fut.value()
-                        tensor /= num_participants

-                        return tensor
-
-            fut = fut.then(callback)
-
-            fut = self.wrap_future(fut, tensor)
-
-            return _WorkWrapper(work, fut)
+            return _ManagedWork(work, self, tensor, num_participants)

         except Exception as e:
             self._logger.exception(

@@ -954,3 +941,103 @@ def warn(self, msg: str) -> None:

     def exception(self, msg: str) -> None:
         self._logger.exception(f"{self.prefix()} {msg}")
+
+
+class _ManagedWork(dist._Work):
+    def __init__(
+        self,
+        work: dist._Work,
+        manager: Manager,
+        tensor: torch.Tensor,
+        num_participants: int,
+    ) -> None:
+        super().__init__()
+        self._manager = manager
+        self._work = work
+        self._tensor = tensor
+        self._num_participants = num_participants
+        self._fut: Union[
+            torch.futures.Future[torch.Tensor], torch.futures.Future[None]
+        ] = work.get_future()
+
+        self._stream: Optional[torch.cuda.Stream] = (
+            torch.cuda.current_stream() if torch.cuda.is_available() else None
+        )
+
+        self._is_set_future_callback_called = False
+
+    def _set_future_callback(
+        self,
+    ) -> None:
+        if self._is_set_future_callback_called:
+            return
+
+        # schedule grad normalization as a continuation
+        # on the Future
+        @torch.profiler.record_function("torchft::manager::allreduce::callback")
+        def callback(
+            fut: torch.futures.Future[List[torch.Tensor]],
+        ) -> torch.Tensor:
+            # change the stream to avoid making the callback stream
+            # dependent on process group stream running the allreduce
+            with (
+                torch.cuda.stream(self._stream)
+                if self._stream is not None
+                else nullcontext()
+            ):
+                # Setup stream dependency
+                fut.wait()
+                self._tensor /= self._num_participants
+
+            return self._tensor
+
+        fut = self._fut
+        fut = fut.then(callback)
+        fut = self._manager.wrap_future(fut, self._tensor)
+        self._fut = fut
+
+        self._is_set_future_callback_called = True
+
+    def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        with (
+            torch.cuda.stream(self._stream)
+            if self._stream is not None
+            else nullcontext()
+        ):
+            self._work.wait()
+
+        self._set_future_callback()
+
+        with (
+            torch.cuda.stream(self._stream)
+            if self._stream is not None
+            else nullcontext()
+        ):
+            self._fut.wait()
+
+        return True
+
+    def block_current_stream(self, timeout: Optional[timedelta] = None) -> None:
+        with (
+            torch.cuda.stream(self._stream)
+            if self._stream is not None
+            else nullcontext()
+        ):
+            self._work.block_current_stream()
+
+        self._set_future_callback()
+
+    def synchronize(self) -> None:
+        if torch.cuda.is_available():
+            self.block_current_stream()
+        else:
+            # No stream dependencies need to be set
+            self._set_future_callback()
+
+    def get_future(
+        self,
+    ) -> Union[torch.futures.Future[torch.Tensor], torch.futures.Future[None]]:
+        assert (
+            self._is_set_future_callback_called
+        ), "getting the future without calling synchronize() is unsafe"
+        return self._fut
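
The new block_current_stream() path is what HSDP-style callers can use when they only need a GPU-side dependency rather than a host-blocking wait. A hypothetical usage sketch (the `manager` and `sharded_grad` names are illustrative assumptions, not part of this commit):

    # queue the allreduce; normalization is deferred until a sync point
    work = manager.allreduce(sharded_grad)
    # make the current CUDA stream wait on the allreduce without blocking the CPU;
    # this also installs the division and wrap_future callbacks on the future
    work.block_current_stream()
    # safe now: block_current_stream() marked the callback setup as done
    fut = work.get_future()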

torchft/manager_test.py

Lines changed: 2 additions & 0 deletions
@@ -588,6 +588,7 @@ def test_manager_numerics(self, client_mock: MagicMock) -> None:

         self.assertTrue(manager.is_participating())
         work = manager.allreduce(torch.tensor([1.0]))
+        work.synchronize()
         fut = work.get_future()
         result = fut.value()
         torch.testing.assert_close(result, torch.tensor([1.0 / 5]))

@@ -596,6 +597,7 @@ def test_manager_numerics(self, client_mock: MagicMock) -> None:
         manager._healing = True
         self.assertFalse(manager.is_participating())
         work = manager.allreduce(torch.tensor([1.0]))
+        work.synchronize()
         fut = work.get_future()
         result = fut.value()
         torch.testing.assert_close(result, torch.tensor([0.0]))

torchft/work.py

Lines changed: 0 additions & 16 deletions
@@ -18,19 +18,3 @@ def wait(self, timeout: Optional[timedelta] = None) -> bool:

     def get_future(self) -> torch.futures.Future[object]:
         return self.future_
-
-
-class _WorkWrapper(dist._Work):
-    def __init__(
-        self, work: dist._Work, fut: torch.futures.Future[torch.Tensor]
-    ) -> None:
-        super().__init__()
-        self._work = work
-        self._fut = fut
-
-    def wait(self, timeout: Optional[timedelta] = None) -> bool:
-        self._fut.wait()
-        return True
-
-    def get_future(self) -> torch.futures.Future[torch.Tensor]:
-        return self._fut
