 from contextlib import nullcontext
 from datetime import timedelta
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    TypeAlias,
+    TypeVar,
+    Union,
+    cast,
+)
 
 import torch
+import torch.distributed as dist
 from torch.distributed import ReduceOp, TCPStore
 from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp, Work
 
 from torchft._torchft import ManagerClient, ManagerServer
 from torchft.checkpointing import CheckpointTransport, HTTPTransport
 from torchft.futures import future_timeout
-from torchft.work import _DummyWork, _WorkWrapper
+from torchft.work import _DummyWork
 
 if TYPE_CHECKING:
     from torchft.process_group import ProcessGroup
@@ -350,7 +362,11 @@ def shutdown(self, wait: bool = True) -> None:
         self._executor.shutdown(wait=wait)
 
     @torch.profiler.record_function("torchft::manager::allreduce")
-    def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work:
+    def allreduce(
+        self,
+        tensor: torch.Tensor,
+        should_quantize: bool = False,
+    ) -> Work:
         """
         Fault tolerant allreduce the tensor and return a Future that will be completed when
         the tensor is ready.
@@ -388,38 +404,8 @@ def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work
                 )
             else:
                 work = self._pg.allreduce([tensor], ReduceOp.SUM)
-            if torch.cuda.is_available():
-                work.block_current_stream()
-
-            fut = work.get_future()
-
-            stream: Optional[torch.cuda.Stream] = (
-                torch.cuda.current_stream() if torch.cuda.is_available() else None
-            )
-
-            # schedule grad normalization as a continuation
-            # on the Future
-            @torch.profiler.record_function("torchft::manager::allreduce::callback")
-            def callback(
-                fut: torch.futures.Future[List[torch.Tensor]],
-            ) -> torch.Tensor:
-                nonlocal tensor, stream, num_participants
-
-                # change the stream to avoid making the callback stream
-                # dependent on process group stream running the allreduce
-                with torch.cuda.stream(stream) if stream is not None else nullcontext():
-                    # Setup stream dependency
-                    fut.wait()
-                    fut.value()
-                    tensor /= num_participants
 
-                return tensor
-
-            fut = fut.then(callback)
-
-            fut = self.wrap_future(fut, tensor)
-
-            return _WorkWrapper(work, fut)
+            return _WorkWrapper(work, self, tensor, num_participants)
 
         except Exception as e:
             self._logger.exception(
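For context, a hedged caller-side sketch of the new return path (not part of this diff; it assumes an already-initialized torchft Manager named `manager`): `allreduce` now returns a `_WorkWrapper` that defers the division by the participant count until the caller synchronizes, instead of chaining the normalization callback inline as the removed lines did.

    # Illustrative only; `manager` and the gradient tensor are assumptions.
    grad = torch.ones(1024)
    work = manager.allreduce(grad)   # Work is wrapped in the new _WorkWrapper
    work.wait()                      # sums across replicas, then divides by num_participants
    fut = work.get_future()          # safe after wait(); the callback is attached only once
    fut.wait()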
@@ -939,3 +925,92 @@ def warn(self, msg: str) -> None:
 
     def exception(self, msg: str) -> None:
         self._logger.exception(f"{self.prefix()} {msg}")
+
+
+class _WorkWrapper(dist._Work):
+    def __init__(
+        self,
+        work: dist._Work,
+        manager: Manager,
+        tensor: torch.Tensor,
+        num_participants: int,
+    ) -> None:
+        super().__init__()
+        self._manager = manager
+        self._work = work
+        self._tensor = tensor
+        self._num_participants = num_participants
+        self._fut: Union[
+            torch.futures.Future[torch.Tensor], torch.futures.Future[None]
+        ] = work.get_future()
+
+        self._stream: Optional[torch.cuda.Stream] = (
+            torch.cuda.current_stream() if torch.cuda.is_available() else None
+        )
+
+        self._is_set_future_callback_called = False
+
+    def _set_future_callback(
+        self,
+    ) -> None:
+        if self._is_set_future_callback_called:
+            return
+
+        # schedule grad normalization as a continuation
+        # on the Future
+        @torch.profiler.record_function("torchft::manager::allreduce::callback")
+        def callback(
+            fut: torch.futures.Future[List[torch.Tensor]],
+        ) -> torch.Tensor:
+            # change the stream to avoid making the callback stream
+            # dependent on process group stream running the allreduce
+            with (
+                torch.cuda.stream(self._stream)
+                if self._stream is not None
+                else nullcontext()
+            ):
+                # Setup stream dependency
+                fut.wait()
+                self._tensor /= self._num_participants
+
+            return self._tensor
+
+        fut = self._fut
+        fut = fut.then(callback)
+        fut = self._manager.wrap_future(fut, self._tensor)
+        self._fut = fut
+
+        self._is_set_future_callback_called = True
+
+    def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        with (
+            torch.cuda.stream(self._stream)
+            if self._stream is not None
+            else nullcontext()
+        ):
+            self._work.wait()
+
+        self._set_future_callback()
+
+        return True
+
+    def block_current_stream(self, timeout: Optional[timedelta] = None) -> None:
+        with (
+            torch.cuda.stream(self._stream)
+            if self._stream is not None
+            else nullcontext()
+        ):
+            self._work.block_current_stream()
+
+        self._set_future_callback()
+
+    def get_future(
+        self,
+    ) -> Union[torch.futures.Future[torch.Tensor], torch.futures.Future[None]]:
+        if torch.cuda.is_available():
+            self.block_current_stream()
+        else:
+            # No stream dependencies need to be set
+            self._set_future_callback()
+
+        return self._fut
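To make the deferred-callback behavior above concrete, here is a small, self-contained sketch of the same pattern (names are illustrative and not part of torchft): the normalization continuation is attached lazily and guarded so it runs at most once, whichever of wait() or get_future() the caller uses first. The real class additionally manages CUDA streams and wraps the future through the manager for fault tolerance.

    import torch

    # Illustrative stand-in for _WorkWrapper's lazy normalization.
    class _LazyNormalizer:
        def __init__(self, tensor: torch.Tensor, num_participants: int) -> None:
            self._tensor = tensor
            self._num_participants = num_participants
            # Stands in for work.get_future(); here it is already completed.
            self._fut: torch.futures.Future[torch.Tensor] = torch.futures.Future()
            self._fut.set_result(tensor)
            self._callback_attached = False

        def _attach_callback(self) -> None:
            if self._callback_attached:
                return

            def callback(fut: torch.futures.Future[torch.Tensor]) -> torch.Tensor:
                fut.wait()
                self._tensor /= self._num_participants
                return self._tensor

            self._fut = self._fut.then(callback)
            self._callback_attached = True

        def wait(self) -> bool:
            self._attach_callback()
            return True

        def get_future(self) -> torch.futures.Future[torch.Tensor]:
            self._attach_callback()
            return self._fut

    t = torch.full((4,), 6.0)           # pretend this is the summed allreduce result
    w = _LazyNormalizer(t, num_participants=3)
    w.wait()                            # divides by 3 exactly once
    w.get_future().wait()               # guard prevents a second division
    assert torch.allclose(t, torch.full((4,), 2.0))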