Commit 22b8fa1

fix compute/communication overlap for gloo (#240)
* remove recovery from regression test

  Summary:
  - we currently do some validation of the training numerics in the regression test
  - the forced recovery on the first step interferes with this because it makes the test non-deterministic; in particular, after the recovery the replica takes a non-deterministic number of steps, which makes the gradients non-deterministic
  - to fix this, perform a quorum inside the fake training loop of the regression test before doing any training
  - we also need to increase the manager step count by 2 (i.e. do 2 should_commit calls), because we have 2 fragments and we are testing numerics as if we started from step 0 -- starting from step 2 gives the fragments the same sync schedule as starting from step 0

* fix compute/communication overlap for gloo

  Summary:
  - we currently wait on the process group work's future when preparing a fragment
  - with gloo, this blocks the CPU
  - move the wait call to where we perform the actual sync of the fragment
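For the first change, the shape of the fix is roughly the sketch below. This is a hedged illustration, not the actual test code: it assumes torchft's `Manager` API (`start_quorum()`, `should_commit()`) and that two `should_commit()` calls are how the step count advances by 2.

```python
from torchft import Manager

def settle_before_training(manager: Manager) -> None:
    # Hypothetical helper, not the real regression test: run a quorum up
    # front so the forced first-step recovery happens before any of the
    # numerics being validated.
    manager.start_quorum()

    # Advance the manager step count by 2 (one should_commit per
    # fragment): starting from step 2 gives the two fragments the same
    # sync schedule as starting from step 0.
    for _ in range(2):
        manager.should_commit()
```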
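For the second change, here is a minimal standalone sketch of why the wait is deferred. It assumes a world-size-1 gloo group, and `prepare()`/`perform_sync()` are hypothetical helpers standing in for LocalSGD's fragment logic:

```python
import tempfile

import torch
import torch.distributed as dist

def prepare(grads):
    # Launch async allreduces but do NOT call work.wait() here: with
    # gloo, wait() blocks the calling CPU thread, which would serialize
    # the communication with the compute that follows.
    return [dist.all_reduce(g, async_op=True) for g in grads]

def perform_sync(pending):
    # Block only when the reduced values are actually needed, so the
    # compute between prepare() and perform_sync() overlaps with the
    # gloo communication.
    for work in pending:
        work.wait()

if __name__ == "__main__":
    store = dist.FileStore(tempfile.mktemp(), 1)
    dist.init_process_group("gloo", store=store, rank=0, world_size=1)
    grads = [torch.ones(4) for _ in range(2)]
    pending = prepare(grads)
    # ... local training compute for other fragments happens here ...
    perform_sync(pending)
    print(grads[0])  # unchanged with world_size=1
    dist.destroy_process_group()
```

With NCCL, `work.wait()` only inserts a dependency on the current CUDA stream, so the old placement was harmless; with gloo it parks the CPU thread, which is the behavior this commit fixes.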
1 parent f121e4a commit 22b8fa1

File tree

1 file changed: +12 −7 lines changed


torchft/local_sgd.py

Lines changed: 12 additions & 7 deletions
```diff
@@ -400,13 +400,6 @@ def prepare_sync(self) -> None:
         ):
             self._average_grads()
 
-        for work in self._allreduce_work:
-            work.wait()
-
-        if self._stream is not None:
-            self._stop_event = torch.cuda.Event()
-            self._stop_event.record()
-
     @torch.profiler.record_function("torchft::local_sgd::perform_sync")
     def perform_sync(self) -> bool:
         """
@@ -416,6 +409,18 @@ def perform_sync(self) -> bool:
         # Waiting for an allreduce before it has been sent is currently not supported.
         assert len(self._allreduce_work) > 0
 
+        with (
+            torch.cuda.stream(self._stream)
+            if self._stream is not None
+            else nullcontext()
+        ):
+            for work in self._allreduce_work:
+                work.wait()
+
+        if self._stream is not None:
+            self._stop_event = torch.cuda.Event()
+            self._stop_event.record()
+
         self.wait()
 
         # save the parameters so they can be used for merging
```
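In isolation, the pattern the added lines use looks like the sketch below. This is a minimal sketch, not the torchft code itself: `stream` and `pending_work` stand in for `self._stream` and `self._allreduce_work`, and `nullcontext` keeps the same code path valid on CPU-only gloo setups where there is no stream at all.

```python
from contextlib import nullcontext

import torch

# A side stream when CUDA is available, else None; with gloo on CPU
# there is no stream and nullcontext makes the `with` block a no-op.
stream = torch.cuda.Stream() if torch.cuda.is_available() else None

pending_work = []  # stand-in for the pending allreduce work handles

# Wait on the side stream so the default stream's compute is not
# serialized behind the communication; with gloo this wait blocks the
# CPU, which is why it belongs in perform_sync, not prepare_sync.
with torch.cuda.stream(stream) if stream is not None else nullcontext():
    for work in pending_work:
        work.wait()

# Record an event on the side stream so later consumers can synchronize
# against it instead of blocking immediately.
if stream is not None:
    stop_event = torch.cuda.Event()
    stop_event.record()
```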
