diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml
index 937e34c8..589dada3 100644
--- a/.github/workflows/lint.yaml
+++ b/.github/workflows/lint.yaml
@@ -26,6 +26,7 @@ jobs:
         pip install lintrunner lintrunner-adapters
         lintrunner init
+        pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
         pip install .[dev] -v
     - name: Run lintrunner
       run: |
diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml
index 3cb05d42..d1f5bd55 100644
--- a/.github/workflows/unittest.yaml
+++ b/.github/workflows/unittest.yaml
@@ -15,11 +15,7 @@ jobs:
         - runs-on: "linux.2xlarge"
           gpu-arch-type: "cpu"
           gpu-arch-version: ""
-          torch-version: "stable"
-        - runs-on: "linux.g5.12xlarge.nvidia.gpu"
-          gpu-arch-type: "cuda"
-          gpu-arch-version: "12.4"
-          torch-version: "stable"
+          torch-version: "nightly"
         - runs-on: "linux.g5.12xlarge.nvidia.gpu"
           gpu-arch-type: "cuda"
           gpu-arch-version: "12.4"
diff --git a/torchft/collectives.py b/torchft/collectives.py
index af95cbbb..cd84b0b9 100644
--- a/torchft/collectives.py
+++ b/torchft/collectives.py
@@ -162,7 +162,7 @@ def reduce_scatter_quantized(
     opts: ReduceScatterOptions | ReduceOp,
     process_group: "ProcessGroup",
     sync_stream: cuda.Stream | None = None,
-) -> Future[None]:
+) -> Work:
     """
     Performs a quantized reduce-scatter operation on a list of tensors.

@@ -196,10 +196,10 @@ def reduce_scatter_quantized(
     """

     if isinstance(opts, ReduceOp):
-        reducescatter_opts = ReduceScatterOptions()
+        reducescatter_opts: ReduceScatterOptions = ReduceScatterOptions()
         reducescatter_opts.reduceOp = opts
     else:
-        reducescatter_opts = opts
+        reducescatter_opts: ReduceScatterOptions = opts

     # Check if the reduceOp is AVG or SUM
     if reducescatter_opts.reduceOp not in {
@@ -211,15 +211,15 @@ def reduce_scatter_quantized(
             f"for quantized reduce-scatter, only AVG and SUM are supported"
         )

-    rank = process_group.rank()
-    world_size = process_group.size()
+    rank: int = process_group.rank()
+    world_size: int = process_group.size()

     reduce_output_sizes = [
         torch.Size((s[0] // world_size, *s[1:]))
         for s in get_padded_sizes(inputs, world_size)
     ]
     reduce_output_numels = [s.numel() for s in reduce_output_sizes]
-    reduce_outputs = [
+    reduce_outputs: list[torch.Tensor] = [
         o.view(s)
         for o, s in zip(
             output.split(reduce_output_numels),
@@ -240,48 +240,51 @@ def reduce_scatter_quantized(
         quantized_inputs = fused_quantize_into_fp8(inputs, world_size)

         # Allocate output tensor where all-reduce results will be stored
-        quantized_inputs_out = torch.zeros_like(quantized_inputs)
+        quantized_inputs_out: torch.Tensor = torch.zeros_like(quantized_inputs)

         # Collect chunks and their scales from other ranks
-        process_group.alltoall_base(
+        work = process_group.alltoall_base(
             quantized_inputs_out.view(world_size, -1),
             quantized_inputs.view(world_size, -1),
             [],
             [],
             _to_alltoall_options(reducescatter_opts),
-        ).wait()
-
-        # Reduce chunks locally in higher precision after dequantization.
-        # The output is again quantized.
-        fused_reduce_fp8(
-            inputs,
-            quantized_inputs_out,
-            world_size,
-            rank,
-            reducescatter_opts.reduceOp,
         )
+        work.wait()

-        # Get view into the output tensor that corresponds to the
-        # current rank
-        quantized_reduce_scatter = (
-            quantized_inputs_out.view(world_size, -1).split(1)[rank].squeeze(0)
-        )
-        # Dequantize the result back to the original precision for
-        # the current rank
-        fused_dequantize_from_fp8(
-            reduce_outputs,
-            quantized_reduce_scatter,
-            1,
-        )
+        fut = work.get_future()

-        # pyre-ignore[29]
-        return _QuantizedOpFuture(
-            sync_stream,
-            [
-                quantized_inputs,
-                quantized_inputs_out,
-            ],
-            [output],
-        )
+        def callback(fut: Future[list[torch.Tensor]]) -> None:
+            nonlocal inputs, quantized_inputs_out, world_size, sync_stream, rank, reduce_outputs, reducescatter_opts
+
+            with torch.cuda.stream(sync_stream):
+                # Setup stream dependency
+                fut.wait()
+                # Reduce chunks locally in higher precision after dequantization.
+                # The output is again quantized.
+                fused_reduce_fp8(
+                    inputs,
+                    quantized_inputs_out,
+                    world_size,
+                    rank,
+                    reducescatter_opts.reduceOp,
+                )
+
+                # Get view into the output tensor that corresponds to the
+                # current rank
+                quantized_reduce_scatter = (
+                    quantized_inputs_out.view(world_size, -1).split(1)[rank].squeeze(0)
+                )
+                # Dequantize the result back to the original precision for
+                # the current rank
+                fused_dequantize_from_fp8(
+                    reduce_outputs,
+                    quantized_reduce_scatter,
+                    1,
+                )
+
+        fut.add_done_callback(callback)
+
+        return work


 def allreduce_quantized(
diff --git a/torchft/collectives_test.py b/torchft/collectives_test.py
index 6660abe2..b73a18b2 100644
--- a/torchft/collectives_test.py
+++ b/torchft/collectives_test.py
@@ -141,8 +141,8 @@ def _run_reduce_scatter_collective(
     opts = ReduceScatterOptions()
     opts.reduceOp = reduce_op

-    fut = reduce_scatter_quantized(actual_output, tensors, opts, pg)
-    fut.wait()
+    work = reduce_scatter_quantized(actual_output, tensors, opts, pg)
+    work.get_future().wait()

     padded_sizes = get_padded_sizes(tensors, world_size)
     padded_numel = sum(s.numel() for s in padded_sizes)
diff --git a/torchft/futures.py b/torchft/futures.py
index 52bb96ef..c20ad65e 100644
--- a/torchft/futures.py
+++ b/torchft/futures.py
@@ -148,7 +148,6 @@ def register(self, fut: Future[T], timeout: timedelta) -> Future[T]:
         loop = self._maybe_start_event_loop()

-        # pyre-fixme[29]: Future is not a function
         timed_fut: Future[T] = Future()
         handle: _TimerHandle = _TimerHandle()
         loop.call_soon_threadsafe(
diff --git a/torchft/futures_test.py b/torchft/futures_test.py
index cdc4cb1c..59ca73d5 100644
--- a/torchft/futures_test.py
+++ b/torchft/futures_test.py
@@ -24,38 +24,32 @@ def tearDown(self) -> None:
         _TIMEOUT_MANAGER._watchdog_interval = self._original_watchdog_interval

     def test_future_wait(self) -> None:
-        # pyre-fixme[29]: Future is not a function
         fut = Future()
         with self.assertRaisesRegex(TimeoutError, "future did not complete within"):
             future_wait(fut, timeout=timedelta(seconds=0.01))

-        # pyre-fixme[29]: Future is not a function
         fut = Future()
         fut.set_result(1)
         self.assertEqual(future_wait(fut, timeout=timedelta(seconds=1.0)), 1)

-        # pyre-fixme[29]: Future is not a function
         fut = Future()
         fut.set_exception(RuntimeError("test"))
         with self.assertRaisesRegex(RuntimeError, "test"):
             future_wait(fut, timeout=timedelta(seconds=1.0))

     def test_future_timeout(self) -> None:
-        # pyre-fixme[29]: Future is not a function
         fut = Future()
         timed_fut = future_timeout(fut, timeout=timedelta(seconds=0.01))
         with self.assertRaisesRegex(TimeoutError, "future did not complete within"):
             timed_fut.wait()

     def test_future_timeout_result(self) -> None:
-        # pyre-fixme[29]: Future is not a function
         fut = Future()
         timed_fut = future_timeout(fut, timeout=timedelta(seconds=10))
         fut.set_result(1)
         self.assertEqual(timed_fut.wait(), 1)

     def test_future_timeout_exception(self) -> None:
-        # pyre-fixme[29]: Future is not a function
         fut = Future()
         timed_fut = future_timeout(fut, timeout=timedelta(seconds=10))
         fut.set_exception(RuntimeError("test"))
diff --git a/torchft/manager_test.py b/torchft/manager_test.py
index d0c81a0f..b5616bba 100644
--- a/torchft/manager_test.py
+++ b/torchft/manager_test.py
@@ -401,7 +401,7 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None:

         self.assertFalse(manager._errored)

-        bad_fut = torch.futures.Future()  # pyre-fixme[29]: not a function
+        bad_fut = torch.futures.Future()
         bad_fut.set_exception(RuntimeError("injected failure"))
         manager._pg.allreduce.return_value.get_future.return_value = bad_fut
         manager.allreduce(torch.tensor([1.0])).wait()
@@ -542,7 +542,7 @@ def test_manager_wrap_future(self, client_mock: MagicMock) -> None:

         self.assertIsNone(manager.errored())

-        fut = torch.futures.Future()  # pyre-fixme[29]: not a function
+        fut = torch.futures.Future()
         wrapped_fut = manager.wrap_future(fut, 2)

         self.assertIsNone(manager.errored())
@@ -559,7 +559,7 @@ def test_manager_wrap_future_timeout(self, client_mock: MagicMock) -> None:

         self.assertFalse(manager.errored())

-        fut = torch.futures.Future()  # pyre-fixme[29]: not a function
+        fut = torch.futures.Future()
         wrapped_fut = manager.wrap_future(fut, 2)
         wrapped_fut.wait()
         error = manager.errored()
diff --git a/torchft/process_group.py b/torchft/process_group.py
index 4750dc97..bfbfe561 100644
--- a/torchft/process_group.py
+++ b/torchft/process_group.py
@@ -183,6 +183,7 @@ def alltoall_base(
         """
         raise NotImplementedError("not implemented")

+    # pyre-fixme[14]: inconsistent override
     def barrier(self, opts: BarrierOptions) -> Work:
         """
         Synchronizes all processes.
@@ -496,7 +497,7 @@ def alltoall_base(
             opts,
         )

-    def barrier(self, opts: BarrierOptions) -> Work:
+    def barrier(self, opts: Optional[BarrierOptions] = None) -> Work:
         with self._run_context():
             return self._wrap_work(self.parent.barrier(self._opts_hook(opts)), opts)

@@ -866,7 +867,7 @@ def alltoall_base(
         self._work.append(res)
         return res

-    def barrier(self, opts: BarrierOptions) -> Work:
+    def barrier(self, opts: Optional[BarrierOptions] = None) -> Work:
         return _DummyWork(None)

     def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
@@ -1497,7 +1498,7 @@ def _get_future(
         self, op_id: int, stream: Optional[torch.cuda.Stream]
     ) -> Future[object]:
         with self._futures_lock:
-            fut = Future()  # pyre-fixme[29]: is not a function
+            fut = Future()
             self._futures[op_id] = _FutureMetadata(future=fut, stream=stream)
             assert self._pipe is not None
             self._pipe.send(("future", op_id))
@@ -1629,7 +1630,7 @@ def alltoall_base(
             opts,
         )

-    def barrier(self, opts: BarrierOptions) -> Work:
+    def barrier(self, opts: Optional[BarrierOptions] = None) -> Work:
         return self._run_func("barrier", opts)

     def broadcast(
diff --git a/torchft/work.py b/torchft/work.py
index 7211c0d1..8cb056a8 100644
--- a/torchft/work.py
+++ b/torchft/work.py
@@ -10,7 +10,6 @@ class _DummyWork(dist._Work):
     def __init__(self, result: object) -> None:
         super().__init__()
         self.result_ = result
-        # pyre-fixme[29]: Future is not a function
         self.future_: torch.futures.Future[object] = torch.futures.Future()
         self.future_.set_result(result)
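Note on the API change above: reduce_scatter_quantized now returns a Work handle instead of a Future[None], so callers obtain the future via work.get_future() before waiting, as the updated collectives_test.py does. A minimal caller sketch under assumptions not shown in this diff: a quantization-capable ProcessGroup named pg that is already initialized across ranks, and illustrative tensor shapes chosen so the first dimension is divisible by the world size (avoiding padding).

    # Hypothetical caller; `pg` setup and the shapes are assumptions, not part of this diff.
    import torch
    from torch.distributed import ReduceOp
    from torchft.collectives import reduce_scatter_quantized

    world_size = pg.size()
    # Input tensors to reduce-scatter across ranks.
    inputs = [torch.randn(world_size * 4, 8, device="cuda") for _ in range(2)]
    # Flat buffer that receives this rank's shard of each reduced tensor.
    output = torch.empty(
        sum(t.numel() // world_size for t in inputs), device="cuda"
    )

    work = reduce_scatter_quantized(output, inputs, ReduceOp.AVG, pg)
    # The collective is asynchronous: wait on the returned Work's future
    # before reading `output`; the dequantize callback runs as part of it.
    work.get_future().wait()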