
Conversation

@neoblizz (Member) commented Dec 8, 2025

  • All-gather
  • Reduce-scatter
  • Updated defaults
  • Bump docker to ROCm7.1 + Latest triton

Submission Checklist

Copilot AI review requested due to automatic review settings December 8, 2025 18:40
@github-actions github-actions bot added the in-progress (We are working on it) and iris (Iris project issue) labels Dec 8, 2025
Copilot AI (Contributor) left a comment


Pull request overview

This PR adds the all-gather collective communication primitive to Iris, updates default configuration values, and upgrades the Docker environment to ROCm 7.1 with the latest Triton version. The all-gather implementation matches PyTorch's all_gather_into_tensor behavior and ships with comprehensive tests and benchmarks.

Key Changes:

  • Implements all-gather collective with kernel implementation and API integration
  • Updates default comm_sms from 32 to 64 in Config class
  • Upgrades Docker base image to ROCm 7.1, Ubuntu 24.04, Python 3.13, and PyTorch 2.9.1

Reviewed changes

Copilot reviewed 7 out of 8 changed files in this pull request and generated 3 comments.

Summary per file:

  • iris/ccl/all_gather.py: New all-gather collective implementation with persistent kernel and tensor concatenation along dimension 0
  • iris/experimental/iris_gluon.py: Adds all_gather method to the CCL interface with a consistent API design
  • tests/ccl/test_all_gather.py: Comprehensive test suite validating against PyTorch's implementation across multiple dtypes and sizes
  • benchmark/ccl/all_gather/benchmark.py: Complete benchmark implementation with bandwidth reporting and optional RCCL comparison
  • iris/ccl/all_to_all.py: Minor cleanup; removes a redundant comment and adds a torch import for consistency
  • benchmark/ccl/all_to_all/benchmark.py: Refactors to use a comm_ranks variable and updates default comm_sms to 64
  • docker/Dockerfile: Upgrades to ROCm 7.1, Ubuntu 24.04, Python 3.13, PyTorch 2.9.1, and the latest Triton; temporarily disables rocprofiler-systems
  • .gitignore: Adds Python cache directories and compiled-file patterns


"""
All-gather collective communication primitive for Iris.
Gathers tensors from all ranks and concatenates them along the last dimension.

Copilot AI Dec 8, 2025


The docstring incorrectly states that tensors are concatenated "along the last dimension" when they are actually concatenated along dimension 0 (the first dimension/rows), as correctly described elsewhere in the file and in the actual implementation.

Suggested change
Gathers tensors from all ranks and concatenates them along the last dimension.
Gathers tensors from all ranks and concatenates them along dimension 0 (rows).
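For reference, torch.distributed.all_gather_into_tensor (the behavior the PR says it matches) concatenates each rank's contribution along dimension 0, so a per-rank input of shape (M, N) yields an output of shape (world_size * M, N). A minimal standalone sketch of that reference semantics, assuming a process group is already initialized (shapes and fill values are illustrative, not taken from the PR):

import torch
import torch.distributed as dist

# Assumes dist.init_process_group(...) has already been called on every rank.
rank = dist.get_rank()
world_size = dist.get_world_size()

M, N = 128, 64
# Each rank contributes an (M, N) tensor filled with its own rank id.
input_tensor = torch.full((M, N), float(rank), device=f"cuda:{rank}")

# Output rows [r * M : (r + 1) * M] come from rank r, i.e. concatenation along dim 0.
output_tensor = torch.empty((world_size * M, N), device=f"cuda:{rank}")
dist.all_gather_into_tensor(output_tensor, input_tensor)

assert torch.equal(output_tensor[rank * M : (rank + 1) * M], input_tensor)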


import triton
import triton.language as tl
import torch

Copilot AI Dec 8, 2025


Import of 'torch' is not used.


import triton
import triton.language as tl
import torch

Copilot AI Dec 8, 2025


Import of 'torch' is not used.

Suggested change
import torch

@neoblizz neoblizz changed the title from "Adding more collectives" to "iris.ccl.all_gather, iris.ccl.reduce_scatter + some updates to ccl benchmarking" Dec 9, 2025
@neoblizz neoblizz requested a review from Copilot December 9, 2025 07:40
Copilot AI (Contributor) left a comment


Pull request overview

Copilot reviewed 13 out of 14 changed files in this pull request and generated 13 comments.

rn = rn_base + tl.arange(0, BLOCK_SIZE_N)
rm_input = tl.max_contiguous(tl.multiple_of(rm_input, BLOCK_SIZE_M), BLOCK_SIZE_M)
rn = tl.max_contiguous(tl.multiple_of(rn, BLOCK_SIZE_N), BLOCK_SIZE_N)


Copilot AI Dec 9, 2025


[nitpick] There's trailing whitespace at the end of line 100. This should be removed for consistency with coding standards.

Suggested change: remove the trailing whitespace.
Comment on lines +227 to +254
# So validation is more complex - we'll just check that outputs are non-zero and correct pattern

# Basic validation: check that output contains reduced values (sum of inputs for assigned tiles)
# Since tile assignment is complex, we'll use a simpler check:
# The sum of all outputs should equal the sum of all inputs (scaled by world_size for each tile location)

# For now, validate that output is not all zeros and has expected magnitude
output_sum = output_tensor.sum().item()
input_sum = input_tensor.sum().item()

# Expected: each tile location gets sum of all ranks' contributions
# Total sum of all outputs should equal world_size * sum of one input (since each location is reduced)
# Actually, in our two-shot implementation, each rank reduces its assigned tiles
# The sum across all ranks' outputs for their assigned tiles should equal sum of all inputs
total_expected_sum = world_size * input_sum  # Each tile gets sum of all ranks

# Simple validation: output should be non-zero and have reasonable values
atol = 1e-3 if datatype == torch.float16 else 1e-5
has_data = output_tensor.abs().max().item() > atol

if not has_data:
    shmem.error(f"Rank {rank}: Validation failed - output is all zeros")
    success = False
else:
    # Check that values are in expected range (sum of inputs from all ranks for assigned tiles)
    # The exact validation depends on tile assignment, so we do a basic sanity check
    success = True
    shmem.info(f"Rank {rank}: Output sum: {output_sum:.2f}, Input sum: {input_sum:.2f}")

Copilot AI Dec 9, 2025


The validation logic in the reduce_scatter benchmark is incomplete and potentially misleading. The comments indicate the complexity of validating tile-based reduction, but the actual validation only checks if the output is non-zero and always sets success = True at line 253 regardless of the actual correctness. A proper validation should either compare against PyTorch's reduce_scatter_tensor result (accounting for the difference in semantics) or use a more rigorous correctness check.

Suggested change
# So validation is more complex - we'll just check that outputs are non-zero and correct pattern
# Basic validation: check that output contains reduced values (sum of inputs for assigned tiles)
# Since tile assignment is complex, we'll use a simpler check:
# The sum of all outputs should equal the sum of all inputs (scaled by world_size for each tile location)
# For now, validate that output is not all zeros and has expected magnitude
output_sum = output_tensor.sum().item()
input_sum = input_tensor.sum().item()
# Expected: each tile location gets sum of all ranks' contributions
# Total sum of all outputs should equal world_size * sum of one input (since each location is reduced)
# Actually, in our two-shot implementation, each rank reduces its assigned tiles
# The sum across all ranks' outputs for their assigned tiles should equal sum of all inputs
total_expected_sum = world_size * input_sum  # Each tile gets sum of all ranks
# Simple validation: output should be non-zero and have reasonable values
atol = 1e-3 if datatype == torch.float16 else 1e-5
has_data = output_tensor.abs().max().item() > atol
if not has_data:
    shmem.error(f"Rank {rank}: Validation failed - output is all zeros")
    success = False
else:
    # Check that values are in expected range (sum of inputs from all ranks for assigned tiles)
    # The exact validation depends on tile assignment, so we do a basic sanity check
    success = True
    shmem.info(f"Rank {rank}: Output sum: {output_sum:.2f}, Input sum: {input_sum:.2f}")
# So validation is more complex - we'll compare the output for the assigned region
# Prepare input list for PyTorch reduce_scatter_tensor
pytorch_input_list = [torch.zeros_like(pytorch_input) for _ in range(world_size)]
for r in range(world_size):
    pytorch_input_list[r].fill_(float(r + 1))
pytorch_input_cat = torch.stack(pytorch_input_list, dim=0)  # shape: (world_size, M, N)
pytorch_input_cat = pytorch_input_cat.view(world_size * M, N)
# PyTorch reduce_scatter expects input of shape (world_size * chunk, N)
chunk_size = M // world_size
pytorch_output = torch.zeros(chunk_size, N, dtype=datatype, device=f"cuda:{rank}")
dist.reduce_scatter_tensor(pytorch_output, pytorch_input_cat, op=dist.ReduceOp.SUM)
torch.cuda.synchronize()
# Now, compare the output_tensor for the region assigned to this rank
# Assume tile assignment is contiguous chunks for validation
iris_output_chunk = output_tensor[rank * chunk_size : (rank + 1) * chunk_size, :]
atol = 1e-3 if datatype == torch.float16 else 1e-5
rtol = 1e-2 if datatype == torch.float16 else 1e-4
if not torch.allclose(iris_output_chunk, pytorch_output, atol=atol, rtol=rtol):
    shmem.error(f"Rank {rank}: Validation failed - output does not match PyTorch reference")
    # Optionally, print some diagnostics
    diff = (iris_output_chunk - pytorch_output).abs().max().item()
    shmem.error(f"Rank {rank}: Max abs diff: {diff:.6f}")
    success = False
else:
    success = True
    shmem.info(f"Rank {rank}: Output matches PyTorch reference (max abs diff < {atol})")
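For context on the reduce_scatter_tensor semantics the suggested validation relies on: every rank passes an input of shape (world_size * chunk, N), the inputs are summed elementwise across ranks, and rank r receives only rows [r * chunk : (r + 1) * chunk] of that sum. A minimal standalone sketch under those assumptions (fill values and shapes are illustrative, not taken from the benchmark):

import torch
import torch.distributed as dist

# Assumes dist.init_process_group(...) has already been called on every rank.
rank = dist.get_rank()
world_size = dist.get_world_size()

chunk, N = 4, 8
# Every rank contributes an input of shape (world_size * chunk, N), filled with rank + 1.
full_input = torch.full((world_size * chunk, N), float(rank + 1), device=f"cuda:{rank}")

# Rank r receives the elementwise sum of rows [r * chunk : (r + 1) * chunk] across all ranks.
output = torch.empty((chunk, N), device=f"cuda:{rank}")
dist.reduce_scatter_tensor(output, full_input, op=dist.ReduceOp.SUM)

# With fill values 1..world_size, every output element equals 1 + 2 + ... + world_size.
expected = world_size * (world_size + 1) / 2
assert torch.allclose(output, torch.full_like(output, expected))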

Comment on lines +69 to +122
    parser.add_argument("--use_gluon", action="store_true", help="Use Gluon implementation with traffic shaping")

    return vars(parser.parse_args())


def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
    """Worker function for PyTorch distributed execution."""
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(
        backend=backend,
        init_method=init_url,
        world_size=world_size,
        rank=local_rank,
        device_id=torch.device(f"cuda:{local_rank}"),
    )

    # Use Gluon if requested and available
    if args.get("use_gluon", False):
        if not GLUON_AVAILABLE:
            raise RuntimeError("Gluon is not available. Install Triton with Gluon support or remove --use_gluon flag")
        shmem = iris_gluon.iris(args["heap_size"])
    else:
        shmem = iris.iris(args["heap_size"])

    rank = shmem.get_rank()
    world_size = shmem.get_num_ranks()

    # Datatype mapping
    datatype = torch.float32
    if args["datatype"] == "fp16":
        datatype = torch.float16
    elif args["datatype"] == "fp32":
        datatype = torch.float32
    elif args["datatype"] == "bf16":
        datatype = torch.bfloat16
    else:
        print("Unknown datatype.")
        exit(1)

    M = args["m"]
    N = args["n"]

    # Create config with optional block size parameters
    config_kwargs = {"comm_sms": args["comm_sms"]}
    if args["block_size_m"] is not None:
        config_kwargs["block_size_m"] = args["block_size_m"]
    if args["block_size_n"] is not None:
        config_kwargs["block_size_n"] = args["block_size_n"]
    if args["swizzle_size"] is not None:
        config_kwargs["swizzle_size"] = args["swizzle_size"]
    if args["num_xcds"] is not None:
        config_kwargs["num_xcds"] = args["num_xcds"]
    if args.get("use_gluon", False):
        config_kwargs["use_gluon"] = True

Copilot AI Dec 9, 2025


The benchmark allows setting use_gluon=True via command-line argument (line 69, 121-122), but the all_gather operation explicitly raises a ValueError when use_gluon=True is passed in the config (see iris/ccl/all_gather.py lines 178-183). This will cause the benchmark to fail at runtime when --use_gluon is specified. Either remove the use_gluon option from this benchmark or clarify in the help text that it's not supported for all_gather.
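If the flag is kept for consistency with the other benchmarks, one possible mitigation (a sketch only; the error message and help text below are illustrative, not part of the PR) is to fail fast in the benchmark worker and to state the limitation in the argument's help string:

# Hypothetical guard near the top of the worker: surface the mismatch before
# the collective rejects use_gluon=True deep inside the config path.
if args.get("use_gluon", False):
    raise NotImplementedError(
        "--use_gluon is not supported by the all_gather benchmark; "
        "iris.ccl.all_gather rejects use_gluon=True in its config."
    )

# Alternatively (or additionally), make the limitation explicit where the flag is declared:
parser.add_argument(
    "--use_gluon",
    action="store_true",
    help="Use Gluon implementation with traffic shaping (not supported for all_gather)",
)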

Comment on lines +16 to +29
@triton.jit()
def chiplet_transform_chunked(pid, num_workgroups: tl.constexpr, num_xcds: tl.constexpr, chunk_size: tl.constexpr):
    if pid > (num_workgroups // (num_xcds * chunk_size)) * (num_xcds * chunk_size):
        return pid

    local_pid = pid // num_xcds
    chunk_idx = local_pid // chunk_size
    pos_in_chunk = local_pid % chunk_size

    xcd = pid % num_xcds
    new_pid = chunk_idx * num_xcds * chunk_size + xcd * chunk_size + pos_in_chunk
    return new_pid



Copilot AI Dec 9, 2025


The chiplet_transform_chunked function is duplicated across multiple files: all_gather.py, all_reduce.py, all_to_all.py, and reduce_scatter.py. This is code duplication that should be addressed by extracting it into a shared utility module (e.g., iris/ccl/utils.py) to improve maintainability and reduce redundancy.

Suggested change
@triton.jit()
def chiplet_transform_chunked(pid, num_workgroups: tl.constexpr, num_xcds: tl.constexpr, chunk_size: tl.constexpr):
    if pid > (num_workgroups // (num_xcds * chunk_size)) * (num_xcds * chunk_size):
        return pid
    local_pid = pid // num_xcds
    chunk_idx = local_pid // chunk_size
    pos_in_chunk = local_pid % chunk_size
    xcd = pid % num_xcds
    new_pid = chunk_idx * num_xcds * chunk_size + xcd * chunk_size + pos_in_chunk
    return new_pid
from iris.ccl.utils import chiplet_transform_chunked
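A sketch of what that extraction could look like, assuming a new iris/ccl/utils.py module (the file name and location are the reviewer's suggestion, not existing code); the body is copied from the duplicated definition above, with comments added to describe the remapping:

# iris/ccl/utils.py (hypothetical shared module)
import triton
import triton.language as tl


@triton.jit()
def chiplet_transform_chunked(pid, num_workgroups: tl.constexpr, num_xcds: tl.constexpr, chunk_size: tl.constexpr):
    # Workgroups beyond the last full num_xcds * chunk_size group keep their original pid.
    if pid > (num_workgroups // (num_xcds * chunk_size)) * (num_xcds * chunk_size):
        return pid

    # pids are assumed to be dispatched round-robin across XCDs (xcd == pid % num_xcds);
    # regroup them so each XCD handles chunk_size consecutive tile indices.
    local_pid = pid // num_xcds
    chunk_idx = local_pid // chunk_size
    pos_in_chunk = local_pid % chunk_size

    xcd = pid % num_xcds
    new_pid = chunk_idx * num_xcds * chunk_size + xcd * chunk_size + pos_in_chunk
    return new_pid

Each collective would then drop its local copy and import the shared version, as the suggestions here and below already show (from iris.ccl.utils import chiplet_transform_chunked, or the relative from .utils import chiplet_transform_chunked).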


"""
All-gather collective communication primitive for Iris.
Gathers tensors from all ranks and concatenates them along the last dimension.

Copilot AI Dec 9, 2025


The module docstring states "concatenates them along the last dimension" but the implementation actually concatenates along dimension 0 (rows), not the last dimension. This is inconsistent with the docstring. The implementation correctly matches PyTorch's all_gather_into_tensor behavior which concatenates along dimension 0.

Suggested change
Gathers tensors from all ranks and concatenates them along the last dimension.
Gathers tensors from all ranks and concatenates them along dimension 0 (rows), matching PyTorch's all_gather_into_tensor behavior.

Comment on lines +16 to +29
@triton.jit()
def chiplet_transform_chunked(pid, num_workgroups: tl.constexpr, num_xcds: tl.constexpr, chunk_size: tl.constexpr):
    if pid > (num_workgroups // (num_xcds * chunk_size)) * (num_xcds * chunk_size):
        return pid

    local_pid = pid // num_xcds
    chunk_idx = local_pid // chunk_size
    pos_in_chunk = local_pid % chunk_size

    xcd = pid % num_xcds
    new_pid = chunk_idx * num_xcds * chunk_size + xcd * chunk_size + pos_in_chunk
    return new_pid


Copy link

Copilot AI Dec 9, 2025


The chiplet_transform_chunked function is duplicated across multiple files: all_gather.py, all_reduce.py, all_to_all.py, and reduce_scatter.py. This is code duplication that should be addressed by extracting it into a shared utility module (e.g., iris/ccl/utils.py) to improve maintainability and reduce redundancy.

Suggested change
@triton.jit()
def chiplet_transform_chunked(pid, num_workgroups: tl.constexpr, num_xcds: tl.constexpr, chunk_size: tl.constexpr):
    if pid > (num_workgroups // (num_xcds * chunk_size)) * (num_xcds * chunk_size):
        return pid
    local_pid = pid // num_xcds
    chunk_idx = local_pid // chunk_size
    pos_in_chunk = local_pid % chunk_size
    xcd = pid % num_xcds
    new_pid = chunk_idx * num_xcds * chunk_size + xcd * chunk_size + pos_in_chunk
    return new_pid
from .utils import chiplet_transform_chunked

# Note: Must use shmem.zeros() to allocate on Iris symmetric heap for iris.load() compatibility
input_tensor = shmem.zeros((M, N), dtype=datatype)
output_tensor = shmem.zeros((M, N), dtype=datatype)
expected_tensor = shmem.zeros((M, N), dtype=datatype)

Copilot AI Dec 9, 2025


Variable expected_tensor is not used.

Suggested change
expected_tensor = shmem.zeros((M, N), dtype=datatype)

# Total sum of all outputs should equal world_size * sum of one input (since each location is reduced)
# Actually, in our two-shot implementation, each rank reduces its assigned tiles
# The sum across all ranks' outputs for their assigned tiles should equal sum of all inputs
total_expected_sum = world_size * input_sum # Each tile gets sum of all ranks

Copilot AI Dec 9, 2025


Variable total_expected_sum is not used.


import triton
import triton.language as tl
import torch

Copilot AI Dec 9, 2025


Import of 'torch' is not used.

Suggested change
import torch


import triton
import triton.language as tl
import torch

Copilot AI Dec 9, 2025


Import of 'torch' is not used.

Suggested change
import torch


Labels

in-progress (We are working on it), iris (Iris project issue)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

  • [Feature]: GEMM + ReduceScatter
  • [Feature]: Implement All Gather + GEMM Fused Kernels

2 participants