Skip to content

only use nightly pytorch in ci #243

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/lint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ jobs:
pip install lintrunner lintrunner-adapters
lintrunner init

pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
pip install .[dev] -v
- name: Run lintrunner
run: |
Expand Down
6 changes: 1 addition & 5 deletions .github/workflows/unittest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,7 @@ jobs:
- runs-on: "linux.2xlarge"
gpu-arch-type: "cpu"
gpu-arch-version: ""
torch-version: "stable"
- runs-on: "linux.g5.12xlarge.nvidia.gpu"
gpu-arch-type: "cuda"
gpu-arch-version: "12.4"
torch-version: "stable"
torch-version: "nightly"
- runs-on: "linux.g5.12xlarge.nvidia.gpu"
gpu-arch-type: "cuda"
gpu-arch-version: "12.4"
Expand Down
81 changes: 42 additions & 39 deletions torchft/collectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def reduce_scatter_quantized(
opts: ReduceScatterOptions | ReduceOp,
process_group: "ProcessGroup",
sync_stream: cuda.Stream | None = None,
) -> Future[None]:
) -> Work:
"""
Performs a quantized reduce-scatter operation on a list of tensors.

Expand Down Expand Up @@ -196,10 +196,10 @@ def reduce_scatter_quantized(
"""

if isinstance(opts, ReduceOp):
reducescatter_opts = ReduceScatterOptions()
reducescatter_opts: ReduceScatterOptions = ReduceScatterOptions()
reducescatter_opts.reduceOp = opts
else:
reducescatter_opts = opts
reducescatter_opts: ReduceScatterOptions = opts

# Check if the reduceOp is AVG or SUM
if reducescatter_opts.reduceOp not in {
Expand All @@ -211,15 +211,15 @@ def reduce_scatter_quantized(
f"for quantized reduce-scatter, only AVG and SUM are supported"
)

rank = process_group.rank()
world_size = process_group.size()
rank: int = process_group.rank()
world_size: int = process_group.size()

reduce_output_sizes = [
torch.Size((s[0] // world_size, *s[1:]))
for s in get_padded_sizes(inputs, world_size)
]
reduce_output_numels = [s.numel() for s in reduce_output_sizes]
reduce_outputs = [
reduce_outputs: list[torch.Tensor] = [
o.view(s)
for o, s in zip(
output.split(reduce_output_numels),
Expand All @@ -240,48 +240,51 @@ def reduce_scatter_quantized(
quantized_inputs = fused_quantize_into_fp8(inputs, world_size)

# Allocate output tensor where all-reduce results will be stored
quantized_inputs_out = torch.zeros_like(quantized_inputs)
quantized_inputs_out: torch.Tensor = torch.zeros_like(quantized_inputs)
# Collect chunks and their scales from other ranks
process_group.alltoall_base(
work = process_group.alltoall_base(
quantized_inputs_out.view(world_size, -1),
quantized_inputs.view(world_size, -1),
[],
[],
_to_alltoall_options(reducescatter_opts),
).wait()

# Reduce chunks locally in higher precision after dequantization.
# The output is again quantized.
fused_reduce_fp8(
inputs,
quantized_inputs_out,
world_size,
rank,
reducescatter_opts.reduceOp,
)
work.wait()

# Get view into the output tensor that corresponds to the
# current rank
quantized_reduce_scatter = (
quantized_inputs_out.view(world_size, -1).split(1)[rank].squeeze(0)
)
# Dequantize the result back to the original precision for
# the current rank
fused_dequantize_from_fp8(
reduce_outputs,
quantized_reduce_scatter,
1,
)
fut = work.get_future()

# pyre-ignore[29]
return _QuantizedOpFuture(
sync_stream,
[
quantized_inputs,
quantized_inputs_out,
],
[output],
)
def callback(fut: Future[list[torch.Tensor]]) -> None:
nonlocal inputs, quantized_inputs_out, world_size, sync_stream, rank, reduce_outputs, reducescatter_opts

with torch.cuda.stream(sync_stream):
# Setup stream dependency
fut.wait()
# Reduce chunks locally in higher precision after dequantization.
# The output is again quantized.
fused_reduce_fp8(
inputs,
quantized_inputs_out,
world_size,
rank,
reducescatter_opts.reduceOp,
)

# Get view into the output tensor that corresponds to the
# current rank
quantized_reduce_scatter = (
quantized_inputs_out.view(world_size, -1).split(1)[rank].squeeze(0)
)
# Dequantize the result back to the original precision for
# the current rank
fused_dequantize_from_fp8(
reduce_outputs,
quantized_reduce_scatter,
1,
)

fut.add_done_callback(callback)

return work


def allreduce_quantized(
Expand Down
4 changes: 2 additions & 2 deletions torchft/collectives_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ def _run_reduce_scatter_collective(
opts = ReduceScatterOptions()
opts.reduceOp = reduce_op

fut = reduce_scatter_quantized(actual_output, tensors, opts, pg)
fut.wait()
work = reduce_scatter_quantized(actual_output, tensors, opts, pg)
work.get_future().wait()

padded_sizes = get_padded_sizes(tensors, world_size)
padded_numel = sum(s.numel() for s in padded_sizes)
Expand Down
1 change: 0 additions & 1 deletion torchft/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ def register(self, fut: Future[T], timeout: timedelta) -> Future[T]:

loop = self._maybe_start_event_loop()

# pyre-fixme[29]: Future is not a function
timed_fut: Future[T] = Future()
handle: _TimerHandle = _TimerHandle()
loop.call_soon_threadsafe(
Expand Down
6 changes: 0 additions & 6 deletions torchft/futures_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,38 +24,32 @@ def tearDown(self) -> None:
_TIMEOUT_MANAGER._watchdog_interval = self._original_watchdog_interval

def test_future_wait(self) -> None:
# pyre-fixme[29]: Future is not a function
fut = Future()
with self.assertRaisesRegex(TimeoutError, "future did not complete within"):
future_wait(fut, timeout=timedelta(seconds=0.01))

# pyre-fixme[29]: Future is not a function
fut = Future()
fut.set_result(1)
self.assertEqual(future_wait(fut, timeout=timedelta(seconds=1.0)), 1)

# pyre-fixme[29]: Future is not a function
fut = Future()
fut.set_exception(RuntimeError("test"))
with self.assertRaisesRegex(RuntimeError, "test"):
future_wait(fut, timeout=timedelta(seconds=1.0))

def test_future_timeout(self) -> None:
# pyre-fixme[29]: Future is not a function
fut = Future()
timed_fut = future_timeout(fut, timeout=timedelta(seconds=0.01))
with self.assertRaisesRegex(TimeoutError, "future did not complete within"):
timed_fut.wait()

def test_future_timeout_result(self) -> None:
# pyre-fixme[29]: Future is not a function
fut = Future()
timed_fut = future_timeout(fut, timeout=timedelta(seconds=10))
fut.set_result(1)
self.assertEqual(timed_fut.wait(), 1)

def test_future_timeout_exception(self) -> None:
# pyre-fixme[29]: Future is not a function
fut = Future()
timed_fut = future_timeout(fut, timeout=timedelta(seconds=10))
fut.set_exception(RuntimeError("test"))
Expand Down
6 changes: 3 additions & 3 deletions torchft/manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None:

self.assertFalse(manager._errored)

bad_fut = torch.futures.Future() # pyre-fixme[29]: not a function
bad_fut = torch.futures.Future()
bad_fut.set_exception(RuntimeError("injected failure"))
manager._pg.allreduce.return_value.get_future.return_value = bad_fut
manager.allreduce(torch.tensor([1.0])).wait()
Expand Down Expand Up @@ -542,7 +542,7 @@ def test_manager_wrap_future(self, client_mock: MagicMock) -> None:

self.assertIsNone(manager.errored())

fut = torch.futures.Future() # pyre-fixme[29]: not a function
fut = torch.futures.Future()
wrapped_fut = manager.wrap_future(fut, 2)
self.assertIsNone(manager.errored())

Expand All @@ -559,7 +559,7 @@ def test_manager_wrap_future_timeout(self, client_mock: MagicMock) -> None:

self.assertFalse(manager.errored())

fut = torch.futures.Future() # pyre-fixme[29]: not a function
fut = torch.futures.Future()
wrapped_fut = manager.wrap_future(fut, 2)
wrapped_fut.wait()
error = manager.errored()
Expand Down
9 changes: 5 additions & 4 deletions torchft/process_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def alltoall_base(
"""
raise NotImplementedError("not implemented")

# pyre-fixme[14]: inconsistent override
def barrier(self, opts: BarrierOptions) -> Work:
"""
Synchronizes all processes.
Expand Down Expand Up @@ -496,7 +497,7 @@ def alltoall_base(
opts,
)

def barrier(self, opts: BarrierOptions) -> Work:
def barrier(self, opts: Optional[BarrierOptions] = None) -> Work:
with self._run_context():
return self._wrap_work(self.parent.barrier(self._opts_hook(opts)), opts)

Expand Down Expand Up @@ -866,7 +867,7 @@ def alltoall_base(
self._work.append(res)
return res

def barrier(self, opts: BarrierOptions) -> Work:
def barrier(self, opts: Optional[BarrierOptions] = None) -> Work:
return _DummyWork(None)

def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
Expand Down Expand Up @@ -1497,7 +1498,7 @@ def _get_future(
self, op_id: int, stream: Optional[torch.cuda.Stream]
) -> Future[object]:
with self._futures_lock:
fut = Future() # pyre-fixme[29]: is not a function
fut = Future()
self._futures[op_id] = _FutureMetadata(future=fut, stream=stream)
assert self._pipe is not None
self._pipe.send(("future", op_id))
Expand Down Expand Up @@ -1629,7 +1630,7 @@ def alltoall_base(
opts,
)

def barrier(self, opts: BarrierOptions) -> Work:
def barrier(self, opts: Optional[BarrierOptions] = None) -> Work:
return self._run_func("barrier", opts)

def broadcast(
Expand Down
1 change: 0 additions & 1 deletion torchft/work.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ class _DummyWork(dist._Work):
def __init__(self, result: object) -> None:
super().__init__()
self.result_ = result
# pyre-fixme[29]: Future is not a function
self.future_: torch.futures.Future[object] = torch.futures.Future()
self.future_.set_result(result)

Expand Down
Loading