From fef4abc8772d34b9a50a8e7a4d3d37b2c1a8009e Mon Sep 17 00:00:00 2001 From: Tushar Jain Date: Fri, 25 Jul 2025 18:30:02 -0700 Subject: [PATCH 1/2] use http transport Summary: use http transport instead of pg transport -- pg transport fails to resolve address when running locally --- train_diloco.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/train_diloco.py b/train_diloco.py index 0c6b9cf..e207e73 100644 --- a/train_diloco.py +++ b/train_diloco.py @@ -34,7 +34,7 @@ ProcessGroupGloo, ProcessGroupNCCL, ) -from torchft.checkpointing.pg_transport import PGTransport +from torchft.checkpointing.http_transport import HTTPTransport from torchft.local_sgd import DiLoCo logging.basicConfig(level=logging.INFO) @@ -67,13 +67,12 @@ def state_dict(): timeout=timedelta(seconds=10), ) if torch.cuda.is_available() and USE_NCCL - else ProcessGroupGloo(timeout=timedelta(seconds=5)) + else ProcessGroupGloo(timeout=timedelta(seconds=10)) ) - transport = PGTransport( - pg, + transport = HTTPTransport( timeout=timedelta(seconds=10), - device=device, + num_chunks=0, ) manager = Manager( From 09bbdea7ca569378f6ac0f9a39161bb9b6421bbb Mon Sep 17 00:00:00 2001 From: Tushar Jain Date: Thu, 31 Jul 2025 18:55:23 -0700 Subject: [PATCH 2/2] fix stream dependencies in callbacks Summary: - call future.wait in callbacks to make sure the continuation executes after the future has completed - set the stream correctly to execute callback scheduled by bucketized allreduce --- torchft/collectives.py | 2 ++ torchft/local_sgd.py | 11 ++++++++--- torchft/manager.py | 2 ++ 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/torchft/collectives.py b/torchft/collectives.py index 837fbcd..927309a 100644 --- a/torchft/collectives.py +++ b/torchft/collectives.py @@ -387,6 +387,8 @@ def callback(fut: Future[list[torch.Tensor]]) -> list[torch.Tensor]: nonlocal tensors, quantized_tensors, world_size, sync_stream with torch.cuda.stream(sync_stream): + # Setup stream dependency + fut.wait() # Dequantize the result back to the original precision fused_dequantize_from_fp8(tensors, quantized_tensors, world_size) return tensors diff --git a/torchft/local_sgd.py b/torchft/local_sgd.py index c7230ee..d0eeccc 100644 --- a/torchft/local_sgd.py +++ b/torchft/local_sgd.py @@ -513,9 +513,14 @@ def _bucketize_and_allreduce( ) def callback(fut: torch.futures.Future[torch.Tensor]) -> None: - nonlocal bucket_tensors, flat_buffer - for t, pack_offset, numel in bucket_tensors: - t.copy_(flat_buffer[pack_offset : pack_offset + numel].view_as(t)) + with torch.cuda.stream(self._stream) if self._stream else nullcontext(): + nonlocal bucket_tensors, flat_buffer + # Setup stream dependency + fut.wait() + for t, pack_offset, numel in bucket_tensors: + t.copy_( + flat_buffer[pack_offset : pack_offset + numel].view_as(t) + ) work = work.then(callback) self._allreduce_futures.append(work) diff --git a/torchft/manager.py b/torchft/manager.py index e01a965..09100c3 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -403,6 +403,8 @@ def callback( # change the stream to avoid making the callback stream # dependent on process group stream running the allreduce with torch.cuda.stream(stream) if stream is not None else nullcontext(): + # Setup stream dependency + fut.wait() fut.value() tensor /= num_participants