fix compute/communication overlap for gloo #240


Merged · 2 commits · Aug 2, 2025
@@ -193,6 +193,78 @@
-53.0
]
]
},
"16": {
"layers.0.weight": [
[
-61.0
]
],
"layers.1.weight": [
[
-55.0
]
]
},
"17": {
"layers.0.weight": [
[
-63.0
]
],
"layers.1.weight": [
[
-57.0
]
]
},
"18": {
"layers.0.weight": [
[
-65.0
]
],
"layers.1.weight": [
[
-71.0
]
]
},
"19": {
"layers.0.weight": [
[
-67.0
]
],
"layers.1.weight": [
[
-73.0
]
]
},
"20": {
"layers.0.weight": [
[
-69.0
]
],
"layers.1.weight": [
[
-75.0
]
]
},
"21": {
"layers.0.weight": [
[
-83.0
]
],
"layers.1.weight": [
[
-77.0
]
]
}
},
"global_parameter_history": {
@@ -255,6 +327,30 @@
-47.0
]
]
},
"15": {
"layers.0.weight": [
[
-59.0
]
],
"layers.1.weight": [
[
-47.0
]
]
},
"18": {
"layers.0.weight": [
[
-59.0
]
],
"layers.1.weight": [
[
-71.0
]
]
}
}
}
@@ -381,6 +477,78 @@
-53.0
]
]
},
"10": {
"layers.0.weight": [
[
-61.0
]
],
"layers.1.weight": [
[
-55.0
]
]
},
"11": {
"layers.0.weight": [
[
-63.0
]
],
"layers.1.weight": [
[
-57.0
]
]
},
"12": {
"layers.0.weight": [
[
-65.0
]
],
"layers.1.weight": [
[
-71.0
]
]
},
"13": {
"layers.0.weight": [
[
-67.0
]
],
"layers.1.weight": [
[
-73.0
]
]
},
"14": {
"layers.0.weight": [
[
-69.0
]
],
"layers.1.weight": [
[
-75.0
]
]
},
"15": {
"layers.0.weight": [
[
-83.0
]
],
"layers.1.weight": [
[
-77.0
]
]
}
},
"global_parameter_history": {
@@ -419,6 +587,30 @@
-47.0
]
]
},
"9": {
"layers.0.weight": [
[
-59.0
]
],
"layers.1.weight": [
[
-47.0
]
]
},
"12": {
"layers.0.weight": [
[
-59.0
]
],
"layers.1.weight": [
[
-71.0
]
]
}
}
}
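The JSON hunks above extend the expected parameter-history fixtures with additional manager steps. They correspond to the test change below, which runs the mocked training loop to manager step 7 instead of 5 (and issues two `should_commit()` calls right after the initial quorum), presumably accounting for the extra recorded steps.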
18 changes: 17 additions & 1 deletion torchft/diloco_regression_test.py
@@ -3,6 +3,7 @@
import json
import logging
import os
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import ExitStack
from datetime import timedelta
@@ -141,6 +142,7 @@ def __init__(
diloco_args: dict[str, Any],
inner_lr: float = 1,
outer_lr: float = 2,
quorum_barrier: Optional[threading.Barrier] = None,
) -> None:
self.inner_lr = inner_lr
self.outer_lr = outer_lr
@@ -150,6 +152,8 @@
rank, store_port, device, runner, model_state_dict, n_fragments, diloco_args
)

self.quorum_barrier = quorum_barrier

def setup_model(self) -> MockModel:
"""Set up the mock model and move it to the device."""
model = MockModel(in_dim=1, out_dim=1, n_layers=self.n_fragments)
@@ -186,6 +190,14 @@ def train_loop(self) -> Dict[str, Any]:
backup_device=self.device,
**self.diloco_args,
) as self.diloco:
if self.quorum_barrier is not None:
self.manager.start_quorum()
self.manager.wait_quorum()
assert self.quorum_barrier is not None
self.quorum_barrier.wait()
assert self.manager.should_commit()
assert self.manager.should_commit()

local_step = 0
manager_steps = set()
while True:
@@ -197,7 +209,7 @@

manager_curr_step = self.manager.current_step()

if manager_curr_step == 5:
if manager_curr_step == 7:
break

if manager_curr_step not in manager_steps:
@@ -248,6 +260,7 @@ def mock_diloco_train_loop(
model_state_dict = train_loop_args.get("model_state_dict", {})
n_fragments = train_loop_args.get("n_fragments", 1)
diloco_args = train_loop_args.get("diloco_args", {})
quorum_barrier = train_loop_args.get("quorum_barrier", None)

with ExitStack() as stack:
trainer = MockDiLoCoTrainer(
@@ -258,6 +271,7 @@
model_state_dict,
n_fragments,
diloco_args,
quorum_barrier=quorum_barrier,
)
stack.callback(trainer.manager.shutdown)
return trainer.train_loop()
@@ -304,6 +318,7 @@ def test_diloco_mocked_updates(
# Create a proper state_dict for the model to avoid load_state_dict errors
temp_model = MockModel(in_dim=1, out_dim=1, n_layers=n_fragments)
model_state_dict = temp_model.state_dict()
quorum_barrier = threading.Barrier(num_replicas)

with ThreadPoolExecutor(max_workers=num_replicas) as executor:
for replica_id in range(num_replicas):
@@ -316,6 +331,7 @@
train_loop=mock_diloco_train_loop,
use_cuda=use_cuda,
train_loop_args={
"quorum_barrier": quorum_barrier,
"n_fragments": n_fragments,
"model_state_dict": model_state_dict,
"diloco_args": {
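The new `quorum_barrier` makes every replica establish the initial quorum before any replica begins stepping, which keeps the mocked replicas in lockstep so the recorded histories stay deterministic. Below is a minimal, self-contained sketch of that handshake; `FakeManager` is a stub standing in for torchft's `Manager`, and its no-op bodies are assumptions for illustration, not the real implementation.

    # Sketch of the quorum handshake added above (FakeManager is a stub).
    import threading
    from concurrent.futures import ThreadPoolExecutor

    class FakeManager:
        def start_quorum(self) -> None:
            pass  # real Manager: kicks off the quorum computation asynchronously

        def wait_quorum(self) -> None:
            pass  # real Manager: blocks until the quorum result is available

        def should_commit(self) -> bool:
            return True  # real Manager: fences pending work and may return False

    def replica_main(rank: int, barrier: threading.Barrier) -> bool:
        manager = FakeManager()
        manager.start_quorum()
        manager.wait_quorum()
        # No replica proceeds past this point until every replica has joined
        # the quorum, so they all observe the same manager step at the start.
        barrier.wait()
        return manager.should_commit()

    num_replicas = 2
    quorum_barrier = threading.Barrier(num_replicas)
    with ThreadPoolExecutor(max_workers=num_replicas) as executor:
        futures = [
            executor.submit(replica_main, rank, quorum_barrier)
            for rank in range(num_replicas)
        ]
    assert all(f.result() for f in futures)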
19 changes: 12 additions & 7 deletions torchft/local_sgd.py
@@ -400,13 +400,6 @@ def prepare_sync(self) -> None:
):
self._average_grads()

for work in self._allreduce_work:
work.wait()

if self._stream is not None:
self._stop_event = torch.cuda.Event()
self._stop_event.record()

@torch.profiler.record_function("torchft::local_sgd::perform_sync")
def perform_sync(self) -> bool:
"""
@@ -416,6 +409,18 @@ def perform_sync(self) -> bool:
# Waiting for an allreduce before it has been sent is currently not supported.
assert len(self._allreduce_work) > 0

with (
torch.cuda.stream(self._stream)
if self._stream is not None
else nullcontext()
):
for work in self._allreduce_work:
work.wait()

if self._stream is not None:
self._stop_event = torch.cuda.Event()
self._stop_event.record()

self.wait()

# save the parameters so they can be used for merging
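This is the core fix. Previously `prepare_sync` launched the allreduces and immediately blocked on `work.wait()`; with gloo, `wait()` stalls the calling CPU thread until the collective finishes, so the communication never overlapped the inner optimization steps. Moving the wait (and the CUDA stop-event recording) into `perform_sync`, under the communication stream when one exists, defers the block until the reduced values are actually needed. Here is a self-contained toy of the launch-now/wait-later pattern on a gloo process group; the single-process setup (world size 1, loopback rendezvous) is illustrative only, not how torchft wires up its process groups.

    # Toy of deferred waiting on an async gloo allreduce.
    import os

    import torch
    import torch.distributed as dist

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    grad = torch.randn(1024)

    # prepare_sync: only *launch* the collective. async_op=True returns a
    # Work handle immediately instead of blocking the CPU thread.
    work = dist.all_reduce(grad, op=dist.ReduceOp.SUM, async_op=True)

    # ... inner optimization steps run here, overlapping the allreduce ...
    _ = torch.randn(1024) @ torch.randn(1024)

    # perform_sync: block only now, when the reduced values are needed.
    # On gloo, wait() stalls the CPU thread, which is why waiting inside
    # prepare_sync defeated the overlap this PR restores.
    work.wait()

    dist.destroy_process_group()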