diff --git a/test_fixtures/torchft.diloco_regression_test.DiLoCoMockedUpdateTest.test_diloco_mocked_failure_recovery_0.json b/test_fixtures/torchft.diloco_regression_test.DiLoCoMockedUpdateTest.test_diloco_mocked_failure_recovery_0.json
index 1111b06..40b6732 100644
--- a/test_fixtures/torchft.diloco_regression_test.DiLoCoMockedUpdateTest.test_diloco_mocked_failure_recovery_0.json
+++ b/test_fixtures/torchft.diloco_regression_test.DiLoCoMockedUpdateTest.test_diloco_mocked_failure_recovery_0.json
@@ -193,6 +193,78 @@
             -53.0
           ]
         ]
+      },
+      "16": {
+        "layers.0.weight": [
+          [
+            -61.0
+          ]
+        ],
+        "layers.1.weight": [
+          [
+            -55.0
+          ]
+        ]
+      },
+      "17": {
+        "layers.0.weight": [
+          [
+            -63.0
+          ]
+        ],
+        "layers.1.weight": [
+          [
+            -57.0
+          ]
+        ]
+      },
+      "18": {
+        "layers.0.weight": [
+          [
+            -65.0
+          ]
+        ],
+        "layers.1.weight": [
+          [
+            -71.0
+          ]
+        ]
+      },
+      "19": {
+        "layers.0.weight": [
+          [
+            -67.0
+          ]
+        ],
+        "layers.1.weight": [
+          [
+            -73.0
+          ]
+        ]
+      },
+      "20": {
+        "layers.0.weight": [
+          [
+            -69.0
+          ]
+        ],
+        "layers.1.weight": [
+          [
+            -75.0
+          ]
+        ]
+      },
+      "21": {
+        "layers.0.weight": [
+          [
+            -83.0
+          ]
+        ],
+        "layers.1.weight": [
+          [
+            -77.0
+          ]
+        ]
       }
     },
     "global_parameter_history": {
@@ -255,6 +327,30 @@
             -47.0
           ]
         ]
+      },
+      "15": {
+        "layers.0.weight": [
+          [
+            -59.0
+          ]
+        ],
+        "layers.1.weight": [
+          [
+            -47.0
+          ]
+        ]
+      },
+      "18": {
+        "layers.0.weight": [
+          [
+            -59.0
+          ]
+        ],
+        "layers.1.weight": [
+          [
+            -71.0
+          ]
+        ]
       }
     }
   }
@@ -381,6 +477,78 @@
             -53.0
           ]
         ]
+      },
+      "10": {
+        "layers.0.weight": [
+          [
+            -61.0
+          ]
+        ],
+        "layers.1.weight": [
+          [
+            -55.0
+          ]
+        ]
+      },
+      "11": {
+        "layers.0.weight": [
+          [
+            -63.0
+          ]
+        ],
+        "layers.1.weight": [
+          [
+            -57.0
+          ]
+        ]
+      },
+      "12": {
+        "layers.0.weight": [
+          [
+            -65.0
+          ]
+        ],
+        "layers.1.weight": [
+          [
+            -71.0
+          ]
+        ]
+      },
+      "13": {
+        "layers.0.weight": [
+          [
+            -67.0
+          ]
+        ],
+        "layers.1.weight": [
+          [
+            -73.0
+          ]
+        ]
+      },
+      "14": {
+        "layers.0.weight": [
+          [
+            -69.0
+          ]
+        ],
+        "layers.1.weight": [
+          [
+            -75.0
+          ]
+        ]
+      },
+      "15": {
+        "layers.0.weight": [
+          [
+            -83.0
+          ]
+        ],
+        "layers.1.weight": [
+          [
+            -77.0
+          ]
+        ]
       }
     },
     "global_parameter_history": {
@@ -419,6 +587,30 @@
             -47.0
           ]
         ]
+      },
+      "9": {
+        "layers.0.weight": [
+          [
+            -59.0
+          ]
+        ],
+        "layers.1.weight": [
+          [
+            -47.0
+          ]
+        ]
+      },
+      "12": {
+        "layers.0.weight": [
+          [
+            -59.0
+          ]
+        ],
+        "layers.1.weight": [
+          [
+            -71.0
+          ]
+        ]
       }
     }
   }
diff --git a/torchft/diloco_regression_test.py b/torchft/diloco_regression_test.py
index 99ee931..a1cb7c9 100644
--- a/torchft/diloco_regression_test.py
+++ b/torchft/diloco_regression_test.py
@@ -3,6 +3,7 @@
 import json
 import logging
 import os
+import threading
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from contextlib import ExitStack
 from datetime import timedelta
@@ -141,6 +142,7 @@ def __init__(
         diloco_args: dict[str, Any],
         inner_lr: float = 1,
         outer_lr: float = 2,
+        quorum_barrier: Optional[threading.Barrier] = None,
     ) -> None:
         self.inner_lr = inner_lr
         self.outer_lr = outer_lr
@@ -150,6 +152,8 @@ def __init__(
             rank, store_port, device, runner, model_state_dict, n_fragments, diloco_args
         )

+        self.quorum_barrier = quorum_barrier
+
     def setup_model(self) -> MockModel:
         """Set up the mock model and move it to the device."""
         model = MockModel(in_dim=1, out_dim=1, n_layers=self.n_fragments)
@@ -186,6 +190,14 @@ def train_loop(self) -> Dict[str, Any]:
                 backup_device=self.device,
                 **self.diloco_args,
             ) as self.diloco:
+                if self.quorum_barrier is not None:
+                    self.manager.start_quorum()
+                    self.manager.wait_quorum()
+                    assert self.quorum_barrier is not None
+                    self.quorum_barrier.wait()
+                    assert self.manager.should_commit()
+                    assert self.manager.should_commit()
+
                 local_step = 0
                 manager_steps = set()
                 while True:
@@ -197,7 +209,7 @@

                     manager_curr_step = self.manager.current_step()

-                    if manager_curr_step == 5:
+                    if manager_curr_step == 7:
                         break

                     if manager_curr_step not in manager_steps:
@@ -248,6 +260,7 @@ def mock_diloco_train_loop(
     model_state_dict = train_loop_args.get("model_state_dict", {})
     n_fragments = train_loop_args.get("n_fragments", 1)
     diloco_args = train_loop_args.get("diloco_args", {})
+    quorum_barrier = train_loop_args.get("quorum_barrier", None)

     with ExitStack() as stack:
         trainer = MockDiLoCoTrainer(
@@ -258,6 +271,7 @@ def mock_diloco_train_loop(
             model_state_dict,
             n_fragments,
             diloco_args,
+            quorum_barrier=quorum_barrier,
         )
         stack.callback(trainer.manager.shutdown)
         return trainer.train_loop()
@@ -304,6 +318,7 @@ def test_diloco_mocked_updates(
         # Create a proper state_dict for the model to avoid load_state_dict errors
         temp_model = MockModel(in_dim=1, out_dim=1, n_layers=n_fragments)
         model_state_dict = temp_model.state_dict()
+        quorum_barrier = threading.Barrier(num_replicas)

         with ThreadPoolExecutor(max_workers=num_replicas) as executor:
             for replica_id in range(num_replicas):
@@ -316,6 +331,7 @@ def test_diloco_mocked_updates(
                     train_loop=mock_diloco_train_loop,
                     use_cuda=use_cuda,
                     train_loop_args={
+                        "quorum_barrier": quorum_barrier,
                         "n_fragments": n_fragments,
                         "model_state_dict": model_state_dict,
                         "diloco_args": {
diff --git a/torchft/local_sgd.py b/torchft/local_sgd.py
index 957680e..69f7130 100644
--- a/torchft/local_sgd.py
+++ b/torchft/local_sgd.py
@@ -400,13 +400,6 @@ def prepare_sync(self) -> None:
         ):
             self._average_grads()

-            for work in self._allreduce_work:
-                work.wait()
-
-            if self._stream is not None:
-                self._stop_event = torch.cuda.Event()
-                self._stop_event.record()
-
     @torch.profiler.record_function("torchft::local_sgd::perform_sync")
     def perform_sync(self) -> bool:
         """
@@ -416,6 +409,18 @@ def perform_sync(self) -> bool:
         # Waiting for an allreduce before it has been sent is currently not supported.
        assert len(self._allreduce_work) > 0

+        with (
+            torch.cuda.stream(self._stream)
+            if self._stream is not None
+            else nullcontext()
+        ):
+            for work in self._allreduce_work:
+                work.wait()
+
+            if self._stream is not None:
+                self._stop_event = torch.cuda.Event()
+                self._stop_event.record()
+
         self.wait()

         # save the parameters so they can be used for merging
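Note on the two synchronization changes above, for anyone reading the diff in isolation:

- In `diloco_regression_test.py`, every replica thread now blocks on a shared `threading.Barrier` right after the first quorum, so no replica starts its inner steps until all replicas have joined.
- In `local_sgd.py`, the `work.wait()` calls and the stop-event recording move from `prepare_sync()` into `perform_sync()` and now run under the side CUDA stream, so preparing a sync no longer blocks on the allreduces.

Below is a minimal, stdlib-only sketch of the barrier pattern the test relies on. Nothing in it is torchft API: `replica_main` and `NUM_REPLICAS` are illustrative names, and the first print stands in for the manager's `start_quorum()`/`wait_quorum()` calls.

```python
import threading
from concurrent.futures import ThreadPoolExecutor

NUM_REPLICAS = 2  # the test sizes its barrier with num_replicas

# One barrier shared by all replica threads. Barrier.wait() blocks each
# caller until all NUM_REPLICAS parties have arrived, then releases them
# together -- the "everyone reaches the first quorum" property the test
# change enforces.
quorum_barrier = threading.Barrier(NUM_REPLICAS)

def replica_main(replica_id: int) -> str:
    # Stand-in for start_quorum()/wait_quorum(); purely illustrative.
    print(f"replica {replica_id}: reached the quorum point")
    quorum_barrier.wait()  # no replica proceeds until every replica is here
    return f"replica {replica_id}: released past the barrier"

with ThreadPoolExecutor(max_workers=NUM_REPLICAS) as executor:
    for line in executor.map(replica_main, range(NUM_REPLICAS)):
        print(line)
```

Passing the barrier in through `train_loop_args` rather than a module-level global keeps each test invocation's replicas synchronized only with one another.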