return work from manager allreduce #247

Open · wants to merge 4 commits into base: main
12 changes: 11 additions & 1 deletion torchft/checkpointing/http_transport.py
@@ -16,6 +16,7 @@
from typing import Generator, List, Optional, TypeVar, cast

import torch
from torch.distributed.tensor import DTensor, distribute_tensor
from torch.utils._pytree import TreeSpec, tree_flatten, tree_unflatten

from torchft.checkpointing._rwlock import RWLock
@@ -266,6 +267,15 @@ def recv_checkpoint(
return tree_unflatten(values, spec)


def _clone_cpu_tensor(tensor: torch.Tensor) -> torch.Tensor:
if isinstance(tensor, DTensor):
return distribute_tensor(
tensor.to_local().clone(), tensor.device_mesh, tensor.placements
)
else:
return tensor.clone()


def _to_cpu(values: List[T], pin_memory: bool) -> List[T]:
out = []
for v in values:
@@ -278,7 +288,7 @@ def _to_cpu(values: List[T], pin_memory: bool) -> List[T]:
else:
out.append(v.cpu())
else:
out.append(v)
out.append(_clone_cpu_tensor(v))
else:
out.append(v)
return out
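Note (not part of the diff): the new _clone_cpu_tensor helper makes sure CPU tensors, and the local shards of DTensors, are copied rather than aliased when the checkpoint is built, so a concurrent optimizer step cannot mutate the data being sent. A minimal usage sketch of the plain-tensor branch, assuming a CPU tensor from the state_dict:

import torch

t = torch.arange(4, dtype=torch.float32)  # a CPU tensor in the state_dict
c = _clone_cpu_tensor(t)                  # plain-tensor branch: tensor.clone()
t.add_(1.0)                               # simulate an optimizer step mutating the original
assert c.data_ptr() != t.data_ptr()       # the copy owns its own storage
assert not torch.equal(c, t)              # the checkpointed values are unaffected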
3 changes: 3 additions & 0 deletions torchft/checkpointing/pg_transport.py
@@ -194,6 +194,9 @@ def metadata(self) -> str:
def disallow_checkpoint(self) -> None:
pass

def allow_checkpoint(self, step: int) -> None:
pass

def send_checkpoint(
self, dst_ranks: list[int], step: int, state_dict: T, timeout: timedelta
) -> None:
9 changes: 9 additions & 0 deletions torchft/checkpointing/transport.py
@@ -44,6 +44,15 @@ def disallow_checkpoint(self) -> None:
"""
...

def allow_checkpoint(self, step: int) -> None:
"""
Called when the checkpoint is allowed to be sent, to make sure access to the state_dict is safe.

Args:
step: the step number that the checkpoint is for
"""
...

@abstractmethod
def recv_checkpoint(
self, src_rank: int, metadata: str, step: int, timeout: timedelta
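Note (not part of the diff): allow_checkpoint(step) is the counterpart to the existing disallow_checkpoint(), letting callers bracket the code that mutates the state_dict. A hedged sketch of the intended call pattern, where transport is any CheckpointTransport implementation and optimizer/step come from the training loop:

transport.disallow_checkpoint()   # parameters are about to be mutated; block checkpoint reads
optimizer.step()
transport.allow_checkpoint(step)  # the state_dict is consistent again for this step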
20 changes: 15 additions & 5 deletions torchft/collectives.py
@@ -18,6 +18,7 @@
AllreduceOptions,
AllToAllOptions,
ReduceScatterOptions,
Work,
)
from torch.futures import Future

@@ -288,7 +289,7 @@ def allreduce_quantized(
opts: AllreduceOptions | ReduceOp,
process_group: "ProcessGroup",
sync_stream: cuda.Stream | None = None,
) -> Future[list[torch.Tensor]]:
) -> Work:
"""
Performs a quantized all-reduce operation on a list of tensors.

@@ -379,17 +380,26 @@ def allreduce_quantized(
[torch.split(quantized_tensors_out.view(world_size, -1), 1)[rank]],
_to_allgather_options(allreduce_opts),
)

# NOTE: This is not supposed to be used with gloo, only with NCCL.
# So we setup the stream dependency here by calling work.wait(),
# which doesn't block the CPU.
#
# The future callback below will run after the work has been
# completed.

work.wait()
fut = work.get_future()

def callback(fut: Future[list[torch.Tensor]]) -> list[torch.Tensor]:
def callback(fut: Future[list[torch.Tensor]]) -> None:
# Dequantize and copy to output buffer.
nonlocal tensors, quantized_tensors, world_size, sync_stream

with torch.cuda.stream(sync_stream):
# Setup stream dependency
fut.wait()
# Dequantize the result back to the original precision
fused_dequantize_from_fp8(tensors, quantized_tensors, world_size)
return tensors

fut = fut.then(callback)
return fut
fut.add_done_callback(callback)
return work
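Note (not part of the diff): allreduce_quantized now returns a c10d Work instead of a Future. A usage sketch that mirrors collectives_test.py; with a NCCL-backed process group, work.wait() records a stream dependency rather than blocking the CPU, and the future callback above writes the dequantized result back into tensors:

work = allreduce_quantized(tensors, reduce_op, pg)  # reduce_op: a ReduceOp supported by the quantized path
work.wait()
# Kernels launched afterwards on the current stream observe the reduced, dequantized tensors.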
4 changes: 2 additions & 2 deletions torchft/collectives_test.py
@@ -94,8 +94,8 @@ def _run_all_reduce_collective(
)
]

fut = allreduce_quantized(tensors, reduce_op, pg)
fut.wait()
work = allreduce_quantized(tensors, reduce_op, pg)
work.wait()

work = pg.allreduce([expected], reduce_op)
work.get_future().wait()
3 changes: 2 additions & 1 deletion torchft/ddp.py
@@ -68,7 +68,8 @@ def __init__(self, manager: "Manager", module: nn.Module, **kwargs: object) -> N
def _comm_hook(
state: "Manager", bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
return state.allreduce(bucket.buffer())
work = state.allreduce(bucket.buffer())
return work.get_future()


class PureDistributedDataParallel(nn.Module):
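Note (not part of the diff): DDP communication hooks must return a torch.futures.Future[torch.Tensor] that resolves to the reduced bucket, which is why _comm_hook now bridges the Work returned by manager.allreduce() through work.get_future(). A minimal sketch of that contract using only standard PyTorch APIs; the no-op hook below is illustrative, not torchft code:

import torch
import torch.distributed as dist

def passthrough_hook(
    state: object, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    # A "do nothing" reduction: resolve immediately with the bucket's own buffer.
    fut: torch.futures.Future[torch.Tensor] = torch.futures.Future()
    fut.set_result(bucket.buffer())
    return fut

# ddp_model.register_comm_hook(state=None, hook=passthrough_hook)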
10 changes: 6 additions & 4 deletions torchft/ddp_test.py
@@ -10,11 +10,13 @@
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.distributed_c10d import Work
from torch.futures import Future

from torchft.ddp import DistributedDataParallel, PureDistributedDataParallel
from torchft.manager import Manager
from torchft.process_group import ProcessGroupBabyGloo, ProcessGroupGloo
from torchft.work import _DummyWork


class TestDDP(TestCase):
@@ -39,14 +41,14 @@ def test_ddp(self) -> None:

call_count = 0

def allreduce(tensor: torch.Tensor) -> Future[torch.Tensor]:
def allreduce(
tensor: torch.Tensor,
) -> Work:
nonlocal call_count

call_count += 1

fut = Future() # pyre-fixme[29]: not a function
fut.set_result(tensor)
return fut
return _DummyWork(tensor)

manager.allreduce = allreduce

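Note (not part of the diff): torchft.work._DummyWork is imported by these tests but not shown in this PR. A hedged sketch of what such a helper could look like, assuming a Work subclass that completes immediately and exposes the input through a pre-resolved future; treat the details as an assumption:

from datetime import timedelta
from typing import Optional

import torch
from torch.distributed.distributed_c10d import Work


class _DummyWork(Work):
    def __init__(self, result: object) -> None:
        super().__init__()
        self._future: torch.futures.Future[object] = torch.futures.Future()
        self._future.set_result(result)

    def wait(self, timeout: Optional[timedelta] = None) -> bool:
        # Nothing to wait for: the "collective" already completed at construction time.
        return True

    def get_future(self) -> torch.futures.Future[object]:
        return self._future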
44 changes: 32 additions & 12 deletions torchft/local_sgd.py
@@ -18,6 +18,7 @@
import torch
import torch.distributed as dist
from torch import nn, optim
from torch.distributed.distributed_c10d import Work
from torch.distributed.tensor import DTensor
from torch.nn.parameter import Parameter
from torch.optim.optimizer import Optimizer
@@ -85,6 +86,9 @@ def __init__(
self._hooks: List[RemovableHandle] = []

def __enter__(self) -> "LocalSGD":
self._hooks.append(
self._local_optimizer.register_step_pre_hook(self._step_pre_hook)
)
# Add optimizer hook which increments the local step counter and syncs if necessary
self._hooks.append(
self._local_optimizer.register_step_post_hook(self._step_post_hook)
@@ -105,12 +109,20 @@ def __exit__(

return False # Propagate exceptions

def _step_pre_hook(
self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
) -> None:
# The checkpoint may transfer model parameters, so we need to make access to it thread safe
self._manager.disallow_checkpoint()

def _step_post_hook(
self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
) -> None:
"""
This hook is registered on the optimizer and is called after the optimizer step.
"""
self._manager.allow_checkpoint()

self._local_step += 1
if self._local_step >= self._sync_every:
self.sync()
@@ -200,7 +212,7 @@ def __init__(
self._outer_optimizer = outer_optimizer

# Stores pending all reduce
self._allreduce_futures: list[torch.futures.Future[torch.Tensor]] = []
self._allreduce_work: list[Work] = []
self._stream: Optional[torch.cuda.Stream] = (
torch.cuda.Stream() if torch.cuda.is_available() else None
)
@@ -368,15 +380,15 @@ def wait(self) -> None:
"""
Waits for the previously scheduled allreduce to finish
"""
if len(self._allreduce_futures) == 0:
if len(self._allreduce_work) == 0:
return

if self._stream is not None:
assert self._stop_event is not None
self._stop_event.synchronize()
self._stop_event = None

self._allreduce_futures = []
self._allreduce_work = []

@torch.profiler.record_function("torchft::local_sgd::prepare_sync")
def prepare_sync(self) -> None:
"""
self._save_grads()

assert len(self._allreduce_futures) == 0
assert len(self._allreduce_work) == 0

# Make sure tensors are available to `_stream`
if self._stream is not None:
):
self._average_grads()

for work in self._allreduce_futures:
for work in self._allreduce_work:
work.wait()

if self._stream is not None:
steps using the outer optimizer.
"""
# Waiting for an allreduce before it has been sent is currently not supported.
assert len(self._allreduce_futures) > 0
assert len(self._allreduce_work) > 0

self.wait()

@@ -467,7 +479,8 @@ def _allreduce_per_param(self) -> None:
work = self._manager.allreduce(
self._grads[name], should_quantize=self.should_quantize
)
self._allreduce_futures.append(work)

self._allreduce_work.append(work)

def _bucketize_and_allreduce(
self,
@@ -513,12 +526,19 @@ def _bucketize_and_allreduce(
)

def callback(fut: torch.futures.Future[torch.Tensor]) -> None:
nonlocal bucket_tensors, flat_buffer
for t, pack_offset, numel in bucket_tensors:
t.copy_(flat_buffer[pack_offset : pack_offset + numel].view_as(t))
with torch.cuda.stream(self._stream) if self._stream else nullcontext():
nonlocal bucket_tensors, flat_buffer
# Setup stream dependency
fut.wait()
for t, pack_offset, numel in bucket_tensors:
t.copy_(
flat_buffer[pack_offset : pack_offset + numel].view_as(t)
)

fut = work.get_future()
fut.add_done_callback(callback)

work = work.then(callback)
self._allreduce_futures.append(work)
self._allreduce_work.append(work)

offset += chunk_size

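Note (not part of the diff): LocalSGD gates checkpointing around every optimizer step with the standard torch.optim hook API: register_step_pre_hook runs before optimizer.step() (where disallow_checkpoint() is called) and register_step_post_hook runs after it (where allow_checkpoint() is called and the local step counter advances). A self-contained sketch of that mechanism, with print statements standing in for the manager calls:

import torch
from torch import nn, optim

model = nn.Linear(4, 2)
opt = optim.SGD(model.parameters(), lr=0.1)

def pre_hook(_optim, _args, _kwargs) -> None:
    print("pre-step: parameters are about to be mutated")    # LocalSGD calls manager.disallow_checkpoint()

def post_hook(_optim, _args, _kwargs) -> None:
    print("post-step: parameters are consistent again")      # LocalSGD calls manager.allow_checkpoint()

handles = [opt.register_step_pre_hook(pre_hook), opt.register_step_post_hook(post_hook)]

model(torch.randn(8, 4)).sum().backward()
opt.step()  # both hooks fire around the parameter update

for h in handles:  # mirrors LocalSGD.__exit__ removing self._hooks
    h.remove()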
25 changes: 12 additions & 13 deletions torchft/local_sgd_test.py
@@ -11,10 +11,12 @@
import torch
from parameterized import parameterized
from torch import Tensor, nn, optim
from torch.distributed.distributed_c10d import Work
from torch.distributed.tensor import DTensor

from torchft.local_sgd import DiLoCo, LocalSGD, extract_local_tensor
from torchft.manager import Manager
from torchft.work import _DummyWork


def create_manager() -> MagicMock:
@@ -26,6 +28,11 @@ def create_manager() -> MagicMock:

manager.errored.return_value = None

def mock_allreduce(tensor: torch.Tensor, should_quantize: bool = False) -> Work:
return _DummyWork(tensor)

manager.allreduce.side_effect = mock_allreduce

return manager


@@ -66,7 +73,7 @@ class LocalSGDTest(TestCase):
def test_local_sgd_healthy(self) -> None:
model = SimpleModel()
optimizer = optim.SGD(model.parameters())
manager = create_autospec(Manager)
manager = create_manager()
with LocalSGD(manager, model, optimizer, sync_every=2) as local_sgd:
self.assertEqual(local_sgd._local_step, 0)
inp = torch.rand(2, 3)
@@ -240,13 +247,9 @@ def test_bucketization_correctness(self) -> None:
manager.should_commit.return_value = True

# Define fake allreduce: multiplies buffer by 2
def fake_allreduce(
tensor: Tensor, should_quantize: bool
) -> torch.futures.Future[Tensor]:
def fake_allreduce(tensor: Tensor, should_quantize: bool) -> Work:
tensor.mul_(2)
fut = torch.futures.Future() # pyre-fixme[29]: not a function
fut.set_result(tensor)
return fut
return _DummyWork(tensor)

manager.allreduce.side_effect = fake_allreduce

@@ -284,13 +287,9 @@ def test_gradient_correctness(self) -> None:
manager.should_commit.return_value = True

# Define fake allreduce: multiplies buffer by 2
def fake_allreduce(
tensor: Tensor, should_quantize: bool
) -> torch.futures.Future[Tensor]:
def fake_allreduce(tensor: Tensor, should_quantize: bool) -> Work:
tensor.mul_(2)
fut = torch.futures.Future() # pyre-fixme[29]: not a function
fut.set_result(tensor)
return fut
return _DummyWork(tensor)

manager.allreduce.side_effect = fake_allreduce
