
[Benchmark] Add all gather matmul benchmark #400

Open · wants to merge 1 commit into base: joydddd/stack/21
1 change: 1 addition & 0 deletions benchmarks/distributed/__init__.py
@@ -1,3 +1,4 @@
from __future__ import annotations

from .all_gather_matmul import AGMatmulBench as AGMatmulBench
from .all_reduce import AllReduceBench as AllReduceBench
129 changes: 129 additions & 0 deletions benchmarks/distributed/all_gather_matmul.py
@@ -0,0 +1,129 @@
from __future__ import annotations

import argparse

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

from .experiment_util import BenchmarkOperator
from .experiment_util import ExperimentConfig

BUILDIN_SHAPES = [
(256, 256, 256),
(384, 384, 384),
(512, 512, 512),
(640, 640, 640),
(768, 768, 768),
(896, 896, 896),
(1024, 1024, 1024),
(1152, 1152, 1152),
(1280, 1280, 1280),
(1408, 1408, 1408),
(1536, 1536, 1536),
(1664, 1664, 1664),
(1792, 1792, 1792),
(1920, 1920, 1920),
(2048, 2048, 2048),
(2176, 2176, 2176),
(2304, 2304, 2304),
(2432, 2432, 2432),
(2560, 2560, 2560),
(2688, 2688, 2688),
(2816, 2816, 2816),
(2944, 2944, 2944),
(3072, 3072, 3072),
(3200, 3200, 3200),
(3328, 3328, 3328),
(3456, 3456, 3456),
(3584, 3584, 3584),
(3712, 3712, 3712),
(3840, 3840, 3840),
(3968, 3968, 3968),
(4096, 4096, 4096),
]


class AGMatmulBench(BenchmarkOperator):
def gen_configs(self, args: argparse.Namespace) -> list[ExperimentConfig]:
all_configs = []
for sz in args.shape:
all_configs.append(
ExperimentConfig(
shape=sz,
dtype=args.dtype,
backends=args.backend,
device=self.device,
)
)

return all_configs

def gen_inputs(self, config: ExperimentConfig) -> tuple:
M, N, K = config.shape
a = symm_mem.empty(
(M, K),
dtype=config.dtype,
device=config.device,
)
b = (
torch.randn((K, N), device=config.device, dtype=config.dtype)
.T.contiguous()
.T
)
assert dist.group.WORLD is not None
symm_mem.rendezvous(a, dist.group.WORLD.group_name)
return (a, b)

def additional_parser_args(
self, parser: argparse.ArgumentParser
) -> argparse.ArgumentParser:
def matmul_shape_type(s: str) -> tuple[int, int, int]:
try:
M, N, K = map(int, s.split(","))
return M, N, K
except Exception as e:
raise argparse.ArgumentTypeError(
"Matmul shape must be M, N, K. (M, K) @ (K, N) -> (M, N)"
) from e

parser.add_argument(
"--shape",
type=matmul_shape_type,
nargs="+",
default=BUILDIN_SHAPES,
help="matmul shapes: (M, N, K). (M, K) @ (K, N) -> (M, N)",
)
return parser

def __init__(self) -> None:
self.op_name = "ag_matmul"
self.baseline = "nccl"
super().__init__()

def nccl_mem_ag_mm(
a_shared: torch.Tensor, b: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
from torch.distributed._functional_collectives import all_gather_tensor

a_gathered = all_gather_tensor(a_shared, 0, "0")
return a_gathered, torch.matmul(a_gathered, b)

def torch_symm_mem_ag_mm(
a_shared: torch.Tensor, b: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
assert dist.group.WORLD is not None
a_gathered, c = torch.ops.symm_mem.fused_all_gather_matmul(
a_shared, [b], gather_dim=0, group_name=dist.group.WORLD.group_name
)
return a_gathered, c[0]

assert dist.group.WORLD is not None

AG_MATMUL_DICT = {
"nccl": nccl_mem_ag_mm,
"torch_symm_mem": torch_symm_mem_ag_mm,
"helion": ("examples.all_gather_matmul", "helion_all_gather_matmul"),
"kraken": ("kraken.all_gather", "all_gather_matmul"),
}
self.backend_dict = AG_MATMUL_DICT
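
For orientation, this is the computation every backend in AG_MATMUL_DICT implements; a minimal single-process sketch (not part of the PR; shapes and variable names are illustrative): gather each rank's (M, K) shard of A along dim 0, then multiply by B.

import torch

world_size = 4                    # illustrative rank count
M, K, N = 256, 256, 256           # per-rank shard shape, the smallest BUILDIN_SHAPES entry
a_shards = [torch.randn(M, K) for _ in range(world_size)]  # stand-ins for each rank's symm_mem buffer
b = torch.randn(K, N)

a_gathered = torch.cat(a_shards, dim=0)  # what the all-gather along dim 0 produces
c = a_gathered @ b                       # shape (world_size * M, N)
assert c.shape == (world_size * M, N)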
36 changes: 28 additions & 8 deletions benchmarks/distributed/experiment_util.py
@@ -37,7 +37,12 @@ def clone_symm_mem_tensor(tensor: torch.Tensor) -> torch.Tensor:
device=tensor.device,
)
assert dist.group.WORLD is not None
symm_mem.rendezvous(symm_mem_tensor, dist.group.WORLD.group_name)
try:
symm_mem.rendezvous(symm_mem_tensor, dist.group.WORLD.group_name)
except RuntimeError as e:
raise RuntimeError(
f"Failed to rendezvous tensor symmetric memory tensor of shape {tensor.shape}. "
) from e
symm_mem_tensor.copy_(tensor)
return symm_mem_tensor

@@ -68,7 +73,7 @@ class ExperimentConfig:
device: Target device for the experiment, defaults to None (auto-detected)
"""

shape: tuple[int]
shape: tuple[int, ...]
dtype: torch.dtype
backends: list[str]
device: torch.device | None = None
@@ -145,7 +150,7 @@ class BenchmarkOperator:
--nnodes 1 --nproc-per-node 8 \
--rdzv-backend c10d --rdzv-endpoint localhost:0 \
--no_python python3 \
benchmarks/run_distributed.py
benchmarks/run_distributed.py <op>
"""

experiments: list[Experiment]
@@ -207,6 +212,12 @@ def _parse_args(self) -> argparse.Namespace:
description=f"Run benchmark for {self.__name__}. " + self.help_str
)

parser.add_argument(
"op",
type=str,
help="Operator to benchmark. ",
)

parser.add_argument(
"--backend",
type=str,
@@ -229,6 +240,8 @@ def _parse_args(self) -> argparse.Namespace:
self.args = parser.parse_args()
self.args.dtype = getattr(torch, self.args.dtype)

assert self.args.op == self.op_name

return self.args

def __init__(self) -> None:
@@ -244,7 +257,6 @@ def __init__(self) -> None:

self.device = torch.device(f"cuda:{self.local_rank}")
torch.cuda.set_device(self.device)
dist.init_process_group("nccl")
torch.manual_seed(42 + self.local_rank)

self.experiments = []
@@ -405,35 +417,43 @@ def get_results(self, metric: str = "speedup") -> defaultdict | None:

def _run_experiment(self, config: ExperimentConfig) -> dict[str, float]:
if self.baseline not in config.backends:
backends = config.backends.append(self.baseline)
backends = [*config.backends, self.baseline]
else:
backends = config.backends

gloden_inp = self.gen_inputs(config)
inputs = {backend: clone_inputs(gloden_inp) for backend in backends} # pyright: ignore[reportOptionalIterable]

gloden_fn = self.fn_dict[self.baseline]
assert gloden_fn is not None

inp_og = clone_inputs(gloden_inp)
gloden_o = gloden_fn(*gloden_inp)

results = {}
for backend in backends: # pyright: ignore[reportOptionalIterable]
for backend in backends:
fn = self.fn_dict[backend]
if fn is None:
results[backend] = float("nan")
continue
inp = inputs[backend]
inp = clone_inputs(inp_og)
target_fn = functools.partial(fn, *inp)
try:
test_o = target_fn()
except RuntimeError:
results[backend] = float("nan")
continue
except AssertionError:
results[backend] = float("nan")
continue
torch.testing.assert_close(test_o, gloden_o, atol=1e-1, rtol=1e-1)

results[backend] = benchmark_distributed(
target_fn, profile_ranks=[self.MASTER_RANK]
)
del test_o
del inp

del gloden_inp
del gloden_o

return results
42 changes: 38 additions & 4 deletions benchmarks/run_distributed.py
@@ -1,13 +1,47 @@
from __future__ import annotations

from benchmarks.distributed import AllReduceBench as AllReduceBenchmark
import sys

from benchmarks.distributed import AGMatmulBench as AGMatmulBench
from benchmarks.distributed import AllReduceBench as AllReduceBench
import torch.distributed as dist

OP_BENCH = {
"allreduce": AllReduceBench,
"ag_matmul": AGMatmulBench,
}


def main() -> None:
bench = AllReduceBenchmark()
bench.run()
bench.print_results(metric="time_us")
try:
dist.init_process_group("nccl")
except ValueError:
print("""
Failed to initialize process group. Are you running with torchrun?
run distributed benchmark with:
torchrun \
--nnodes 1 --nproc-per-node 8 \
--rdzv-backend c10d --rdzv-endpoint localhost:0 \
--no_python python3 \
benchmarks/run_distributed.py <op>
""")
sys.exit(1)

if len(sys.argv) < 2:
print("Usage: python3 benchmarks/run_distributed.py <op>")
print(f"Available ops: {OP_BENCH.keys()}")
sys.exit(1)

op = sys.argv[1]

if op not in OP_BENCH:
print(f"Unknown op: {op}")
print(f"value ops: {OP_BENCH.keys()}")
sys.exit(1)

op_bench = OP_BENCH[op]()
op_bench.run()
op_bench.print_results(metric="time_us")

dist.destroy_process_group()
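
With this dispatch in place, the new benchmark would be launched the same way as the all-reduce one, for example (shapes are illustrative; --shape comes from AGMatmulBench.additional_parser_args, and further flags such as --backend come from the shared parser in experiment_util):

torchrun \
    --nnodes 1 --nproc-per-node 8 \
    --rdzv-backend c10d --rdzv-endpoint localhost:0 \
    --no_python python3 \
    benchmarks/run_distributed.py ag_matmul --shape 1024,1024,1024 2048,2048,2048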

29 changes: 17 additions & 12 deletions examples/all_gather_matmul.py
@@ -59,7 +59,7 @@ def copy_engine_all_gather_w_progress(
backend_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(backend_stream):
for step in range(world_size):
src_rank = (rank + step + 1) % world_size
src_rank = (rank + step) % world_size
for split_id in range(splits_per_rank):
src_buf = symm_mem_hdl.get_buffer(
src_rank, chunks[0].shape, inp.dtype, chunks[0].numel() * split_id
@@ -81,7 +81,9 @@
block_sizes=[128, 256, 64],
num_warps=8,
num_stages=3,
indexing="block_ptr",
indexing="tensor_descriptor",
pid_type="persistent_interleaved",
l2_groupings=[4],
),
static_shapes=True,
)
@@ -90,7 +92,7 @@ def helion_matmul_w_progress(
a_shared: torch.Tensor,
b: torch.Tensor,
progress: torch.Tensor,
SPLITS_PER_RANK: int,
SPLITS_PER_RANK: hl.constexpr,
RANK: int,
) -> torch.Tensor:
"""
@@ -114,16 +116,19 @@
M_per_rank = a_shared.size(0)
for tile_m, tile_n in hl.tile([M, N]):
acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
hl.wait(
progress,
[
tile_m.begin // (M_per_rank // SPLITS_PER_RANK),
],
signal=1,
)
# TODO(joydddd): natively support starting range from non_zero index.
comm_block_id = ((tile_m.begin + RANK * M_per_rank) % M) // (
M_per_rank // SPLITS_PER_RANK
) # pyright: ignore[reportOperatorIssue]
hl.wait(progress, [comm_block_id], signal=1)
for tile_k in hl.tile(K):
acc = torch.addmm(acc, a[tile_m, tile_k], b[tile_k, tile_n])
out[tile_m, tile_n] = acc
# TODO(joydddd): use a_shared and skip barrier when data is available on local rank.
acc = torch.addmm(
acc,
a[(tile_m.index + RANK * M_per_rank) % M, tile_k],
b[tile_k, tile_n],
)
out[(tile_m.index + RANK * M_per_rank) % M, tile_n] = acc
return out
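
A quick standalone check of the new wait index (illustrative numbers, not from the PR): with world_size = 2, SPLITS_PER_RANK = 2, M_per_rank = 4 (so M = 8) and RANK = 1, each output tile waits on the progress block that holds its rotated rows, starting with the rank's own local data, which lines up with the copy loop above now starting at src_rank = rank:

M_per_rank, M, RANK, SPLITS_PER_RANK = 4, 8, 1, 2
for begin in (0, 2, 4, 6):  # tile_m.begin values for a tile size of 2
    comm_block_id = ((begin + RANK * M_per_rank) % M) // (M_per_rank // SPLITS_PER_RANK)
    print(begin, "->", comm_block_id)  # 0 -> 2, 2 -> 3, 4 -> 0, 6 -> 1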

