setup stream dependencies inside work wrapper #248
```diff
@@ -35,16 +35,28 @@
 from contextlib import nullcontext
 from datetime import timedelta
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    TypeAlias,
+    TypeVar,
+    Union,
+    cast,
+)
 
 import torch
+import torch.distributed as dist
 from torch.distributed import ReduceOp, TCPStore
 from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp, Work
 
 from torchft._torchft import ManagerClient, ManagerServer
 from torchft.checkpointing import CheckpointTransport, HTTPTransport
 from torchft.futures import future_timeout
-from torchft.work import _DummyWork, _WorkWrapper
+from torchft.work import _DummyWork
 
 if TYPE_CHECKING:
     from torchft.process_group import ProcessGroup
```
```diff
@@ -344,7 +356,11 @@ def shutdown(self, wait: bool = True) -> None:
         self._executor.shutdown(wait=wait)
 
     @torch.profiler.record_function("torchft::manager::allreduce")
-    def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work:
+    def allreduce(
+        self,
+        tensor: torch.Tensor,
+        should_quantize: bool = False,
+    ) -> Work:
         """
         Fault tolerant allreduce the tensor and return a Future that will be completed when
         the tensor is ready.
```
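For orientation, here is a hypothetical caller-side sketch (not part of this diff; `manager` and `grad` are illustrative names) of how this method is typically consumed: the collective is issued asynchronously and the returned `Work` handle is waited on before the tensor is read.

```python
import torch


def average_gradient(manager, grad: torch.Tensor) -> torch.Tensor:
    # Issue the fault-tolerant allreduce; this returns a Work handle
    # instead of blocking the caller.
    work = manager.allreduce(grad)
    # Wait until the reduced (and normalized) values are safe to read.
    work.wait()
    return grad
```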
```diff
@@ -382,37 +398,8 @@ def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work
                 )
             else:
                 work = self._pg.allreduce([tensor], ReduceOp.SUM)
-                work.wait()
-
-                fut = work.get_future()
-
-                stream: Optional[torch.cuda.Stream] = (
-                    torch.cuda.current_stream() if torch.cuda.is_available() else None
-                )
-
-                # schedule grad normalization as a continuation
-                # on the Future
-                @torch.profiler.record_function("torchft::manager::allreduce::callback")
-                def callback(
-                    fut: torch.futures.Future[List[torch.Tensor]],
-                ) -> torch.Tensor:
-                    nonlocal tensor, stream, num_participants
-
-                    # change the stream to avoid making the callback stream
-                    # dependent on process group stream running the allreduce
-                    with torch.cuda.stream(stream) if stream is not None else nullcontext():
-                        # Setup stream dependency
-                        fut.wait()
-                        fut.value()
-                        tensor /= num_participants
-
-                    return tensor
-
-                fut = fut.then(callback)
-
-                fut = self.wrap_future(fut, tensor)
-
-                return _WorkWrapper(work, fut)
+                return _ManagedWork(work, self, tensor, num_participants)
 
         except Exception as e:
             self._logger.exception(
```
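The block removed above chained gradient normalization onto the collective's future with `fut.then(...)`; the `_ManagedWork` class added in the next hunk attaches the same continuation lazily. Below is a minimal, self-contained sketch of that `torch.futures` continuation pattern, independent of torchft; the values and names are illustrative.

```python
import torch

num_participants = 4
tensor = torch.full((3,), float(num_participants))

# Stand-in for the future returned by the process group's allreduce.
fut: torch.futures.Future = torch.futures.Future()


def normalize(f: torch.futures.Future) -> torch.Tensor:
    # Runs as a continuation once `f` completes; converts the SUM result
    # into an average in place, mirroring the removed callback above.
    f.wait()
    tensor.div_(num_participants)
    return tensor


chained = fut.then(normalize)
fut.set_result([tensor])   # simulate the collective completing
print(chained.wait())      # tensor([1., 1., 1.])
```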
```diff
@@ -932,3 +919,103 @@ def warn(self, msg: str) -> None:
 
     def exception(self, msg: str) -> None:
         self._logger.exception(f"{self.prefix()} {msg}")
+
+
+class _ManagedWork(dist._Work):
+    def __init__(
+        self,
+        work: dist._Work,
+        manager: Manager,
+        tensor: torch.Tensor,
+        num_participants: int,
+    ) -> None:
+        super().__init__()
+        self._manager = manager
+        self._work = work
+        self._tensor = tensor
+        self._num_participants = num_participants
+        self._fut: Union[
+            torch.futures.Future[torch.Tensor], torch.futures.Future[None]
+        ] = work.get_future()
+
+        self._stream: Optional[torch.cuda.Stream] = (
+            torch.cuda.current_stream() if torch.cuda.is_available() else None
+        )
+
+        self._is_set_future_callback_called = False
+
+    def _set_future_callback(
+        self,
+    ) -> None:
+        if self._is_set_future_callback_called:
+            return
+
+        # schedule grad normalization as a continuation
+        # on the Future
+        @torch.profiler.record_function("torchft::manager::allreduce::callback")
+        def callback(
+            fut: torch.futures.Future[List[torch.Tensor]],
+        ) -> torch.Tensor:
+            # change the stream to avoid making the callback stream
+            # dependent on process group stream running the allreduce
+            with (
+                torch.cuda.stream(self._stream)
+                if self._stream is not None
+                else nullcontext()
+            ):
+                # Setup stream dependency
+                fut.wait()
+                self._tensor /= self._num_participants
+
+            return self._tensor
+
+        fut = self._fut
+        fut = fut.then(callback)
+        fut = self._manager.wrap_future(fut, self._tensor)
+        self._fut = fut
+
+        self._is_set_future_callback_called = True
+
+    def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        with (
+            torch.cuda.stream(self._stream)
+            if self._stream is not None
+            else nullcontext()
+        ):
+            self._work.wait()
+
+        self._set_future_callback()
+
+        with (
+            torch.cuda.stream(self._stream)
+            if self._stream is not None
+            else nullcontext()
+        ):
+            self._fut.wait()
+
+        return True
+
+    def block_current_stream(self, timeout: Optional[timedelta] = None) -> None:
+        with (
+            torch.cuda.stream(self._stream)
+            if self._stream is not None
+            else nullcontext()
+        ):
+            self._work.block_current_stream()
+
+        self._set_future_callback()
+
+    def synchronize(self) -> None:
+        if torch.cuda.is_available():
+            self.block_current_stream()
+        else:
+            # No stream dependencies need to be set
+            self._set_future_callback()
+
+    def get_future(
+        self,
+    ) -> Union[torch.futures.Future[torch.Tensor], torch.futures.Future[None]]:
+        assert (
+            self._is_set_future_callback_called
+        ), "getting the future without calling synchronize() is unsafe"
+        return self._fut
```

Review comment on `wait()`: `.wait()` should set a dependency between the work and the current stream -- it looks like we're running all operations on `self._stream`?

Review comment on `block_current_stream()`: we probably shouldn't rely on this until we've thought this through more / tested it.

Reply: yeah, we can test it more before we change the hsdp implementation. I think we can also do some other alternative for bucketized allreduce and ddp without having to use `block_current_stream`.
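As a usage note, here is a hypothetical consumer-side sketch of the flow the new class is designed for (`manager`, `grad`, and `post_process` are illustrative names, not from this PR): stream dependencies are established via `synchronize()` or `wait()` before `get_future()` hands out the chained future.

```python
def allreduce_then_postprocess(manager, grad, post_process):
    # manager.allreduce() now returns a _ManagedWork wrapping the
    # process-group work.
    work = manager.allreduce(grad)

    # Establish stream dependencies and attach the normalization callback
    # without necessarily blocking the host (block_current_stream on CUDA,
    # callback-only on CPU).
    work.synchronize()

    # Safe only after synchronize()/wait(): get_future() asserts that the
    # callback has already been installed.
    fut = work.get_future()
    return fut.then(lambda f: post_process(grad))
```

The lazy `_set_future_callback()` call is what allows the stream dependencies to be established before the normalization callback joins the future chain.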
Review comment: Why does this need to live in the Work object? Can't we pass the stream + the future to the `_WorkWrapper` and have it manage things correctly?
Reply: That would be ideal but doesn't work in all cases, I think; these two conflict with each other, so this is what I came up with.
Reply: @tushar00jain you can use `work.synchronize()` to set up the dependency in a guaranteed non-blocking way.
Reply: @d4l3k that's for nccl, right? For nccl you mentioned `block_current_stream` also just calls `synchronize`, but it also works in a non-blocking way for gloo. We needed `block_current_stream` for that, because I'm guessing `synchronize` doesn't do that for gloo.

Also, based on our discussion offline, the current APIs work for all cases and have the same semantics as the underlying process group work:
- `work.wait()` -- we call it only when we need to synchronize.
- `work.get_future()` in bucketized allreduce -- `block_current_stream` runs first to set up the stream dependency for nccl (just a proxy to `work.synchronize`) and for gloo with cuda. We also add a callback to the future chain, but carefully set up the stream dependency after all the other stream dependencies have been set up; that's why we call `block_current_stream` in `get_future` anyway.

In the future, `_ManagedWork` […] (the above will also help us do that).
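To make the distinction discussed in this thread concrete, here is a hedged sketch (assuming a work object that exposes `block_current_stream()`, as this PR relies on): `wait()` blocks the host until the collective completes, while `block_current_stream()` only makes the current CUDA stream wait on the result, leaving the host free.

```python
import torch


def make_result_usable(work, tensor: torch.Tensor) -> torch.Tensor:
    if torch.cuda.is_available():
        # Non-blocking on the host: kernels subsequently launched on the
        # current stream will observe the reduced values.
        work.block_current_stream()
    else:
        # CPU-only path: no streams to order, so block until completion.
        work.wait()
    return tensor
```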