|
39 | 39 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
40 | 40 |
41 | 41 | import torch
| 42 | +import torch.distributed as dist
42 | 43 | from torch.distributed import ReduceOp, TCPStore
43 | 44 | from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp, Work
44 | 45 |
45 | 46 | from torchft._torchft import ManagerClient, ManagerServer
46 | 47 | from torchft.checkpointing import CheckpointTransport, HTTPTransport
47 | 48 | from torchft.futures import future_timeout
48 | | -from torchft.work import _DummyWork, _WorkWrapper
| 49 | +from torchft.work import _DummyWork
49 | 50 |
50 | 51 | if TYPE_CHECKING:
51 | 52 |     from torchft.process_group import ProcessGroup
|
@@ -383,37 +384,8 @@ def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work
|
383 | 384 |                 )
384 | 385 |             else:
385 | 386 |                 work = self._pg.allreduce([tensor], ReduceOp.SUM)
386 | | -            work.wait()
387 | 387 |
388 | | -            fut = work.get_future()
389 | | -
390 | | -            stream: Optional[torch.cuda.Stream] = (
391 | | -                torch.cuda.current_stream() if torch.cuda.is_available() else None
392 | | -            )
393 | | -
394 | | -            # schedule grad normalization as a continuation
395 | | -            # on the Future
396 | | -            @torch.profiler.record_function("torchft::manager::allreduce::callback")
397 | | -            def callback(
398 | | -                fut: torch.futures.Future[List[torch.Tensor]],
399 | | -            ) -> torch.Tensor:
400 | | -                nonlocal tensor, stream, num_participants
401 | | -
402 | | -                # change the stream to avoid making the callback stream
403 | | -                # dependent on process group stream running the allreduce
404 | | -                with torch.cuda.stream(stream) if stream is not None else nullcontext():
405 | | -                    # Setup stream dependency
406 | | -                    fut.wait()
407 | | -                    fut.value()
408 | | -                    tensor /= num_participants
409 | | -
410 | | -                return tensor
411 | | -
412 | | -            fut = fut.then(callback)
413 | | -
414 | | -            fut = self.wrap_future(fut, tensor)
415 | | -
416 | | -            return _WorkWrapper(work, fut)
| 388 | +            return _WorkWrapper(work, self, tensor, num_participants)
417 | 389 |
418 | 390 |         except Exception as e:
419 | 391 |             self._logger.exception(
|
@@ -933,3 +905,59 @@ def warn(self, msg: str) -> None:
|
933 | 905 |
934 | 906 |     def exception(self, msg: str) -> None:
935 | 907 |         self._logger.exception(f"{self.prefix()} {msg}")
| 908 | +
| 909 | +
| 910 | +class _WorkWrapper(dist._Work):
| 911 | +    def __init__(
| 912 | +        self,
| 913 | +        work: dist._Work,
| 914 | +        manager: Manager,
| 915 | +        tensor: torch.Tensor,
| 916 | +        num_participants: int,
| 917 | +    ) -> None:
| 918 | +        super().__init__()
| 919 | +        self._manager = manager
| 920 | +        self._work = work
| 921 | +        self._tensor = tensor
| 922 | +        self._num_participants = num_participants
| 923 | +
| 924 | +        self._fut = self._work.get_future()
| 925 | +        self._stream: Optional[torch.cuda.Stream] = (
| 926 | +            torch.cuda.current_stream() if torch.cuda.is_available() else None
| 927 | +        )
| 928 | +
| 929 | +    def wait(self, timeout: Optional[timedelta] = None) -> bool:
| 930 | +        with (
| 931 | +            torch.cuda.stream(self._stream)
| 932 | +            if self._stream is not None
| 933 | +            else nullcontext()
| 934 | +        ):
| 935 | +            self._work.wait()
| 936 | +
| 937 | +        # schedule grad normalization as a continuation
| 938 | +        # on the Future
| 939 | +        @torch.profiler.record_function("torchft::manager::allreduce::callback")
| 940 | +        def callback(
| 941 | +            fut: torch.futures.Future[List[torch.Tensor]],
| 942 | +        ) -> torch.Tensor:
| 943 | +            # change the stream to avoid making the callback stream
| 944 | +            # dependent on process group stream running the allreduce
| 945 | +            with (
| 946 | +                torch.cuda.stream(self._stream)
| 947 | +                if self._stream is not None
| 948 | +                else nullcontext()
| 949 | +            ):
| 950 | +                # Setup stream dependency
| 951 | +                fut.wait()
| 952 | +                self._tensor /= self._num_participants
| 953 | +
| 954 | +            return self._tensor
| 955 | +
| 956 | +        self._fut = self._fut.then(callback)
| 957 | +        self._fut = self._manager.wrap_future(self._fut, self._tensor)
| 958 | +
| 959 | +        return True
| 960 | +
| 961 | +    def get_future(self) -> torch.futures.Future[torch.Tensor]:
| 962 | +        self.wait()
| 963 | +        return self._fut
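For reference, a minimal usage sketch of the change (the names manager and grad are illustrative placeholders, not from the diff): Manager.allreduce now returns the _WorkWrapper directly, and the division by num_participants only runs once the caller waits on the returned work or its future.

    # Hypothetical sketch; assumes an already-constructed torchft Manager
    # instance named `manager` and a local gradient tensor `grad`.
    import torch

    grad = torch.ones(8)            # local gradient shard (illustrative)
    work = manager.allreduce(grad)  # returns _WorkWrapper(work, manager, grad, num_participants)

    fut = work.get_future()         # get_future() calls wait(), which attaches the
                                    # tensor /= num_participants callback and wraps
                                    # the future via manager.wrap_future()
    fut.wait()                      # after completion, grad holds the group average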