iris.x: Device-side communication + .x APIs.
#296
base: muhosama/ccl-more
Conversation
Co-authored-by: copilot-swe-agent[bot] <[email protected]>
Co-authored-by: mawad-amd <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Pull request overview
This PR introduces iris.x, a new module providing device-side tile-level primitives for fine-grained collective operations. Unlike iris.ccl which handles full tensors with internal tiling, iris.x provides composable functions that users can call from their own kernels to manage tile iteration themselves.
Key Changes:
- New `iris.x` module with tile-level communication primitives (all-reduce, all-gather, all-to-all, reduce-scatter)
- Fused GEMM+communication operations requiring tritonBLAS (`gemm_all_reduce`, `gemm_all_gather`, etc.)
- Comprehensive test suite for new primitives in `tests/x/`
- CI/CD modernization with unified workflow replacing 3 separate workflows
- Documentation updates and benchmark enhancements
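To make the composability distinction concrete, here is a rough usage sketch: a user-written Triton kernel that owns its own tile iteration and would invoke a device-side collective per tile. The helper name `all_reduce_tile` and all parameter names are hypothetical placeholders, not the actual `iris.x` API; the real entry points live in the files listed under Reviewed changes below.

```python
import triton
import triton.language as tl

@triton.jit
def my_fused_kernel(out_ptr, M, N, stride_m, stride_n,
                    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # The caller owns the tile loop: each program instance picks its own tile.
    pid = tl.program_id(0)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    tile_m = pid // num_pid_n
    tile_n = pid % num_pid_n

    offs_m = tile_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tile_n * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    ptrs = out_ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n

    # Local per-tile compute would go here (a simple scale as a stand-in).
    tile = tl.load(ptrs, mask=mask, other=0.0)
    tl.store(ptrs, tile * 2.0, mask=mask)

    # A tile-level collective would then be invoked on just this tile,
    # e.g. something like: all_reduce_tile(out_ptr, tile_m, tile_n, ...)
    # (hypothetical call, left as a comment).
```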
Reviewed changes
Copilot reviewed 33 out of 33 changed files in this pull request and generated 14 comments.
| File | Description |
|---|---|
| iris/x/__init__.py | Module initialization exposing all tile-level primitives with optional GEMM operations |
| iris/x/all_reduce.py | Five all-reduce variants (atomic, one-shot, two-shot, spinlock, ring) for different use cases |
| iris/x/all_gather.py | Tile-level all-gather primitive for gathering data from all ranks |
| iris/x/all_to_all.py | Tile-level all-to-all primitive for bidirectional data exchange |
| iris/x/reduce_scatter.py | Tile-level reduce-scatter that reduces and scatters to assigned ranks |
| iris/x/gemm_all_reduce.py | Fused GEMM + all-reduce using tritonBLAS stages |
| iris/x/gemm_all_gather.py | Fused GEMM + all-gather combining computation and communication |
| iris/x/gemm_reduce_scatter.py | Fused GEMM + reduce-scatter for column-parallel workloads |
| iris/x/all_gather_gemm.py | Fused all-gather + GEMM for tensor-parallel workloads |
| iris/x/common.py | Shared utilities for tile indexing and offset computation |
| tests/x/test_*.py | Comprehensive test suite validating all primitives against PyTorch references |
| .github/workflows/iris-tests.yml | New unified test workflow supporting multiple test directories and install methods |
| .github/scripts/run_tests.sh | Updated test runner with tritonBLAS installation for iris.x tests |
| tests/ccl/test_all_reduce.py | Modified to add explicit preamble calls for better test isolation |
| pyproject.toml | Added optional gemm dependency group for tritonBLAS |
| docs/reference/examples.md | Updated documentation with new example references |
| benchmark/ccl/all_to_all/benchmark.py | Added RCCL comparison benchmarking option |
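For context on the "PyTorch references" used by the tests above, a reference check for a sum all-reduce typically looks like the following host-side sketch. Shapes, tolerances, and the way the iris.x result would be obtained are assumptions for illustration, not taken from the test files.

```python
import torch

# Reference for an all-reduce (sum): every rank should end up with the
# elementwise sum of all ranks' inputs.
world_size = 4
inputs = [torch.randn(256, 256) for _ in range(world_size)]

expected = torch.zeros_like(inputs[0])
for t in inputs:
    expected += t

# `result` stands in for the tensor produced by the device-side primitive
# on one rank; here it is cloned from the reference just to show the check.
result = expected.clone()
torch.testing.assert_close(result, expected, rtol=1e-3, atol=1e-3)
```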
```python
tl.assume(stride_cn > 0)

# Determine accumulator dtype based on output type
acc_dtype = tl.int32 if C.type.element_ty == tl.int8 else tl.float32
```
Copilot AI · Dec 9, 2025
Inconsistent accumulator dtype logic compared to other GEMM kernels. In gemm_all_reduce.py and gemm_reduce_scatter.py, the logic is:

`acc_dtype = tl.int32 if C.type.element_ty != tl.int8 else tl.float32`

But here it's:

`acc_dtype = tl.int32 if C.type.element_ty == tl.int8 else tl.float32`

This appears to be inverted. For int8 output types you'd want an int32 accumulator (correct in the other files), but for floating-point types you'd want a float32 accumulator. The condition should be `!= tl.int8` for consistency.
Suggested change:

```diff
- acc_dtype = tl.int32 if C.type.element_ty == tl.int8 else tl.float32
+ acc_dtype = tl.int32 if C.type.element_ty != tl.int8 else tl.float32
```
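As background for the int8/int32 pairing discussed above, a quick host-side illustration (independent of the kernels under review) of why int8 outputs need a wider accumulator:

```python
import torch

# A single 64-element dot product of int8 values near the type's maximum
# already overflows int8 by a wide margin, which is why int8 GEMMs
# accumulate in int32 before any final requantization.
a = torch.full((1, 64), 127, dtype=torch.int8)
b = torch.full((64, 1), 127, dtype=torch.int8)

exact = (a.to(torch.int32) @ b.to(torch.int32)).item()
print(exact)  # 1032256, far outside the int8 range [-128, 127]
```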
```python
tl.assume(stride_ag_n > 0)

# Determine accumulator dtype based on output type
acc_dtype = tl.int32 if C.type.element_ty != tl.int8 else tl.float32
```
Copilot AI · Dec 9, 2025
Same accumulator dtype logic inconsistency as in gemm_all_gather.py. The condition should be `!= tl.int8` instead of `== tl.int8` to match the logic in gemm_all_reduce.py and gemm_reduce_scatter.py.
Suggested change:

```diff
- acc_dtype = tl.int32 if C.type.element_ty != tl.int8 else tl.float32
+ acc_dtype = tl.int32 if C.type.element_ty == tl.int8 else tl.float32
```
```python
bias_vector = tl.load(
    bias_ptr + row_indices * stride_bias, mask=row_indices < M, other=0.0
)
acc = add_vector(acc, bias_vector[:, None], QUANTIZED=False)
```
Copilot AI · Dec 9, 2025
Inconsistent bias vector broadcasting. In gemm_all_reduce.py line 172, the code uses:

`acc = add_vector(acc, bias_vector, QUANTIZED=False)`

But here it uses:

`acc = add_vector(acc, bias_vector[:, None], QUANTIZED=False)`

The `[:, None]` adds an extra dimension. For consistency across GEMM kernels, the bias handling should be uniform. Based on typical GEMM bias usage where bias is (M,) and acc is (M, N), the `[:, None]` expansion may be needed, but all GEMM kernels should handle it the same way.
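For reference, the `[:, None]` expansion follows standard NumPy-style broadcasting (which Triton also uses): a (M,) bias gains a trailing axis so it adds row-wise across a (M, N) accumulator. A minimal PyTorch illustration with made-up shapes:

```python
import torch

M, N = 4, 8
acc = torch.zeros(M, N)
bias = torch.arange(M, dtype=torch.float32)  # shape (M,)

# bias[:, None] has shape (M, 1) and broadcasts across the N columns,
# so bias[i] is added to every element of row i.
out = acc + bias[:, None]
print(out[:, 0])  # tensor([0., 1., 2., 3.])

# Without the added axis, a (M,) bias would broadcast across the last
# dimension instead, which is only shape-compatible when N == M and is
# not the intended per-row GEMM bias.
```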
```python
# K_local is the local shard size (K = world_size * K_local)
# A_sharded has shape (M, K_local), A_gathered has shape (M, K)
K_local = K // world_size
```
Copilot AI · Dec 9, 2025
Variable K_local is not used.
Suggested change: remove the line `K_local = K // world_size`.
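For reference, the shard arithmetic described by the comments in the excerpt (K = world_size * K_local, with A_sharded of shape (M, K_local) gathered into A_gathered of shape (M, K)) can be sketched on the host as follows; the concrete sizes are made up:

```python
import torch

world_size = 4
M, K = 128, 256
K_local = K // world_size  # each rank owns a (M, K_local) column shard

shards = [torch.randn(M, K_local) for _ in range(world_size)]
A_gathered = torch.cat(shards, dim=1)  # what the all-gather reconstructs
assert A_gathered.shape == (M, K)
```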
```python
# Launch all_gather_gemm kernel
num_pid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
total_tiles = num_pid_m * num_pid_n
```
Copilot AI · Dec 9, 2025
Variable total_tiles is not used.
Suggested change: remove the line `total_tiles = num_pid_m * num_pid_n`.
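For reference, the excerpt's tile count is plain ceiling division over the output grid; with illustrative sizes:

```python
M, N = 1000, 768
BLOCK_SIZE_M, BLOCK_SIZE_N = 128, 64

num_pid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M  # 8 row tiles cover 1000 rows
num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N  # 12 column tiles cover 768 cols
total_tiles = num_pid_m * num_pid_n                 # 96 output tiles in total
print(num_pid_m, num_pid_n, total_tiles)            # 8 12 96
```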
```python
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

# Each rank processes tiles assigned to it (striding distribution)
tiles_per_rank = tl.cdiv(total_tiles, world_size)
```
Copilot AI · Dec 9, 2025
Variable tiles_per_rank is not used.
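For context, the "striding distribution" mentioned in the excerpt assigns output tiles to ranks round-robin; a minimal host-side sketch of that assignment (names are illustrative):

```python
def tiles_for_rank(rank, world_size, total_tiles):
    # Rank r takes tiles r, r + world_size, r + 2 * world_size, ...
    return list(range(rank, total_tiles, world_size))

print(tiles_for_rank(0, 4, 10))  # [0, 4, 8]
print(tiles_for_rank(3, 4, 10))  # [3, 7]
```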
```python
import triton
import triton.language as tl
import torch
```
Copilot AI · Dec 9, 2025
Import of 'torch' is not used.
Suggested change: remove `import torch`.
```python
import triton
import triton.language as tl
import torch
```
Copilot AI · Dec 9, 2025
Import of 'torch' is not used.
Suggested change: remove `import torch`.
```python
import triton
import triton.language as tl
import torch
```
Copilot AI · Dec 9, 2025
Import of 'torch' is not used.
Suggested change: remove `import torch`.
```python
import torch
```
Copilot AI · Dec 9, 2025
Import of 'torch' is not used.
Suggested change: remove `import torch`.
Introduces `iris.x` APIs.

Submission Checklist