Conversation

@neoblizz (Member) commented Dec 9, 2025

Introduces iris.x APIs.

Submission Checklist

Copilot AI and others added 5 commits December 8, 2025 06:33
@github-actions bot added the in-progress (We are working on it) and iris (Iris project issue) labels Dec 9, 2025
@neoblizz changed the base branch from main to muhosama/ccl-more December 9, 2025 20:32
@neoblizz requested a review from Copilot December 9, 2025 20:33
Copilot AI (Contributor) left a comment

Pull request overview

This PR introduces iris.x, a new module providing device-side, tile-level primitives for fine-grained collective operations. Unlike iris.ccl, which handles full tensors with internal tiling, iris.x provides composable functions that users call from their own kernels, letting them manage tile iteration themselves (a sketch of this model follows the change list below).

Key Changes:

  • New iris.x module with tile-level communication primitives (all-reduce, all-gather, all-to-all, reduce-scatter)
  • Fused GEMM+Communication operations requiring tritonBLAS (gemm_all_reduce, gemm_all_gather, etc.)
  • Comprehensive test suite for new primitives in tests/x/
  • CI/CD modernization with unified workflow replacing 3 separate workflows
  • Documentation updates and benchmark enhancements
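
To make the compositional model concrete, here is a minimal sketch of a user-owned tile loop. my_tile_primitive is a hypothetical stand-in for an iris.x device-side primitive; the actual iris.x function names and signatures are not shown in this excerpt and may differ.

import triton
import triton.language as tl

@triton.jit
def my_tile_primitive(ptr, offs_m, offs_n, stride, mask):
    # Hypothetical stand-in for an iris.x tile-level collective step:
    # it touches exactly one tile, selected by the caller's kernel.
    idx = offs_m[:, None] * stride + offs_n[None, :]
    tile = tl.load(ptr + idx, mask=mask, other=0.0)
    tl.store(ptr + idx, tile, mask=mask)

@triton.jit
def user_kernel(ptr, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # The user's kernel owns tile iteration (one program per tile here)
    # and invokes the tile-level primitive itself, rather than handing a
    # whole tensor to iris.ccl.
    pid = tl.program_id(0)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    my_tile_primitive(ptr, offs_m, offs_n, N, mask)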

Reviewed changes

Copilot reviewed 33 out of 33 changed files in this pull request and generated 14 comments.

Summary per file:

  • iris/x/__init__.py: Module initialization exposing all tile-level primitives, with optional GEMM operations
  • iris/x/all_reduce.py: Five all-reduce variants (atomic, one-shot, two-shot, spinlock, ring) for different use cases
  • iris/x/all_gather.py: Tile-level all-gather primitive for gathering data from all ranks
  • iris/x/all_to_all.py: Tile-level all-to-all primitive for bidirectional data exchange
  • iris/x/reduce_scatter.py: Tile-level reduce-scatter that reduces and scatters to assigned ranks
  • iris/x/gemm_all_reduce.py: Fused GEMM + all-reduce using tritonBLAS stages
  • iris/x/gemm_all_gather.py: Fused GEMM + all-gather combining computation and communication
  • iris/x/gemm_reduce_scatter.py: Fused GEMM + reduce-scatter for column-parallel workloads
  • iris/x/all_gather_gemm.py: Fused all-gather + GEMM for tensor-parallel workloads
  • iris/x/common.py: Shared utilities for tile indexing and offset computation
  • tests/x/test_*.py: Comprehensive test suite validating all primitives against PyTorch references
  • .github/workflows/iris-tests.yml: New unified test workflow supporting multiple test directories and install methods
  • .github/scripts/run_tests.sh: Updated test runner with tritonBLAS installation for iris.x tests
  • tests/ccl/test_all_reduce.py: Modified to add explicit preamble calls for better test isolation
  • pyproject.toml: Added optional gemm dependency group for tritonBLAS
  • docs/reference/examples.md: Updated documentation with new example references
  • benchmark/ccl/all_to_all/benchmark.py: Added RCCL comparison benchmarking option

tl.assume(stride_cn > 0)

# Determine accumulator dtype based on output type
acc_dtype = tl.int32 if C.type.element_ty == tl.int8 else tl.float32
Copilot AI Dec 9, 2025

Inconsistent accumulator dtype logic compared to the other GEMM kernels.

In gemm_all_reduce.py and gemm_reduce_scatter.py, the logic is:

acc_dtype = tl.int32 if C.type.element_ty != tl.int8 else tl.float32

But here it's:

acc_dtype = tl.int32 if C.type.element_ty == tl.int8 else tl.float32

One of these is inverted, and the version here is the correct one: int8 output types need an int32 accumulator, while floating-point output types need a float32 accumulator. Rather than changing this line, the != tl.int8 condition in gemm_all_reduce.py and gemm_reduce_scatter.py should be changed to == tl.int8 so that all the GEMM kernels agree.
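
For background on why int8 outputs pair with an int32 accumulator (general quantized-GEMM reasoning, not code from this PR), a small host-side check of the magnitudes involved:

import torch

# int8 products reach 127 * 127 = 16129; summing K = 1024 of them gives
# ~16.5M, far past int8/int16 range, so quantized GEMMs accumulate in int32.
a = torch.full((1024,), 127, dtype=torch.int8)
b = torch.full((1024,), 127, dtype=torch.int8)
acc = (a.to(torch.int32) * b.to(torch.int32)).sum(dtype=torch.int32)
print(acc.item())  # 16516096, fits comfortably in int32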

tl.assume(stride_ag_n > 0)

# Determine accumulator dtype based on output type
acc_dtype = tl.int32 if C.type.element_ty != tl.int8 else tl.float32
Copilot AI Dec 9, 2025

Same inverted accumulator dtype condition as in gemm_all_reduce.py and gemm_reduce_scatter.py. The condition should be == tl.int8 instead of != tl.int8, so that int8 outputs accumulate in int32 and floating-point outputs accumulate in float32, matching the correct logic in gemm_all_gather.py.

Suggested change
acc_dtype = tl.int32 if C.type.element_ty != tl.int8 else tl.float32
acc_dtype = tl.int32 if C.type.element_ty == tl.int8 else tl.float32

bias_vector = tl.load(
bias_ptr + row_indices * stride_bias, mask=row_indices < M, other=0.0
)
acc = add_vector(acc, bias_vector[:, None], QUANTIZED=False)
Copilot AI Dec 9, 2025

Inconsistent bias vector broadcasting. In gemm_all_reduce.py line 172, the code uses:

acc = add_vector(acc, bias_vector, QUANTIZED=False)

But here it uses:

acc = add_vector(acc, bias_vector[:, None], QUANTIZED=False)

The [:, None] adds a trailing axis so that a (M,) bias broadcasts across the (M, N) accumulator. Given that convention, the [:, None] expansion is most likely the correct form, but whichever form is right, all the GEMM kernels should handle bias the same way.
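
A minimal check of what the extra axis does, assuming the usual (M,) bias and (M, N) accumulator shapes:

import torch

M, N = 4, 8
acc = torch.zeros(M, N)        # (M, N) accumulator tile
bias = torch.arange(float(M))  # (M,) per-row bias
out = acc + bias[:, None]      # bias[:, None] is (M, 1) and broadcasts over N
# acc + bias would fail: broadcasting aligns trailing axes, and N != M here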


# K_local is the local shard size (K = world_size * K_local)
# A_sharded has shape (M, K_local), A_gathered has shape (M, K)
K_local = K // world_size
Copilot AI Dec 9, 2025

Variable K_local is not used.

Suggested change
K_local = K // world_size
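
The shape relationship described in the snippet's comment can be sanity-checked host-side (sizes here are arbitrary):

import torch

M, K, world_size = 128, 1024, 4
K_local = K // world_size                                      # shard size per rank
shards = [torch.randn(M, K_local) for _ in range(world_size)]  # A_sharded on each rank
A_gathered = torch.cat(shards, dim=1)                          # all-gather along K
assert A_gathered.shape == (M, K)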

# Launch all_gather_gemm kernel
num_pid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
total_tiles = num_pid_m * num_pid_n
Copilot AI Dec 9, 2025

Variable total_tiles is not used.

Suggested change
total_tiles = num_pid_m * num_pid_n
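
As an aside, the manual ceiling division in this launch code is equivalent to the triton.cdiv host-side helper used elsewhere in these kernels:

import triton

M, BLOCK_SIZE_M = 300, 128
assert (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M == triton.cdiv(M, BLOCK_SIZE_M) == 3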

num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

# Each rank processes tiles assigned to it (striding distribution)
tiles_per_rank = tl.cdiv(total_tiles, world_size)
Copilot AI Dec 9, 2025

Variable tiles_per_rank is not used.
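
For reference, the striding distribution mentioned in the snippet's comment assigns tiles to ranks round-robin; a small illustration with assumed sizes:

world_size, total_tiles = 4, 10
for rank in range(world_size):
    # rank r processes tiles r, r + world_size, r + 2 * world_size, ...
    print(rank, list(range(rank, total_tiles, world_size)))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6]
# 3 [3, 7]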


import triton
import triton.language as tl
import torch
Copilot AI Dec 9, 2025

Import of 'torch' is not used.

Suggested change
import torch


import triton
import triton.language as tl
import torch
Copilot AI Dec 9, 2025

Import of 'torch' is not used.

Suggested change
import torch


import triton
import triton.language as tl
import torch
Copilot AI Dec 9, 2025

Import of 'torch' is not used.

Suggested change
import torch

Comment on lines +13 to +14
import torch

Copilot AI Dec 9, 2025

Import of 'torch' is not used.

Suggested change
import torch
