iris.ccl.all_gather, iris.ccl.reduce_scatter + some updates to ccl benchmarking.
#295
base: main
Conversation
Pull request overview
This PR adds the all-gather collective communication primitive to Iris, updates default configuration values, and upgrades the Docker environment to ROCm 7.1 with the latest Triton version. The implementation includes comprehensive tests and benchmarks, matching PyTorch's all_gather_into_tensor behavior.
Key Changes:
- Implements all-gather collective with kernel implementation and API integration
- Updates default comm_sms from 32 to 64 in the Config class
- Upgrades the Docker base image to ROCm 7.1, Ubuntu 24.04, Python 3.13, and PyTorch 2.9.1
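For reference, PyTorch's all_gather_into_tensor, which the overview says the new primitive matches, concatenates the per-rank inputs along dimension 0. A minimal sketch of that reference behavior (function and tensor names here are illustrative, not part of the PR):

```python
import torch
import torch.distributed as dist

# Each rank contributes an (M, N) tensor; the gathered result stacks the
# per-rank tensors along dimension 0, giving a (world_size * M, N) tensor.
def torch_all_gather_reference(input_tensor: torch.Tensor) -> torch.Tensor:
    world_size = dist.get_world_size()
    rows, cols = input_tensor.shape
    output = torch.empty(
        world_size * rows, cols, dtype=input_tensor.dtype, device=input_tensor.device
    )
    dist.all_gather_into_tensor(output, input_tensor)
    return output
```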
Reviewed changes
Copilot reviewed 7 out of 8 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| iris/ccl/all_gather.py | New all-gather collective implementation with persistent kernel and tensor concatenation along dimension 0 |
| iris/experimental/iris_gluon.py | Adds all_gather method to CCL interface with consistent API design |
| tests/ccl/test_all_gather.py | Comprehensive test suite validating against PyTorch's implementation across multiple dtypes and sizes |
| benchmark/ccl/all_gather/benchmark.py | Complete benchmark implementation with bandwidth reporting and optional RCCL comparison |
| iris/ccl/all_to_all.py | Minor cleanup: removes redundant comment and adds torch import for consistency |
| benchmark/ccl/all_to_all/benchmark.py | Refactors to use comm_ranks variable and updates default comm_sms to 64 |
| docker/Dockerfile | Upgrades to ROCm 7.1, Ubuntu 24.04, Python 3.13, PyTorch 2.9.1, and latest Triton; temporarily disables rocprofiler-systems |
| .gitignore | Adds Python cache directories and compiled file patterns |
| """ | ||
| All-gather collective communication primitive for Iris. | ||
| Gathers tensors from all ranks and concatenates them along the last dimension. |
Copilot AI · Dec 8, 2025
The docstring incorrectly states that tensors are concatenated "along the last dimension" when they are actually concatenated along dimension 0 (the first dimension/rows), as correctly described elsewhere in the file and in the actual implementation.
Suggested change:
```diff
-Gathers tensors from all ranks and concatenates them along the last dimension.
+Gathers tensors from all ranks and concatenates them along dimension 0 (rows).
```
```python
import triton
import triton.language as tl
import torch
```
Copilot AI · Dec 8, 2025
Import of 'torch' is not used.
```python
import triton
import triton.language as tl
import torch
```
Copilot AI · Dec 8, 2025
Import of 'torch' is not used.
Suggested change:
```diff
-import torch
```
Pull request overview
Copilot reviewed 13 out of 14 changed files in this pull request and generated 13 comments.
```python
rn = rn_base + tl.arange(0, BLOCK_SIZE_N)
rm_input = tl.max_contiguous(tl.multiple_of(rm_input, BLOCK_SIZE_M), BLOCK_SIZE_M)
rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N)
```
Copilot AI · Dec 9, 2025
[nitpick] There's trailing whitespace at the end of line 100. This should be removed for consistency with coding standards.
```python
# So validation is more complex - we'll just check that outputs are non-zero and correct pattern

# Basic validation: check that output contains reduced values (sum of inputs for assigned tiles)
# Since tile assignment is complex, we'll use a simpler check:
# The sum of all outputs should equal the sum of all inputs (scaled by world_size for each tile location)

# For now, validate that output is not all zeros and has expected magnitude
output_sum = output_tensor.sum().item()
input_sum = input_tensor.sum().item()

# Expected: each tile location gets sum of all ranks' contributions
# Total sum of all outputs should equal world_size * sum of one input (since each location is reduced)
# Actually, in our two-shot implementation, each rank reduces its assigned tiles
# The sum across all ranks' outputs for their assigned tiles should equal sum of all inputs
total_expected_sum = world_size * input_sum  # Each tile gets sum of all ranks

# Simple validation: output should be non-zero and have reasonable values
atol = 1e-3 if datatype == torch.float16 else 1e-5
has_data = output_tensor.abs().max().item() > atol

if not has_data:
    shmem.error(f"Rank {rank}: Validation failed - output is all zeros")
    success = False
else:
    # Check that values are in expected range (sum of inputs from all ranks for assigned tiles)
    # The exact validation depends on tile assignment, so we do a basic sanity check
    success = True
    shmem.info(f"Rank {rank}: Output sum: {output_sum:.2f}, Input sum: {input_sum:.2f}")
```
Copilot AI · Dec 9, 2025
The validation logic in the reduce_scatter benchmark is incomplete and potentially misleading. The comments indicate the complexity of validating tile-based reduction, but the actual validation only checks if the output is non-zero and always sets success = True at line 253 regardless of the actual correctness. A proper validation should either compare against PyTorch's reduce_scatter_tensor result (accounting for the difference in semantics) or use a more rigorous correctness check.
Suggested change (replacing the validation block quoted above):
```python
# So validation is more complex - we'll compare the output for the assigned region
# Prepare input list for PyTorch reduce_scatter_tensor
pytorch_input_list = [torch.zeros_like(pytorch_input) for _ in range(world_size)]
for r in range(world_size):
    pytorch_input_list[r].fill_(float(r + 1))
pytorch_input_cat = torch.stack(pytorch_input_list, dim=0)  # shape: (world_size, M, N)
pytorch_input_cat = pytorch_input_cat.view(world_size * M, N)
# PyTorch reduce_scatter expects input of shape (world_size * chunk, N)
chunk_size = M // world_size
pytorch_output = torch.zeros(chunk_size, N, dtype=datatype, device=f"cuda:{rank}")
dist.reduce_scatter_tensor(pytorch_output, pytorch_input_cat, op=dist.ReduceOp.SUM)
torch.cuda.synchronize()
# Now, compare the output_tensor for the region assigned to this rank
# Assume tile assignment is contiguous chunks for validation
iris_output_chunk = output_tensor[rank * chunk_size : (rank + 1) * chunk_size, :]
atol = 1e-3 if datatype == torch.float16 else 1e-5
rtol = 1e-2 if datatype == torch.float16 else 1e-4
if not torch.allclose(iris_output_chunk, pytorch_output, atol=atol, rtol=rtol):
    shmem.error(f"Rank {rank}: Validation failed - output does not match PyTorch reference")
    # Optionally, print some diagnostics
    diff = (iris_output_chunk - pytorch_output).abs().max().item()
    shmem.error(f"Rank {rank}: Max abs diff: {diff:.6f}")
    success = False
else:
    success = True
    shmem.info(f"Rank {rank}: Output matches PyTorch reference (max abs diff < {atol})")
```
| parser.add_argument("--use_gluon", action="store_true", help="Use Gluon implementation with traffic shaping") | ||
|
|
||
| return vars(parser.parse_args()) | ||
|
|
||
|
|
||
| def _worker(local_rank: int, world_size: int, init_url: str, args: dict): | ||
| """Worker function for PyTorch distributed execution.""" | ||
| backend = "nccl" if torch.cuda.is_available() else "gloo" | ||
| dist.init_process_group( | ||
| backend=backend, | ||
| init_method=init_url, | ||
| world_size=world_size, | ||
| rank=local_rank, | ||
| device_id=torch.device(f"cuda:{local_rank}"), | ||
| ) | ||
|
|
||
| # Use Gluon if requested and available | ||
| if args.get("use_gluon", False): | ||
| if not GLUON_AVAILABLE: | ||
| raise RuntimeError("Gluon is not available. Install Triton with Gluon support or remove --use_gluon flag") | ||
| shmem = iris_gluon.iris(args["heap_size"]) | ||
| else: | ||
| shmem = iris.iris(args["heap_size"]) | ||
|
|
||
| rank = shmem.get_rank() | ||
| world_size = shmem.get_num_ranks() | ||
|
|
||
| # Datatype mapping | ||
| datatype = torch.float32 | ||
| if args["datatype"] == "fp16": | ||
| datatype = torch.float16 | ||
| elif args["datatype"] == "fp32": | ||
| datatype = torch.float32 | ||
| elif args["datatype"] == "bf16": | ||
| datatype = torch.bfloat16 | ||
| else: | ||
| print("Unknown datatype.") | ||
| exit(1) | ||
|
|
||
| M = args["m"] | ||
| N = args["n"] | ||
|
|
||
| # Create config with optional block size parameters | ||
| config_kwargs = {"comm_sms": args["comm_sms"]} | ||
| if args["block_size_m"] is not None: | ||
| config_kwargs["block_size_m"] = args["block_size_m"] | ||
| if args["block_size_n"] is not None: | ||
| config_kwargs["block_size_n"] = args["block_size_n"] | ||
| if args["swizzle_size"] is not None: | ||
| config_kwargs["swizzle_size"] = args["swizzle_size"] | ||
| if args["num_xcds"] is not None: | ||
| config_kwargs["num_xcds"] = args["num_xcds"] | ||
| if args.get("use_gluon", False): | ||
| config_kwargs["use_gluon"] = True |
Copilot AI · Dec 9, 2025
The benchmark allows setting use_gluon=True via command-line argument (line 69, 121-122), but the all_gather operation explicitly raises a ValueError when use_gluon=True is passed in the config (see iris/ccl/all_gather.py lines 178-183). This will cause the benchmark to fail at runtime when --use_gluon is specified. Either remove the use_gluon option from this benchmark or clarify in the help text that it's not supported for all_gather.
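One way to address this, as a sketch only (the validate_args helper and its placement are assumptions, not part of the PR; the argument names follow the benchmark code quoted above):

```python
# Hypothetical early guard for the all_gather benchmark: fail fast with a clear
# message instead of hitting the ValueError raised inside iris.ccl.all_gather
# when use_gluon=True reaches the Config.
def validate_args(args: dict) -> None:
    if args.get("use_gluon", False):
        raise SystemExit(
            "--use_gluon is not supported by the all_gather benchmark; "
            "iris.ccl.all_gather rejects use_gluon=True."
        )
```

Alternatively, the --use_gluon flag could simply be dropped from this benchmark's parser, as the comment suggests.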
```python
@triton.jit()
def chiplet_transform_chunked(pid, num_workgroups: tl.constexpr, num_xcds: tl.constexpr, chunk_size: tl.constexpr):
    if pid > (num_workgroups // (num_xcds * chunk_size)) * (num_xcds * chunk_size):
        return pid

    local_pid = pid // num_xcds
    chunk_idx = local_pid // chunk_size
    pos_in_chunk = local_pid % chunk_size

    xcd = pid % num_xcds
    new_pid = chunk_idx * num_xcds * chunk_size + xcd * chunk_size + pos_in_chunk
    return new_pid
```
Copilot AI · Dec 9, 2025
The chiplet_transform_chunked function is duplicated across multiple files: all_gather.py, all_reduce.py, all_to_all.py, and reduce_scatter.py. This is code duplication that should be addressed by extracting it into a shared utility module (e.g., iris/ccl/utils.py) to improve maintainability and reduce redundancy.
Suggested change:
```diff
-@triton.jit()
-def chiplet_transform_chunked(pid, num_workgroups: tl.constexpr, num_xcds: tl.constexpr, chunk_size: tl.constexpr):
-    if pid > (num_workgroups // (num_xcds * chunk_size)) * (num_xcds * chunk_size):
-        return pid
-    local_pid = pid // num_xcds
-    chunk_idx = local_pid // chunk_size
-    pos_in_chunk = local_pid % chunk_size
-    xcd = pid % num_xcds
-    new_pid = chunk_idx * num_xcds * chunk_size + xcd * chunk_size + pos_in_chunk
-    return new_pid
+from iris.ccl.utils import chiplet_transform_chunked
```
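A sketch of what the proposed shared module could look like. The path iris/ccl/utils.py is the reviewer's suggestion, not an existing file; the function body is copied verbatim from the duplicated definitions quoted above:

```python
# iris/ccl/utils.py (hypothetical shared module suggested by the review)
import triton
import triton.language as tl


@triton.jit()
def chiplet_transform_chunked(pid, num_workgroups: tl.constexpr, num_xcds: tl.constexpr, chunk_size: tl.constexpr):
    # Workgroups beyond the last full group of (num_xcds * chunk_size) keep their id.
    if pid > (num_workgroups // (num_xcds * chunk_size)) * (num_xcds * chunk_size):
        return pid

    # pid % num_xcds identifies the XCD; pid // num_xcds is the index within that XCD.
    local_pid = pid // num_xcds
    chunk_idx = local_pid // chunk_size
    pos_in_chunk = local_pid % chunk_size

    # Regroup so each XCD is assigned chunk_size consecutive remapped ids.
    xcd = pid % num_xcds
    new_pid = chunk_idx * num_xcds * chunk_size + xcd * chunk_size + pos_in_chunk
    return new_pid
```

Each collective could then replace its local copy with `from iris.ccl.utils import chiplet_transform_chunked` (or the relative form `from .utils import chiplet_transform_chunked`, as in the second suggestion below).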
| """ | ||
| All-gather collective communication primitive for Iris. | ||
| Gathers tensors from all ranks and concatenates them along the last dimension. |
Copilot AI · Dec 9, 2025
The module docstring states "concatenates them along the last dimension" but the implementation actually concatenates along dimension 0 (rows), not the last dimension. This is inconsistent with the docstring. The implementation correctly matches PyTorch's all_gather_into_tensor behavior which concatenates along dimension 0.
Suggested change:
```diff
-Gathers tensors from all ranks and concatenates them along the last dimension.
+Gathers tensors from all ranks and concatenates them along dimension 0 (rows), matching PyTorch's all_gather_into_tensor behavior.
```
```python
@triton.jit()
def chiplet_transform_chunked(pid, num_workgroups: tl.constexpr, num_xcds: tl.constexpr, chunk_size: tl.constexpr):
    if pid > (num_workgroups // (num_xcds * chunk_size)) * (num_xcds * chunk_size):
        return pid

    local_pid = pid // num_xcds
    chunk_idx = local_pid // chunk_size
    pos_in_chunk = local_pid % chunk_size

    xcd = pid % num_xcds
    new_pid = chunk_idx * num_xcds * chunk_size + xcd * chunk_size + pos_in_chunk
    return new_pid
```
Copilot AI · Dec 9, 2025
The chiplet_transform_chunked function is duplicated across multiple files: all_gather.py, all_reduce.py, all_to_all.py, and reduce_scatter.py. This is code duplication that should be addressed by extracting it into a shared utility module (e.g., iris/ccl/utils.py) to improve maintainability and reduce redundancy.
Suggested change:
```diff
-@triton.jit()
-def chiplet_transform_chunked(pid, num_workgroups: tl.constexpr, num_xcds: tl.constexpr, chunk_size: tl.constexpr):
-    if pid > (num_workgroups // (num_xcds * chunk_size)) * (num_xcds * chunk_size):
-        return pid
-    local_pid = pid // num_xcds
-    chunk_idx = local_pid // chunk_size
-    pos_in_chunk = local_pid % chunk_size
-    xcd = pid % num_xcds
-    new_pid = chunk_idx * num_xcds * chunk_size + xcd * chunk_size + pos_in_chunk
-    return new_pid
+from .utils import chiplet_transform_chunked
```
```python
# Note: Must use shmem.zeros() to allocate on Iris symmetric heap for iris.load() compatibility
input_tensor = shmem.zeros((M, N), dtype=datatype)
output_tensor = shmem.zeros((M, N), dtype=datatype)
expected_tensor = shmem.zeros((M, N), dtype=datatype)
```
Copilot AI · Dec 9, 2025
Variable expected_tensor is not used.
Suggested change:
```diff
-expected_tensor = shmem.zeros((M, N), dtype=datatype)
```
```python
# Total sum of all outputs should equal world_size * sum of one input (since each location is reduced)
# Actually, in our two-shot implementation, each rank reduces its assigned tiles
# The sum across all ranks' outputs for their assigned tiles should equal sum of all inputs
total_expected_sum = world_size * input_sum  # Each tile gets sum of all ranks
```
Copilot AI · Dec 9, 2025
Variable total_expected_sum is not used.
```python
import triton
import triton.language as tl
import torch
```
Copilot AI · Dec 9, 2025
Import of 'torch' is not used.
Suggested change:
```diff
-import torch
```
```python
import triton
import triton.language as tl
import torch
```
Copilot AI · Dec 9, 2025
Import of 'torch' is not used.
Suggested change:
```diff
-import torch
```