diff --git a/.gitignore b/.gitignore index 34f9a2a5..6d8f13f3 100644 --- a/.gitignore +++ b/.gitignore @@ -28,4 +28,19 @@ slurm*.out examples/gemm/results/* asm/ -*.img \ No newline at end of file +*.img + +.cache/ +.local/ +.triton/ +.pytest_cache/ +.ruff_cache/ +__pycache__/ +*.pyc +*.pyo +*.pyd +*.pyw +*.pyz +*.pywz +*.pyzw +*.pyzwz \ No newline at end of file diff --git a/benchmark/ccl/all_gather/benchmark.py b/benchmark/ccl/all_gather/benchmark.py new file mode 100644 index 00000000..714cfa8f --- /dev/null +++ b/benchmark/ccl/all_gather/benchmark.py @@ -0,0 +1,350 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Benchmark for iris-ccl all-gather collective operation. + +This benchmark showcases the all-gather collective and reports achieved bandwidth. +""" + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import random +import argparse + +from examples.common.utils import JSONWriter + +import iris +from iris.ccl import Config + +# Conditional import for Gluon +try: + import iris.experimental.iris_gluon as iris_gluon + + GLUON_AVAILABLE = True +except ImportError: + GLUON_AVAILABLE = False + +torch.manual_seed(123) +random.seed(123) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark all-gather collective operation.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=16384, help="Number of rows in input tensors") + parser.add_argument("-n", type=int, default=16384, help="Number of columns in input tensors") + parser.add_argument("-d", "--debug", action="store_true", help="Enable debug mode") + parser.add_argument("-v", "--validate", action="store_true", help="Enable validation mode") + parser.add_argument("-b", "--benchmark", action="store_true", help="Enable benchmarking mode") + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "bf16"], + help="Datatype of tensors", + ) + parser.add_argument( + "--output_file", + type=str, + default="log.json", + help="Output file", + ) + parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size") + parser.add_argument("--comm_sms", type=int, default=64, help="Number of SMs for all-gather kernel") + parser.add_argument( + "--benchmark_rccl", + action="store_true", + help="Also benchmark PyTorch RCCL (all_gather_into_tensor) for comparison", + ) + parser.add_argument("--block_size_m", type=int, default=None, help="Block size for M dimension tiling") + parser.add_argument("--block_size_n", type=int, default=None, help="Block size for N dimension tiling") + parser.add_argument("--swizzle_size", type=int, default=None, help="Number of tiles to swizzle together") + parser.add_argument("--num_xcds", type=int, default=None, help="Number of XCDs (auto-detected if not set)") + parser.add_argument("-r", "--num_ranks", type=int, default=8, help="Number of ranks/processes") + parser.add_argument("--use_gluon", action="store_true", help="Use Gluon implementation with traffic shaping") + + return vars(parser.parse_args()) + + +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) + + 
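# Illustration (not part of the patch): a single-process sketch of the layout this
# benchmark validates below -- rank r's (M, N) input lands at rows [r * M, (r + 1) * M)
# of the gathered output, matching torch.distributed.all_gather_into_tensor. The helper
# name is hypothetical and the function is never called by the benchmark.
import torch  # already imported at module scope; repeated so the sketch is self-contained

def _reference_all_gather(per_rank_inputs):
    # Concatenate the per-rank (M, N) tensors along dim 0, in rank order.
    return torch.cat(per_rank_inputs, dim=0)

# Example: _reference_all_gather([torch.full((2, 2), 1.0), torch.full((2, 2), 2.0)])
# puts rank 0's values in rows 0-1 and rank 1's values in rows 2-3.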
# Use Gluon if requested and available + if args.get("use_gluon", False): + if not GLUON_AVAILABLE: + raise RuntimeError("Gluon is not available. Install Triton with Gluon support or remove --use_gluon flag") + shmem = iris_gluon.iris(args["heap_size"]) + else: + shmem = iris.iris(args["heap_size"]) + + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + # Datatype mapping + datatype = torch.float32 + if args["datatype"] == "fp16": + datatype = torch.float16 + elif args["datatype"] == "fp32": + datatype = torch.float32 + elif args["datatype"] == "bf16": + datatype = torch.bfloat16 + else: + print("Unknown datatype.") + exit(1) + + M = args["m"] + N = args["n"] + + # Create config with optional block size parameters + config_kwargs = {"comm_sms": args["comm_sms"]} + if args["block_size_m"] is not None: + config_kwargs["block_size_m"] = args["block_size_m"] + if args["block_size_n"] is not None: + config_kwargs["block_size_n"] = args["block_size_n"] + if args["swizzle_size"] is not None: + config_kwargs["swizzle_size"] = args["swizzle_size"] + if args["num_xcds"] is not None: + config_kwargs["num_xcds"] = args["num_xcds"] + if args.get("use_gluon", False): + config_kwargs["use_gluon"] = True + + config = Config(**config_kwargs) + + json_writer = JSONWriter(args["output_file"]) + json_writer.add_field("world_size", world_size) + + for key, value in args.items(): + json_writer.add_field(key, value) + + # Export config values to JSON (use actual values from config, including defaults) + json_writer.add_field("block_size_m", config.block_size_m) + json_writer.add_field("block_size_n", config.block_size_n) + json_writer.add_field("swizzle_size", config.swizzle_size) + json_writer.add_field("num_xcds", config.num_xcds) + json_writer.add_field("use_gluon", config.use_gluon) + + # Create input and output tensors for all-gather + # Input: each rank has (M, N) tensor + # Output: (world_size * M, N) - concatenated along dimension 0 + # Note: Must use shmem.zeros() to allocate on Iris symmetric heap for iris.put() compatibility + input_tensor = shmem.zeros((M, N), dtype=datatype) + output_tensor = shmem.zeros((world_size * M, N), dtype=datatype) + expected_tensor = shmem.zeros((world_size * M, N), dtype=datatype) + + # Fill input with deterministic values + val = float(rank + 1) + input_tensor.fill_(val) + + # Expected output: each rank's input appears at output[rank * M : (rank + 1) * M, :] + for r in range(world_size): + expected_val = float(r + 1) + expected_tensor[r * M : (r + 1) * M, :] = expected_val + + comm_stream = torch.cuda.Stream() + + kernel_timing = { + "all_gather": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + }, + } + + def run_experiment(): + nonlocal kernel_timing + shmem.barrier() + + torch.cuda.nvtx.range_push("All-Gather") + with torch.cuda.stream(comm_stream): + kernel_timing["all_gather"]["start_event"].record() + shmem.ccl.all_gather(output_tensor, input_tensor, config=config, async_op=False) + kernel_timing["all_gather"]["end_event"].record() + kernel_timing["all_gather"]["experiments"] += 1 + torch.cuda.nvtx.range_pop() + + # Synchronize before querying event timing + shmem.barrier() + + # Update timing + ms = kernel_timing["all_gather"]["start_event"].elapsed_time(kernel_timing["all_gather"]["end_event"]) + kernel_timing["all_gather"]["ms"] += ms + + # Synchronize across all GPUs + shmem.barrier() + + if args["validate"]: + shmem.info("Validating...") + + # Reset 
output before validation + output_tensor.zero_() + shmem.barrier() + + # Reinitialize input data + val = float(rank + 1) + input_tensor.fill_(val) + shmem.barrier() + + run_experiment() + torch.cuda.synchronize() + shmem.barrier() + + atol = 1e-3 if datatype == torch.float16 else 1e-5 + success = torch.allclose(output_tensor, expected_tensor, atol=atol) + if not success: + max_diff = torch.abs(output_tensor - expected_tensor).max().item() + shmem.error(f"Rank {rank}: Validation failed, max diff: {max_diff}") + + if success: + shmem.info("All-gather validation passed!") + else: + shmem.error("All-gather validation failed!") + + json_writer.add_field("success", success) + + # Wait for all to finish validation + shmem.barrier() + + if args["benchmark"]: + # Warmup for benchmarking + run_experiment() + shmem.barrier() + + for k in ["all_gather"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + # Reset output before benchmarking + output_tensor.zero_() + shmem.barrier() + + # Reinitialize input data + val = float(rank + 1) + input_tensor.fill_(val) + shmem.barrier() + + shmem.info("Benchmarking...") + + # Calculate bandwidth + # In all-gather, each rank sends its (M, N) tensor to all ranks + # Total bytes sent = (world_size - 1) * M * N * element_size (excluding local copy) + # Total bytes received = (world_size - 1) * M * N * element_size + # Total bytes = (world_size - 1) * M * N * element_size + element_size = torch.tensor([], dtype=datatype).element_size() + total_bytes = (world_size - 1) * M * N * element_size + total_bytes_gb = total_bytes / (1024**3) + + triton_ms = iris.do_bench(run_experiment, shmem.barrier) + bandwidth_gbps = total_bytes_gb / ( + (kernel_timing["all_gather"]["ms"] / kernel_timing["all_gather"]["experiments"]) * 1e-3 + ) + + shmem.info( + f"All-gather (M={M}, N={N}, world_size={world_size}, dtype={args['datatype']}): " + f"{triton_ms:.3f} ms, {bandwidth_gbps:.3f} GB/s" + ) + + json_writer.add_field("bandwidth_gbps", bandwidth_gbps) + json_writer.add_field("total_ms", triton_ms) + json_writer.add_field("total_bytes", total_bytes) + json_writer.add_field("total_bytes_gb", total_bytes_gb) + json_writer.add_field( + "all_gather_ms", kernel_timing["all_gather"]["ms"] / kernel_timing["all_gather"]["experiments"] + ) + json_writer.add_field("all_gather_experiments", kernel_timing["all_gather"]["experiments"]) + + # Wait for all to finish benchmarking + shmem.barrier() + + # Benchmark RCCL (PyTorch all_gather_into_tensor) for comparison + if args.get("benchmark_rccl", False): + shmem.info("Benchmarking PyTorch RCCL (all_gather_into_tensor)...") + + # Create PyTorch tensors (not on Iris heap) + pytorch_input = torch.zeros(M, N, dtype=datatype, device=f"cuda:{rank}") + pytorch_input.fill_(float(rank + 1)) + pytorch_output = torch.zeros(world_size * M, N, dtype=datatype, device=f"cuda:{rank}") + + # Warmup + for _ in range(10): + dist.all_gather_into_tensor(pytorch_output, pytorch_input) + torch.cuda.synchronize() + dist.barrier() + + # Benchmark + pytorch_output.zero_() + pytorch_input.fill_(float(rank + 1)) + dist.barrier() + + rccl_start = torch.cuda.Event(enable_timing=True) + rccl_end = torch.cuda.Event(enable_timing=True) + + num_iterations = 126 # Match Iris benchmark iterations + dist.barrier() + rccl_start.record() + for _ in range(num_iterations): + dist.all_gather_into_tensor(pytorch_output, pytorch_input) + rccl_end.record() + torch.cuda.synchronize() + dist.barrier() + + rccl_ms = rccl_start.elapsed_time(rccl_end) / num_iterations + 
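# Illustration (not part of the patch): a worked example of the bandwidth accounting used
# above. The payload per iteration excludes the local copy, as noted in the comment block
# before do_bench; the helper name is hypothetical.
def _all_gather_payload_gib(world_size, m, n, element_size):
    # (world_size - 1) * M * N * element_size bytes, expressed in GiB.
    return (world_size - 1) * m * n * element_size / (1024**3)

# With the defaults (M = N = 16384, fp16, world_size = 8):
#   _all_gather_payload_gib(8, 16384, 16384, 2) == 3.5
# so an average iteration time of 10 ms would be reported as 3.5 / 0.010 = 350 GB/s.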
element_size = torch.tensor([], dtype=datatype).element_size() + total_bytes = (world_size - 1) * M * N * element_size + total_bytes_gb = total_bytes / (1024**3) + rccl_bandwidth_gbps = total_bytes_gb / (rccl_ms * 1e-3) + + shmem.info( + f"RCCL all_gather_into_tensor (M={M}, N={N}, world_size={world_size}, dtype={args['datatype']}): " + f"{rccl_ms:.3f} ms, {rccl_bandwidth_gbps:.3f} GB/s" + ) + + if args["benchmark"]: + # Calculate performance ratio + iris_bandwidth = bandwidth_gbps + rccl_ratio = (iris_bandwidth / rccl_bandwidth_gbps) * 100 if rccl_bandwidth_gbps > 0 else 0 + shmem.info(f"Performance ratio (Iris/RCCL): {rccl_ratio:.1f}%") + + json_writer.add_field("rccl_bandwidth_gbps", rccl_bandwidth_gbps) + json_writer.add_field("rccl_ms", rccl_ms) + json_writer.add_field("rccl_ratio_percent", rccl_ratio) + + # Wait for all to finish RCCL benchmarking + shmem.barrier() + + if rank == 0: + json_writer.flush() + json_writer.display() + + shmem.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + num_ranks = args["num_ranks"] + init_url = "tcp://127.0.0.1:29234" + + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main() diff --git a/benchmark/ccl/all_reduce/benchmark.py b/benchmark/ccl/all_reduce/benchmark.py index edecd1c8..73af6e05 100755 --- a/benchmark/ccl/all_reduce/benchmark.py +++ b/benchmark/ccl/all_reduce/benchmark.py @@ -47,16 +47,21 @@ def parse_args(): help="Output file", ) parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size") - parser.add_argument("--comm_sms", type=int, default=32, help="Number of SMs for all-reduce kernel") - parser.add_argument("--block_size_m", type=int, default=None, help="Block size for M dimension tiling") - parser.add_argument("--block_size_n", type=int, default=None, help="Block size for N dimension tiling") - parser.add_argument("--swizzle_size", type=int, default=None, help="Number of tiles to swizzle together") + parser.add_argument("--comm_sms", type=int, default=64, help="Number of SMs for all-reduce kernel") + parser.add_argument( + "--benchmark_rccl", + action="store_true", + help="Also benchmark PyTorch RCCL (all_reduce) for comparison", + ) + parser.add_argument("--block_size_m", type=int, default=64, help="Block size for M dimension tiling") + parser.add_argument("--block_size_n", type=int, default=64, help="Block size for N dimension tiling") + parser.add_argument("--swizzle_size", type=int, default=4, help="Number of tiles to swizzle together") parser.add_argument("--num_xcds", type=int, default=None, help="Number of XCDs (auto-detected if not set)") parser.add_argument("-r", "--num_ranks", type=int, default=8, help="Number of ranks/processes") parser.add_argument( "--variant", type=str, - default="atomic", + default="two_shot", choices=["atomic", "ring", "two_shot", "one_shot", "spinlock"], help="All-reduce variant to use", ) @@ -79,6 +84,9 @@ def parse_args(): default=None, help="Column slice size for ring variant (power of two, must divide block_size_n)", ) + parser.add_argument( + "--init_url", type=str, default="tcp://127.0.0.1:29527", help="Initialization URL for distributed setup" + ) return vars(parser.parse_args()) @@ -95,10 +103,8 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): ) shmem = iris.iris(args["heap_size"]) - rank = shmem.get_rank() world_size = shmem.get_num_ranks() - # Datatype mapping datatype = torch.float32 if args["datatype"] == "fp16": @@ -300,6 
+306,62 @@ def run_experiment(): # Wait for all to finish benchmarking shmem.barrier() + # Benchmark RCCL (PyTorch all_reduce) for comparison + if args.get("benchmark_rccl", False): + shmem.info("Benchmarking PyTorch RCCL (all_reduce)...") + + # Create PyTorch tensors (not on Iris heap) + pytorch_tensor = torch.zeros(M, N, dtype=datatype, device=f"cuda:{rank}") + pytorch_tensor.fill_(float(rank + 1)) + + # Warmup + for _ in range(10): + dist.all_reduce(pytorch_tensor, op=dist.ReduceOp.SUM) + torch.cuda.synchronize() + dist.barrier() + + # Benchmark + pytorch_tensor.fill_(float(rank + 1)) + dist.barrier() + + rccl_start = torch.cuda.Event(enable_timing=True) + rccl_end = torch.cuda.Event(enable_timing=True) + + num_iterations = 126 # Match Iris benchmark iterations + dist.barrier() + rccl_start.record() + for _ in range(num_iterations): + dist.all_reduce(pytorch_tensor, op=dist.ReduceOp.SUM) + rccl_end.record() + torch.cuda.synchronize() + dist.barrier() + + rccl_ms = rccl_start.elapsed_time(rccl_end) / num_iterations + element_size = torch.tensor([], dtype=datatype).element_size() + # RCCL all-reduce: same bandwidth calculation as Iris + # All-reduce moves 2 * (world_size - 1) / world_size * data_size bytes + total_bytes = M * N * element_size * (2 * (world_size - 1)) / world_size + total_bytes_gb = total_bytes / (1024**3) + rccl_bandwidth_gbps = total_bytes_gb / (rccl_ms * 1e-3) + + shmem.info( + f"RCCL all_reduce (M={M}, N={N}, world_size={world_size}, dtype={args['datatype']}): " + f"{rccl_ms:.3f} ms, {rccl_bandwidth_gbps:.3f} GB/s" + ) + + if args["benchmark"]: + # Calculate performance ratio + iris_bandwidth = bandwidth_gbps + rccl_ratio = (iris_bandwidth / rccl_bandwidth_gbps) * 100 if rccl_bandwidth_gbps > 0 else 0 + shmem.info(f"Performance ratio (Iris/RCCL): {rccl_ratio:.1f}%") + + json_writer.add_field("rccl_bandwidth_gbps", rccl_bandwidth_gbps) + json_writer.add_field("rccl_ms", rccl_ms) + json_writer.add_field("rccl_ratio_percent", rccl_ratio) + + # Wait for all to finish RCCL benchmarking + shmem.barrier() + if rank == 0: if args["variant"] == "ring": json_writer.add_field("all_reduce_ring_slice_n", config.all_reduce_ring_slice_n) @@ -313,7 +375,7 @@ def run_experiment(): def main(): args = parse_args() num_ranks = args["num_ranks"] - init_url = "tcp://127.0.0.1:29503" + init_url = args["init_url"] mp.spawn( fn=_worker, diff --git a/benchmark/ccl/all_to_all/benchmark.py b/benchmark/ccl/all_to_all/benchmark.py index b9af8689..4676bfd6 100644 --- a/benchmark/ccl/all_to_all/benchmark.py +++ b/benchmark/ccl/all_to_all/benchmark.py @@ -55,13 +55,18 @@ def parse_args(): help="Output file", ) parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size") - parser.add_argument("--comm_sms", type=int, default=32, help="Number of SMs for all-to-all kernel") + parser.add_argument("--comm_sms", type=int, default=64, help="Number of SMs for all-to-all kernel") parser.add_argument("--block_size_m", type=int, default=None, help="Block size for M dimension tiling") parser.add_argument("--block_size_n", type=int, default=None, help="Block size for N dimension tiling") parser.add_argument("--swizzle_size", type=int, default=None, help="Number of tiles to swizzle together") parser.add_argument("--num_xcds", type=int, default=None, help="Number of XCDs (auto-detected if not set)") parser.add_argument("-r", "--num_ranks", type=int, default=8, help="Number of ranks/processes") parser.add_argument("--use_gluon", action="store_true", help="Use Gluon implementation with traffic 
shaping") + parser.add_argument( + "--benchmark_rccl", + action="store_true", + help="Also benchmark PyTorch RCCL (all_to_all) for comparison", + ) return vars(parser.parse_args()) @@ -140,7 +145,10 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict): output_concat = shmem.zeros((M, N * world_size), dtype=datatype) expected_concat = shmem.zeros((M, N * world_size), dtype=datatype) - for target_rank in range(world_size): + # Determine which ranks to communicate with + comm_ranks = list(range(world_size)) + + for target_rank in comm_ranks: # Input: rank sends data at position (target_rank * N) val = float(rank * 1000 + target_rank) input_concat[:, target_rank * N : (target_rank + 1) * N] = val @@ -190,7 +198,7 @@ def run_experiment(): shmem.barrier() # Reinitialize input data - for target_rank in range(world_size): + for target_rank in comm_ranks: val = float(rank * 1000 + target_rank) input_concat[:, target_rank * N : (target_rank + 1) * N] = val shmem.barrier() @@ -229,7 +237,7 @@ def run_experiment(): shmem.barrier() # Reinitialize input data - for target_rank in range(world_size): + for target_rank in comm_ranks: val = float(rank * 1000 + target_rank) input_concat[:, target_rank * N : (target_rank + 1) * N] = val shmem.barrier() @@ -265,6 +273,69 @@ def run_experiment(): # Wait for all to finish benchmarking shmem.barrier() + # Benchmark RCCL (PyTorch all_to_all) for comparison + if args.get("benchmark_rccl", False): + shmem.info("Benchmarking PyTorch RCCL (all_to_all)...") + + # Create PyTorch tensors (not on Iris heap) + # For all_to_all, we need a list of tensors to send and receive + pytorch_input_list = [torch.zeros(M, N, dtype=datatype, device=f"cuda:{rank}") for _ in range(world_size)] + pytorch_output_list = [torch.zeros(M, N, dtype=datatype, device=f"cuda:{rank}") for _ in range(world_size)] + + # Fill input tensors with deterministic values + for target_rank in range(world_size): + val = float(rank * 1000 + target_rank) + pytorch_input_list[target_rank].fill_(val) + + # Warmup + for _ in range(10): + dist.all_to_all(pytorch_output_list, pytorch_input_list) + torch.cuda.synchronize() + dist.barrier() + + # Benchmark + for target_rank in range(world_size): + pytorch_output_list[target_rank].zero_() + val = float(rank * 1000 + target_rank) + pytorch_input_list[target_rank].fill_(val) + dist.barrier() + + rccl_start = torch.cuda.Event(enable_timing=True) + rccl_end = torch.cuda.Event(enable_timing=True) + + num_iterations = 126 # Match Iris benchmark iterations + dist.barrier() + rccl_start.record() + for _ in range(num_iterations): + dist.all_to_all(pytorch_output_list, pytorch_input_list) + rccl_end.record() + torch.cuda.synchronize() + dist.barrier() + + rccl_ms = rccl_start.elapsed_time(rccl_end) / num_iterations + element_size = torch.tensor([], dtype=datatype).element_size() + total_bytes = (world_size - 1) * M * N * element_size + total_bytes_gb = total_bytes / (1024**3) + rccl_bandwidth_gbps = total_bytes_gb / (rccl_ms * 1e-3) + + shmem.info( + f"RCCL all_to_all (M={M}, N={N}, world_size={world_size}, dtype={args['datatype']}): " + f"{rccl_ms:.3f} ms, {rccl_bandwidth_gbps:.3f} GB/s" + ) + + if args["benchmark"]: + # Calculate performance ratio + iris_bandwidth = bandwidth_gbps + rccl_ratio = (iris_bandwidth / rccl_bandwidth_gbps) * 100 if rccl_bandwidth_gbps > 0 else 0 + shmem.info(f"Performance ratio (Iris/RCCL): {rccl_ratio:.1f}%") + + json_writer.add_field("rccl_bandwidth_gbps", rccl_bandwidth_gbps) + json_writer.add_field("rccl_ms", rccl_ms) + 
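# Illustration (not part of the patch): the exchange pattern the fill value encodes.
# dist.all_to_all delivers input_list[dst] from rank src into output_list[src] on rank dst,
# so with val = rank * 1000 + target_rank the value below is what each destination expects.
def _expected_from(source_rank, dest_rank):
    # Value rank `source_rank` sends to rank `dest_rank` under this benchmark's fill pattern.
    return float(source_rank * 1000 + dest_rank)

# Example: on rank 2, output_list[5] (and, in the concatenated Iris layout, presumably
# output_concat[:, 5 * N : 6 * N]) arrives filled with _expected_from(5, 2) == 5002.0.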
json_writer.add_field("rccl_ratio_percent", rccl_ratio) + + # Wait for all to finish RCCL benchmarking + shmem.barrier() + if rank == 0: json_writer.flush() json_writer.display() @@ -276,7 +347,7 @@ def run_experiment(): def main(): args = parse_args() num_ranks = args["num_ranks"] - init_url = "tcp://127.0.0.1:29503" + init_url = "tcp://127.0.0.1:29569" mp.spawn( fn=_worker, diff --git a/benchmark/ccl/comprehensive_sweep.py b/benchmark/ccl/comprehensive_sweep.py new file mode 100644 index 00000000..a5773376 --- /dev/null +++ b/benchmark/ccl/comprehensive_sweep.py @@ -0,0 +1,443 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Comprehensive CCL benchmark with CU sweep across all operations. + +This benchmark runs all_gather, all_reduce, all_to_all, and reduce_scatter +with a sweep across different numbers of CUs (comm_sms) and outputs results to CSV. +Runs each benchmark as a separate subprocess to avoid memory accumulation. +""" + +import subprocess +import argparse +import csv +import os +from datetime import datetime +from typing import Dict, List +import json +import tempfile + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Comprehensive CCL benchmark with CU sweep.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Problem size + parser.add_argument("-m", type=int, default=16384, help="Number of rows in tensors") + parser.add_argument("-n", type=int, default=16384, help="Number of columns in tensors") + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "bf16"], + help="Datatype of tensors", + ) + + # CU sweep parameters + parser.add_argument( + "--min_cus", + type=int, + default=8, + help="Minimum number of CUs (comm_sms) to test", + ) + parser.add_argument( + "--max_cus", + type=int, + default=128, + help="Maximum number of CUs (comm_sms) to test", + ) + parser.add_argument( + "--cu_step", + type=int, + default=8, + help="Step size for CU sweep", + ) + + # Operations to benchmark + parser.add_argument( + "--operations", + type=str, + nargs="+", + default=["all_gather", "all_reduce", "all_to_all", "reduce_scatter"], + choices=["all_gather", "all_reduce", "all_to_all", "reduce_scatter"], + help="CCL operations to benchmark", + ) + + # All-Gather configuration + parser.add_argument("--all_gather_block_size_m", type=int, default=32, help="All-Gather: Block size M") + parser.add_argument("--all_gather_block_size_n", type=int, default=64, help="All-Gather: Block size N") + parser.add_argument("--all_gather_swizzle_size", type=int, default=4, help="All-Gather: Swizzle size") + + # All-Reduce configuration + parser.add_argument("--all_reduce_block_size_m", type=int, default=32, help="All-Reduce: Block size M") + parser.add_argument("--all_reduce_block_size_n", type=int, default=64, help="All-Reduce: Block size N") + parser.add_argument("--all_reduce_swizzle_size", type=int, default=4, help="All-Reduce: Swizzle size") + parser.add_argument( + "--all_reduce_variant", + type=str, + default="two_shot", + choices=["atomic", "spinlock", "ring", "two_shot", "one_shot"], + help="All-Reduce: Variant to use", + ) + parser.add_argument( + "--all_reduce_distribution", + type=int, + default=1, + choices=[0, 1], + help="All-Reduce: Distribution mode (0=striding, 1=block)", + ) + + # All-to-All configuration + parser.add_argument("--all_to_all_block_size_m", type=int, default=32, help="All-to-All: Block size M") + 
parser.add_argument("--all_to_all_block_size_n", type=int, default=128, help="All-to-All: Block size N") + parser.add_argument("--all_to_all_swizzle_size", type=int, default=4, help="All-to-All: Swizzle size") + + # Reduce-Scatter configuration + parser.add_argument("--reduce_scatter_block_size_m", type=int, default=32, help="Reduce-Scatter: Block size M") + parser.add_argument("--reduce_scatter_block_size_n", type=int, default=64, help="Reduce-Scatter: Block size N") + parser.add_argument("--reduce_scatter_swizzle_size", type=int, default=4, help="Reduce-Scatter: Swizzle size") + parser.add_argument( + "--reduce_scatter_distribution", + type=int, + default=1, + choices=[0, 1], + help="Reduce-Scatter: Distribution mode (0=striding, 1=block)", + ) + + # General configuration + parser.add_argument("--num_xcds", type=int, default=None, help="Number of XCDs (auto-detected if not set)") + parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size") + parser.add_argument("-r", "--num_ranks", type=int, default=8, help="Number of ranks/processes") + + # Output + parser.add_argument( + "--output_csv", + type=str, + default=None, + help="Output CSV file (default: auto-generated with timestamp)", + ) + parser.add_argument("--benchmark_rccl", action="store_true", help="Also benchmark RCCL for comparison") + parser.add_argument("--validate", action="store_false", help="Run validation before benchmarking") + parser.add_argument("--skip_on_validation_failure", action="store_true", help="Skip benchmark if validation fails") + + return vars(parser.parse_args()) + + +def run_validation(operation, comm_sms, args): + """Run validation for a single operation.""" + # Get the directory where this script is located + script_dir = os.path.dirname(os.path.abspath(__file__)) + iris_root = os.path.dirname(os.path.dirname(script_dir)) + + script_map = { + "all_gather": os.path.join(iris_root, "benchmark/ccl/all_gather/benchmark.py"), + "all_reduce": os.path.join(iris_root, "benchmark/ccl/all_reduce/benchmark.py"), + "all_to_all": os.path.join(iris_root, "benchmark/ccl/all_to_all/benchmark.py"), + "reduce_scatter": os.path.join(iris_root, "benchmark/ccl/reduce_scatter/benchmark.py"), + } + + script_path = script_map[operation] + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + temp_output = f.name + + cmd = [ + "python", + script_path, + "-m", + str(args["m"]), + "-n", + str(args["n"]), + "--datatype", + args["datatype"], + "--comm_sms", + str(comm_sms), + "-r", + str(args["num_ranks"]), + "--heap_size", + str(args["heap_size"]), + "--validate", + "--output_file", + temp_output, + ] + + # Add operation-specific parameters (same as benchmark) + if operation == "all_gather": + cmd.extend(["--block_size_m", str(args["all_gather_block_size_m"])]) + cmd.extend(["--block_size_n", str(args["all_gather_block_size_n"])]) + cmd.extend(["--swizzle_size", str(args["all_gather_swizzle_size"])]) + elif operation == "all_reduce": + cmd.extend(["--block_size_m", str(args["all_reduce_block_size_m"])]) + cmd.extend(["--block_size_n", str(args["all_reduce_block_size_n"])]) + cmd.extend(["--swizzle_size", str(args["all_reduce_swizzle_size"])]) + cmd.extend(["--variant", args["all_reduce_variant"]]) + cmd.extend(["--distribution", str(args["all_reduce_distribution"])]) + elif operation == "all_to_all": + cmd.extend(["--block_size_m", str(args["all_to_all_block_size_m"])]) + cmd.extend(["--block_size_n", str(args["all_to_all_block_size_n"])]) + cmd.extend(["--swizzle_size", 
str(args["all_to_all_swizzle_size"])]) + elif operation == "reduce_scatter": + cmd.extend(["--block_size_m", str(args["reduce_scatter_block_size_m"])]) + cmd.extend(["--block_size_n", str(args["reduce_scatter_block_size_n"])]) + cmd.extend(["--swizzle_size", str(args["reduce_scatter_swizzle_size"])]) + cmd.extend(["--all_reduce_distribution", str(args["reduce_scatter_distribution"])]) + + if args["num_xcds"] is not None: + cmd.extend(["--num_xcds", str(args["num_xcds"])]) + + print(f" Validating {operation} with comm_sms={comm_sms}...") + try: + result = subprocess.run(cmd, capture_output=True, text=True, check=True) + + with open(temp_output, "r") as f: + data = json.load(f) + + os.unlink(temp_output) + + success = data.get("success", False) + return success + except subprocess.CalledProcessError as e: + print(f" Validation failed for {operation}: {e}") + if os.path.exists(temp_output): + os.unlink(temp_output) + return False + except Exception as e: + print(f" Error during validation for {operation}: {e}") + if os.path.exists(temp_output): + os.unlink(temp_output) + return False + + +def run_benchmark(operation, comm_sms, args): + """Run a single benchmark as a subprocess and return the results.""" + # Get the directory where this script is located + script_dir = os.path.dirname(os.path.abspath(__file__)) + # Go up two levels to get to the iris root directory + iris_root = os.path.dirname(os.path.dirname(script_dir)) + + # Map operation to benchmark script (relative to iris root) + script_map = { + "all_gather": os.path.join(iris_root, "benchmark/ccl/all_gather/benchmark.py"), + "all_reduce": os.path.join(iris_root, "benchmark/ccl/all_reduce/benchmark.py"), + "all_to_all": os.path.join(iris_root, "benchmark/ccl/all_to_all/benchmark.py"), + "reduce_scatter": os.path.join(iris_root, "benchmark/ccl/reduce_scatter/benchmark.py"), + } + + script_path = script_map[operation] + + # Create temporary output file + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + temp_output = f.name + + # Build command + cmd = [ + "python", + script_path, + "-m", + str(args["m"]), + "-n", + str(args["n"]), + "--datatype", + args["datatype"], + "--comm_sms", + str(comm_sms), + "-r", + str(args["num_ranks"]), + "--heap_size", + str(args["heap_size"]), + "--benchmark", + "--output_file", + temp_output, + ] + + # Add operation-specific parameters + if operation == "all_gather": + cmd.extend(["--block_size_m", str(args["all_gather_block_size_m"])]) + cmd.extend(["--block_size_n", str(args["all_gather_block_size_n"])]) + cmd.extend(["--swizzle_size", str(args["all_gather_swizzle_size"])]) + elif operation == "all_reduce": + cmd.extend(["--block_size_m", str(args["all_reduce_block_size_m"])]) + cmd.extend(["--block_size_n", str(args["all_reduce_block_size_n"])]) + cmd.extend(["--swizzle_size", str(args["all_reduce_swizzle_size"])]) + cmd.extend(["--variant", args["all_reduce_variant"]]) + cmd.extend(["--distribution", str(args["all_reduce_distribution"])]) + elif operation == "all_to_all": + cmd.extend(["--block_size_m", str(args["all_to_all_block_size_m"])]) + cmd.extend(["--block_size_n", str(args["all_to_all_block_size_n"])]) + cmd.extend(["--swizzle_size", str(args["all_to_all_swizzle_size"])]) + elif operation == "reduce_scatter": + cmd.extend(["--block_size_m", str(args["reduce_scatter_block_size_m"])]) + cmd.extend(["--block_size_n", str(args["reduce_scatter_block_size_n"])]) + cmd.extend(["--swizzle_size", str(args["reduce_scatter_swizzle_size"])]) + 
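# Note: the reduce-scatter benchmark reuses the all-reduce flag name for its distribution
# mode (see --all_reduce_distribution in benchmark/ccl/reduce_scatter/benchmark.py), which
# is why the sweep forwards args["reduce_scatter_distribution"] under that option below.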
cmd.extend(["--all_reduce_distribution", str(args["reduce_scatter_distribution"])]) + + if args["num_xcds"] is not None: + cmd.extend(["--num_xcds", str(args["num_xcds"])]) + + # Add --benchmark_rccl flag if requested + if args.get("benchmark_rccl", False): + cmd.append("--benchmark_rccl") + + # Set NCCL environment variables to control number of channels (CUs) + env = os.environ.copy() + if args.get("benchmark_rccl", False): + env["NCCL_MIN_NCHANNELS"] = str(comm_sms) + env["NCCL_MAX_NCHANNELS"] = str(comm_sms) + + # Run benchmark + print(f"\nRunning {operation} with comm_sms={comm_sms}...") + try: + result = subprocess.run(cmd, capture_output=True, text=True, check=True, env=env) + + # Read results from JSON file + with open(temp_output, "r") as f: + data = json.load(f) + + # Clean up temp file + os.unlink(temp_output) + + return data + except subprocess.CalledProcessError as e: + print(f"Error running {operation}: {e}") + print(f"stdout: {e.stdout}") + print(f"stderr: {e.stderr}") + if os.path.exists(temp_output): + os.unlink(temp_output) + return None + except Exception as e: + print(f"Error processing results for {operation}: {e}") + if os.path.exists(temp_output): + os.unlink(temp_output) + return None + + +def main(): + args = parse_args() + + # Generate CU sweep range + cu_values = list(range(args["min_cus"], args["max_cus"] + 1, args["cu_step"])) + + results = [] + + print(f"{'=' * 80}") + print("Comprehensive CCL Benchmark Sweep") + print(f"Operations: {', '.join(args['operations'])}") + print(f"CU range: {args['min_cus']} to {args['max_cus']} (step {args['cu_step']})") + print(f"Problem size: {args['m']}x{args['n']}") + print(f"Datatype: {args['datatype']}") + print(f"Ranks: {args['num_ranks']}") + print(f"{'=' * 80}") + + for comm_sms in cu_values: + print(f"\n{'=' * 80}") + print(f"Testing with comm_sms={comm_sms}") + print(f"{'=' * 80}") + + for operation in args["operations"]: + # Run validation if requested + validation_passed = True + if args.get("validate", False): + validation_passed = run_validation(operation, comm_sms, args) + if validation_passed: + print(f" ✓ Validation passed for {operation}") + else: + print(f" ✗ Validation FAILED for {operation}") + if args.get("skip_on_validation_failure", False): + print(f" Skipping benchmark for {operation} due to validation failure") + continue + + # Run benchmark + data = run_benchmark(operation, comm_sms, args) + + if data is not None: + # Add validation status to result + if args.get("validate", False): + validation_status = "passed" if validation_passed else "failed" + else: + validation_status = "not_run" + # Extract relevant fields and add to results + result = { + "operation": operation, + "comm_sms": comm_sms, + "m": args["m"], + "n": args["n"], + "world_size": args["num_ranks"], + "datatype": args["datatype"], + "block_size_m": data.get("block_size_m"), + "block_size_n": data.get("block_size_n"), + "swizzle_size": data.get("swizzle_size"), + "num_xcds": data.get("num_xcds"), + "iris_latency_ms": data.get(f"{operation}_ms"), + "iris_bandwidth_gbps": data.get("bandwidth_gbps"), + } + + # Add operation-specific fields + if operation == "all_reduce": + result["variant"] = args["all_reduce_variant"] + result["distribution"] = args["all_reduce_distribution"] + elif operation == "reduce_scatter": + result["distribution"] = args["reduce_scatter_distribution"] + + # Add RCCL results if available + if args.get("benchmark_rccl", False): + result["rccl_latency_ms"] = data.get("rccl_ms") + result["rccl_bandwidth_gbps"] = 
data.get("rccl_bandwidth_gbps") + result["iris_vs_rccl_ratio"] = data.get("rccl_ratio_percent", 0) / 100.0 + + results.append(result) + + print(f" Iris: {result['iris_latency_ms']:.3f} ms, {result['iris_bandwidth_gbps']:.3f} GB/s") + if args.get("benchmark_rccl", False) and result.get("rccl_bandwidth_gbps"): + print(f" RCCL: {result['rccl_latency_ms']:.3f} ms, {result['rccl_bandwidth_gbps']:.3f} GB/s") + print(f" Ratio: {result['iris_vs_rccl_ratio']:.2f}x") + + # Generate output filename if not provided + if args["output_csv"] is None: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + args["output_csv"] = f"ccl_sweep_{timestamp}.csv" + + # Write results to CSV + if results: + # Collect all unique fieldnames from all results + all_fieldnames = set() + for result in results: + all_fieldnames.update(result.keys()) + + # Sort fieldnames for consistent column order + # Put common fields first, then operation-specific fields + common_fields = [ + "operation", + "comm_sms", + "m", + "n", + "world_size", + "datatype", + "block_size_m", + "block_size_n", + "swizzle_size", + "num_xcds", + "iris_latency_ms", + "iris_bandwidth_gbps", + ] + optional_fields = sorted(all_fieldnames - set(common_fields)) + fieldnames = [f for f in common_fields if f in all_fieldnames] + optional_fields + + with open(args["output_csv"], "w", newline="") as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(results) + + print(f"\n{'=' * 80}") + print(f"Results written to: {args['output_csv']}") + print(f"Total benchmarks run: {len(results)}") + print(f"{'=' * 80}\n") + else: + print("\nNo results collected!") + + +if __name__ == "__main__": + main() diff --git a/benchmark/ccl/plot_sweep_results.py b/benchmark/ccl/plot_sweep_results.py new file mode 100644 index 00000000..16c3ed5c --- /dev/null +++ b/benchmark/ccl/plot_sweep_results.py @@ -0,0 +1,184 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Plot comprehensive CCL benchmark sweep results. + +This script reads the CSV output from comprehensive_sweep.py and creates +subplots comparing Iris vs RCCL bandwidth for each collective operation. 
+""" + +import argparse +import csv +import matplotlib.pyplot as plt +import numpy as np +from collections import defaultdict +import os + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Plot CCL benchmark sweep results.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "input_csv", + type=str, + help="Input CSV file from comprehensive_sweep.py", + ) + parser.add_argument( + "--output", + type=str, + default=None, + help="Output plot file (default: auto-generated from input filename)", + ) + parser.add_argument( + "--title", + type=str, + default="CCL Benchmark: Iris vs RCCL", + help="Overall plot title", + ) + parser.add_argument( + "--dpi", + type=int, + default=150, + help="DPI for output image", + ) + parser.add_argument( + "--figsize", + type=int, + nargs=2, + default=[16, 10], + help="Figure size in inches (width height)", + ) + + return parser.parse_args() + + +def load_results(csv_file): + """Load results from CSV file and organize by operation.""" + data = defaultdict(lambda: {"comm_sms": [], "iris_bw": [], "rccl_bw": []}) + + with open(csv_file, "r") as f: + reader = csv.DictReader(f) + for row in reader: + operation = row["operation"] + comm_sms = int(row["comm_sms"]) + iris_bw = float(row["iris_bandwidth_gbps"]) + + data[operation]["comm_sms"].append(comm_sms) + data[operation]["iris_bw"].append(iris_bw) + + # RCCL data may not be present for all operations + if "rccl_bandwidth_gbps" in row and row["rccl_bandwidth_gbps"]: + rccl_bw = float(row["rccl_bandwidth_gbps"]) + data[operation]["rccl_bw"].append(rccl_bw) + else: + data[operation]["rccl_bw"].append(None) + + return data + + +def plot_results(data, args): + """Create subplots comparing Iris vs RCCL for each operation.""" + operations = sorted(data.keys()) + num_ops = len(operations) + + # Create subplots - 2x2 grid for up to 4 operations + if num_ops <= 2: + nrows, ncols = 1, num_ops + elif num_ops <= 4: + nrows, ncols = 2, 2 + else: + nrows = (num_ops + 1) // 2 + ncols = 2 + + fig, axes = plt.subplots(nrows, ncols, figsize=tuple(args.figsize)) + fig.suptitle(args.title, fontsize=16, fontweight="bold") + + # Flatten axes for easier iteration + if num_ops == 1: + axes = [axes] + else: + axes = axes.flatten() if num_ops > 1 else [axes] + + for idx, operation in enumerate(operations): + ax = axes[idx] + op_data = data[operation] + + comm_sms = np.array(op_data["comm_sms"]) + iris_bw = np.array(op_data["iris_bw"]) + rccl_bw = np.array(op_data["rccl_bw"]) + + # Plot Iris bandwidth + ax.plot(comm_sms, iris_bw, "o-", linewidth=2, markersize=8, label="Iris", color="#2E86AB") + + # Plot RCCL bandwidth if available + if not all(x is None for x in rccl_bw): + # Filter out None values + valid_indices = [i for i, x in enumerate(rccl_bw) if x is not None] + if valid_indices: + rccl_comm_sms = comm_sms[valid_indices] + rccl_bw_valid = rccl_bw[valid_indices] + ax.plot(rccl_comm_sms, rccl_bw_valid, "s--", linewidth=2, markersize=8, label="RCCL", color="#A23B72") + + # Formatting + ax.set_xlabel("Number of CUs (comm_sms)", fontsize=11) + ax.set_ylabel("Bandwidth (GB/s)", fontsize=11) + ax.set_title(f"{operation.replace('_', '-').title()}", fontsize=13, fontweight="bold") + ax.grid(True, alpha=0.3, linestyle="--") + ax.legend(loc="best", fontsize=10) + + # Set x-axis to show all CU values + ax.set_xticks(comm_sms) + + # Add some padding to y-axis + y_min = min( + iris_bw.min(), rccl_bw[rccl_bw is not None].min() if any(x is not None for x in rccl_bw) else iris_bw.min() + ) + 
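# Illustration (not part of the patch): an alternative way to compute the plotted extent at
# this step -- take the Iris values plus whichever RCCL samples are present, then widen the
# range by 10% on each side as done a few lines below. Names are hypothetical.
def _finite_extent(iris_vals, rccl_vals):
    vals = list(iris_vals) + [v for v in rccl_vals if v is not None]
    return min(vals), max(vals)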
y_max = max( + iris_bw.max(), rccl_bw[rccl_bw is not None].max() if any(x is not None for x in rccl_bw) else iris_bw.max() + ) + y_range = y_max - y_min + ax.set_ylim(y_min - 0.1 * y_range, y_max + 0.1 * y_range) + + # Hide unused subplots + for idx in range(num_ops, len(axes)): + axes[idx].set_visible(False) + + plt.tight_layout() + + # Generate output filename if not provided + if args.output is None: + base_name = os.path.splitext(args.input_csv)[0] + args.output = f"{base_name}_plot.png" + + plt.savefig(args.output, dpi=args.dpi, bbox_inches="tight") + print(f"\nPlot saved to: {args.output}") + + # Also display if running interactively + try: + plt.show() + except Exception: + pass + + +def main(): + args = parse_args() + + print(f"Loading results from: {args.input_csv}") + data = load_results(args.input_csv) + + print(f"Found {len(data)} operations:") + for op in sorted(data.keys()): + num_points = len(data[op]["comm_sms"]) + print(f" - {op}: {num_points} data points") + + print("\nCreating plots...") + plot_results(data, args) + + +if __name__ == "__main__": + main() diff --git a/benchmark/ccl/reduce_scatter/benchmark.py b/benchmark/ccl/reduce_scatter/benchmark.py new file mode 100755 index 00000000..61bb9991 --- /dev/null +++ b/benchmark/ccl/reduce_scatter/benchmark.py @@ -0,0 +1,439 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Benchmark for iris-ccl reduce-scatter collective operation. + +This benchmark showcases the reduce-scatter collective and reports achieved bandwidth. +""" + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import random +import argparse + +from examples.common.utils import JSONWriter + +import iris +from iris.ccl import Config + +# Conditional import for Gluon +try: + import iris.experimental.iris_gluon as iris_gluon + + GLUON_AVAILABLE = True +except ImportError: + GLUON_AVAILABLE = False + +torch.manual_seed(123) +random.seed(123) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark reduce-scatter collective operation.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument("-m", type=int, default=16384, help="Number of rows in input tensors") + parser.add_argument("-n", type=int, default=16384, help="Number of columns in input tensors") + parser.add_argument("-d", "--debug", action="store_true", help="Enable debug mode") + parser.add_argument("-v", "--validate", action="store_true", help="Enable validation mode") + parser.add_argument("-b", "--benchmark", action="store_true", help="Enable benchmarking mode") + parser.add_argument( + "--datatype", + type=str, + default="fp16", + choices=["fp16", "fp32", "bf16"], + help="Datatype of tensors", + ) + parser.add_argument( + "--output_file", + type=str, + default="log.json", + help="Output file", + ) + parser.add_argument("--heap_size", type=int, default=1 << 34, help="Iris heap size") + parser.add_argument("--comm_sms", type=int, default=64, help="Number of SMs for reduce-scatter kernel") + parser.add_argument( + "--benchmark_rccl", + action="store_true", + help="Also benchmark PyTorch RCCL (reduce_scatter) for comparison", + ) + parser.add_argument("--block_size_m", type=int, default=64, help="Block size for M dimension tiling (default: 64)") + parser.add_argument("--block_size_n", type=int, default=64, help="Block size for N dimension tiling (default: 64)") + parser.add_argument("--swizzle_size", type=int, default=8, 
help="Number of tiles to swizzle together (default: 8)") + parser.add_argument("--num_xcds", type=int, default=None, help="Number of XCDs (auto-detected if not set)") + parser.add_argument( + "--all_reduce_distribution", + type=int, + default=0, + choices=[0, 1], + help="Distribution mode for two-shot reduce-scatter: 0=striding (default), 1=block", + ) + parser.add_argument("-r", "--num_ranks", type=int, default=8, help="Number of ranks/processes") + parser.add_argument("--use_gluon", action="store_true", help="Use Gluon implementation with traffic shaping") + + return vars(parser.parse_args()) + + +def _worker(local_rank: int, world_size: int, init_url: str, args: dict): + """Worker function for PyTorch distributed execution.""" + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group( + backend=backend, + init_method=init_url, + world_size=world_size, + rank=local_rank, + device_id=torch.device(f"cuda:{local_rank}"), + ) + + # Use Gluon if requested and available + if args.get("use_gluon", False): + if not GLUON_AVAILABLE: + raise RuntimeError("Gluon is not available. Install Triton with Gluon support or remove --use_gluon flag") + shmem = iris_gluon.iris(args["heap_size"]) + else: + shmem = iris.iris(args["heap_size"]) + + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + # Datatype mapping + datatype = torch.float32 + if args["datatype"] == "fp16": + datatype = torch.float16 + elif args["datatype"] == "fp32": + datatype = torch.float32 + elif args["datatype"] == "bf16": + datatype = torch.bfloat16 + else: + print("Unknown datatype.") + exit(1) + + M = args["m"] + N = args["n"] + + # Create config with optimized defaults for reduce-scatter + config_kwargs = { + "comm_sms": args["comm_sms"], + "all_reduce_distribution": args["all_reduce_distribution"], + "block_size_m": args["block_size_m"], + "block_size_n": args["block_size_n"], + "swizzle_size": args["swizzle_size"], + } + if args["num_xcds"] is not None: + config_kwargs["num_xcds"] = args["num_xcds"] + if args.get("use_gluon", False): + config_kwargs["use_gluon"] = True + + config = Config(**config_kwargs) + + json_writer = JSONWriter(args["output_file"]) + json_writer.add_field("world_size", world_size) + + for key, value in args.items(): + json_writer.add_field(key, value) + + # Export config values to JSON (use actual values from config, including defaults) + json_writer.add_field("block_size_m", config.block_size_m) + json_writer.add_field("block_size_n", config.block_size_n) + json_writer.add_field("swizzle_size", config.swizzle_size) + json_writer.add_field("num_xcds", config.num_xcds) + json_writer.add_field("use_gluon", config.use_gluon) + json_writer.add_field("all_reduce_distribution", config.all_reduce_distribution) + + # Create input and output tensors for reduce-scatter + # Input: each rank has (M, N) tensor + # Output: each rank has (M, N) tensor - contains reduced tiles assigned to this rank + # Note: Must use shmem.zeros() to allocate on Iris symmetric heap for iris.load() compatibility + input_tensor = shmem.zeros((M, N), dtype=datatype) + output_tensor = shmem.zeros((M, N), dtype=datatype) + expected_tensor = shmem.zeros((M, N), dtype=datatype) + + # Fill input with deterministic values + # For reduce-scatter, each rank's input contributes to the reduction + # Use smaller values to avoid overflow, especially with fp16 + val = float(rank + 1) * 0.1 # Scale down to prevent overflow + input_tensor.fill_(val) + + # Expected output: each rank gets the sum of all ranks' inputs 
for its assigned tiles + # Since reduce-scatter uses two-shot with tile assignment, we need to compute + # which tiles are assigned to each rank based on the distribution mode + # For validation, we'll use PyTorch's reduce_scatter as reference + comm_stream = torch.cuda.Stream() + + kernel_timing = { + "reduce_scatter": { + "start_event": torch.cuda.Event(enable_timing=True), + "end_event": torch.cuda.Event(enable_timing=True), + "ms": 0, + "experiments": 0, + }, + } + + def run_experiment(): + nonlocal kernel_timing + shmem.barrier() + + torch.cuda.nvtx.range_push("Reduce-Scatter") + with torch.cuda.stream(comm_stream): + kernel_timing["reduce_scatter"]["start_event"].record() + shmem.ccl.reduce_scatter(output_tensor, input_tensor, config=config, async_op=False) + kernel_timing["reduce_scatter"]["end_event"].record() + kernel_timing["reduce_scatter"]["experiments"] += 1 + torch.cuda.nvtx.range_pop() + + # Synchronize before querying event timing + shmem.barrier() + + # Update timing + ms = kernel_timing["reduce_scatter"]["start_event"].elapsed_time(kernel_timing["reduce_scatter"]["end_event"]) + kernel_timing["reduce_scatter"]["ms"] += ms + + # Synchronize across all GPUs + shmem.barrier() + + if args["validate"]: + shmem.info("Validating...") + + # Reset output before validation + output_tensor.zero_() + shmem.barrier() + + # Reinitialize input data + val = float(rank + 1) * 0.1 # Scale down to prevent overflow + input_tensor.fill_(val) + shmem.barrier() + + # Run Iris reduce_scatter + run_experiment() + torch.cuda.synchronize() + shmem.barrier() + + # Create reference output by manually computing expected reduce-scatter result + # Each rank should reduce its assigned tiles from all ranks' inputs + reference_output = shmem.zeros((M, N), dtype=datatype) + + # Compute reference: sum all ranks' inputs for tiles assigned to this rank + # This simulates what reduce_scatter should produce + for r in range(world_size): + # Create input for rank r + rank_input = shmem.zeros((M, N), dtype=datatype) + rank_input.fill_(float(r + 1) * 0.1) + + # Add to reference (all tiles get summed) + reference_output += rank_input + + # Now reference_output contains the sum of all inputs at each location + # In reduce_scatter, each rank only gets its assigned tiles (rest should be zero) + # But we can use this to validate the non-zero values + + # Validate using double precision to avoid overflow in sum computation + output_sum = output_tensor.double().sum().item() + input_sum = input_tensor.double().sum().item() + + # Expected: each tile location gets sum of all ranks' contributions + # For reduce-scatter, each rank gets its assigned tiles reduced + # The expected value at each reduced location is the sum of all ranks' inputs + expected_value_per_element = sum(float(r + 1) * 0.1 for r in range(world_size)) + + # Simple validation: output should be non-zero and have reasonable values + atol = 1e-3 if datatype == torch.float16 else 1e-5 + + # Count non-zero elements across entire tensor + non_zero_mask = output_tensor.abs() > atol + num_non_zero = non_zero_mask.sum().item() + total_elements = output_tensor.numel() + + # Get statistics on non-zero values and compare with reference + if num_non_zero > 0: + non_zero_values = output_tensor[non_zero_mask].double() + mean_value = non_zero_values.mean().item() + min_value = non_zero_values.min().item() + max_value = non_zero_values.max().item() + + # Compare with reference output + # For non-zero elements, they should match the reference (sum of all inputs) + 
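# Illustration (not part of the patch): a worked example of the expected per-element value.
# With the default world_size = 8, each reduced element should equal
#   sum((r + 1) * 0.1 for r in range(8)) = 0.1 * (1 + 2 + ... + 8) = 3.6
# which is what expected_value_per_element holds and what the mean and match-percentage
# checks below compare against (within the fp16/fp32 tolerances).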
reference_non_zero = reference_output[non_zero_mask].double() + + # Count how many elements match the reference (within tolerance) + match_tolerance = 1e-2 if datatype == torch.float16 else 1e-4 + matches = (non_zero_values - reference_non_zero).abs() < match_tolerance + num_matches = matches.sum().item() + match_percentage = (num_matches / num_non_zero) * 100 + + # Check that non-zero values are close to expected sum + expected_close = abs(mean_value - expected_value_per_element) < (expected_value_per_element * 0.2) + + if expected_close and match_percentage > 95: + success = True + shmem.info( + f"Rank {rank}: {num_non_zero}/{total_elements} non-zero elements, " + f"mean: {mean_value:.4f} (expected: {expected_value_per_element:.4f}), " + f"range: [{min_value:.4f}, {max_value:.4f}], " + f"matches reference: {num_matches}/{num_non_zero} ({match_percentage:.1f}%)" + ) + else: + shmem.error( + f"Rank {rank}: Validation failed - mean {mean_value:.4f} != expected {expected_value_per_element:.4f}, " + f"{num_non_zero}/{total_elements} non-zero, " + f"matches: {num_matches}/{num_non_zero} ({match_percentage:.1f}%)" + ) + success = False + else: + # No non-zero values - this might be valid if this rank has no assigned tiles + # In reduce-scatter, tiles are distributed across ranks, so some ranks might have fewer tiles + shmem.warning(f"Rank {rank}: No non-zero values found ({num_non_zero}/{total_elements})") + # Consider this a pass for now - the operation may have assigned no tiles to this rank + success = True + + if success: + shmem.info("Reduce-scatter validation passed!") + else: + shmem.error("Reduce-scatter validation failed!") + + json_writer.add_field("success", success) + + # Wait for all to finish validation + shmem.barrier() + + if args["benchmark"]: + # Warmup for benchmarking + run_experiment() + shmem.barrier() + + for k in ["reduce_scatter"]: + kernel_timing[k]["ms"] = 0 + kernel_timing[k]["experiments"] = 0 + + # Reset output before benchmarking + output_tensor.zero_() + shmem.barrier() + + # Reinitialize input data + val = float(rank + 1) * 0.1 # Scale down to prevent overflow + input_tensor.fill_(val) + shmem.barrier() + + shmem.info("Benchmarking...") + + # Calculate bandwidth + # Reduce-scatter moves (world_size - 1) / world_size * data_size bytes + # This accounts for the two-shot approach where each rank reads from all ranks + # and writes only to its own output (no broadcast phase) + # Each rank transfers (world_size - 1) / world_size * M * N * element_size bytes + # This is similar to all-reduce but without the broadcast phase + element_size = torch.tensor([], dtype=datatype).element_size() + total_bytes = M * N * element_size * (world_size - 1) / world_size + total_bytes_gb = total_bytes / (1024**3) + + triton_ms = iris.do_bench(run_experiment, shmem.barrier) + bandwidth_gbps = total_bytes_gb / ( + (kernel_timing["reduce_scatter"]["ms"] / kernel_timing["reduce_scatter"]["experiments"]) * 1e-3 + ) + + shmem.info( + f"Reduce-scatter (M={M}, N={N}, world_size={world_size}, dtype={args['datatype']}): " + f"{triton_ms:.3f} ms, {bandwidth_gbps:.3f} GB/s" + ) + + json_writer.add_field("bandwidth_gbps", bandwidth_gbps) + json_writer.add_field("total_ms", triton_ms) + json_writer.add_field("total_bytes", total_bytes) + json_writer.add_field("total_bytes_gb", total_bytes_gb) + json_writer.add_field( + "reduce_scatter_ms", kernel_timing["reduce_scatter"]["ms"] / kernel_timing["reduce_scatter"]["experiments"] + ) + json_writer.add_field("reduce_scatter_experiments", 
kernel_timing["reduce_scatter"]["experiments"]) + + # Wait for all to finish benchmarking + shmem.barrier() + + # Benchmark RCCL (PyTorch reduce_scatter_tensor) for comparison + if args.get("benchmark_rccl", False): + shmem.info("Benchmarking PyTorch RCCL (reduce_scatter_tensor)...") + + # Create PyTorch tensors (not on Iris heap) + # PyTorch reduce_scatter_tensor: input is (M, N), output is (M // world_size, N) + # Our implementation is different (tiles vs chunks), so we'll benchmark with same input size + pytorch_input = torch.zeros(M, N, dtype=datatype, device=f"cuda:{rank}") + pytorch_input.fill_(float(rank + 1) * 0.1) # Scale down to prevent overflow + + # PyTorch reduce_scatter_tensor splits along dim 0 + output_size_m = M // world_size + pytorch_output = torch.zeros(output_size_m, N, dtype=datatype, device=f"cuda:{rank}") + + # Warmup + for _ in range(10): + dist.reduce_scatter_tensor(pytorch_output, pytorch_input, op=dist.ReduceOp.SUM) + torch.cuda.synchronize() + dist.barrier() + + # Benchmark + pytorch_output.zero_() + pytorch_input.fill_(float(rank + 1) * 0.1) # Scale down to prevent overflow + dist.barrier() + + rccl_start = torch.cuda.Event(enable_timing=True) + rccl_end = torch.cuda.Event(enable_timing=True) + + num_iterations = 126 # Match Iris benchmark iterations + dist.barrier() + rccl_start.record() + for _ in range(num_iterations): + dist.reduce_scatter_tensor(pytorch_output, pytorch_input, op=dist.ReduceOp.SUM) + rccl_end.record() + torch.cuda.synchronize() + dist.barrier() + + rccl_ms = rccl_start.elapsed_time(rccl_end) / num_iterations + element_size = torch.tensor([], dtype=datatype).element_size() + # RCCL reduce-scatter: similar bandwidth calculation + # Each rank reads from all ranks and writes its output chunk + total_bytes = M * N * element_size * (world_size - 1) / world_size + total_bytes_gb = total_bytes / (1024**3) + rccl_bandwidth_gbps = total_bytes_gb / (rccl_ms * 1e-3) + + shmem.info( + f"RCCL reduce_scatter_tensor (M={M}, N={N}, world_size={world_size}, dtype={args['datatype']}): " + f"{rccl_ms:.3f} ms, {rccl_bandwidth_gbps:.3f} GB/s" + ) + + if args["benchmark"]: + # Calculate performance ratio + iris_bandwidth = bandwidth_gbps + rccl_ratio = (iris_bandwidth / rccl_bandwidth_gbps) * 100 if rccl_bandwidth_gbps > 0 else 0 + shmem.info(f"Performance ratio (Iris/RCCL): {rccl_ratio:.1f}%") + + json_writer.add_field("rccl_bandwidth_gbps", rccl_bandwidth_gbps) + json_writer.add_field("rccl_ms", rccl_ms) + json_writer.add_field("rccl_ratio_percent", rccl_ratio) + + # Wait for all to finish RCCL benchmarking + shmem.barrier() + + if rank == 0: + json_writer.flush() + json_writer.display() + + shmem.barrier() + dist.destroy_process_group() + + +def main(): + args = parse_args() + num_ranks = args["num_ranks"] + init_url = "tcp://127.0.0.1:29234" + + mp.spawn( + fn=_worker, + args=(num_ranks, init_url, args), + nprocs=num_ranks, + join=True, + ) + + +if __name__ == "__main__": + main() diff --git a/docker/Dockerfile b/docker/Dockerfile index 8b49c01a..6ae142dc 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,7 +1,7 @@ # SPDX-License-Identifier: MIT # Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. 
-FROM rocm/pytorch:rocm6.3.1_ubuntu22.04_py3.10_pytorch +FROM rocm/pytorch:rocm7.1_ubuntu24.04_py3.13_pytorch_release_2.9.1 # Use bash shell for RUN commands SHELL ["/bin/bash", "-c"] @@ -31,15 +31,15 @@ RUN pip3 install --upgrade pip && \ # Clone and install Triton WORKDIR $TRITON_PATH RUN git clone https://github.com/triton-lang/triton.git $TRITON_PATH -RUN git checkout dd5823453bcc7973eabadb65f9d827c43281c434 +RUN git checkout 715f6b1d442601436bf8d462db6ff8e17aec8cfb RUN pip3 install -e . ENV PYTHONPATH=$TRITON_PATH # Install rocprofiler-systems WORKDIR /workspace -RUN wget https://github.com/ROCm/rocprofiler-systems/releases/download/rocm-6.3.1/rocprofiler-systems-install.py && \ - python3 ./rocprofiler-systems-install.py --prefix /opt/rocprofiler-systems --rocm 6.3 && \ - rm -f rocprofiler-systems-install.py +# RUN wget https://github.com/ROCm/rocprofiler-systems/releases/latest/download/rocprofiler-systems-install.py && \ +# python3 ./rocprofiler-systems-install.py --prefix /opt/rocprofiler-systems --rocm 7.1 && \ +# rm -f rocprofiler-systems-install.py # Create entrypoint script RUN echo '#!/bin/bash' > /entrypoint.sh && \ diff --git a/iris/ccl/all_gather.py b/iris/ccl/all_gather.py new file mode 100644 index 00000000..ddc0a800 --- /dev/null +++ b/iris/ccl/all_gather.py @@ -0,0 +1,224 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +All-gather collective communication primitive for Iris. +Gathers tensors from all ranks and concatenates them along the last dimension. +""" + +import triton +import triton.language as tl +import torch +import iris +from .config import Config + + +@triton.jit() +def chiplet_transform_chunked(pid, num_workgroups: tl.constexpr, num_xcds: tl.constexpr, chunk_size: tl.constexpr): + if pid > (num_workgroups // (num_xcds * chunk_size)) * (num_xcds * chunk_size): + return pid + + local_pid = pid // num_xcds + chunk_idx = local_pid // chunk_size + pos_in_chunk = local_pid % chunk_size + + xcd = pid % num_xcds + new_pid = chunk_idx * num_xcds * chunk_size + xcd * chunk_size + pos_in_chunk + return new_pid + + +@triton.jit() +def persistent_all_gather( + input_ptr, + output_ptr, + M, + N, + stride_in_m, + stride_in_n, + stride_out_m, + stride_out_n, + heap_bases: tl.tensor, + cur_rank: tl.constexpr, + world_size: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + COMM_SMS: tl.constexpr, + NUM_XCDS: tl.constexpr, + CHUNK_SIZE: tl.constexpr, +): + """ + Persistent all-gather kernel. + + Each rank sends its input tensor to all ranks, and all ranks receive + and concatenate all input tensors along dimension 0 (rows), matching + torch.distributed.all_gather_into_tensor behavior. 
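+
+    This rank's block is written to rows [cur_rank * M, (cur_rank + 1) * M) of the
+    output, at the same offset on every destination rank; columns are unchanged.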
+ + Args: + input_ptr: Pointer to input tensor (local rank's data to send) of shape (M, N) + output_ptr: Pointer to output tensor (will receive from all ranks) of shape (world_size * M, N) + M: Number of rows per rank (output will be world_size * M rows) + N: Number of columns + stride_in_m, stride_in_n: Strides for input tensor + stride_out_m, stride_out_n: Strides for output tensor + heap_bases: Heap base pointers for all ranks + cur_rank: Current rank + world_size: Total number of ranks + BLOCK_SIZE_M, BLOCK_SIZE_N: Block sizes for tiling + GROUP_SIZE_M: Group size for M dimension tiling + COMM_SMS: Number of SMs for communication + NUM_XCDS: Number of XCDs + CHUNK_SIZE: Chunk size for chiplet transform + """ + pid = tl.program_id(0) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + total_tiles = num_pid_m * num_pid_n + tl.assume(total_tiles > 0) + for tile_id in range(pid, total_tiles, COMM_SMS): + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + tl.assume(tile_id >= 0) + tl.assume(stride_in_m >= 0) + tl.assume(stride_in_n >= 0) + tl.assume(stride_out_m >= 0) + tl.assume(stride_out_n >= 0) + + # Compute local row and column indices for input tensor + rm_base = pid_m * BLOCK_SIZE_M + rn_base = pid_n * BLOCK_SIZE_N + rm_input = rm_base + tl.arange(0, BLOCK_SIZE_M) + rn = rn_base + tl.arange(0, BLOCK_SIZE_N) + rm_input = tl.max_contiguous(tl.multiple_of(rm_input, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + + # Mask for local input bounds + input_mask = (rm_input[:, None] < M) & (rn[None, :] < N) + + # Compute input offset and load local shard data once + # Each rank loads its own input data and then broadcasts it to all ranks + input_base_m = rm_input[:, None] * stride_in_m + input_base_n = rn[None, :] * stride_in_n + input_offset = input_base_m + input_base_n + input_ptr_source = input_ptr + input_offset + input_ptr_source = tl.multiple_of(input_ptr_source, (BLOCK_SIZE_M, BLOCK_SIZE_N)) + + # Load local input data once for this tile + data = tl.load(input_ptr_source, mask=input_mask, other=0.0) + + # Send local shard data to all destination ranks + # Each rank's input goes to output[cur_rank * M : (cur_rank + 1) * M, :] on all ranks + for rank in tl.static_range(world_size): + # Compute global output row indices: offset by cur_rank * M + # This rank's data should be placed at output[cur_rank * M : (cur_rank + 1) * M, :] + rm_output = rm_input + cur_rank * M + + # Output mask: check bounds for output tensor (world_size * M rows, N cols) + output_mask = (rm_output[:, None] < (world_size * M)) & (rn[None, :] < N) + + # Combine masks: must be valid in both input and output + combined_mask = input_mask & output_mask + + # Compute output offset: write to output at rows [cur_rank * M : (cur_rank + 1) * M] + # This is the same location on all destination ranks + output_base_m = rm_output[:, None] * stride_out_m + output_base_n = rn[None, :] * stride_out_n + output_offset = output_base_m + output_base_n + output_ptr_target = output_ptr + output_offset + output_ptr_target = tl.multiple_of(output_ptr_target, (BLOCK_SIZE_M, BLOCK_SIZE_N)) + + if rank == cur_rank: + # Local destination: use direct store 
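+                # output_ptr already points into this rank's own heap, so the
+                # local copy needs no address translation through heap_bases.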
+ tl.store(output_ptr_target, data, cache_modifier=".wt") + else: + # Remote destination: use iris.put to send from local source to remote destination + # from_ptr: local input source, to_ptr: remote output destination + iris.put( + input_ptr_source, + output_ptr_target, + cur_rank, + rank, + heap_bases, + ) + + +def all_gather(output_tensor, input_tensor, shmem, config=None, async_op=False): + """ + Internal all-gather collective operation implementation. + + This function is called internally by shmem.ccl.all_gather(). + Users should use the Iris instance method instead: + >>> shmem.ccl.all_gather(output_tensor, input_tensor) + + Each rank sends its input tensor to all ranks, and all ranks receive + and concatenate all input tensors along dimension 0 (rows), matching + torch.distributed.all_gather_into_tensor behavior. + + Args: + output_tensor: Output tensor of shape (world_size * M, N) - will contain concatenated inputs + input_tensor: Input tensor of shape (M, N) - local rank's data to send + shmem: Iris shmem context + config: Config instance with kernel parameters (default: None). + If None, uses default Config values. + async_op: If False, performs a barrier at the end. If True, returns immediately. + Default: False. + """ + # Use provided config or create default one + if config is None: + config = Config() + + # Check for unsupported options + if config.use_gluon: + raise ValueError( + "all_gather does not support use_gluon=True. " + "Gluon implementation is not available for all_gather. " + "Use default config (use_gluon=False)." + ) + + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + M, N = input_tensor.shape[:2] + expected_output_shape = (world_size * M, N) + + if output_tensor.shape[:2] != expected_output_shape: + raise ValueError( + f"Output tensor shape {output_tensor.shape[:2]} does not match expected shape {expected_output_shape}. " + f"Expected (world_size * M, N) = ({world_size * M}, {N})" + ) + + stride_in_m, stride_in_n = input_tensor.stride(0), input_tensor.stride(1) + stride_out_m, stride_out_n = output_tensor.stride(0), output_tensor.stride(1) + + heap_bases = shmem.get_heap_bases() + + persistent_all_gather[(config.comm_sms,)]( + input_tensor, + output_tensor, + M, + N, + stride_in_m, + stride_in_n, + stride_out_m, + stride_out_n, + heap_bases, + rank, + world_size, + config.block_size_m, + config.block_size_n, + config.swizzle_size, + config.comm_sms, + config.num_xcds, + config.chunk_size, + ) + + if not async_op: + shmem.barrier() diff --git a/iris/ccl/all_reduce.py b/iris/ccl/all_reduce.py index 1c96e985..6d3fc14c 100644 --- a/iris/ccl/all_reduce.py +++ b/iris/ccl/all_reduce.py @@ -544,7 +544,7 @@ def persistent_all_reduce_ring( ) -@triton.jit() +@triton.jit def persistent_all_reduce_two_shot( input_ptr, output_ptr, @@ -561,16 +561,15 @@ def persistent_all_reduce_two_shot( BLOCK_SIZE_N: tl.constexpr, GROUP_SIZE_M: tl.constexpr, COMM_SMS: tl.constexpr, - NUM_XCDS: tl.constexpr, - CHUNK_SIZE: tl.constexpr, + NUM_XCDS: tl.constexpr, # unused here but kept for signature compatibility + CHUNK_SIZE: tl.constexpr, # unused here but kept for signature compatibility DISTRIBUTION: tl.constexpr, ): - """Reduce assigned tiles for a rank and broadcast the result to all peers.""" + """Reduce assigned tiles for a rank and broadcast the result to all peers. + Single kernel: unmasked fast path for full tiles, masked slow path for tails. 
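+    Tile-to-rank assignment follows DISTRIBUTION: 0 strides tiles across ranks
+    round-robin, 1 gives each rank a contiguous block of tiles.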
+ """ pid = tl.program_id(0) - if NUM_XCDS != 1: - pid = chiplet_transform_chunked(pid, COMM_SMS, NUM_XCDS, CHUNK_SIZE) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) total_tiles = num_pid_m * num_pid_n @@ -591,6 +590,7 @@ def persistent_all_reduce_two_shot( remaining = tl.maximum(remaining, 0) max_tile_offset = tl.minimum(tiles_per_rank, remaining) + # Persistent traversal for tile_offset in range(pid, max_tile_offset, COMM_SMS): tile_id = start_tile + tile_offset * stride @@ -601,45 +601,61 @@ def persistent_all_reduce_two_shot( pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) pid_n = (tile_id % num_pid_in_group) // group_size_m - tl.assume(pid_m >= 0) - tl.assume(pid_n >= 0) - rm_base = pid_m * BLOCK_SIZE_M rn_base = pid_n * BLOCK_SIZE_N + + is_full = (rm_base + BLOCK_SIZE_M <= M) & (rn_base + BLOCK_SIZE_N <= N) + + # Build indices (used by both paths) rm = rm_base + tl.arange(0, BLOCK_SIZE_M) rn = rn_base + tl.arange(0, BLOCK_SIZE_N) + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) - mask = (rm[:, None] < M) & (rn[None, :] < N) input_offset = rm[:, None] * stride_in_m + rn[None, :] * stride_in_n output_offset = rm[:, None] * stride_out_m + rn[None, :] * stride_out_n - acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) - for remote_rank in range(world_size): - partial = iris.load( - input_ptr + input_offset, - cur_rank, - remote_rank, - heap_bases, - mask=mask, - ) - acc += partial.to(acc_dtype) + base_ptr = input_ptr + input_offset + out_ptr = output_ptr + output_offset - reduced = acc.to(output_ptr.type.element_ty) + # Fast path: NO MASKS + if is_full: + mask = (rm[:, None] < M) & (rn[None, :] < N) - for remote_rank in range(world_size): - if remote_rank == cur_rank: - tl.store(output_ptr + output_offset, reduced, mask=mask, cache_modifier=".wt") - else: - iris.store( - output_ptr + output_offset, - reduced, - cur_rank, - remote_rank, - heap_bases, - mask=mask, - ) + start_rank = pid % world_size + acc = iris.load(base_ptr, cur_rank, start_rank, heap_bases).to(acc_dtype) + for i in tl.static_range(1, world_size): + remote_rank = (start_rank + i) % world_size + acc += iris.load(base_ptr, cur_rank, remote_rank, heap_bases).to(acc_dtype) + + reduced = acc.to(output_ptr.type.element_ty) + + tl.store(out_ptr, reduced, cache_modifier=".wt") + + for i in tl.static_range(0, world_size): + remote_rank = (start_rank + i) % world_size + if remote_rank != cur_rank: + iris.store(out_ptr, reduced, cur_rank, remote_rank, heap_bases) + + # Slow path: masked (only boundary tiles land here) + else: + mask = (rm[:, None] < M) & (rn[None, :] < N) + + start_rank = pid % world_size + acc = iris.load(base_ptr, cur_rank, start_rank, heap_bases, mask=mask).to(acc_dtype) + for i in tl.static_range(1, world_size): + remote_rank = (start_rank + i) % world_size + acc += iris.load(base_ptr, cur_rank, remote_rank, heap_bases, mask=mask).to(acc_dtype) + + reduced = acc.to(output_ptr.type.element_ty) + + tl.store(out_ptr, reduced, mask=mask, cache_modifier=".wt") + + for i in tl.static_range(0, world_size): + remote_rank = (start_rank + i) % world_size + if remote_rank != cur_rank: + iris.store(out_ptr, reduced, cur_rank, remote_rank, heap_bases, mask=mask) def all_reduce( @@ -669,6 +685,14 @@ def all_reduce( if config is None: config = Config() + # Check for unsupported options + if config.use_gluon: + raise ValueError( + "all_reduce does not support 
use_gluon=True. " + "Gluon implementation is not available for all_reduce. " + "Use default config (use_gluon=False)." + ) + rank = shmem.get_rank() world_size = shmem.get_num_ranks() M, N = input_tensor.shape[:2] @@ -820,6 +844,9 @@ def all_reduce( config.num_xcds, config.chunk_size, config.all_reduce_distribution, + num_warps=8, + num_stages=1, + waves_per_eu=1, ) elif variant == VARIANT_ONE_SHOT: persistent_all_reduce_one_shot[(config.comm_sms,)]( diff --git a/iris/ccl/all_to_all.py b/iris/ccl/all_to_all.py index 16fb5210..6e1239e6 100644 --- a/iris/ccl/all_to_all.py +++ b/iris/ccl/all_to_all.py @@ -8,6 +8,7 @@ import triton import triton.language as tl +import torch import iris from .config import Config @@ -79,6 +80,7 @@ def persistent_all_to_all( GROUP_SIZE_M: Group size for M dimension tiling COMM_SMS: Number of SMs for communication NUM_XCDS: Number of XCDs + CHUNK_SIZE: Chunk size for chiplet transform """ pid = tl.program_id(0) @@ -100,118 +102,90 @@ def persistent_all_to_all( tl.assume(pid_m >= 0) tl.assume(pid_n >= 0) - # Compute row and column indices - rm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - rn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + # Compute base indices for this tile + rm_base = pid_m * BLOCK_SIZE_M + rn_base = pid_n * BLOCK_SIZE_N + + # Check if this tile is fully within bounds (no edge cases) + is_full = (rm_base + BLOCK_SIZE_M <= M) & (rn_base + BLOCK_SIZE_N <= N) + + # Build indices (used by both paths) + rm = rm_base + tl.arange(0, BLOCK_SIZE_M) + rn = rn_base + tl.arange(0, BLOCK_SIZE_N) rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) - mask = (rm[:, None] < M) & (rn[None, :] < N) # Pre-compute base offsets for better memory access patterns and vectorization - # Base offset for input rows (M dimension) input_base_m = rm[:, None] * stride_in_m - # Base offset for output rows (M dimension) output_base_m = rm[:, None] * stride_out_m - # Base offset for input columns (N dimension) - will be adjusted per rank input_base_n = rn[None, :] * stride_in_n - # Base offset for output columns (N dimension) - will be adjusted per rank output_base_n = rn[None, :] * stride_out_n - # Process local rank first for better cache locality - # Local path: copy input[cur_rank] chunk to output[cur_rank] chunk - input_offset_local = input_base_m + (input_base_n + cur_rank * N * stride_in_n) - output_offset_local = output_base_m + (output_base_n + cur_rank * N * stride_out_n) - input_ptr_local = input_ptr + input_offset_local - output_ptr_local = output_ptr + output_offset_local - # Vectorization hints for 2D access pattern - input_ptr_local = tl.multiple_of(input_ptr_local, (BLOCK_SIZE_M, BLOCK_SIZE_N)) - output_ptr_local = tl.multiple_of(output_ptr_local, (BLOCK_SIZE_M, BLOCK_SIZE_N)) - - data = tl.load(input_ptr_local, mask=mask) - tl.store(output_ptr_local, data, mask=mask, cache_modifier=".wt") - - # Pre-compute constant parts that don't depend on target_rank - # Base offset for input (without rank-specific column offset) - input_base_offset = input_base_m + input_base_n - # Remote store offset: write into target's output at columns [cur_rank*N : (cur_rank+1)*N] - # This is constant for all target_rank iterations since it only depends on cur_rank - output_offset_remote = output_base_m + (output_base_n + cur_rank * N * stride_out_n) - output_ptr_remote = tl.multiple_of(output_ptr + output_offset_remote, (BLOCK_SIZE_M, BLOCK_SIZE_N)) - - # Pre-compute rank 
stride for input (N * stride_in_n) - rank_stride_in = N * stride_in_n - - # Traffic shaping: Break each tile into 64x64 sub-blocks and process them - # This creates better memory access patterns and allows hardware to distribute - # traffic across XGMI links based on access patterns - SUB_BLOCK_M: tl.constexpr = 64 - SUB_BLOCK_N: tl.constexpr = 64 - - # Calculate number of 64x64 sub-blocks needed to cover the tile - num_sub_blocks_m = tl.cdiv(BLOCK_SIZE_M, SUB_BLOCK_M) - num_sub_blocks_n = tl.cdiv(BLOCK_SIZE_N, SUB_BLOCK_N) - total_sub_blocks = num_sub_blocks_m * num_sub_blocks_n - - # Base row/column indices for the tile - tile_base_m = pid_m * BLOCK_SIZE_M - tile_base_n = pid_n * BLOCK_SIZE_N - - # Process all remote ranks: load each chunk and scatter to corresponding target - # Each target_rank may have different input data, so we must load separately - for target_rank in range(world_size): - # Skip local rank as it's already processed above - if target_rank != cur_rank: - # Traffic shaping: Process tile in 64x64 sub-blocks - # Loop over all sub-blocks to ensure complete coverage - for sub_block_id in range(total_sub_blocks): - # Calculate sub-block position within the tile - sub_block_m = (sub_block_id // num_sub_blocks_n) * SUB_BLOCK_M - sub_block_n = (sub_block_id % num_sub_blocks_n) * SUB_BLOCK_N - - # Compute row and column indices for this 64x64 sub-block - # Start from tile base and add sub-block offset, then create arrays - sub_rm_base = tile_base_m + sub_block_m - sub_rn_base = tile_base_n + sub_block_n - sub_rm = sub_rm_base + tl.arange(0, SUB_BLOCK_M) - sub_rn = sub_rn_base + tl.arange(0, SUB_BLOCK_N) - - # Create mask for this sub-block - sub_mask = ( - (sub_rm[:, None] < M) - & (sub_rn[None, :] < N) - & (sub_rm[:, None] < (tile_base_m + BLOCK_SIZE_M)) - & (sub_rn[None, :] < (tile_base_n + BLOCK_SIZE_N)) - ) + # Fast path: NO MASKS (full tiles) + if is_full: + # Process local rank first for better cache locality + input_offset_local = input_base_m + (input_base_n + cur_rank * N * stride_in_n) + output_offset_local = output_base_m + (output_base_n + cur_rank * N * stride_out_n) + input_ptr_local = input_ptr + input_offset_local + output_ptr_local = output_ptr + output_offset_local + input_ptr_local = tl.multiple_of(input_ptr_local, (BLOCK_SIZE_M, BLOCK_SIZE_N)) + output_ptr_local = tl.multiple_of(output_ptr_local, (BLOCK_SIZE_M, BLOCK_SIZE_N)) - # Compute offsets for this sub-block - sub_input_base_m = sub_rm[:, None] * stride_in_m - sub_input_base_n = sub_rn[None, :] * stride_in_n - sub_output_base_m = sub_rm[:, None] * stride_out_m - sub_output_base_n = sub_rn[None, :] * stride_out_n + data = tl.load(input_ptr_local) + tl.store(output_ptr_local, data, cache_modifier=".wt") - # Compute input pointer for this target_rank's chunk (sub-block) - sub_input_offset = sub_input_base_m + (sub_input_base_n + target_rank * N * stride_in_n) - sub_input_ptr_send = input_ptr + sub_input_offset - sub_input_ptr_send = tl.multiple_of(sub_input_ptr_send, (SUB_BLOCK_M, SUB_BLOCK_N)) + # Process all remote ranks + for target_rank in range(world_size): + if target_rank != cur_rank: + input_offset_remote = input_base_m + (input_base_n + target_rank * N * stride_in_n) + output_offset_remote = output_base_m + (output_base_n + cur_rank * N * stride_out_n) + input_ptr_remote = input_ptr + input_offset_remote + output_ptr_remote = output_ptr + output_offset_remote + input_ptr_remote = tl.multiple_of(input_ptr_remote, (BLOCK_SIZE_M, BLOCK_SIZE_N)) + output_ptr_remote = tl.multiple_of(output_ptr_remote, 
(BLOCK_SIZE_M, BLOCK_SIZE_N)) + + remote_data = tl.load(input_ptr_remote) + iris.store( + output_ptr_remote, + remote_data, + cur_rank, + target_rank, + heap_bases, + ) - # Compute output pointer (sub-block) - sub_output_offset = sub_output_base_m + (sub_output_base_n + cur_rank * N * stride_out_n) - sub_output_ptr_remote = output_ptr + sub_output_offset - sub_output_ptr_remote = tl.multiple_of(sub_output_ptr_remote, (SUB_BLOCK_M, SUB_BLOCK_N)) + # Slow path: masked (only boundary tiles land here) + else: + mask = (rm[:, None] < M) & (rn[None, :] < N) - # Load data chunk for this target rank (64x64 sub-block) - sub_data = tl.load(sub_input_ptr_send, mask=sub_mask) + # Process local rank first for better cache locality + input_offset_local = input_base_m + (input_base_n + cur_rank * N * stride_in_n) + output_offset_local = output_base_m + (output_base_n + cur_rank * N * stride_out_n) + input_ptr_local = input_ptr + input_offset_local + output_ptr_local = output_ptr + output_offset_local + input_ptr_local = tl.multiple_of(input_ptr_local, (BLOCK_SIZE_M, BLOCK_SIZE_N)) + output_ptr_local = tl.multiple_of(output_ptr_local, (BLOCK_SIZE_M, BLOCK_SIZE_N)) - # Scatter to target rank's output - # Processing in 64x64 sub-blocks creates better memory access patterns - # that allow hardware to distribute traffic across XGMI links + data = tl.load(input_ptr_local, mask=mask) + tl.store(output_ptr_local, data, mask=mask, cache_modifier=".wt") + + # Process all remote ranks + for target_rank in range(world_size): + if target_rank != cur_rank: + input_offset_remote = input_base_m + (input_base_n + target_rank * N * stride_in_n) + output_offset_remote = output_base_m + (output_base_n + cur_rank * N * stride_out_n) + input_ptr_remote = input_ptr + input_offset_remote + output_ptr_remote = output_ptr + output_offset_remote + input_ptr_remote = tl.multiple_of(input_ptr_remote, (BLOCK_SIZE_M, BLOCK_SIZE_N)) + output_ptr_remote = tl.multiple_of(output_ptr_remote, (BLOCK_SIZE_M, BLOCK_SIZE_N)) + + remote_data = tl.load(input_ptr_remote, mask=mask) iris.store( - sub_output_ptr_remote, - sub_data, + output_ptr_remote, + remote_data, cur_rank, target_rank, heap_bases, - mask=sub_mask, + mask=mask, ) diff --git a/iris/ccl/config.py b/iris/ccl/config.py index 48c156c4..c7da52e8 100644 --- a/iris/ccl/config.py +++ b/iris/ccl/config.py @@ -38,6 +38,8 @@ class Config: all_reduce_num_rings: Number of concurrent rings to form in ring-based all-reduce (default: 1) all_reduce_ring_slice_n: Column slice size for ring reduce-scatter/all-gather (default: auto-set to block_size_n // world_size at runtime) + reduce_scatter_variant: Variant for reduce-scatter operation (default: "two_shot") + Only "two_shot" is supported Example: >>> import iris @@ -68,6 +70,7 @@ class Config: all_reduce_distribution: int = 0 all_reduce_num_rings: int = 1 all_reduce_ring_slice_n: int | None = None + reduce_scatter_variant: str = "two_shot" def __post_init__(self): """Validate and auto-detect num_xcds if not set.""" @@ -109,3 +112,7 @@ def __post_init__(self): ) if self.all_reduce_ring_slice_n & (self.all_reduce_ring_slice_n - 1): raise ValueError(f"all_reduce_ring_slice_n must be a power of two, got {self.all_reduce_ring_slice_n}") + + # Validate reduce_scatter_variant + if self.reduce_scatter_variant != "two_shot": + raise ValueError(f"reduce_scatter_variant must be 'two_shot', got '{self.reduce_scatter_variant}'") diff --git a/iris/ccl/reduce_scatter.py b/iris/ccl/reduce_scatter.py new file mode 100644 index 00000000..7402f32b --- /dev/null +++ 
b/iris/ccl/reduce_scatter.py @@ -0,0 +1,233 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Reduce-scatter collective communication primitive for Iris. +Uses the two-shot approach: reduce assigned tiles and store only to own rank. +""" + +import triton +import triton.language as tl +import torch +import iris +from .config import Config + + +@triton.jit() +def chiplet_transform_chunked(pid, num_workgroups: tl.constexpr, num_xcds: tl.constexpr, chunk_size: tl.constexpr): + if pid > (num_workgroups // (num_xcds * chunk_size)) * (num_xcds * chunk_size): + return pid + + local_pid = pid // num_xcds + chunk_idx = local_pid // chunk_size + pos_in_chunk = local_pid % chunk_size + + xcd = pid % num_xcds + new_pid = chunk_idx * num_xcds * chunk_size + xcd * chunk_size + pos_in_chunk + return new_pid + + +@triton.jit() +def persistent_reduce_scatter_two_shot( + input_ptr, + output_ptr, + M, + N, + stride_in_m, + stride_in_n, + stride_out_m, + stride_out_n, + heap_bases: tl.tensor, + cur_rank: tl.constexpr, + world_size: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + COMM_SMS: tl.constexpr, + NUM_XCDS: tl.constexpr, + CHUNK_SIZE: tl.constexpr, + DISTRIBUTION: tl.constexpr, +): + """ + Reduce-scatter using two-shot approach. + + Each rank reduces its assigned tiles from all ranks and stores the result + only to its own output (no broadcast to other ranks). + """ + pid = tl.program_id(0) + + if NUM_XCDS != 1: + pid = chiplet_transform_chunked(pid, COMM_SMS, NUM_XCDS, CHUNK_SIZE) + + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + total_tiles = num_pid_m * num_pid_n + + acc_dtype = tl.float32 if output_ptr.type.element_ty != tl.int8 else tl.int32 + + tiles_per_rank = tl.cdiv(total_tiles, world_size) + if DISTRIBUTION == 0: + start_tile = cur_rank + stride = world_size + remaining = total_tiles - start_tile + remaining = tl.maximum(remaining, 0) + max_tile_offset = tl.cdiv(remaining, stride) + else: + start_tile = cur_rank * tiles_per_rank + stride = 1 + remaining = total_tiles - start_tile + remaining = tl.maximum(remaining, 0) + max_tile_offset = tl.minimum(tiles_per_rank, remaining) + + for tile_offset in range(pid, max_tile_offset, COMM_SMS): + tile_id = start_tile + tile_offset * stride + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + rm_base = pid_m * BLOCK_SIZE_M + rn_base = pid_n * BLOCK_SIZE_N + + is_full = (rm_base + BLOCK_SIZE_M <= M) & (rn_base + BLOCK_SIZE_N <= N) + + # Build indices (used by both paths) + rm = rm_base + tl.arange(0, BLOCK_SIZE_M) + rn = rn_base + tl.arange(0, BLOCK_SIZE_N) + + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N) + + input_offset = rm[:, None] * stride_in_m + rn[None, :] * stride_in_n + output_offset = rm[:, None] * stride_out_m + rn[None, :] * stride_out_n + + base_ptr = input_ptr + input_offset + out_ptr = output_ptr + output_offset + + # Fast path: NO MASKS + if is_full: + start_rank = pid % world_size + acc = iris.load(base_ptr, cur_rank, start_rank, heap_bases).to(acc_dtype) + for i in 
tl.static_range(1, world_size): + remote_rank = (start_rank + i) % world_size + acc += iris.load(base_ptr, cur_rank, remote_rank, heap_bases).to(acc_dtype) + + reduced = acc.to(output_ptr.type.element_ty) + + # Store only to own rank (no broadcast) + tl.store(out_ptr, reduced, cache_modifier=".wt") + + # Slow path: masked (only boundary tiles land here) + else: + mask = (rm[:, None] < M) & (rn[None, :] < N) + + start_rank = pid % world_size + acc = iris.load(base_ptr, cur_rank, start_rank, heap_bases, mask=mask).to(acc_dtype) + for i in tl.static_range(1, world_size): + remote_rank = (start_rank + i) % world_size + acc += iris.load(base_ptr, cur_rank, remote_rank, heap_bases, mask=mask).to(acc_dtype) + + reduced = acc.to(output_ptr.type.element_ty) + + # Store only to own rank (no broadcast) + tl.store(out_ptr, reduced, mask=mask, cache_modifier=".wt") + + +def reduce_scatter(output_tensor, input_tensor, shmem, config=None, async_op=False): + """ + Internal reduce-scatter collective operation implementation. + + This function is called internally by shmem.ccl.reduce_scatter(). + Users should use the Iris instance method instead: + >>> shmem.ccl.reduce_scatter(output_tensor, input_tensor) + + Each rank reduces its assigned tiles from all ranks' inputs and stores + the result only to its own output tensor. This is similar to all-reduce + but without broadcasting the result to all ranks. + + Args: + output_tensor: Output tensor of shape (M, N) - will contain reduced tiles for this rank + input_tensor: Input tensor of shape (M, N) - local rank's partial data + shmem: Iris shmem context + config: Config instance with kernel parameters (default: None). + If None, uses default Config values. + Only supports reduce_scatter_variant="two_shot". + async_op: If False, performs a barrier at the end. If True, returns immediately. + Default: False. + + Example: + >>> shmem = iris.iris() + >>> shmem.ccl.reduce_scatter(output_tensor, input_tensor) + + >>> # Custom configuration + >>> from iris.ccl import Config + >>> config = Config(reduce_scatter_variant="two_shot", all_reduce_distribution=1) + >>> shmem.ccl.reduce_scatter(output_tensor, input_tensor, config=config) + """ + if config is None: + config = Config() + + # Check for unsupported options + if config.use_gluon: + raise ValueError( + "reduce_scatter does not support use_gluon=True. " + "Gluon implementation is not available for reduce_scatter. " + "Use default config (use_gluon=False)." + ) + + # Validate that only two_shot variant is used + variant = getattr(config, "reduce_scatter_variant", "two_shot") + if variant != "two_shot": + raise ValueError( + f"reduce_scatter only supports variant='two_shot', got '{variant}'. " + f"Set config.reduce_scatter_variant='two_shot' or use default config." + ) + + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + M, N = input_tensor.shape[:2] + + # Validate output shape matches input shape + if output_tensor.shape[:2] != (M, N): + raise ValueError( + f"Output tensor shape {output_tensor.shape[:2]} does not match input shape {(M, N)}. " + f"For reduce-scatter, output should have the same shape as input." 
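+            f" Each rank keeps a full (M, N) output buffer; only its assigned tiles are written."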
+ ) + + stride_in_m, stride_in_n = input_tensor.stride(0), input_tensor.stride(1) + stride_out_m, stride_out_n = output_tensor.stride(0), output_tensor.stride(1) + + heap_bases = shmem.get_heap_bases() + + # Use all_reduce_distribution for tile distribution + distribution = config.all_reduce_distribution + + persistent_reduce_scatter_two_shot[(config.comm_sms,)]( + input_tensor, + output_tensor, + M, + N, + stride_in_m, + stride_in_n, + stride_out_m, + stride_out_n, + heap_bases, + rank, + world_size, + config.block_size_m, + config.block_size_n, + config.swizzle_size, + config.comm_sms, + config.num_xcds, + config.chunk_size, + distribution, + ) + + if not async_op: + shmem.barrier() diff --git a/iris/experimental/iris_gluon.py b/iris/experimental/iris_gluon.py index f9ab82c4..63207943 100644 --- a/iris/experimental/iris_gluon.py +++ b/iris/experimental/iris_gluon.py @@ -578,6 +578,66 @@ def all_to_all(self, output_tensor, input_tensor, config=None, async_op=False): _all_to_all(output_tensor, input_tensor, self._iris, config=config, async_op=async_op) + def all_gather(self, output_tensor, input_tensor, config=None, async_op=False): + """ + All-gather collective operation. + + Each rank sends its input tensor to all ranks, and all ranks receive + and concatenate all input tensors along dimension 0 (rows), matching + torch.distributed.all_gather_into_tensor behavior. + + Args: + output_tensor: Output tensor of shape (world_size * M, N) - will contain concatenated inputs + input_tensor: Input tensor of shape (M, N) - local rank's data to send + config: Config instance with kernel parameters (default: None). + If None, uses default Config values. + async_op: If False, performs a barrier at the end. If True, returns immediately. + Default: False. + + Example: + >>> shmem = iris_gluon.iris() + >>> # Input: (M, N), Output: (world_size * M, N) + >>> shmem.ccl.all_gather(output_tensor, input_tensor) + + >>> # Custom configuration + >>> from iris.ccl import Config + >>> config = Config(block_size_m=128, block_size_n=32) + >>> shmem.ccl.all_gather(output_tensor, input_tensor, config=config) + """ + from iris.ccl.all_gather import all_gather as _all_gather + + _all_gather(output_tensor, input_tensor, self._iris, config=config, async_op=async_op) + + def reduce_scatter(self, output_tensor, input_tensor, config=None, async_op=False): + """ + Reduce-scatter collective operation. + + Each rank reduces its assigned tiles from all ranks' inputs and stores + the result only to its own output tensor. This is similar to all-reduce + but without broadcasting the result to all ranks. + + Args: + output_tensor: Output tensor of shape (M, N) - will contain reduced tiles for this rank + input_tensor: Input tensor of shape (M, N) - local rank's partial data + config: Config instance with kernel parameters (default: None). + If None, uses default Config values. + Only supports reduce_scatter_variant="two_shot". + async_op: If False, performs a barrier at the end. If True, returns immediately. + Default: False. 
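+
+        Note: unlike torch.distributed.reduce_scatter_tensor, the output keeps the
+        same (M, N) shape as the input; the scatter is over the tiles assigned to
+        this rank rather than over contiguous row chunks.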
+ + Example: + >>> shmem = iris_gluon.iris() + >>> shmem.ccl.reduce_scatter(output_tensor, input_tensor) + + >>> # Custom configuration + >>> from iris.ccl import Config + >>> config = Config(reduce_scatter_variant="two_shot", all_reduce_distribution=1) + >>> shmem.ccl.reduce_scatter(output_tensor, input_tensor, config=config) + """ + from iris.ccl.reduce_scatter import reduce_scatter as _reduce_scatter + + _reduce_scatter(output_tensor, input_tensor, self._iris, config=config, async_op=async_op) + def _log_with_rank(self, level, message): """Helper method to log with rank information injected into the record.""" extra = {"iris_rank": self.cur_rank, "iris_num_ranks": self.num_ranks} diff --git a/iris/iris.py b/iris/iris.py index 9f6c574e..49257686 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1548,6 +1548,39 @@ def all_to_all(self, output_tensor, input_tensor, config=None, async_op=False): _all_to_all(output_tensor, input_tensor, self._iris, config=config, async_op=async_op) + def all_gather(self, output_tensor, input_tensor, config=None, async_op=False): + """ + All-gather collective operation. + + Each rank sends its input tensor to all ranks, and all ranks receive + and concatenate all input tensors along dimension 0 (rows), matching + torch.distributed.all_gather_into_tensor behavior. + + Args: + output_tensor: Output tensor of shape (world_size * M, N) - will contain concatenated inputs + input_tensor: Input tensor of shape (M, N) - local rank's data to send + config: Config instance with kernel parameters (default: None). + If None, uses default Config values. + async_op: If False, performs a barrier at the end. If True, returns immediately. + Default: False. + + Example: + >>> shmem = iris.iris() + >>> # Input: (M, N), Output: (world_size * M, N) + >>> shmem.ccl.all_gather(output_tensor, input_tensor) + + >>> # Custom configuration + >>> from iris.ccl import Config + >>> config = Config(block_size_m=128, block_size_n=32) + >>> shmem.ccl.all_gather(output_tensor, input_tensor, config=config) + + >>> # Async operation (no barrier) + >>> shmem.ccl.all_gather(output_tensor, input_tensor, async_op=True) + """ + from iris.ccl.all_gather import all_gather as _all_gather + + _all_gather(output_tensor, input_tensor, self._iris, config=config, async_op=async_op) + def all_reduce_preamble(self, output_tensor, input_tensor, config=None, workspace=None): """ Prepare reusable workspace for all-reduce. @@ -1616,6 +1649,36 @@ def all_reduce(self, output_tensor, input_tensor, config=None, async_op=False, w workspace=workspace, ) + def reduce_scatter(self, output_tensor, input_tensor, config=None, async_op=False): + """ + Reduce-scatter collective operation. + + Each rank reduces its assigned tiles from all ranks' inputs and stores + the result only to its own output tensor. This is similar to all-reduce + but without broadcasting the result to all ranks. + + Args: + output_tensor: Output tensor of shape (M, N) - will contain reduced tiles for this rank + input_tensor: Input tensor of shape (M, N) - local rank's partial data + config: Config instance with kernel parameters (default: None). + If None, uses default Config values. + Only supports reduce_scatter_variant="two_shot". + async_op: If False, performs a barrier at the end. If True, returns immediately. + Default: False. 
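+
+        The launch consumes config.block_size_m, block_size_n, swizzle_size,
+        comm_sms, num_xcds, chunk_size, and all_reduce_distribution;
+        reduce_scatter_variant must remain at its default "two_shot".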
+ + Example: + >>> shmem = iris.iris() + >>> shmem.ccl.reduce_scatter(output_tensor, input_tensor) + + >>> # Custom configuration + >>> from iris.ccl import Config + >>> config = Config(reduce_scatter_variant="two_shot", all_reduce_distribution=1) + >>> shmem.ccl.reduce_scatter(output_tensor, input_tensor, config=config) + """ + from iris.ccl.reduce_scatter import reduce_scatter as _reduce_scatter + + _reduce_scatter(output_tensor, input_tensor, self._iris, config=config, async_op=async_op) + @triton.jit def __translate(ptr, from_rank, to_rank, heap_bases): @@ -1634,8 +1697,9 @@ def __translate(ptr, from_rank, to_rank, heap_bases): # Optimization to vectorize the load/store # We can't do this in general because we don't know the shape of the tensor - # ptr = tl.max_contiguous(tl.multiple_of(ptr, (64, 64)), (64, 64)) - # translated_ptr = tl.max_contiguous(tl.multiple_of(translated_ptr, (64, 64)), (64, 64)) + # ptr = tl.max_contiguous(tl.multiple_of(ptr, (16, 16)), (16, 32)) + translated_ptr = tl.multiple_of(translated_ptr, (32, 32)) + translated_ptr = tl.max_contiguous(translated_ptr, (1, 32)) # ptr = tl.max_contiguous(tl.multiple_of(ptr, 512), 512) # translated_ptr = tl.max_contiguous(tl.multiple_of(translated_ptr, 512), 512) diff --git a/tests/ccl/test_all_gather.py b/tests/ccl/test_all_gather.py new file mode 100644 index 00000000..ae649043 --- /dev/null +++ b/tests/ccl/test_all_gather.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +""" +Test suite for all-gather collective operation. +""" + +import pytest +import torch +import torch.distributed as dist +import iris +from iris.ccl import Config + + +@pytest.mark.parametrize( + "dtype", + [ + torch.float16, + torch.float32, + torch.bfloat16, + ], +) +@pytest.mark.parametrize( + "M, N", + [ + (128, 64), # Small + (1024, 256), # Medium + (8192, 8192), # Large + ], +) +def test_all_gather(dtype, M, N): + """Test all-gather functionality by comparing against PyTorch's implementation.""" + # Ensure torch.distributed is initialized (should be done by test runner) + if not dist.is_initialized(): + pytest.skip("torch.distributed not initialized") + + heap_size = 2**33 # 8GB + shmem = iris.iris(heap_size) + rank = shmem.get_rank() + world_size = shmem.get_num_ranks() + + # PyTorch's all_gather_into_tensor format: each rank has M x N input + # Output is (world_size * M, N) - concatenated along dimension 0 + pytorch_input_tensor = torch.randn(M, N, dtype=dtype, device=f"cuda:{rank}") + # Fill with deterministic values for easier debugging + pytorch_input_tensor.fill_(float(rank + 1)) + + # Create output tensor for PyTorch: (world_size * M, N) + pytorch_output_tensor = torch.zeros(world_size * M, N, dtype=dtype, device=f"cuda:{rank}") + + # Run PyTorch's all_gather_into_tensor to get reference output + shmem.barrier() + dist.all_gather_into_tensor(pytorch_output_tensor, pytorch_input_tensor) + torch.cuda.synchronize() + + # Now set up Iris all_gather format + # Iris format: same as PyTorch - input is (M, N), output is (world_size * M, N) + iris_input_tensor = shmem.zeros((M, N), dtype=dtype) + iris_input_tensor.copy_(pytorch_input_tensor) + + iris_output_tensor = shmem.zeros((world_size * M, N), dtype=dtype) + + # Run Iris all_gather + shmem.barrier() + config = Config() + shmem.ccl.all_gather(iris_output_tensor, iris_input_tensor, config=config) + torch.cuda.synchronize() + + # Compare results + atol = 1e-3 if dtype == torch.float16 else 1e-5 + max_diff = 
torch.abs(iris_output_tensor - pytorch_output_tensor).max().item() + + try: + assert torch.allclose(iris_output_tensor, pytorch_output_tensor, atol=atol), ( + f"Max difference: {max_diff}, expected < {atol}\n" + f"Rank {rank}: Iris output doesn't match PyTorch's all_gather_into_tensor" + ) + finally: + # Final barrier to ensure all ranks complete before test cleanup + # This helps with test isolation when running multiple tests + # Note: shmem.barrier() already does cuda.synchronize() + shmem.barrier() + # Explicitly delete the shmem instance to trigger cleanup + del shmem + # Force garbage collection to ensure IPC handles are cleaned up + import gc + + gc.collect()
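As a companion to the test above, here is a minimal usage sketch (not part of the diff) for the two collectives introduced in this patch. It relies only on entry points shown above — iris.iris(heap_size), shmem.zeros, shmem.ccl.all_gather, shmem.ccl.reduce_scatter, and shmem.barrier — and assumes a torch.distributed process group is already initialized on each rank; the shapes, dtype, and heap size are illustrative.

# Minimal usage sketch for the new collectives (illustrative; assumes an
# initialized torch.distributed process group on every rank, e.g. via the
# _worker() helpers used by the benchmarks in this patch).
import torch
import iris


def run_new_collectives(M=1024, N=1024, dtype=torch.float16):
    shmem = iris.iris(1 << 33)  # symmetric heap size is illustrative
    rank = shmem.get_rank()
    world_size = shmem.get_num_ranks()

    # All-gather: each rank contributes (M, N); every rank ends up with
    # (world_size * M, N), concatenated along dimension 0.
    ag_input = shmem.zeros((M, N), dtype=dtype)
    ag_input.fill_(float(rank + 1))
    ag_output = shmem.zeros((world_size * M, N), dtype=dtype)
    shmem.ccl.all_gather(ag_output, ag_input)

    # Reduce-scatter: input and output are both (M, N); only the tiles
    # assigned to this rank hold the reduced sums afterwards.
    rs_input = shmem.zeros((M, N), dtype=dtype)
    rs_input.fill_(float(rank + 1) * 0.1)
    rs_output = shmem.zeros((M, N), dtype=dtype)
    shmem.ccl.reduce_scatter(rs_output, rs_input)

    shmem.barrier()
    return ag_output, rs_output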