setup stream dependencies inside work wrapper #248

Merged
merged 1 commit on Aug 1, 2025
1 change: 1 addition & 0 deletions torchft/ddp.py
@@ -69,6 +69,7 @@ def _comm_hook(
state: "Manager", bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
work = state.allreduce(bucket.buffer())
work.synchronize()
return work.get_future()


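For context (not part of this diff): a hook with this signature is the kind that gets registered through DistributedDataParallel.register_comm_hook, with the torchft Manager passed as the hook state. DDP only ever calls .wait() on the future the hook returns, so the new work.synchronize() call inside the hook is what attaches the stream dependencies and the averaging callback before the future is handed back. A minimal registration sketch, where manager and module are placeholder names and the manual registration mirrors what torchft's DDP wrapper presumably does internally rather than code from this PR:

# Illustrative registration sketch; `manager` and `module` are placeholders.
from torch.nn.parallel import DistributedDataParallel

model = DistributedDataParallel(module)
# DDP invokes _comm_hook(manager, bucket) per gradient bucket and only calls
# .wait() on the returned future, so synchronize() inside the hook must have
# already set up the stream dependencies and the averaging callback.
model.register_comm_hook(state=manager, hook=_comm_hook)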
1 change: 1 addition & 0 deletions torchft/local_sgd.py
@@ -524,6 +524,7 @@ def callback(fut: torch.futures.Future[torch.Tensor]) -> None:
flat_buffer[pack_offset : pack_offset + numel].view_as(t)
)

work.synchronize()
fut = work.get_future()
fut.add_done_callback(callback)

153 changes: 120 additions & 33 deletions torchft/manager.py
@@ -35,16 +35,28 @@
from contextlib import nullcontext
from datetime import timedelta
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
TypeAlias,
TypeVar,
Union,
cast,
)

import torch
import torch.distributed as dist
from torch.distributed import ReduceOp, TCPStore
from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp, Work

from torchft._torchft import ManagerClient, ManagerServer
from torchft.checkpointing import CheckpointTransport, HTTPTransport
from torchft.futures import future_timeout
from torchft.work import _DummyWork, _WorkWrapper
from torchft.work import _DummyWork

if TYPE_CHECKING:
from torchft.process_group import ProcessGroup
@@ -344,7 +356,11 @@ def shutdown(self, wait: bool = True) -> None:
self._executor.shutdown(wait=wait)

@torch.profiler.record_function("torchft::manager::allreduce")
def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work:
def allreduce(
self,
tensor: torch.Tensor,
should_quantize: bool = False,
) -> Work:
"""
Fault tolerant allreduce the tensor and return a Future that will be completed when
the tensor is ready.
Expand Down Expand Up @@ -382,37 +398,8 @@ def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work
)
else:
work = self._pg.allreduce([tensor], ReduceOp.SUM)
work.wait()

fut = work.get_future()

stream: Optional[torch.cuda.Stream] = (
torch.cuda.current_stream() if torch.cuda.is_available() else None
)

# schedule grad normalization as a continuation
# on the Future
@torch.profiler.record_function("torchft::manager::allreduce::callback")
def callback(
fut: torch.futures.Future[List[torch.Tensor]],
) -> torch.Tensor:
nonlocal tensor, stream, num_participants

# change the stream to avoid making the callback stream
# dependent on process group stream running the allreduce
with torch.cuda.stream(stream) if stream is not None else nullcontext():
# Setup stream dependency
fut.wait()
fut.value()
tensor /= num_participants

return tensor

fut = fut.then(callback)

fut = self.wrap_future(fut, tensor)

return _WorkWrapper(work, fut)
return _ManagedWork(work, self, tensor, num_participants)

except Exception as e:
self._logger.exception(
@@ -932,3 +919,103 @@ def warn(self, msg: str) -> None:

def exception(self, msg: str) -> None:
self._logger.exception(f"{self.prefix()} {msg}")


class _ManagedWork(dist._Work):
def __init__(
self,
work: dist._Work,
manager: Manager,
tensor: torch.Tensor,
num_participants: int,
) -> None:
super().__init__()
self._manager = manager
self._work = work
self._tensor = tensor
self._num_participants = num_participants
self._fut: Union[
torch.futures.Future[torch.Tensor], torch.futures.Future[None]
] = work.get_future()

self._stream: Optional[torch.cuda.Stream] = (
torch.cuda.current_stream() if torch.cuda.is_available() else None
)

self._is_set_future_callback_called = False

def _set_future_callback(

Member:
Why does this need to live in the Work object? Can't we pass the stream + the future to the _WorkWrapper and have it manage things correctly?

Contributor Author:
That would be ideal, but I don't think it works in all cases:

  • For NCCL, we need to call work.wait() before doing everything in _set_future_callback(); otherwise the stream dependency isn't hooked up in the right order, i.e. we could end up calling future.wait before work.wait.
  • For CPU, we can't call work.wait() because that will block.

These two requirements conflict with each other, so this is what I came up with.

Member:
@tushar00jain you can use work.synchronize() to set up the dependency in a guaranteed non-blocking way.

Contributor Author (@tushar00jain, Jul 31, 2025):
@d4l3k That's for NCCL, right? For NCCL you mentioned block_current_stream also just calls synchronize, but it also works in a non-blocking way for Gloo. We needed block_current_stream for that because I'm guessing synchronize doesn't do that for Gloo.

Also, based on our discussion offline, the current APIs work for all cases and have the same semantics as the underlying process group work:

  • In torchft, we only ever use work.wait(), and we call it only when we need to synchronize.
    • For NCCL, and for Gloo with CUDA, this sets up the stream dependencies properly with a custom stream that we synchronize on to wait for the allreduce to finish, along with the future associated with that work.
    • For Gloo with CPU, it just blocks until the work is done; the future callbacks run after the work is done.
  • That was a lie: we also call work.get_future() in bucketized allreduce.
    • In this case we call block_current_stream first to set up the stream dependency for NCCL (it's just a proxy to work.synchronize) and for Gloo with CUDA. We also add a callback to the future chain, but carefully set up the stream dependency after all the other stream dependencies have been set up; that's why we call block_current_stream in get_future anyway.
    • For Gloo with CPU, it doesn't call anything on the work because futures run after the work is done anyway.
  • We will call work.block_current_stream for HSDP in torchtitan -- this is pretty much the same as the case above for bucketized allreduce.
  • For DDP, we call get_future but don't expect users to do anything besides calling .wait on that future.

In the future, we can:

  • consider creating our own future, instead of using torch.futures.Future, that sets up stream dependencies the way we want
  • consider simplifying the implementation of _ManagedWork (the above will also help us do that)

self,
) -> None:
if self._is_set_future_callback_called:
return

# schedule grad normalization as a continuation
# on the Future
@torch.profiler.record_function("torchft::manager::allreduce::callback")
def callback(
fut: torch.futures.Future[List[torch.Tensor]],
) -> torch.Tensor:
# change the stream to avoid making the callback stream
# dependent on process group stream running the allreduce
with (
torch.cuda.stream(self._stream)
if self._stream is not None
else nullcontext()
):
# Setup stream dependency
fut.wait()
self._tensor /= self._num_participants

return self._tensor

fut = self._fut
fut = fut.then(callback)
fut = self._manager.wrap_future(fut, self._tensor)
self._fut = fut

self._is_set_future_callback_called = True

def wait(self, timeout: Optional[timedelta] = None) -> bool:

Member:
.wait() should set a dependency between the work and the current stream -- it looks like we're running all operations on self._stream?

with (
torch.cuda.stream(self._stream)
if self._stream is not None
else nullcontext()
):
self._work.wait()

self._set_future_callback()

with (
torch.cuda.stream(self._stream)
if self._stream is not None
else nullcontext()
):
self._fut.wait()

return True

def block_current_stream(self, timeout: Optional[timedelta] = None) -> None:

Member:
We probably shouldn't rely on this until we've thought it through more / tested it.

Contributor Author:
Yeah, we can test it more before we change the HSDP implementation. I think we can also find some other alternative for bucketized allreduce and DDP without having to use block_current_stream.

with (
torch.cuda.stream(self._stream)
if self._stream is not None
else nullcontext()
):
self._work.block_current_stream()

self._set_future_callback()

def synchronize(self) -> None:
if torch.cuda.is_available():
self.block_current_stream()
else:
# No stream dependencies need to be set
self._set_future_callback()

def get_future(
self,
) -> Union[torch.futures.Future[torch.Tensor], torch.futures.Future[None]]:
assert (
self._is_set_future_callback_called
), "getting the future without calling synchronize() is unsafe"
return self._fut
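
Putting the diff and the review thread together, the caller-side contract for the object returned by Manager.allreduce looks roughly like the sketch below. This is an illustration rather than code from this PR: it assumes an already-initialized torchft Manager named manager (and, on the CUDA path, an initialized process group), and it only uses the calls that appear in this diff (wait, synchronize, get_future).

import torch

grad = torch.ones(8)  # stand-in for a gradient tensor

# Path 1: wait-only callers. On NCCL and Gloo with CUDA this establishes the
# stream dependency and runs the normalization callback under the stream
# captured when the work was created; on Gloo with CPU it simply blocks until
# the allreduce has finished.
work = manager.allreduce(grad)
work.wait()

# Path 2: future-based callers (DDP comm hook, bucketized allreduce). They
# must call synchronize() before get_future(); get_future() now asserts that
# synchronize() was called. synchronize() maps to block_current_stream() when
# CUDA is available and only attaches the callback chain on CPU.
work = manager.allreduce(grad)
work.synchronize()
fut = work.get_future()
fut.wait()  # grad has been divided by the number of participants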
2 changes: 2 additions & 0 deletions torchft/manager_test.py
@@ -588,6 +588,7 @@ def test_manager_numerics(self, client_mock: MagicMock) -> None:

self.assertTrue(manager.is_participating())
work = manager.allreduce(torch.tensor([1.0]))
work.synchronize()
fut = work.get_future()
result = fut.value()
torch.testing.assert_close(result, torch.tensor([1.0 / 5]))
@@ -596,6 +597,7 @@ def test_manager_numerics(self, client_mock: MagicMock) -> None:
manager._healing = True
self.assertFalse(manager.is_participating())
work = manager.allreduce(torch.tensor([1.0]))
work.synchronize()
fut = work.get_future()
result = fut.value()
torch.testing.assert_close(result, torch.tensor([0.0]))
16 changes: 0 additions & 16 deletions torchft/work.py
@@ -18,19 +18,3 @@ def wait(self, timeout: Optional[timedelta] = None) -> bool:

def get_future(self) -> torch.futures.Future[object]:
return self.future_


class _WorkWrapper(dist._Work):
def __init__(
self, work: dist._Work, fut: torch.futures.Future[torch.Tensor]
) -> None:
super().__init__()
self._work = work
self._fut = fut

def wait(self, timeout: Optional[timedelta] = None) -> bool:
self._fut.wait()
return True

def get_future(self) -> torch.futures.Future[torch.Tensor]:
return self._fut