1 change: 1 addition & 0 deletions .gitignore
@@ -7,3 +7,4 @@ __pycache__/
.pytest_cache/
.coverage
htmlcov/
.DS_Store
14 changes: 14 additions & 0 deletions benchmark/benchmark_spsv.py
@@ -0,0 +1,14 @@
"""Run SpSV benchmark. From project root: python benchmark/benchmark_spsv.py [--synthetic | --csv-csr out.csv]."""

import sys
from pathlib import Path

root = Path(__file__).resolve().parent.parent
if str(root) not in sys.path:
sys.path.insert(0, str(root))

from tests.test_spsv import main


if __name__ == "__main__":
main()
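The wrapper defers everything to `tests.test_spsv.main`. For orientation, a hypothetical sketch of that entry point, assuming it parses exactly the two flags named in the docstring (the argparse details here are illustrative, not the repository's actual code):

```python
# Hypothetical shape of tests/test_spsv.py::main(); only --synthetic and
# --csv-csr are attested by the docstring above -- the rest is assumed.
import argparse

def main():
    parser = argparse.ArgumentParser(description="SpSV benchmark")
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--synthetic", action="store_true",
                       help="benchmark on generated inputs")
    group.add_argument("--csv-csr", metavar="OUT.CSV",
                       help="write CSR results to the given CSV file")
    args = parser.parse_args()
    ...  # run the SpSV cases selected by the flags
```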
43 changes: 3 additions & 40 deletions src/flagsparse/sparse_operations/_common.py
@@ -18,19 +18,14 @@
cp = None
cpx_sparse = None

_TORCH_COMPLEX32_DTYPE = getattr(torch, "complex32", None)
if _TORCH_COMPLEX32_DTYPE is None:
_TORCH_COMPLEX32_DTYPE = getattr(torch, "chalf", None)

_SUPPORTED_VALUE_DTYPES = [
torch.float16,
torch.bfloat16,
torch.float32,
torch.float64,
torch.complex64,
torch.complex128,
]
if _TORCH_COMPLEX32_DTYPE is not None:
_SUPPORTED_VALUE_DTYPES.append(_TORCH_COMPLEX32_DTYPE)
_SUPPORTED_VALUE_DTYPES.extend([torch.complex64, torch.complex128])
SUPPORTED_VALUE_DTYPES = tuple(_SUPPORTED_VALUE_DTYPES)
SUPPORTED_INDEX_DTYPES = (torch.int32, torch.int64)
_INDEX_LIMIT_INT32 = 2**31 - 1
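With complex32 gone, dtype support reduces to the fixed tuples above. A minimal sketch of a gate built on these constants (the helper is illustrative, not part of the module):

```python
# Illustrative only: gate a values/indices pair on the whitelists this
# hunk defines before dispatching to a kernel.
def check_dtypes(values, indices):
    if values.dtype not in SUPPORTED_VALUE_DTYPES:
        raise TypeError(f"unsupported value dtype: {values.dtype}")
    if indices.dtype not in SUPPORTED_INDEX_DTYPES:
        raise TypeError(f"unsupported index dtype: {indices.dtype}")
    if indices.dtype == torch.int32 and values.numel() > _INDEX_LIMIT_INT32:
        raise OverflowError("tensor too large for int32 indexing")
```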
@@ -40,8 +35,6 @@
"SUPPORTED_VALUE_DTYPES",
"SUPPORTED_INDEX_DTYPES",
"_INDEX_LIMIT_INT32",
"_torch_complex32_dtype",
"_is_complex32_dtype",
"_is_complex_dtype",
"_resolve_scatter_value_dtype",
"_component_dtype_for_complex",
@@ -68,17 +61,8 @@
"tl",
)


def _torch_complex32_dtype():
return _TORCH_COMPLEX32_DTYPE


def _is_complex32_dtype(value_dtype):
return _TORCH_COMPLEX32_DTYPE is not None and value_dtype == _TORCH_COMPLEX32_DTYPE


def _is_complex_dtype(value_dtype):
return _is_complex32_dtype(value_dtype) or value_dtype in (torch.complex64, torch.complex128)
return value_dtype in (torch.complex64, torch.complex128)
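`_is_complex_dtype` is now a plain membership test:

```python
# Post-change behavior: only the two wide complex dtypes qualify.
assert _is_complex_dtype(torch.complex64)
assert _is_complex_dtype(torch.complex128)
assert not _is_complex_dtype(torch.float32)
if hasattr(torch, "chalf"):
    assert not _is_complex_dtype(torch.chalf)  # complex32 no longer counts
```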


def _resolve_scatter_value_dtype(value_dtype, dtype_policy="auto"):
@@ -95,24 +79,13 @@ def _resolve_scatter_value_dtype(value_dtype, dtype_policy="auto"):
"complex64": torch.complex64,
"complex128": torch.complex128,
}
if token == "complex32":
if _TORCH_COMPLEX32_DTYPE is not None:
return _TORCH_COMPLEX32_DTYPE, False, None
if dtype_policy == "strict":
raise TypeError("complex32 is unavailable in this torch build")
return torch.complex64, True, "complex32 is unavailable; fallback to complex64"
if token not in mapping:
raise TypeError(f"Unsupported dtype token: {value_dtype}")
value_dtype = mapping[token]
if _is_complex32_dtype(value_dtype):
# If complex32 exists in torch dtype table, keep native path.
return value_dtype, False, None
return value_dtype, False, None
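With the `"complex32"` branch removed, that token now falls through to the generic unknown-token error instead of a policy-dependent fallback. Assuming string inputs normalize to the tokens in `mapping` (the normalization lives in the collapsed region):

```python
# Accepted tokens resolve with no fallback flag and no note:
dtype, fell_back, note = _resolve_scatter_value_dtype("complex64")
assert dtype is torch.complex64 and fell_back is False and note is None

# The removed branch means "complex32" now hits the generic error:
try:
    _resolve_scatter_value_dtype("complex32")
except TypeError as exc:
    print(exc)  # Unsupported dtype token: complex32
```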


def _component_dtype_for_complex(value_dtype):
if _is_complex32_dtype(value_dtype):
return torch.float16
if value_dtype == torch.complex64:
return torch.float32
if value_dtype == torch.complex128:
@@ -123,8 +96,6 @@ def _component_dtype_for_complex(value_dtype):
def _tolerance_for_dtype(value_dtype):
if value_dtype == torch.float16:
return 2e-3, 2e-3
if _is_complex32_dtype(value_dtype):
return 5e-3, 5e-3
if value_dtype == torch.bfloat16:
return 1e-1, 1e-1
if value_dtype in (torch.float32, torch.complex64):
@@ -155,9 +126,6 @@ def _cupy_dtype_from_torch(torch_dtype):
torch.int32: cp.int32,
torch.int64: cp.int64,
}
if _TORCH_COMPLEX32_DTYPE is not None:
# CuPy has no native complex32 sparse path; use complex64 for baseline parity.
mapping[_TORCH_COMPLEX32_DTYPE] = cp.complex64
if torch_dtype not in mapping:
raise TypeError(f"Unsupported dtype conversion to CuPy: {torch_dtype}")
return mapping[torch_dtype]
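The mapping backs the CuPy baseline; a sketch of the intended round trip, assuming CUDA builds of both torch and CuPy (`cp` is the imported CuPy module):

```python
# Zero-copy mirror of a CUDA torch tensor into CuPy via DLPack; the CuPy
# array's dtype matches the mapping above.
x = torch.randn(8, dtype=torch.float32, device="cuda")
x_cp = cp.from_dlpack(x)
assert x_cp.dtype == _cupy_dtype_from_torch(x.dtype)
```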
@@ -193,8 +161,6 @@ def _to_backend_like(torch_tensor, ref_obj):
def _cusparse_baseline_skip_reason(value_dtype):
if value_dtype == torch.bfloat16:
return "bfloat16 is not supported by the cuSPARSE baseline path; skipped"
if _is_complex32_dtype(value_dtype):
return "complex32 is not supported by the cuSPARSE baseline path; skipped"
if cp is None and value_dtype == torch.float16:
return "float16 is not supported by torch sparse fallback when CuPy is unavailable; skipped"
return None
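Returning a string rather than raising lets callers choose between skipping and failing. A typical call-site pattern (a sketch, not code from the repo):

```python
import pytest

# Skip a parametrized baseline test instead of reporting a failure when
# the dtype has no cuSPARSE baseline path.
value_dtype = torch.bfloat16
reason = _cusparse_baseline_skip_reason(value_dtype)
if reason is not None:
    pytest.skip(reason)
```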
@@ -203,9 +169,6 @@ def _cusparse_baseline_skip_reason(value_dtype):
def _build_random_dense(dense_size, value_dtype, device):
if value_dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
return torch.randn(dense_size, dtype=value_dtype, device=device)
if _is_complex32_dtype(value_dtype):
stacked = torch.randn((dense_size, 2), dtype=torch.float16, device=device)
return torch.view_as_complex(stacked)
if _is_complex_dtype(value_dtype):
component_dtype = _component_dtype_for_complex(value_dtype)
real = torch.randn(dense_size, dtype=component_dtype, device=device)
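The visible branch ends mid-function; the collapsed tail presumably draws the imaginary component the same way and combines the two, along these lines:

```python
# Assumed continuation of the truncated branch: torch.complex accepts
# float32/float64 components, which are exactly the component dtypes
# left once complex32 is removed.
imag = torch.randn(dense_size, dtype=component_dtype, device=device)
return torch.complex(real, imag)
```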
55 changes: 13 additions & 42 deletions src/flagsparse/sparse_operations/gather_scatter.py
@@ -226,6 +230,10 @@ def _launch_triton_scatter_kernel(
)


def _validate_gather_value_dtype(dense_vector, op_name):
return None
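`_validate_gather_value_dtype` lands as a deliberate no-op: it keeps a stable validation call site at each gather entry point while the complex32 handling is stripped out. If a real check returns later, it would plausibly look like this (illustrative only, not part of the PR):

```python
# Hypothetical future body -- today the hook intentionally validates nothing.
def _validate_gather_value_dtype(dense_vector, op_name):
    if dense_vector.dtype not in SUPPORTED_VALUE_DTYPES:
        raise TypeError(f"{op_name}: unsupported value dtype {dense_vector.dtype}")
```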


def _cusparse_spmv(selector_matrix, dense_vector):
if cp is not None and cpx_sparse is not None and isinstance(selector_matrix, cpx_sparse.spmatrix):
if torch.is_tensor(dense_vector):
@@ -330,6 +334,7 @@ def flagsparse_gather(a, indices, out=None, mode="raise", block_size=1024, retur
dense_vector, dense_backend = _to_torch_tensor(a, "a")
indices_tensor, _ = _to_torch_tensor(indices, "indices")
dense_vector, indices_tensor, kernel_indices = _prepare_inputs(dense_vector, indices_tensor)
_validate_gather_value_dtype(dense_vector, "flagsparse_gather")

torch.cuda.synchronize()
start_time = time.perf_counter()
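The bracketing `torch.cuda.synchronize()` calls are the standard pattern for host-clock timing of asynchronous CUDA work; in isolation:

```python
# CUDA launches are asynchronous, so both synchronize() calls are
# load-bearing: without them perf_counter() times only the enqueue.
torch.cuda.synchronize()
t0 = time.perf_counter()
out = dense_vector[kernel_indices]  # stand-in for the real gather launch
torch.cuda.synchronize()
elapsed_s = time.perf_counter() - t0
```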
@@ -479,6 +484,7 @@ def pytorch_index_scatter(
def cusparse_spmv_gather(dense_vector, indices, selector_matrix=None):
"""Equivalent gather baseline via cuSPARSE-backed COO SpMV."""
dense_vector, indices, _ = _prepare_inputs(dense_vector, indices)
_validate_gather_value_dtype(dense_vector, "cusparse_spmv_gather")
skip_reason = _cusparse_baseline_skip_reason(dense_vector.dtype)
if skip_reason:
raise RuntimeError(skip_reason)
@@ -557,21 +563,14 @@ def _cupy_gather_detect_layout(dense_vector):
return "scalar16"
if dense_vector.ndim == 1 and dense_vector.dtype == torch.complex64:
return "complex64"
if (
dense_vector.ndim == 2
and dense_vector.shape[1] == 2
and dense_vector.dtype == torch.float16
):
# complex32 alignment with half2 storage.
return "complex16_pair"
raise TypeError(
"Unsupported gather input format. Expected one of: "
"1D float16/bfloat16, 1D complex64, or 2D (N,2) float16."
"1D float16/bfloat16 or 1D complex64."
)
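After this hunk only two layouts survive; the detector's contract in brief (assuming the collapsed conditions match the error message):

```python
# Surviving layouts:
assert _cupy_gather_detect_layout(torch.empty(4, dtype=torch.float16)) == "scalar16"
assert _cupy_gather_detect_layout(torch.empty(4, dtype=torch.complex64)) == "complex64"
# The (N, 2) float16 pair encoding now raises TypeError:
# _cupy_gather_detect_layout(torch.empty((4, 2), dtype=torch.float16))
```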


def _cupy_gather_dense_size(dense_vector, layout):
if layout in ("scalar16", "complex64", "complex16_pair"):
if layout in ("scalar16", "complex64"):
return int(dense_vector.shape[0])
raise RuntimeError(f"Unknown gather layout: {layout}")

@@ -601,7 +600,7 @@ def _cupy_gather_validate_inputs(dense_vector, indices):

def _cupy_gather_validate_combo(dense_vector, indices, layout):
# Keep only the required extra gather combos:
# Half+Int64, Bfloat16+Int32/Int64, Complex32+Int32/Int64, Complex64+Int32/Int64
# Half+Int64, Bfloat16+Int32/Int64, Complex64+Int32/Int64
if layout == "scalar16":
if dense_vector.dtype == torch.float16:
if indices.dtype != torch.int64:
@@ -611,11 +610,6 @@ def _cupy_gather_validate_combo(dense_vector, indices, layout):
return
raise TypeError("scalar16 gather_cupy supports only float16/bfloat16")

if layout == "complex16_pair":
if dense_vector.dtype != torch.float16:
raise TypeError("complex16_pair gather_cupy supports only float16 pairs")
return

if layout == "complex64":
return
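Per the comment, the surviving combinations are float16+int64, bfloat16+{int32, int64}, and complex64+{int32, int64}. A quick sketch, assuming the collapsed float16/bfloat16 branches match that comment:

```python
i32 = torch.empty(2, dtype=torch.int32)
i64 = torch.empty(2, dtype=torch.int64)

_cupy_gather_validate_combo(torch.empty(4, dtype=torch.float16), i64, "scalar16")    # ok
_cupy_gather_validate_combo(torch.empty(4, dtype=torch.bfloat16), i32, "scalar16")   # ok
_cupy_gather_validate_combo(torch.empty(4, dtype=torch.complex64), i32, "complex64") # ok
# float16 with int32 indices raises in the collapsed branch (int64 required).
```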

@@ -625,8 +619,6 @@ def _cupy_gather_layout_raw_kind(layout):
def _cupy_gather_layout_raw_kind(layout):
if layout == "scalar16":
return 16
if layout == "complex16_pair":
return 32
if layout == "complex64":
return 64
raise RuntimeError(f"Unknown gather layout: {layout}")
@@ -675,9 +667,6 @@ def _cupy_gather_dense_to_raw_torch(dense_t, layout):
return dense_t.reshape(-1).view(torch.uint16)
if layout == "complex64":
return dense_t.reshape(-1).view(torch.uint64)
if layout == "complex16_pair":
lanes_u16 = dense_t.reshape(-1).view(torch.uint16)
return lanes_u16.reshape(-1, 2).view(torch.uint32).reshape(-1)
raise RuntimeError(f"Unknown gather layout: {layout}")


@@ -686,44 +675,30 @@ def _cupy_gather_raw_to_dense_torch(out_raw_t, layout, dense_t_dtype):
return out_raw_t.view(dense_t_dtype).reshape(-1)
if layout == "complex64":
return out_raw_t.view(torch.complex64).reshape(-1)
if layout == "complex16_pair":
lanes_u16 = out_raw_t.view(torch.uint16).reshape(-1, 2)
return lanes_u16.view(dense_t_dtype).reshape(-1, 2)
raise RuntimeError(f"Unknown gather layout: {layout}")


def _cupy_gather_empty(layout, dense_dtype, device):
if layout in ("scalar16", "complex64"):
return torch.empty(0, dtype=dense_dtype, device=device)
if layout == "complex16_pair":
return torch.empty((0, 2), dtype=dense_dtype, device=device)
raise RuntimeError(f"Unknown gather layout: {layout}")


def _cupy_gather_selector_dtype(layout, dense_dtype):
if layout in ("scalar16", "complex16_pair"):
if layout == "scalar16":
return dense_dtype
if layout == "complex64":
return torch.complex64
raise RuntimeError(f"Unknown gather layout: {layout}")


def _cupy_gather_prepare_dense(dense_vector, indices):
runtime_dense = dense_vector
restore_mode = None
native_complex32 = _torch_complex32_dtype()
if native_complex32 is not None and dense_vector.ndim == 1 and dense_vector.dtype == native_complex32:
runtime_dense = torch.view_as_real(dense_vector).contiguous()
restore_mode = "native_complex32"

layout, dense_size = _cupy_gather_validate_inputs(runtime_dense, indices)
_cupy_gather_validate_combo(runtime_dense, indices, layout)
return runtime_dense, layout, dense_size, restore_mode
layout, dense_size = _cupy_gather_validate_inputs(dense_vector, indices)
_cupy_gather_validate_combo(dense_vector, indices, layout)
return dense_vector, layout, dense_size, None


def _cupy_gather_restore_output(gathered_t, restore_mode):
if restore_mode == "native_complex32":
return torch.view_as_complex(gathered_t.contiguous())
return gathered_t
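With the complex32 repacking gone, `_cupy_gather_prepare_dense` always returns `restore_mode=None` and `_cupy_gather_restore_output` collapses to an identity; the pair survives only to keep call sites stable. A contract sketch (input validation details are in the collapsed region):

```python
# Both halves of the pair are now pass-throughs:
x = torch.randn(8, dtype=torch.float16)
idx = torch.tensor([0, 3], dtype=torch.int64)
dense, layout, size, restore = _cupy_gather_prepare_dense(x, idx)
assert dense is x and restore is None
assert _cupy_gather_restore_output(dense, restore) is dense
```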


@@ -880,10 +855,6 @@ def cusparse_spmv_gather_cupy(dense_vector, indices, selector_matrix=None):
start_time = time.perf_counter()
if layout in ("scalar16", "complex64"):
gathered_t = _cusparse_spmv(selector_matrix, runtime_dense_t)
elif layout == "complex16_pair":
gathered_real = _cusparse_spmv(selector_matrix, runtime_dense_t[:, 0])
gathered_imag = _cusparse_spmv(selector_matrix, runtime_dense_t[:, 1])
gathered_t = torch.stack([gathered_real, gathered_imag], dim=1)
else:
raise RuntimeError(f"Unknown gather layout: {layout}")
torch.cuda.synchronize()
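For reference, the whole baseline rests on one identity: a selector matrix whose row k is one-hot at column `idx[k]` turns a gather into an SpMV. A torch-only sketch of that identity (the production path routes the SpMV through cuSPARSE/CuPy instead):

```python
# Row k of S has a single 1 at column idx[k], so S @ x == x[idx].
n = 8
x = torch.randn(n)
idx = torch.tensor([5, 0, 2])
rows = torch.arange(idx.numel())
S = torch.sparse_coo_tensor(torch.stack([rows, idx]),
                            torch.ones(idx.numel()), (idx.numel(), n))
gathered = torch.sparse.mm(S, x.unsqueeze(1)).squeeze(1)
assert torch.equal(gathered, x[idx])
```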