From 3f70f442fcd846829cddef3fd0a6df724f6c4028 Mon Sep 17 00:00:00 2001 From: berlin020 <2261128688@qq.com> Date: Mon, 13 Apr 2026 15:25:56 +0800 Subject: [PATCH 01/22] test --- src/flagsparse/sparse_operations/spsv.py | 428 ++++++++++++++++--- tests/pytest/test_spsv_csr_accuracy.py | 192 ++++++++- tests/test_spsv.py | 501 +++++++++++++++++++++-- 3 files changed, 1041 insertions(+), 80 deletions(-) diff --git a/src/flagsparse/sparse_operations/spsv.py b/src/flagsparse/sparse_operations/spsv.py index cdfbf50..2976ad1 100644 --- a/src/flagsparse/sparse_operations/spsv.py +++ b/src/flagsparse/sparse_operations/spsv.py @@ -11,11 +11,27 @@ torch.bfloat16, torch.float32, torch.float64, + *((_torch_complex32_dtype(),) if _torch_complex32_dtype() is not None else ()), + torch.complex64, ) SUPPORTED_SPSV_INDEX_DTYPES = (torch.int32, torch.int64) SPSV_NON_TRANS_PRIMARY_COMBOS = ( (torch.float32, torch.int32), (torch.float64, torch.int32), + *(((_torch_complex32_dtype(), torch.int32),) if _torch_complex32_dtype() is not None else ()), + (torch.complex64, torch.int32), +) +SPSV_NON_TRANS_EXTENDED_COMBOS = ( + (torch.float32, torch.int64), + (torch.float64, torch.int64), + *(((_torch_complex32_dtype(), torch.int64),) if _torch_complex32_dtype() is not None else ()), + (torch.complex64, torch.int64), +) +SPSV_TRANS_PRIMARY_COMBOS = ( + (torch.float32, torch.int32), + (torch.float64, torch.int32), + *(((_torch_complex32_dtype(), torch.int32),) if _torch_complex32_dtype() is not None else ()), + (torch.complex64, torch.int32), ) SPSV_PROMOTE_FP32_TO_FP64 = str( os.environ.get("FLAGSPARSE_SPSV_PROMOTE_FP32_TO_FP64", "0") @@ -46,15 +62,40 @@ def _validate_spsv_non_trans_combo(data_dtype, index_dtype, fmt_name): """Validate NON_TRANS support matrix and keep error messages explicit.""" if (data_dtype, index_dtype) in SPSV_NON_TRANS_PRIMARY_COMBOS: return + if (data_dtype, index_dtype) in SPSV_NON_TRANS_EXTENDED_COMBOS: + return if data_dtype == torch.bfloat16 and index_dtype == torch.int32: return raise TypeError( - f"{fmt_name} SpSV currently supports NON_TRANS combinations with int32 kernel " - "indices: (float32, int32), (float64, int32), (bfloat16, int32)" + f"{fmt_name} SpSV currently supports NON_TRANS combinations: " + "(float32, int32/int64), (float64, int32/int64), " + "(complex32, int32/int64), (complex64, int32/int64), (bfloat16, int32)" ) -def _prepare_spsv_inputs(data, indices, indptr, b, shape, transpose=False): +def _validate_spsv_trans_combo(data_dtype, index_dtype, fmt_name): + if (data_dtype, index_dtype) in SPSV_TRANS_PRIMARY_COMBOS: + return + raise TypeError( + f"{fmt_name} SpSV currently supports TRANS combinations with int32 indices only: " + "(float32, int32), (float64, int32), (complex32, int32), (complex64, int32)" + ) + + +def _normalize_spsv_transpose_mode(transpose): + if isinstance(transpose, bool): + return "T" if transpose else "N" + token = str(transpose).strip().upper() + if token in ("N", "NON", "NON_TRANS"): + return "N" + if token in ("T", "TRANS"): + return "T" + raise ValueError( + "transpose must be bool or one of: N/NON/NON_TRANS, T/TRANS" + ) + + +def _prepare_spsv_inputs(data, indices, indptr, b, shape): """Validate and normalize inputs for sparse solve A x = b with CSR A.""" if not all(torch.is_tensor(t) for t in (data, indices, indptr, b)): raise TypeError("data, indices, indptr, b must all be torch.Tensor") @@ -77,7 +118,7 @@ def _prepare_spsv_inputs(data, indices, indptr, b, shape, transpose=False): if data.dtype not in SUPPORTED_SPSV_VALUE_DTYPES: raise 
TypeError( - "data dtype must be one of: bfloat16, float32, float64" + "data dtype must be one of: bfloat16, float32, float64, complex32, complex64" ) if indices.dtype not in SUPPORTED_SPSV_INDEX_DTYPES: raise TypeError("indices dtype must be torch.int32 or torch.int64") @@ -85,8 +126,6 @@ def _prepare_spsv_inputs(data, indices, indptr, b, shape, transpose=False): raise TypeError("indptr dtype must be torch.int32 or torch.int64") if b.dtype != data.dtype: raise TypeError("b dtype must match data dtype") - if transpose: - raise NotImplementedError("transpose=True is not implemented in Triton SpSV yet") indices64 = indices.to(torch.int64).contiguous() indptr64 = indptr.to(torch.int64).contiguous() @@ -94,8 +133,6 @@ def _prepare_spsv_inputs(data, indices, indptr, b, shape, transpose=False): raise ValueError( f"int64 index value {int(indices64.max().item())} exceeds Triton int32 kernel range" ) - _validate_spsv_non_trans_combo(data.dtype, torch.int32, "CSR") - if indptr64.numel() > 0: if int(indptr64[0].item()) != 0: raise ValueError("indptr[0] must be 0") @@ -112,6 +149,7 @@ def _prepare_spsv_inputs(data, indices, indptr, b, shape, transpose=False): return ( data.contiguous(), + indices.dtype, indices64, indptr64, b.contiguous(), @@ -120,6 +158,21 @@ def _prepare_spsv_inputs(data, indices, indptr, b, shape, transpose=False): ) +def _promote_complex32_spsv_inputs(data, b): + if _is_complex32_dtype(data.dtype): + return data.to(torch.complex64), b.to(torch.complex64), data.dtype + return data, b, None + + +def _restore_complex32_spsv_output(x, target_dtype): + if _is_complex32_dtype(target_dtype): + limit = 65504.0 + real = torch.clamp(x.real, min=-limit, max=limit).to(torch.float16) + imag = torch.clamp(x.imag, min=-limit, max=limit).to(torch.float16) + return torch.view_as_complex(torch.stack([real, imag], dim=-1).contiguous()) + return x.to(target_dtype) + + @triton.jit def _spsv_csr_level_kernel( data_ptr, @@ -172,6 +225,104 @@ def _spsv_csr_level_kernel( tl.store(x_ptr + row, x_row) +@triton.jit +def _spsv_csr_level_kernel_complex( + data_ri_ptr, + indices_ptr, + indptr_ptr, + b_ri_ptr, + x_ri_ptr, + rows_ptr, + n_level_rows, + BLOCK_NNZ: tl.constexpr, + MAX_SEGMENTS: tl.constexpr, + LOWER: tl.constexpr, + UNIT_DIAG: tl.constexpr, + USE_FP64_ACC: tl.constexpr, + DIAG_EPS: tl.constexpr, +): + pid = tl.program_id(0) + if pid >= n_level_rows: + return + row = tl.load(rows_ptr + pid) + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) + + if USE_FP64_ACC: + acc_re = tl.zeros((1,), dtype=tl.float64) + acc_im = tl.zeros((1,), dtype=tl.float64) + diag_re = tl.zeros((1,), dtype=tl.float64) + diag_im = tl.zeros((1,), dtype=tl.float64) + else: + acc_re = tl.zeros((1,), dtype=tl.float32) + acc_im = tl.zeros((1,), dtype=tl.float32) + diag_re = tl.zeros((1,), dtype=tl.float32) + diag_im = tl.zeros((1,), dtype=tl.float32) + + if UNIT_DIAG: + diag_re = diag_re + 1.0 + + for seg in range(MAX_SEGMENTS): + idx = start + seg * BLOCK_NNZ + offsets = idx + tl.arange(0, BLOCK_NNZ) + mask = offsets < end + + col = tl.load(indices_ptr + offsets, mask=mask, other=0) + a_re = tl.load(data_ri_ptr + offsets * 2, mask=mask, other=0.0) + a_im = tl.load(data_ri_ptr + offsets * 2 + 1, mask=mask, other=0.0) + x_re = tl.load(x_ri_ptr + col * 2, mask=mask, other=0.0) + x_im = tl.load(x_ri_ptr + col * 2 + 1, mask=mask, other=0.0) + + if USE_FP64_ACC: + a_re = a_re.to(tl.float64) + a_im = a_im.to(tl.float64) + x_re = x_re.to(tl.float64) + x_im = x_im.to(tl.float64) + else: + a_re = a_re.to(tl.float32) + 
a_im = a_im.to(tl.float32) + x_re = x_re.to(tl.float32) + x_im = x_im.to(tl.float32) + + if LOWER: + solved = col < row + else: + solved = col > row + is_diag = col == row + + prod_re = a_re * x_re - a_im * x_im + prod_im = a_re * x_im + a_im * x_re + acc_re = acc_re + tl.sum(tl.where(mask & solved, prod_re, 0.0)) + acc_im = acc_im + tl.sum(tl.where(mask & solved, prod_im, 0.0)) + + if not UNIT_DIAG: + diag_re = diag_re + tl.sum(tl.where(mask & is_diag, a_re, 0.0)) + diag_im = diag_im + tl.sum(tl.where(mask & is_diag, a_im, 0.0)) + + rhs_re = tl.load(b_ri_ptr + row * 2) + rhs_im = tl.load(b_ri_ptr + row * 2 + 1) + if USE_FP64_ACC: + rhs_re = rhs_re.to(tl.float64) + rhs_im = rhs_im.to(tl.float64) + else: + rhs_re = rhs_re.to(tl.float32) + rhs_im = rhs_im.to(tl.float32) + + num_re = rhs_re - acc_re + num_im = rhs_im - acc_im + den = diag_re * diag_re + diag_im * diag_im + den_safe = tl.where(den < (DIAG_EPS * DIAG_EPS), 1.0, den) + + x_re_out = (num_re * diag_re + num_im * diag_im) / den_safe + x_im_out = (num_im * diag_re - num_re * diag_im) / den_safe + x_re_out = tl.where(x_re_out == x_re_out, x_re_out, 0.0) + x_im_out = tl.where(x_im_out == x_im_out, x_im_out, 0.0) + + offs1 = tl.arange(0, 1) + tl.store(x_ri_ptr + row * 2 + offs1, x_re_out) + tl.store(x_ri_ptr + row * 2 + 1 + offs1, x_im_out) + + @triton.jit def _spsv_coo_level_kernel_real( data_ptr, @@ -360,6 +511,66 @@ def _triton_spsv_csr_vector( return x +def _triton_spsv_csr_vector_complex( + data, + indices, + indptr, + b_vec, + n_rows, + lower=True, + unit_diagonal=False, + block_nnz=None, + max_segments=None, + diag_eps=1e-12, + levels=None, + block_nnz_use=None, + max_segments_use=None, +): + x = torch.zeros_like(b_vec) + if n_rows == 0: + return x + if levels is None: + levels = _build_spsv_levels(indptr, indices, n_rows, lower=lower) + if block_nnz_use is None or max_segments_use is None: + block_nnz_use, max_segments_use = _auto_spsv_launch_config( + indptr, block_nnz=block_nnz, max_segments=max_segments + ) + + data_ri = torch.view_as_real(data.contiguous()).reshape(-1).contiguous() + b_ri = torch.view_as_real(b_vec.contiguous()).reshape(-1).contiguous() + component_dtype = _component_dtype_for_complex(data.dtype) + use_fp64 = component_dtype == torch.float64 + if component_dtype == torch.float16: + x_ri_work = torch.zeros((n_rows, 2), dtype=torch.float32, device=b_vec.device) + x_ri = x_ri_work.reshape(-1).contiguous() + else: + x_ri = torch.view_as_real(x.contiguous()).reshape(-1).contiguous() + + for rows_lv in levels: + n_lv = rows_lv.numel() + if n_lv == 0: + continue + grid = (n_lv,) + _spsv_csr_level_kernel_complex[grid]( + data_ri, + indices, + indptr, + b_ri, + x_ri, + rows_lv, + n_level_rows=n_lv, + BLOCK_NNZ=block_nnz_use, + MAX_SEGMENTS=max_segments_use, + LOWER=lower, + UNIT_DIAG=unit_diagonal, + USE_FP64_ACC=use_fp64, + DIAG_EPS=diag_eps, + ) + if component_dtype == torch.float16: + return torch.view_as_complex(x_ri_work.contiguous()) + return x + + def _prepare_spsv_coo_inputs(data, row, col, b, shape, transpose=False): if not all(torch.is_tensor(t) for t in (data, row, col, b)): raise TypeError("data, row, col, b must all be torch.Tensor") @@ -378,7 +589,7 @@ def _prepare_spsv_coo_inputs(data, row, col, b, shape, transpose=False): if b.ndim == 2 and b.shape[0] != n_rows: raise ValueError(f"b.shape[0] must equal n_rows={n_rows}") - if data.dtype not in SUPPORTED_SPSV_VALUE_DTYPES: + if data.dtype not in (torch.bfloat16, torch.float32, torch.float64): raise TypeError("data dtype must be one of: bfloat16, 
float32, float64") if b.dtype != data.dtype: raise TypeError("b dtype must match data dtype") @@ -395,7 +606,6 @@ def _prepare_spsv_coo_inputs(data, row, col, b, shape, transpose=False): raise ValueError( f"int64 index value {int(col64.max().item())} exceeds Triton int32 kernel range" ) - _validate_spsv_non_trans_combo(data.dtype, torch.int32, "COO") if row64.numel() > 0: if bool(torch.any(row64 < 0).item()): raise IndexError("row indices must be non-negative") @@ -408,6 +618,7 @@ def _prepare_spsv_coo_inputs(data, row, col, b, shape, transpose=False): if max_col >= n_cols: raise IndexError(f"col indices out of range for n_cols={n_cols}") + _validate_spsv_non_trans_combo(data.dtype, torch.int32, "COO") return ( data.contiguous(), row64, @@ -418,6 +629,44 @@ def _prepare_spsv_coo_inputs(data, row, col, b, shape, transpose=False): ) +def _csr_transpose(data, indices64, indptr64, n_rows, n_cols): + if data.numel() == 0: + out_data = data + out_indices = torch.empty(0, dtype=torch.int64, device=data.device) + out_indptr = torch.zeros(n_cols + 1, dtype=torch.int64, device=data.device) + return out_data, out_indices, out_indptr + + row_ids = torch.repeat_interleave( + torch.arange(n_rows, device=data.device, dtype=torch.int64), + indptr64[1:] - indptr64[:-1], + ) + new_row = indices64 + new_col = row_ids + data_t, indices_t, indptr_t = _coo_to_csr_sorted_unique( + data, new_row, new_col, n_cols, n_rows + ) + return data_t, indices_t, indptr_t + + +def _csr_reverse_rows_cols(data, indices64, indptr64, n_rows): + if data.numel() == 0: + out_data = data + out_indices = torch.empty(0, dtype=torch.int64, device=data.device) + out_indptr = torch.zeros(n_rows + 1, dtype=torch.int64, device=data.device) + return out_data, out_indices, out_indptr + + row_ids = torch.repeat_interleave( + torch.arange(n_rows, device=data.device, dtype=torch.int64), + indptr64[1:] - indptr64[:-1], + ) + new_row = (n_rows - 1) - row_ids + new_col = (n_rows - 1) - indices64 + data_r, indices_r, indptr_r = _coo_to_csr_sorted_unique( + data, new_row, new_col, n_rows, n_rows + ) + return data_r, indices_r, indptr_r + + def _coo_is_sorted_unique(row64, col64, n_cols): nnz = row64.numel() if nnz <= 1: @@ -532,30 +781,61 @@ def flagsparse_spsv_csr( ): """Sparse triangular solve using Triton level-scheduling kernels. 
- Primary NON_TRANS support matrix: - - float32 + int32 indices - - float64 + int32 indices + Primary support matrix: + - NON_TRANS: float32/float64/complex32/complex64 with int32/int64 indices + - TRANS: float32/float64/complex32/complex64 with int32 indices + - bfloat16 remains NON_TRANS + int32 """ - data, indices, indptr, b, n_rows, n_cols = _prepare_spsv_inputs( - data, indices, indptr, b, shape, transpose=transpose + trans_mode = _normalize_spsv_transpose_mode(transpose) + data, input_index_dtype, indices, indptr, b, n_rows, n_cols = _prepare_spsv_inputs( + data, indices, indptr, b, shape ) + original_output_dtype = None + data, b, original_output_dtype = _promote_complex32_spsv_inputs(data, b) if n_rows != n_cols: raise ValueError(f"A must be square, got shape={shape}") - kernel_indices = indices.to(torch.int32) if indices.dtype != torch.int32 else indices - kernel_indptr = indptr + if trans_mode == "N": + _validate_spsv_non_trans_combo(data.dtype, input_index_dtype, "CSR") + lower_eff = lower + kernel_data = data + kernel_indices64 = indices + kernel_indptr64 = indptr + else: + _validate_spsv_trans_combo(data.dtype, input_index_dtype, "CSR") + lower_eff = not lower + kernel_data, kernel_indices64, kernel_indptr64 = _csr_transpose( + data, indices, indptr, n_rows, n_cols + ) + + kernel_indices = ( + kernel_indices64.to(torch.int32) + if kernel_indices64.dtype != torch.int32 + else kernel_indices64 + ) + kernel_indptr = kernel_indptr64 compute_dtype = data.dtype - data_in = data + data_in = kernel_data b_in = b if data.dtype == torch.bfloat16: compute_dtype = torch.float32 - data_in = data.to(torch.float32) + data_in = kernel_data.to(torch.float32) b_in = b.to(torch.float32) + elif data.dtype == torch.complex64 and trans_mode == "T": + compute_dtype = torch.complex128 + data_in = kernel_data.to(torch.complex128) + b_in = b.to(torch.complex128) elif data.dtype == torch.float32 and SPSV_PROMOTE_FP32_TO_FP64: # Optional high-precision mode; disabled by default for throughput. 
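+        # Enabled with FLAGSPARSE_SPSV_PROMOTE_FP32_TO_FP64=1 in the environment
+        # (parsed at import time above): values and rhs are upcast to float64 for
+        # the level solve and the result is cast back to the input dtype after the
+        # solve, trading memory bandwidth for accuracy.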
compute_dtype = torch.float64 - data_in = data.to(torch.float64) + data_in = kernel_data.to(torch.float64) + b_in = b.to(torch.float64) + elif data.dtype == torch.float32 and trans_mode == "T": + compute_dtype = torch.float64 + data_in = kernel_data.to(torch.float64) b_in = b.to(torch.float64) - levels = _build_spsv_levels(kernel_indptr, kernel_indices, n_rows, lower=lower) + levels = _build_spsv_levels( + kernel_indptr, kernel_indices, n_rows, lower=lower_eff + ) block_nnz_use, max_segments_use = _auto_spsv_launch_config( kernel_indptr, block_nnz=block_nnz, max_segments=max_segments ) @@ -563,44 +843,82 @@ def flagsparse_spsv_csr( torch.cuda.synchronize() t0 = time.perf_counter() if b_in.ndim == 1: - x = _triton_spsv_csr_vector( - data_in, - kernel_indices, - kernel_indptr, - b_in, - n_rows, - lower=lower, - unit_diagonal=unit_diagonal, - block_nnz=block_nnz, - max_segments=max_segments, - diag_eps=diag_eps, - levels=levels, - block_nnz_use=block_nnz_use, - max_segments_use=max_segments_use, - ) + if torch.is_complex(data_in): + x = _triton_spsv_csr_vector_complex( + data_in, + kernel_indices, + kernel_indptr, + b_in, + n_rows, + lower=lower_eff, + unit_diagonal=unit_diagonal, + block_nnz=block_nnz, + max_segments=max_segments, + diag_eps=diag_eps, + levels=levels, + block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + ) + else: + x = _triton_spsv_csr_vector( + data_in, + kernel_indices, + kernel_indptr, + b_in, + n_rows, + lower=lower_eff, + unit_diagonal=unit_diagonal, + block_nnz=block_nnz, + max_segments=max_segments, + diag_eps=diag_eps, + levels=levels, + block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + ) else: cols = [] for j in range(b_in.shape[1]): - cols.append( - _triton_spsv_csr_vector( - data_in, - kernel_indices, - kernel_indptr, - b_in[:, j].contiguous(), - n_rows, - lower=lower, - unit_diagonal=unit_diagonal, - block_nnz=block_nnz, - max_segments=max_segments, - diag_eps=diag_eps, - levels=levels, - block_nnz_use=block_nnz_use, - max_segments_use=max_segments_use, + bj = b_in[:, j].contiguous() + if torch.is_complex(data_in): + cols.append( + _triton_spsv_csr_vector_complex( + data_in, + kernel_indices, + kernel_indptr, + bj, + n_rows, + lower=lower_eff, + unit_diagonal=unit_diagonal, + block_nnz=block_nnz, + max_segments=max_segments, + diag_eps=diag_eps, + levels=levels, + block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + ) + ) + else: + cols.append( + _triton_spsv_csr_vector( + data_in, + kernel_indices, + kernel_indptr, + bj, + n_rows, + lower=lower_eff, + unit_diagonal=unit_diagonal, + block_nnz=block_nnz, + max_segments=max_segments, + diag_eps=diag_eps, + levels=levels, + block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + ) ) - ) x = torch.stack(cols, dim=1) - if compute_dtype != data.dtype: - x = x.to(data.dtype) + target_dtype = original_output_dtype if original_output_dtype is not None else data.dtype + if x.dtype != target_dtype: + x = _restore_complex32_spsv_output(x, target_dtype) torch.cuda.synchronize() elapsed_ms = (time.perf_counter() - t0) * 1000.0 if out is not None: @@ -747,4 +1065,4 @@ def flagsparse_spsv_coo( if return_time: return x, elapsed_ms - return x \ No newline at end of file + return x diff --git a/tests/pytest/test_spsv_csr_accuracy.py b/tests/pytest/test_spsv_csr_accuracy.py index f2cf944..94788f2 100644 --- a/tests/pytest/test_spsv_csr_accuracy.py +++ b/tests/pytest/test_spsv_csr_accuracy.py @@ -5,8 +5,105 @@ from tests.pytest.param_shapes import SPSV_N +try: + import 
cupy as cp + import cupyx.scipy.sparse as cpx_sparse + from cupyx.scipy.sparse.linalg import spsolve_triangular as cpx_spsolve_triangular +except Exception: + cp = None + cpx_sparse = None + cpx_spsolve_triangular = None + pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +COMPLEX32_DTYPE = getattr(torch, "complex32", None) +if COMPLEX32_DTYPE is None: + COMPLEX32_DTYPE = getattr(torch, "chalf", None) + +SUPPORTED_COMPLEX_DTYPES = [] +if COMPLEX32_DTYPE is not None: + SUPPORTED_COMPLEX_DTYPES.append(COMPLEX32_DTYPE) +SUPPORTED_COMPLEX_DTYPES.append(torch.complex64) + +SUPPORTED_DTYPES = [torch.float32, torch.float64, *SUPPORTED_COMPLEX_DTYPES] + + +def _dtype_id(dtype): + return str(dtype).replace("torch.", "") + + +def _tol(dtype): + if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: + return 5e-3, 5e-3 + if dtype in (torch.float32, torch.complex64): + return 1e-4, 1e-3 + return 1e-10, 1e-8 + + +def _rand_like(dtype, shape, device): + if dtype in (torch.float32, torch.float64): + return torch.randn(shape, dtype=dtype, device=device) + if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: + pair = torch.randn((*shape, 2), dtype=torch.float16, device=device) * 0.1 + return torch.view_as_complex(pair) + base = torch.float32 + r = torch.randn(shape, dtype=base, device=device) + i = torch.randn(shape, dtype=base, device=device) + return torch.complex(r, i) + + +def _ref_dtype(dtype): + if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: + return torch.complex64 + return dtype + + +def _safe_cast_tensor(tensor, dtype): + if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: + real = tensor.real.to(torch.float16) + imag = tensor.imag.to(torch.float16) + return torch.view_as_complex(torch.stack([real, imag], dim=-1).contiguous()) + return tensor.to(dtype) + + +def _cmp_view(tensor, dtype): + if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: + return tensor.to(torch.complex64) + return tensor + + +def _build_lower_triangular(n, dtype, device): + off = _rand_like(dtype, (n, n), device) * 0.02 + A = torch.tril(off) + if torch.is_complex(A): + diag = (torch.rand(n, device=device, dtype=A.real.dtype) + 2.0).to(A.real.dtype) + A = A + torch.diag(torch.complex(diag, torch.zeros_like(diag))) + else: + diag = torch.rand(n, device=device, dtype=A.dtype) + 2.0 + A = A + torch.diag(diag) + return A + + +def _cupy_csr_from_torch(data, indices, indptr, shape): + if cp is None or cpx_sparse is None: + return None + data_ref = data.to(torch.complex64) if COMPLEX32_DTYPE is not None and data.dtype == COMPLEX32_DTYPE else data + data_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(data_ref.contiguous())) + idx_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(indices.to(torch.int64).contiguous())) + ptr_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(indptr.to(torch.int64).contiguous())) + return cpx_sparse.csr_matrix((data_cp, idx_cp, ptr_cp), shape=shape) + + +def _cupy_ref_spsv(A_cp, b_t, *, lower, unit_diagonal=False): + if cp is None or cpx_spsolve_triangular is None: + return None + b_ref = b_t.to(torch.complex64) if COMPLEX32_DTYPE is not None and b_t.dtype == COMPLEX32_DTYPE else b_t + b_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(b_ref.contiguous())) + x_cp = cpx_spsolve_triangular(A_cp, b_cp, lower=lower, unit_diagonal=unit_diagonal) + x_t = torch.utils.dlpack.from_dlpack(x_cp.toDlpack()) + if COMPLEX32_DTYPE is not None and b_t.dtype == COMPLEX32_DTYPE: + return x_t.to(torch.complex64) + return x_t.to(b_t.dtype) 
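+
+
+# _cupy_csr_from_torch and _cupy_ref_spsv hand torch CUDA buffers to CuPy via
+# DLPack, so the CSR arrays and the right-hand side are shared zero-copy on the
+# same device; complex32 inputs are upcast to complex64 first, since the
+# cuSPARSE-backed reference solver has no half-precision complex path.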
@pytest.mark.spsv @@ -17,12 +114,12 @@ ids=["float32", "float64"], ) def test_spsv_csr_lower_matches_dense(n, dtype): + # Keep the original baseline test case untouched in semantics. device = torch.device("cuda") base = torch.tril(torch.randn(n, n, dtype=dtype, device=device)) eye = torch.eye(n, dtype=dtype, device=device) A = base + eye * (float(n) * 0.5 + 2.0) b = torch.randn(n, dtype=dtype, device=device) - # PyTorch 2.x requires B with rank >= 2 for solve_triangular. x_ref = torch.linalg.solve_triangular( A, b.unsqueeze(-1), upper=False ).squeeze(-1) @@ -42,3 +139,96 @@ def test_spsv_csr_lower_matches_dense(n, dtype): rtol = 1e-4 if dtype == torch.float32 else 1e-10 atol = 1e-5 if dtype == torch.float32 else 1e-10 assert torch.allclose(x, x_ref, rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("dtype", SUPPORTED_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) +def test_spsv_csr_non_trans_supported_combos(n, dtype, index_dtype): + device = torch.device("cuda") + A = _build_lower_triangular(n, dtype, device) + b = _rand_like(dtype, (n,), device) + x_ref = torch.linalg.solve_triangular( + A.to(_ref_dtype(dtype)), b.to(_ref_dtype(dtype)).unsqueeze(-1), upper=False + ).squeeze(-1) + + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(index_dtype) + indptr = Asp.crow_indices().to(index_dtype) + + x = flagsparse_spsv_csr( + data, + indices, + indptr, + b, + (n, n), + lower=True, + unit_diagonal=False, + transpose=False, + ) + rtol, atol = _tol(dtype) + assert torch.allclose(_cmp_view(x, dtype), _cmp_view(x_ref, dtype), rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("dtype", SUPPORTED_DTYPES, ids=_dtype_id) +def test_spsv_csr_trans_int32_supported_combos(n, dtype): + device = torch.device("cuda") + A = _build_lower_triangular(n, dtype, device) + b = _rand_like(dtype, (n,), device) + A_ref = A.to(_ref_dtype(dtype)) + b_ref = b.to(_ref_dtype(dtype)) + x_ref = torch.linalg.solve_triangular( + A_ref.transpose(-2, -1), b_ref.unsqueeze(-1), upper=True + ).squeeze(-1) + + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(torch.int32) + indptr = Asp.crow_indices().to(torch.int32) + + x = flagsparse_spsv_csr( + data, + indices, + indptr, + b, + (n, n), + lower=True, + unit_diagonal=False, + transpose=True, + ) + rtol, atol = _tol(dtype) + assert torch.allclose(_cmp_view(x, dtype), _cmp_view(x_ref, dtype), rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.skipif(cp is None or cpx_spsolve_triangular is None, reason="CuPy/cuSPARSE required") +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("dtype", SUPPORTED_DTYPES, ids=_dtype_id) +def test_spsv_csr_matches_cusparse_non_trans_and_trans(n, dtype): + device = torch.device("cuda") + A = _build_lower_triangular(n, dtype, device) + b = _rand_like(dtype, (n,), device) + + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(torch.int32) + indptr = Asp.crow_indices().to(torch.int32) + A_cp = _cupy_csr_from_torch(data, indices, indptr, (n, n)) + + x_non = flagsparse_spsv_csr( + data, indices, indptr, b, (n, n), lower=True, unit_diagonal=False, transpose=False + ) + x_non_ref = _cupy_ref_spsv(A_cp, b, lower=True, unit_diagonal=False) + + x_trans = flagsparse_spsv_csr( + data, indices, indptr, b, (n, n), lower=True, unit_diagonal=False, transpose=True + ) + x_trans_ref 
= _cupy_ref_spsv(A_cp.transpose().tocsr(), b, lower=False, unit_diagonal=False)
+
+    rtol, atol = _tol(dtype)
+    assert torch.allclose(_cmp_view(x_non, dtype), _cmp_view(x_non_ref, dtype), rtol=rtol, atol=atol)
+    assert torch.allclose(_cmp_view(x_trans, dtype), _cmp_view(x_trans_ref, dtype), rtol=rtol, atol=atol)
diff --git a/tests/test_spsv.py b/tests/test_spsv.py
index 3123774..162dfc8 100644
--- a/tests/test_spsv.py
+++ b/tests/test_spsv.py
@@ -29,6 +29,20 @@
 ITERS = 20
 DENSE_REF_MAX_BYTES = 2 * 1024 * 1024 * 1024  # 2 GiB
 
+FLOAT16_LIMIT = 65504.0
+COMPLEX32_DTYPE = getattr(torch, "complex32", None)
+if COMPLEX32_DTYPE is None:
+    COMPLEX32_DTYPE = getattr(torch, "chalf", None)
+
+# Full CSR combination coverage (new, added alongside the original --csv-csr
+# logic; the original entry point is unaffected).
+CSR_FULL_VALUE_DTYPES = [
+    torch.float32,
+    torch.float64,
+]
+if COMPLEX32_DTYPE is not None:
+    CSR_FULL_VALUE_DTYPES.append(COMPLEX32_DTYPE)
+CSR_FULL_VALUE_DTYPES.append(torch.complex64)
+CSR_FULL_INDEX_DTYPES = [torch.int32, torch.int64]
 
 def _dtype_name(dtype):
@@ -50,11 +64,81 @@ def _fmt_err(v):
 
 
 def _tol_for_dtype(dtype):
-    if dtype == torch.float32:
+    if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE:
+        return 5e-3, 5e-3
+    if dtype in (torch.float32, torch.complex64):
         return 1e-4, 1e-2
     return 1e-12, 1e-10
 
 
+def _randn_by_dtype(n, dtype, device):
+    if dtype in (torch.float32, torch.float64):
+        return torch.randn(n, dtype=dtype, device=device)
+    if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE:
+        pair = torch.randn((n, 2), dtype=torch.float16, device=device) * 0.1
+        return torch.view_as_complex(pair)
+    base = torch.float32
+    real = torch.randn(n, dtype=base, device=device)
+    imag = torch.randn(n, dtype=base, device=device)
+    return torch.complex(real, imag)
+
+
+def _dense_ref_dtype(dtype):
+    if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE:
+        return torch.complex64
+    return dtype
+
+
+def _tensor_from_scalar_values(values, dtype, device):
+    if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE:
+        real = torch.clamp(
+            torch.tensor(values, dtype=torch.float32, device=device),
+            min=-FLOAT16_LIMIT,
+            max=FLOAT16_LIMIT,
+        ).to(torch.float16)
+        imag = torch.zeros_like(real)
+        return torch.view_as_complex(torch.stack([real, imag], dim=-1).contiguous())
+    return torch.tensor(values, dtype=dtype, device=device)
+
+
+def _safe_cast_tensor(tensor, dtype):
+    if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE:
+        real = torch.clamp(tensor.real, min=-FLOAT16_LIMIT, max=FLOAT16_LIMIT).to(torch.float16)
+        imag = torch.clamp(tensor.imag, min=-FLOAT16_LIMIT, max=FLOAT16_LIMIT).to(torch.float16)
+        return torch.view_as_complex(torch.stack([real, imag], dim=-1).contiguous())
+    return tensor.to(dtype)
+
+
+def _cast_real_tensor_to_value_dtype(values, value_dtype):
+    if COMPLEX32_DTYPE is not None and value_dtype == COMPLEX32_DTYPE:
+        real = torch.clamp(values, min=-FLOAT16_LIMIT, max=FLOAT16_LIMIT).to(torch.float16)
+        imag = torch.zeros_like(real)
+        return torch.view_as_complex(torch.stack([real, imag], dim=-1).contiguous())
+    return values.to(value_dtype)
+
+
+def _cupy_ref_inputs(data, b):
+    if COMPLEX32_DTYPE is not None and data.dtype == COMPLEX32_DTYPE:
+        return data.to(torch.complex64), b.to(torch.complex64)
+    return data, b
+
+
+def _compare_view(tensor, value_dtype):
+    if COMPLEX32_DTYPE is not None and value_dtype == COMPLEX32_DTYPE:
+        return tensor.to(torch.complex64)
+    return tensor
+
+
+def _supported_csr_full_ops(value_dtype, index_dtype):
+    if value_dtype not in CSR_FULL_VALUE_DTYPES:
+        return []
+    if index_dtype == torch.int32:
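+        # int32 indices run both NON_TRANS and TRANS kernels; int64 stays
+        # NON_TRANS-only, mirroring _validate_spsv_trans_combo in spsv.py.
+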
return ["NON", "TRANS"] + if index_dtype == torch.int64: + return ["NON"] + return [] + + def _allow_dense_pytorch_ref(shape, dtype): n_rows, n_cols = int(shape[0]), int(shape[1]) elem_bytes = torch.empty((), dtype=dtype).element_size() @@ -68,9 +152,14 @@ def _build_random_triangular_csr(n, value_dtype, index_dtype, device, lower=True rows_host = [] cols_host = [] vals_host = [] - base_real_dtype = ( - torch.float32 if value_dtype == torch.float32 else torch.float64 - ) + if value_dtype == torch.float32: + base_real_dtype = torch.float32 + elif value_dtype == torch.float64: + base_real_dtype = torch.float64 + elif COMPLEX32_DTYPE is not None and value_dtype == COMPLEX32_DTYPE: + base_real_dtype = torch.float16 + else: + base_real_dtype = torch.float32 for i in range(n): if lower: @@ -100,7 +189,10 @@ def _build_random_triangular_csr(n, value_dtype, index_dtype, device, lower=True rows_t = torch.tensor(rows_host, dtype=torch.int64, device=device) cols_t = torch.tensor(cols_host, dtype=torch.int64, device=device) - vals_t = torch.tensor(vals_host, dtype=base_real_dtype, device=device).to(value_dtype) + vals_t = _cast_real_tensor_to_value_dtype( + torch.tensor(vals_host, dtype=base_real_dtype, device=device), + value_dtype, + ) order = torch.argsort(rows_t * max(1, n) + cols_t) rows_t = rows_t[order] cols_t = cols_t[order] @@ -114,15 +206,16 @@ def _build_random_triangular_csr(n, value_dtype, index_dtype, device, lower=True def _csr_to_dense(data, indices, indptr, shape): n_rows, n_cols = shape + coo_data = data.to(torch.complex64) if COMPLEX32_DTYPE is not None and data.dtype == COMPLEX32_DTYPE else data row_ind = torch.repeat_interleave( - torch.arange(n_rows, device=data.device, dtype=torch.int64), + torch.arange(n_rows, device=coo_data.device, dtype=torch.int64), indptr[1:] - indptr[:-1], ) coo = torch.sparse_coo_tensor( torch.stack([row_ind, indices.to(torch.int64)]), - data, + coo_data, (n_rows, n_cols), - device=data.device, + device=coo_data.device, ).coalesce() return coo.to_dense() @@ -137,6 +230,33 @@ def _csr_to_coo(data, indices, indptr, shape): return data, row, col +def _csr_transpose(data, indices, indptr, shape): + n_rows, n_cols = int(shape[0]), int(shape[1]) + if data.numel() == 0: + return ( + data, + torch.empty(0, dtype=torch.int64, device=data.device), + torch.zeros(n_cols + 1, dtype=torch.int64, device=data.device), + ) + + row, col = _csr_to_coo(data, indices, indptr, shape)[1:] + row_t = col + col_t = row + key = row_t * max(1, n_rows) + col_t + try: + order = torch.argsort(key, stable=True) + except TypeError: + order = torch.argsort(key) + + row_t = row_t[order] + col_t = col_t[order] + data_t = data[order] + nnz_per_row = torch.bincount(row_t, minlength=n_cols) + indptr_t = torch.zeros(n_cols + 1, dtype=torch.int64, device=data.device) + indptr_t[1:] = torch.cumsum(nnz_per_row, dim=0) + return data_t, col_t.to(torch.int64), indptr_t + + def _load_mtx_to_csr_torch(file_path, dtype=torch.float32, device=None): if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -214,7 +334,7 @@ def _accum(r, c, v): cols_s.append(c) vals_s.append(row[c]) indptr_list.append(len(cols_s)) - data = torch.tensor(vals_s, dtype=dtype, device=device) + data = _tensor_from_scalar_values(vals_s, dtype, device) indices = torch.tensor(cols_s, dtype=torch.int64, device=device) indptr = torch.tensor(indptr_list, dtype=torch.int64, device=device) return data, indices, indptr, (n_rows, n_cols) @@ -239,6 +359,46 @@ def _coo_inputs_for_csv(data, indices, 
indptr, shape, coo_mode): return data_c, row_c, col_c +def _build_rhs_for_csr_op(data, indices, indptr, x_true, shape, op_mode): + if COMPLEX32_DTYPE is not None and data.dtype == COMPLEX32_DTYPE: + data_ref = data.to(torch.complex64) + x_ref = x_true.to(torch.complex64) + if op_mode == "NON": + b_ref, _ = fs.flagsparse_spmv_csr( + data_ref, indices, indptr, x_ref, shape, return_time=True + ) + return _safe_cast_tensor(b_ref, x_true.dtype) + if op_mode == "TRANS": + data_t, indices_t, indptr_t = _csr_transpose(data_ref, indices, indptr, shape) + b_ref, _ = fs.flagsparse_spmv_csr( + data_t, + indices_t.to(indices.dtype), + indptr_t.to(indptr.dtype), + x_ref, + (shape[1], shape[0]), + return_time=True, + ) + return _safe_cast_tensor(b_ref, x_true.dtype) + raise ValueError("op_mode must be 'NON' or 'TRANS'") + if op_mode == "NON": + b, _ = fs.flagsparse_spmv_csr( + data, indices, indptr, x_true, shape, return_time=True + ) + return b + if op_mode == "TRANS": + data_t, indices_t, indptr_t = _csr_transpose(data, indices, indptr, shape) + b, _ = fs.flagsparse_spmv_csr( + data_t, + indices_t.to(indices.dtype), + indptr_t.to(indptr.dtype), + x_true, + (shape[1], shape[0]), + return_time=True, + ) + return b + raise ValueError("op_mode must be 'NON' or 'TRANS'") + + def _cupy_spsolve_lower_csr_or_coo( fmt, data, @@ -292,12 +452,62 @@ def _cupy_spsolve_lower_csr_or_coo( t1.record() t1.synchronize() cupy_ms = cp.cuda.get_elapsed_time(t0, t1) / iters - x_cu_t = torch.utils.dlpack.from_dlpack(x_cu.toDlpack()).to(b.dtype) + x_cu_t = torch.utils.dlpack.from_dlpack(x_cu.toDlpack()) + if COMPLEX32_DTYPE is not None and b.dtype == COMPLEX32_DTYPE: + x_cu_t = x_cu_t.to(torch.complex64) + else: + x_cu_t = x_cu_t.to(b.dtype) return cupy_ms, x_cu_t except Exception: return None, None +def _cupy_spsolve_csr_with_op(data, indices, indptr, shape, b, op_mode): + if ( + cp is None + or cpx_sparse is None + or cpx_spsolve_triangular is None + ): + return None, None + try: + data_ref, b_ref = _cupy_ref_inputs(data, b) + data_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(data_ref.contiguous())) + idx_cp = cp.from_dlpack( + torch.utils.dlpack.to_dlpack(indices.to(torch.int64).contiguous()) + ) + ptr_cp = cp.from_dlpack( + torch.utils.dlpack.to_dlpack(indptr.to(torch.int64).contiguous()) + ) + b_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(b_ref.contiguous())) + A_cp = cpx_sparse.csr_matrix((data_cp, idx_cp, ptr_cp), shape=shape) + if op_mode == "TRANS": + A_eff = A_cp.transpose().tocsr() + lower_eff = False + else: + A_eff = A_cp + lower_eff = True + + for _ in range(WARMUP): + _ = cpx_spsolve_triangular( + A_eff, b_cp, lower=lower_eff, unit_diagonal=False + ) + cp.cuda.runtime.deviceSynchronize() + c0 = cp.cuda.Event() + c1 = cp.cuda.Event() + c0.record() + for _ in range(ITERS): + x_cp = cpx_spsolve_triangular( + A_eff, b_cp, lower=lower_eff, unit_diagonal=False + ) + c1.record() + c1.synchronize() + ms = cp.cuda.get_elapsed_time(c0, c1) / ITERS + x_t = torch.utils.dlpack.from_dlpack(x_cp.toDlpack()).to(b.dtype) + return ms, x_t + except Exception: + return None, None + + def run_spsv_synthetic_all(): if not torch.cuda.is_available(): print("CUDA is not available. 
Please run on a GPU-enabled system.") @@ -524,22 +734,25 @@ def _finalize_csv_row( A_dense = _csr_to_dense( data, indices.to(torch.int64), indptr, shape ) - A_ref = A_dense - b_ref = b + ref_dtype = _dense_ref_dtype(value_dtype) + A_ref = A_dense.to(ref_dtype) + b_ref = b.to(ref_dtype) e0 = torch.cuda.Event(True) e1 = torch.cuda.Event(True) torch.cuda.synchronize() e0.record() x_ref = torch.linalg.solve(A_ref, b_ref.unsqueeze(1)).squeeze(1) + x_cmp = _compare_view(x, value_dtype) + x_ref_cmp = _compare_view(x_ref, value_dtype) e1.record() torch.cuda.synchronize() pytorch_ms = e0.elapsed_time(e1) err_pt = ( - float(torch.max(torch.abs(x - x_ref)).item()) + float(torch.max(torch.abs(x_cmp - x_ref_cmp)).item()) if n_rows > 0 else 0.0 ) - ok_pt = torch.allclose(x, x_ref, atol=atol, rtol=rtol) + ok_pt = torch.allclose(x_cmp, x_ref_cmp, atol=atol, rtol=rtol) except RuntimeError as e: if "out of memory" in str(e).lower(): pt_skip_reason = "PyTorch dense ref OOM; skipped" @@ -558,17 +771,18 @@ def _finalize_csv_row( cp is not None and cpx_sparse is not None and cpx_spsolve_triangular is not None - and value_dtype in (torch.float32, torch.float64) ): try: - b_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(b.contiguous())) + data_ref, b_ref = _cupy_ref_inputs(data, b) + b_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(b_ref.contiguous())) if ( cupy_coo_data is not None and cupy_coo_row is not None and cupy_coo_col is not None ): + coo_data_ref, _ = _cupy_ref_inputs(cupy_coo_data, b) data_cp = cp.from_dlpack( - torch.utils.dlpack.to_dlpack(cupy_coo_data.contiguous()) + torch.utils.dlpack.to_dlpack(coo_data_ref.contiguous()) ) row_cp = cp.from_dlpack( torch.utils.dlpack.to_dlpack( @@ -585,7 +799,7 @@ def _finalize_csv_row( ) else: data_cp = cp.from_dlpack( - torch.utils.dlpack.to_dlpack(data.contiguous()) + torch.utils.dlpack.to_dlpack(data_ref.contiguous()) ) idx_cp = cp.from_dlpack( torch.utils.dlpack.to_dlpack( @@ -613,13 +827,15 @@ def _finalize_csv_row( c1.record() c1.synchronize() cupy_ms = cp.cuda.get_elapsed_time(c0, c1) / ITERS - x_cu_t = torch.utils.dlpack.from_dlpack(x_cu.toDlpack()).to(x.dtype) + x_cu_t = torch.utils.dlpack.from_dlpack(x_cu.toDlpack()) + x_cmp = _compare_view(x, value_dtype) + x_cu_cmp = _compare_view(x_cu_t, value_dtype) err_cu = ( - float(torch.max(torch.abs(x - x_cu_t)).item()) + float(torch.max(torch.abs(x_cmp - x_cu_cmp)).item()) if n_rows > 0 else 0.0 ) - ok_cu = torch.allclose(x, x_cu_t, atol=atol, rtol=rtol) + ok_cu = torch.allclose(x_cmp, x_cu_cmp, atol=atol, rtol=rtol) except Exception: cupy_ms = None err_cu = None @@ -649,6 +865,243 @@ def _finalize_csv_row( return row, pt_skip_reason +def _run_one_csv_row_csr_full(path, value_dtype, index_dtype, op_mode, device): + data, indices, indptr, shape = _load_mtx_to_csr_torch( + path, dtype=value_dtype, device=device + ) + indices = indices.to(index_dtype) + indptr = indptr.to(index_dtype) + n_rows, n_cols = shape + x_true = _randn_by_dtype(n_rows, value_dtype, device) + b = _build_rhs_for_csr_op(data, indices, indptr, x_true, shape, op_mode) + x, t_ms = fs.flagsparse_spsv_csr( + data, + indices, + indptr, + b, + shape, + lower=True, + transpose=(op_mode == "TRANS"), + return_time=True, + ) + return _finalize_csv_row_csr_full( + path, + value_dtype, + index_dtype, + op_mode, + data, + indices, + indptr, + shape, + x, + t_ms, + b, + n_rows, + n_cols, + ) + + +def _finalize_csv_row_csr_full( + path, + value_dtype, + index_dtype, + op_mode, + data, + indices, + indptr, + shape, + x, + t_ms, + b, + n_rows, + n_cols, 
+): + atol, rtol = _tol_for_dtype(value_dtype) + + pytorch_ms = None + err_pt = None + ok_pt = False + pt_skip_reason = None + if _allow_dense_pytorch_ref(shape, value_dtype): + try: + A_dense = _csr_to_dense( + data, indices.to(torch.int64), indptr.to(torch.int64), shape + ).to(_dense_ref_dtype(value_dtype)) + A_ref = A_dense.transpose(0, 1) if op_mode == "TRANS" else A_dense + e0 = torch.cuda.Event(True) + e1 = torch.cuda.Event(True) + torch.cuda.synchronize() + e0.record() + x_ref = torch.linalg.solve(A_ref, b.to(A_ref.dtype).unsqueeze(1)).squeeze(1) + x_cmp = _compare_view(x, value_dtype) + x_ref_cmp = _compare_view(x_ref, value_dtype) + e1.record() + torch.cuda.synchronize() + pytorch_ms = e0.elapsed_time(e1) + err_pt = ( + float(torch.max(torch.abs(x_cmp - x_ref_cmp)).item()) + if n_rows > 0 + else 0.0 + ) + ok_pt = torch.allclose(x_cmp, x_ref_cmp, atol=atol, rtol=rtol) + except RuntimeError as e: + if "out of memory" in str(e).lower(): + pt_skip_reason = "PyTorch dense ref OOM; skipped" + else: + raise + else: + pt_skip_reason = ( + f"PyTorch dense ref skipped (> {DENSE_REF_MAX_BYTES // (1024**3)} GiB dense matrix)" + ) + + cupy_ms = None + err_cu = None + ok_cu = False + x_cu_t = None + cupy_ms, x_cu_t = _cupy_spsolve_csr_with_op( + data, indices, indptr, shape, b, op_mode + ) + if x_cu_t is not None: + x_cmp = _compare_view(x, value_dtype) + x_cu_cmp = _compare_view(x_cu_t, value_dtype) + err_cu = ( + float(torch.max(torch.abs(x_cmp - x_cu_cmp)).item()) + if n_rows > 0 + else 0.0 + ) + ok_cu = torch.allclose(x_cmp, x_cu_cmp, atol=atol, rtol=rtol) + + status = "PASS" if (ok_pt or ok_cu) else "FAIL" + if (not ok_pt) and (not ok_cu) and (err_pt is None and err_cu is None): + status = "REF_FAIL" + + row = { + "matrix": os.path.basename(path), + "value_dtype": _dtype_name(value_dtype), + "index_dtype": _dtype_name(index_dtype), + "opA": op_mode, + "n_rows": n_rows, + "n_cols": n_cols, + "nnz": int(data.numel()), + "triton_ms": t_ms, + "pytorch_ms": pytorch_ms, + "cusparse_ms": cupy_ms, + "csc_ms": None, + "status": status, + "err_pt": err_pt, + "err_cu": err_cu, + } + return row, pt_skip_reason + + +def run_all_supported_spsv_csr_csv(mtx_paths, csv_path): + if not torch.cuda.is_available(): + print("CUDA is not available.") + return + device = torch.device("cuda") + rows_out = [] + for value_dtype in CSR_FULL_VALUE_DTYPES: + for index_dtype in CSR_FULL_INDEX_DTYPES: + op_modes = _supported_csr_full_ops(value_dtype, index_dtype) + for op_mode in op_modes: + print("=" * 150) + print( + f"Value dtype: {_dtype_name(value_dtype)} | Index dtype: {_dtype_name(index_dtype)} | CSR | opA={op_mode}" + ) + print( + "Formats: FlagSparse=CSR, cuSPARSE=CSR ref, PyTorch=Dense solve." + ) + print( + "Err(PT)=|FlagSparse-PyTorch|, Err(CU)=|FlagSparse-cuSPARSE|. " + "PASS if either error within tolerance." 
+ ) + print("-" * 150) + print( + f"{'Matrix':<28} {'N_rows':>7} {'N_cols':>7} {'NNZ':>10} " + f"{'FlagSparse(ms)':>10} {'CSR(ms)':>10} {'CSC(ms)':>10} {'PyTorch(ms)':>11} " + f"{'FS/CSR':>7} {'FS/PT':>7} {'Status':>6} {'Err(PT)':>10} {'Err(CU)':>10}" + ) + print("-" * 150) + for path in mtx_paths: + try: + row, pt_skip = _run_one_csv_row_csr_full( + path, value_dtype, index_dtype, op_mode, device + ) + rows_out.append(row) + name = os.path.basename(path)[:27] + if len(os.path.basename(path)) > 27: + name = name + "…" + n_rows, n_cols = row["n_rows"], row["n_cols"] + nnz = row["nnz"] + t_ms = row["triton_ms"] + cupy_ms = row["cusparse_ms"] + pytorch_ms = row["pytorch_ms"] + err_pt, err_cu = row["err_pt"], row["err_cu"] + status = row["status"] + print( + f"{name:<28} {n_rows:>7} {n_cols:>7} {nnz:>10} " + f"{_fmt_ms(t_ms):>10} {_fmt_ms(cupy_ms):>10} {_fmt_ms(None):>10} {_fmt_ms(pytorch_ms):>11} " + f"{_fmt_speedup(cupy_ms, t_ms):>7} {_fmt_speedup(pytorch_ms, t_ms):>7} " + f"{status:>6} {_fmt_err(err_pt):>10} {_fmt_err(err_cu):>10}" + ) + if pt_skip: + print(f" NOTE: {pt_skip}") + except Exception as e: + err_msg = str(e) + status = "SKIP" if "SpSV requires square matrices" in err_msg else "ERROR" + rows_out.append( + { + "matrix": os.path.basename(path), + "value_dtype": _dtype_name(value_dtype), + "index_dtype": _dtype_name(index_dtype), + "opA": op_mode, + "n_rows": "ERR", + "n_cols": "ERR", + "nnz": "ERR", + "triton_ms": None, + "pytorch_ms": None, + "cusparse_ms": None, + "csc_ms": None, + "status": status, + "err_pt": None, + "err_cu": None, + } + ) + name = os.path.basename(path)[:27] + if len(os.path.basename(path)) > 27: + name = name + "…" + print( + f"{name:<28} {'ERR':>7} {'ERR':>7} {'ERR':>10} " + f"{_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>11} " + f"{'N/A':>7} {'N/A':>7} " + f"{status:>6} {_fmt_err(None):>10} {_fmt_err(None):>10}" + ) + print(f" {status}: {e}") + print("-" * 150) + fieldnames = [ + "matrix", + "value_dtype", + "index_dtype", + "opA", + "n_rows", + "n_cols", + "nnz", + "triton_ms", + "pytorch_ms", + "cusparse_ms", + "csc_ms", + "status", + "err_pt", + "err_cu", + ] + with open(csv_path, "w", newline="", encoding="utf-8") as f: + w = csv.DictWriter(f, fieldnames=fieldnames) + w.writeheader() + for r in rows_out: + w.writerow(r) + print(f"Wrote {len(rows_out)} rows to {csv_path}") + + def run_all_dtypes_spsv_csv(mtx_paths, csv_path, use_coo=False, coo_mode="auto"): if not torch.cuda.is_available(): print("CUDA is not available.") @@ -785,7 +1238,7 @@ def main(): type=str, default=None, metavar="FILE", - help="Run all dtypes/index dtypes on .mtx (CSR SpSV) and export CSV", + help="Run full supported CSR SpSV combinations (dtype/index/opA) on .mtx and export CSV", ) parser.add_argument( "--csv-coo", @@ -819,7 +1272,7 @@ def main(): if not paths: print("No .mtx files found for --csv-csr") return - run_all_dtypes_spsv_csv(paths, args.csv_csr, use_coo=False) + run_all_supported_spsv_csr_csv(paths, args.csv_csr) return if args.csv_coo: if not paths: @@ -836,4 +1289,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() From 0e76bd8d1e404cd974487fd1519a86aff0dbc349 Mon Sep 17 00:00:00 2001 From: berlin020 <2261128688@qq.com> Date: Thu, 16 Apr 2026 18:34:21 +0800 Subject: [PATCH 02/22] complex128 --- .gitignore | 1 + benchmark/benchmark_spsv.py | 14 + src/flagsparse/sparse_operations/_common.py | 43 +- .../sparse_operations/gather_scatter.py | 55 +- src/flagsparse/sparse_operations/spsv.py | 199 ++++++- 
tests/pytest/test_gather_scatter_accuracy.py | 119 +--- tests/pytest/test_spsv_csr_accuracy.py | 245 +++++++- tests/test_gather.py | 14 +- tests/test_spsv.py | 545 ++++++++++++++---- 9 files changed, 901 insertions(+), 334 deletions(-) create mode 100644 benchmark/benchmark_spsv.py diff --git a/.gitignore b/.gitignore index 4ca5df5..e44113e 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ __pycache__/ .pytest_cache/ .coverage htmlcov/ +.DS_Store diff --git a/benchmark/benchmark_spsv.py b/benchmark/benchmark_spsv.py new file mode 100644 index 0000000..3fa258f --- /dev/null +++ b/benchmark/benchmark_spsv.py @@ -0,0 +1,14 @@ +"""Run SpSV benchmark. From project root: python benchmark/benchmark_spsv.py [--synthetic | --csv-csr out.csv].""" + +import sys +from pathlib import Path + +root = Path(__file__).resolve().parent.parent +if str(root) not in sys.path: + sys.path.insert(0, str(root)) + +from tests.test_spsv import main + + +if __name__ == "__main__": + main() diff --git a/src/flagsparse/sparse_operations/_common.py b/src/flagsparse/sparse_operations/_common.py index 7e89526..6f7725d 100644 --- a/src/flagsparse/sparse_operations/_common.py +++ b/src/flagsparse/sparse_operations/_common.py @@ -18,19 +18,14 @@ cp = None cpx_sparse = None -_TORCH_COMPLEX32_DTYPE = getattr(torch, "complex32", None) -if _TORCH_COMPLEX32_DTYPE is None: - _TORCH_COMPLEX32_DTYPE = getattr(torch, "chalf", None) - _SUPPORTED_VALUE_DTYPES = [ torch.float16, torch.bfloat16, torch.float32, torch.float64, + torch.complex64, + torch.complex128, ] -if _TORCH_COMPLEX32_DTYPE is not None: - _SUPPORTED_VALUE_DTYPES.append(_TORCH_COMPLEX32_DTYPE) -_SUPPORTED_VALUE_DTYPES.extend([torch.complex64, torch.complex128]) SUPPORTED_VALUE_DTYPES = tuple(_SUPPORTED_VALUE_DTYPES) SUPPORTED_INDEX_DTYPES = (torch.int32, torch.int64) _INDEX_LIMIT_INT32 = 2**31 - 1 @@ -40,8 +35,6 @@ "SUPPORTED_VALUE_DTYPES", "SUPPORTED_INDEX_DTYPES", "_INDEX_LIMIT_INT32", - "_torch_complex32_dtype", - "_is_complex32_dtype", "_is_complex_dtype", "_resolve_scatter_value_dtype", "_component_dtype_for_complex", @@ -68,17 +61,8 @@ "tl", ) - -def _torch_complex32_dtype(): - return _TORCH_COMPLEX32_DTYPE - - -def _is_complex32_dtype(value_dtype): - return _TORCH_COMPLEX32_DTYPE is not None and value_dtype == _TORCH_COMPLEX32_DTYPE - - def _is_complex_dtype(value_dtype): - return _is_complex32_dtype(value_dtype) or value_dtype in (torch.complex64, torch.complex128) + return value_dtype in (torch.complex64, torch.complex128) def _resolve_scatter_value_dtype(value_dtype, dtype_policy="auto"): @@ -95,24 +79,13 @@ def _resolve_scatter_value_dtype(value_dtype, dtype_policy="auto"): "complex64": torch.complex64, "complex128": torch.complex128, } - if token == "complex32": - if _TORCH_COMPLEX32_DTYPE is not None: - return _TORCH_COMPLEX32_DTYPE, False, None - if dtype_policy == "strict": - raise TypeError("complex32 is unavailable in this torch build") - return torch.complex64, True, "complex32 is unavailable; fallback to complex64" if token not in mapping: raise TypeError(f"Unsupported dtype token: {value_dtype}") value_dtype = mapping[token] - if _is_complex32_dtype(value_dtype): - # If complex32 exists in torch dtype table, keep native path. 
- return value_dtype, False, None return value_dtype, False, None def _component_dtype_for_complex(value_dtype): - if _is_complex32_dtype(value_dtype): - return torch.float16 if value_dtype == torch.complex64: return torch.float32 if value_dtype == torch.complex128: @@ -123,8 +96,6 @@ def _component_dtype_for_complex(value_dtype): def _tolerance_for_dtype(value_dtype): if value_dtype == torch.float16: return 2e-3, 2e-3 - if _is_complex32_dtype(value_dtype): - return 5e-3, 5e-3 if value_dtype == torch.bfloat16: return 1e-1, 1e-1 if value_dtype in (torch.float32, torch.complex64): @@ -155,9 +126,6 @@ def _cupy_dtype_from_torch(torch_dtype): torch.int32: cp.int32, torch.int64: cp.int64, } - if _TORCH_COMPLEX32_DTYPE is not None: - # CuPy has no native complex32 sparse path; use complex64 for baseline parity. - mapping[_TORCH_COMPLEX32_DTYPE] = cp.complex64 if torch_dtype not in mapping: raise TypeError(f"Unsupported dtype conversion to CuPy: {torch_dtype}") return mapping[torch_dtype] @@ -193,8 +161,6 @@ def _to_backend_like(torch_tensor, ref_obj): def _cusparse_baseline_skip_reason(value_dtype): if value_dtype == torch.bfloat16: return "bfloat16 is not supported by the cuSPARSE baseline path; skipped" - if _is_complex32_dtype(value_dtype): - return "complex32 is not supported by the cuSPARSE baseline path; skipped" if cp is None and value_dtype == torch.float16: return "float16 is not supported by torch sparse fallback when CuPy is unavailable; skipped" return None @@ -203,9 +169,6 @@ def _cusparse_baseline_skip_reason(value_dtype): def _build_random_dense(dense_size, value_dtype, device): if value_dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64): return torch.randn(dense_size, dtype=value_dtype, device=device) - if _is_complex32_dtype(value_dtype): - stacked = torch.randn((dense_size, 2), dtype=torch.float16, device=device) - return torch.view_as_complex(stacked) if _is_complex_dtype(value_dtype): component_dtype = _component_dtype_for_complex(value_dtype) real = torch.randn(dense_size, dtype=component_dtype, device=device) diff --git a/src/flagsparse/sparse_operations/gather_scatter.py b/src/flagsparse/sparse_operations/gather_scatter.py index 6a02d8e..67066b0 100644 --- a/src/flagsparse/sparse_operations/gather_scatter.py +++ b/src/flagsparse/sparse_operations/gather_scatter.py @@ -226,6 +226,10 @@ def _launch_triton_scatter_kernel( ) +def _validate_gather_value_dtype(dense_vector, op_name): + return None + + def _cusparse_spmv(selector_matrix, dense_vector): if cp is not None and cpx_sparse is not None and isinstance(selector_matrix, cpx_sparse.spmatrix): if torch.is_tensor(dense_vector): @@ -330,6 +334,7 @@ def flagsparse_gather(a, indices, out=None, mode="raise", block_size=1024, retur dense_vector, dense_backend = _to_torch_tensor(a, "a") indices_tensor, _ = _to_torch_tensor(indices, "indices") dense_vector, indices_tensor, kernel_indices = _prepare_inputs(dense_vector, indices_tensor) + _validate_gather_value_dtype(dense_vector, "flagsparse_gather") torch.cuda.synchronize() start_time = time.perf_counter() @@ -479,6 +484,7 @@ def pytorch_index_scatter( def cusparse_spmv_gather(dense_vector, indices, selector_matrix=None): """Equivalent gather baseline via cuSPARSE-backed COO SpMV.""" dense_vector, indices, _ = _prepare_inputs(dense_vector, indices) + _validate_gather_value_dtype(dense_vector, "cusparse_spmv_gather") skip_reason = _cusparse_baseline_skip_reason(dense_vector.dtype) if skip_reason: raise RuntimeError(skip_reason) @@ -557,21 +563,14 @@ def 
_cupy_gather_detect_layout(dense_vector): return "scalar16" if dense_vector.ndim == 1 and dense_vector.dtype == torch.complex64: return "complex64" - if ( - dense_vector.ndim == 2 - and dense_vector.shape[1] == 2 - and dense_vector.dtype == torch.float16 - ): - # complex32 alignment with half2 storage. - return "complex16_pair" raise TypeError( "Unsupported gather input format. Expected one of: " - "1D float16/bfloat16, 1D complex64, or 2D (N,2) float16." + "1D float16/bfloat16 or 1D complex64." ) def _cupy_gather_dense_size(dense_vector, layout): - if layout in ("scalar16", "complex64", "complex16_pair"): + if layout in ("scalar16", "complex64"): return int(dense_vector.shape[0]) raise RuntimeError(f"Unknown gather layout: {layout}") @@ -601,7 +600,7 @@ def _cupy_gather_validate_inputs(dense_vector, indices): def _cupy_gather_validate_combo(dense_vector, indices, layout): # Keep only the required extra gather combos: - # Half+Int64, Bfloat16+Int32/Int64, Complex32+Int32/Int64, Complex64+Int32/Int64 + # Half+Int64, Bfloat16+Int32/Int64, Complex64+Int32/Int64 if layout == "scalar16": if dense_vector.dtype == torch.float16: if indices.dtype != torch.int64: @@ -611,11 +610,6 @@ def _cupy_gather_validate_combo(dense_vector, indices, layout): return raise TypeError("scalar16 gather_cupy supports only float16/bfloat16") - if layout == "complex16_pair": - if dense_vector.dtype != torch.float16: - raise TypeError("complex16_pair gather_cupy supports only float16 pairs") - return - if layout == "complex64": return @@ -625,8 +619,6 @@ def _cupy_gather_validate_combo(dense_vector, indices, layout): def _cupy_gather_layout_raw_kind(layout): if layout == "scalar16": return 16 - if layout == "complex16_pair": - return 32 if layout == "complex64": return 64 raise RuntimeError(f"Unknown gather layout: {layout}") @@ -675,9 +667,6 @@ def _cupy_gather_dense_to_raw_torch(dense_t, layout): return dense_t.reshape(-1).view(torch.uint16) if layout == "complex64": return dense_t.reshape(-1).view(torch.uint64) - if layout == "complex16_pair": - lanes_u16 = dense_t.reshape(-1).view(torch.uint16) - return lanes_u16.reshape(-1, 2).view(torch.uint32).reshape(-1) raise RuntimeError(f"Unknown gather layout: {layout}") @@ -686,22 +675,17 @@ def _cupy_gather_raw_to_dense_torch(out_raw_t, layout, dense_t_dtype): return out_raw_t.view(dense_t_dtype).reshape(-1) if layout == "complex64": return out_raw_t.view(torch.complex64).reshape(-1) - if layout == "complex16_pair": - lanes_u16 = out_raw_t.view(torch.uint16).reshape(-1, 2) - return lanes_u16.view(dense_t_dtype).reshape(-1, 2) raise RuntimeError(f"Unknown gather layout: {layout}") def _cupy_gather_empty(layout, dense_dtype, device): if layout in ("scalar16", "complex64"): return torch.empty(0, dtype=dense_dtype, device=device) - if layout == "complex16_pair": - return torch.empty((0, 2), dtype=dense_dtype, device=device) raise RuntimeError(f"Unknown gather layout: {layout}") def _cupy_gather_selector_dtype(layout, dense_dtype): - if layout in ("scalar16", "complex16_pair"): + if layout == "scalar16": return dense_dtype if layout == "complex64": return torch.complex64 @@ -709,21 +693,12 @@ def _cupy_gather_selector_dtype(layout, dense_dtype): def _cupy_gather_prepare_dense(dense_vector, indices): - runtime_dense = dense_vector - restore_mode = None - native_complex32 = _torch_complex32_dtype() - if native_complex32 is not None and dense_vector.ndim == 1 and dense_vector.dtype == native_complex32: - runtime_dense = torch.view_as_real(dense_vector).contiguous() - restore_mode 
= "native_complex32" - - layout, dense_size = _cupy_gather_validate_inputs(runtime_dense, indices) - _cupy_gather_validate_combo(runtime_dense, indices, layout) - return runtime_dense, layout, dense_size, restore_mode + layout, dense_size = _cupy_gather_validate_inputs(dense_vector, indices) + _cupy_gather_validate_combo(dense_vector, indices, layout) + return dense_vector, layout, dense_size, None def _cupy_gather_restore_output(gathered_t, restore_mode): - if restore_mode == "native_complex32": - return torch.view_as_complex(gathered_t.contiguous()) return gathered_t @@ -880,10 +855,6 @@ def cusparse_spmv_gather_cupy(dense_vector, indices, selector_matrix=None): start_time = time.perf_counter() if layout in ("scalar16", "complex64"): gathered_t = _cusparse_spmv(selector_matrix, runtime_dense_t) - elif layout == "complex16_pair": - gathered_real = _cusparse_spmv(selector_matrix, runtime_dense_t[:, 0]) - gathered_imag = _cusparse_spmv(selector_matrix, runtime_dense_t[:, 1]) - gathered_t = torch.stack([gathered_real, gathered_imag], dim=1) else: raise RuntimeError(f"Unknown gather layout: {layout}") torch.cuda.synchronize() diff --git a/src/flagsparse/sparse_operations/spsv.py b/src/flagsparse/sparse_operations/spsv.py index 2976ad1..dd867bc 100644 --- a/src/flagsparse/sparse_operations/spsv.py +++ b/src/flagsparse/sparse_operations/spsv.py @@ -2,6 +2,7 @@ from ._common import * +from collections import OrderedDict import os import time import triton @@ -11,31 +12,34 @@ torch.bfloat16, torch.float32, torch.float64, - *((_torch_complex32_dtype(),) if _torch_complex32_dtype() is not None else ()), torch.complex64, + torch.complex128, + ) SUPPORTED_SPSV_INDEX_DTYPES = (torch.int32, torch.int64) SPSV_NON_TRANS_PRIMARY_COMBOS = ( (torch.float32, torch.int32), (torch.float64, torch.int32), - *(((_torch_complex32_dtype(), torch.int32),) if _torch_complex32_dtype() is not None else ()), (torch.complex64, torch.int32), + (torch.complex128, torch.int32), ) SPSV_NON_TRANS_EXTENDED_COMBOS = ( (torch.float32, torch.int64), (torch.float64, torch.int64), - *(((_torch_complex32_dtype(), torch.int64),) if _torch_complex32_dtype() is not None else ()), (torch.complex64, torch.int64), + (torch.complex128, torch.int64), ) SPSV_TRANS_PRIMARY_COMBOS = ( (torch.float32, torch.int32), (torch.float64, torch.int32), - *(((_torch_complex32_dtype(), torch.int32),) if _torch_complex32_dtype() is not None else ()), (torch.complex64, torch.int32), + (torch.complex128, torch.int32), ) SPSV_PROMOTE_FP32_TO_FP64 = str( os.environ.get("FLAGSPARSE_SPSV_PROMOTE_FP32_TO_FP64", "0") ).lower() in ("1", "true", "yes", "on") +_SPSV_CSR_PREPROCESS_CACHE = OrderedDict() +_SPSV_CSR_PREPROCESS_CACHE_SIZE = 8 def _csr_to_dense(data, indices, indptr, shape): """Convert CSR (torch CUDA tensors) to dense matrix on the same device.""" @@ -69,7 +73,11 @@ def _validate_spsv_non_trans_combo(data_dtype, index_dtype, fmt_name): raise TypeError( f"{fmt_name} SpSV currently supports NON_TRANS combinations: " "(float32, int32/int64), (float64, int32/int64), " +<<<<<<< HEAD "(complex32, int32/int64), (complex64, int32/int64), (bfloat16, int32)" +======= + "(complex64, int32/int64), (complex128, int32/int64), (bfloat16, int32)" +>>>>>>> 5a83e0f (test) ) @@ -78,7 +86,11 @@ def _validate_spsv_trans_combo(data_dtype, index_dtype, fmt_name): return raise TypeError( f"{fmt_name} SpSV currently supports TRANS combinations with int32 indices only: " +<<<<<<< HEAD "(float32, int32), (float64, int32), (complex32, int32), (complex64, int32)" +======= + 
"(float32, int32), (float64, int32), (complex64, int32), (complex128, int32)" +>>>>>>> 5a83e0f (test) ) @@ -118,7 +130,11 @@ def _prepare_spsv_inputs(data, indices, indptr, b, shape): if data.dtype not in SUPPORTED_SPSV_VALUE_DTYPES: raise TypeError( +<<<<<<< HEAD "data dtype must be one of: bfloat16, float32, float64, complex32, complex64" +======= + "data dtype must be one of: bfloat16, float32, float64, complex64, complex128" +>>>>>>> 5a83e0f (test) ) if indices.dtype not in SUPPORTED_SPSV_INDEX_DTYPES: raise TypeError("indices dtype must be torch.int32 or torch.int64") @@ -158,6 +174,7 @@ def _prepare_spsv_inputs(data, indices, indptr, b, shape): ) +<<<<<<< HEAD def _promote_complex32_spsv_inputs(data, b): if _is_complex32_dtype(data.dtype): return data.to(torch.complex64), b.to(torch.complex64), data.dtype @@ -173,6 +190,62 @@ def _restore_complex32_spsv_output(x, target_dtype): return x.to(target_dtype) +======= +def _prepare_spsv_working_inputs(data, b): + return data, b, None + + +def _restore_spsv_output(x, target_dtype): + return x.to(target_dtype) + + +def _spsv_diag_eps_for_dtype(value_dtype): + return 1e-12 if value_dtype in (torch.float64, torch.complex128) else 1e-6 + + +def _tensor_cache_token(tensor): + try: + storage_ptr = int(tensor.untyped_storage().data_ptr()) + except Exception: + storage_ptr = 0 + return ( + str(tensor.device), + str(tensor.dtype), + tuple(int(v) for v in tensor.shape), + int(tensor.numel()), + storage_ptr, + int(getattr(tensor, "_version", 0)), + ) + + +def _spsv_cache_get(cache, key): + value = cache.get(key) + if value is not None: + cache.move_to_end(key) + return value + + +def _spsv_cache_put(cache, key, value, max_entries): + cache[key] = value + cache.move_to_end(key) + while len(cache) > max_entries: + cache.popitem(last=False) + + +def _csr_preprocess_cache_key(data, indices, indptr, shape, lower, trans_mode): + return ( + "csr_preprocess", + trans_mode, + bool(lower), + int(shape[0]), + int(shape[1]), + _tensor_cache_token(data), + _tensor_cache_token(indices), + _tensor_cache_token(indptr), + ) + + +>>>>>>> 5a83e0f (test) @triton.jit def _spsv_csr_level_kernel( data_ptr, @@ -536,7 +609,19 @@ def _triton_spsv_csr_vector_complex( indptr, block_nnz=block_nnz, max_segments=max_segments ) +<<<<<<< HEAD data_ri = torch.view_as_real(data.contiguous()).reshape(-1).contiguous() +======= + # Some PyTorch builds return CSR values with a non-strided layout wrapper. + # Materialize a plain 1D strided buffer before splitting into real/imag parts. 
 @triton.jit
 def _spsv_csr_level_kernel(
     data_ptr,
@@ -536,7 +609,19 @@ def _triton_spsv_csr_vector_complex(
         indptr, block_nnz=block_nnz, max_segments=max_segments
     )

+<<<<<<< HEAD
     data_ri = torch.view_as_real(data.contiguous()).reshape(-1).contiguous()
+=======
+    # Some PyTorch builds return CSR values with a non-strided layout wrapper.
+    # Materialize a plain 1D strided buffer before splitting into real/imag parts.
+    if data.layout != torch.strided:
+        data_strided = torch.empty(data.shape, dtype=data.dtype, device=data.device)
+        data_strided.copy_(data)
+    else:
+        data_strided = data.contiguous()
+
+    data_ri = torch.view_as_real(data_strided).reshape(-1).contiguous()
+>>>>>>> 5a83e0f (test)
     b_ri = torch.view_as_real(b_vec.contiguous()).reshape(-1).contiguous()
     component_dtype = _component_dtype_for_complex(data.dtype)
     use_fp64 = component_dtype == torch.float64
@@ -589,17 +674,27 @@ def _prepare_spsv_coo_inputs(data, row, col, b, shape, transpose=False):
     if b.ndim == 2 and b.shape[0] != n_rows:
         raise ValueError(f"b.shape[0] must equal n_rows={n_rows}")

+<<<<<<< HEAD
     if data.dtype not in (torch.bfloat16, torch.float32, torch.float64):
         raise TypeError("data dtype must be one of: bfloat16, float32, float64")
+=======
+    if data.dtype not in (
+        torch.bfloat16,
+        torch.float32,
+        torch.float64,
+        torch.complex64,
+        torch.complex128,
+    ):
+        raise TypeError(
+            "data dtype must be one of: bfloat16, float32, float64, complex64, complex128"
+        )
+>>>>>>> 5a83e0f (test)
     if b.dtype != data.dtype:
         raise TypeError("b dtype must match data dtype")
     if row.dtype not in SUPPORTED_SPSV_INDEX_DTYPES:
         raise TypeError("row dtype must be torch.int32 or torch.int64")
     if col.dtype not in SUPPORTED_SPSV_INDEX_DTYPES:
         raise TypeError("col dtype must be torch.int32 or torch.int64")
-    if transpose:
-        raise NotImplementedError("transpose=True is not implemented in Triton SpSV yet")
-
     row64 = row.to(torch.int64).contiguous()
     col64 = col.to(torch.int64).contiguous()
     if col64.numel() > 0 and int(col64.max().item()) > _INDEX_LIMIT_INT32:
@@ -782,20 +877,36 @@ def flagsparse_spsv_csr(
     """Sparse triangular solve using Triton level-scheduling kernels.

     Primary support matrix:
+<<<<<<< HEAD
     - NON_TRANS: float32/float64/complex32/complex64 with int32/int64 indices
     - TRANS: float32/float64/complex32/complex64 with int32 indices
     - bfloat16 remains NON_TRANS + int32
     """
+=======
+    - NON_TRANS: float32/float64/complex64/complex128 with int32/int64 indices
+    - TRANS: float32/float64/complex64/complex128 with int32 indices
+    - bfloat16 remains NON_TRANS + int32
+    """
+    input_data = data
+    input_indices = indices
+    input_indptr = indptr
+>>>>>>> 5a83e0f (test)
     trans_mode = _normalize_spsv_transpose_mode(transpose)
     data, input_index_dtype, indices, indptr, b, n_rows, n_cols = _prepare_spsv_inputs(
         data, indices, indptr, b, shape
     )
     original_output_dtype = None
+<<<<<<< HEAD
     data, b, original_output_dtype = _promote_complex32_spsv_inputs(data, b)
+=======
+    rev_perm = None
+    data, b, original_output_dtype = _prepare_spsv_working_inputs(data, b)
+>>>>>>> 5a83e0f (test)
     if n_rows != n_cols:
         raise ValueError(f"A must be square, got shape={shape}")

     if trans_mode == "N":
         _validate_spsv_non_trans_combo(data.dtype, input_index_dtype, "CSR")
+<<<<<<< HEAD
         lower_eff = lower
         kernel_data = data
         kernel_indices64 = indices
         kernel_indptr64 = indptr
@@ -806,6 +917,46 @@
         kernel_data, kernel_indices64, kernel_indptr64 = _csr_transpose(
             data, indices, indptr, n_rows, n_cols
         )
+=======
+    else:
+        _validate_spsv_trans_combo(data.dtype, input_index_dtype, "CSR")
+
+    preprocess_key = _csr_preprocess_cache_key(
+        input_data, input_indices, input_indptr, (n_rows, n_cols), lower, trans_mode
+    )
+    cached = _spsv_cache_get(_SPSV_CSR_PREPROCESS_CACHE, preprocess_key)
+    if cached is None:
+        if trans_mode == "N":
+            lower_eff = lower
+            kernel_data = data
+            kernel_indices64 = indices
+            kernel_indptr64 = indptr
+            rev_perm = None
+        else:
+            lower_eff = not lower
+            kernel_data, kernel_indices64, kernel_indptr64 = _csr_transpose(
+                data, indices, indptr, n_rows, n_cols
+            )
+            rev_perm = None
+        levels = _build_spsv_levels(
+            kernel_indptr64, kernel_indices64, n_rows, lower=lower_eff
+        )
+        cached = (
+            kernel_data,
+            kernel_indices64,
+            kernel_indptr64,
+            rev_perm,
+            lower_eff,
+            levels,
+        )
+        _spsv_cache_put(
+            _SPSV_CSR_PREPROCESS_CACHE,
+            preprocess_key,
+            cached,
+            _SPSV_CSR_PREPROCESS_CACHE_SIZE,
+        )
+    kernel_data, kernel_indices64, kernel_indptr64, rev_perm, lower_eff, levels = cached
+>>>>>>> 5a83e0f (test)

     kernel_indices = (
         kernel_indices64.to(torch.int32)
@@ -828,6 +979,7 @@ def flagsparse_spsv_csr(
         # Optional high-precision mode; disabled by default for throughput.
         compute_dtype = torch.float64
         data_in = kernel_data.to(torch.float64)
+<<<<<<< HEAD
         b_in = b.to(torch.float64)
     elif data.dtype == torch.float32 and trans_mode == "T":
         compute_dtype = torch.float64
         data_in = kernel_data.to(torch.float64)
         b_in = b.to(torch.float64)
     levels = _build_spsv_levels(
         kernel_indptr, kernel_indices, n_rows, lower=lower_eff
     )
+=======
+        b_in = b.to(torch.float64)
+    elif data.dtype == torch.float32 and trans_mode == "T":
+        compute_dtype = torch.float64
+        data_in = kernel_data.to(torch.float64)
+        b_in = b.to(torch.float64)
+>>>>>>> 5a83e0f (test)
     block_nnz_use, max_segments_use = _auto_spsv_launch_config(
         kernel_indptr, block_nnz=block_nnz, max_segments=max_segments
     )
-    diag_eps = 1e-12 if compute_dtype == torch.float64 else 1e-6
+    diag_eps = _spsv_diag_eps_for_dtype(compute_dtype)
     torch.cuda.synchronize()
     t0 = time.perf_counter()
     if b_in.ndim == 1:
@@ -918,7 +1077,11 @@ def flagsparse_spsv_csr(
             x = torch.stack(cols, dim=1)
     target_dtype = original_output_dtype if original_output_dtype is not None else data.dtype
     if x.dtype != target_dtype:
+<<<<<<< HEAD
         x = _restore_complex32_spsv_output(x, target_dtype)
+=======
+        x = _restore_spsv_output(x, target_dtype)
+>>>>>>> 5a83e0f (test)
     torch.cuda.synchronize()
     elapsed_ms = (time.perf_counter() - t0) * 1000.0
     if out is not None:
@@ -950,11 +1113,11 @@ def flagsparse_spsv_coo(
     """COO SpSV with dual mode:
     - direct: use COO level kernel directly (requires sorted+unique COO)
     - csr: convert COO -> CSR (sorted+deduplicated) then call flagsparse_spsv_csr
-    - auto: pick direct when sorted+unique, otherwise csr
+    - auto: pick direct when sorted+unique and supported, otherwise csr

-    Primary NON_TRANS support matrix:
-    - float32 + int32 indices
-    - float64 + int32 indices
+    Notes:
+    - direct mode currently supports only non-transposed real-valued inputs
+    - complex dtypes and transpose=True always route through the CSR implementation
     """
     data, row64, col64, b, n_rows, n_cols = _prepare_spsv_coo_inputs(
         data, row, col, b, shape, transpose=transpose
@@ -967,7 +1130,13 @@ def flagsparse_spsv_coo(
         raise ValueError("coo_mode must be one of: 'auto', 'direct', 'csr'")

     sorted_unique = _coo_is_sorted_unique(row64, col64, n_cols)
-    use_direct = mode == "direct" or (mode == "auto" and sorted_unique)
+    direct_supported = (not transpose) and (not torch.is_complex(data))
+    use_direct = direct_supported and (mode == "direct" or (mode == "auto" and sorted_unique))
+    if mode == "direct" and not direct_supported:
+        raise ValueError(
+            "coo_mode='direct' supports only non-transposed real-valued inputs; "
+            "use coo_mode='csr' or 'auto' for transpose or complex dtypes"
+        )
     if mode == "direct" and not sorted_unique:
         raise ValueError(
             "coo_mode='direct' requires COO sorted by (row, col) with no duplicate coordinates; "
@@ -978,6 +1147,8 @@ def flagsparse_spsv_coo(
         data_csr, indices_csr, indptr_csr = _coo_to_csr_sorted_unique(
             data, row64, col64, n_rows, n_cols
         )
+        if transpose:
+            indices_csr = indices_csr.to(torch.int32)
         return flagsparse_spsv_csr(
             data_csr,
             indices_csr,
@@ -1011,7 +1182,7 @@ def flagsparse_spsv_coo(
     block_nnz_use, max_segments_use = _auto_spsv_launch_config(
         row_ptr, block_nnz=block_nnz, max_segments=max_segments
     )
-    diag_eps = 1e-12 if compute_dtype == torch.float64 else 1e-6
+    diag_eps = _spsv_diag_eps_for_dtype(compute_dtype)
     torch.cuda.synchronize()
     t0 = time.perf_counter()
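
The COO entry point above routes transpose and complex solves through the CSR implementation and keeps the direct COO kernel for sorted, duplicate-free, real-valued inputs. The sorted-and-unique gate can be checked with one comparison over encoded coordinates; this is a sketch of the idea, not the library's actual helper:

import torch

def coo_is_sorted_unique(row, col, n_cols):
    # Encode each (row, col) pair as a scalar key; sorted COO with no
    # duplicate coordinates means the keys increase strictly.
    keys = row.to(torch.int64) * n_cols + col.to(torch.int64)
    if keys.numel() < 2:
        return True
    return bool(torch.all(keys[1:] > keys[:-1]).item())

row = torch.tensor([0, 0, 1, 2])
col = torch.tensor([0, 2, 1, 2])
assert coo_is_sorted_unique(row, col, n_cols=3)
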
diff --git a/tests/pytest/test_gather_scatter_accuracy.py b/tests/pytest/test_gather_scatter_accuracy.py
index 3563ef1..e67c648 100644
--- a/tests/pytest/test_gather_scatter_accuracy.py
+++ b/tests/pytest/test_gather_scatter_accuracy.py
@@ -19,17 +19,8 @@
 RESET_OUTPUT_CASES = [True, False]
 RESET_OUTPUT_IDS = ["reset", "inplace"]

-
-def _complex32_dtype():
-    dtype = getattr(torch, "complex32", None)
-    if dtype is None:
-        dtype = getattr(torch, "chalf", None)
-    return dtype
-
-
 def _scatter_dtype_cases():
     cases = [(str(dtype).replace("torch.", ""), dtype) for dtype in FLOAT_DTYPES]
-    cases.append(("complex32", _complex32_dtype()))
     cases.append(("complex64", torch.complex64))
     cases.append(("complex128", torch.complex128))
     return cases
@@ -59,9 +50,6 @@ def _build_random_values(size, dtype, device):
         real = torch.randn(size, dtype=torch.float64, device=device)
         imag = torch.randn(size, dtype=torch.float64, device=device)
         return torch.complex(real, imag)
-    if _complex32_dtype() is not None and dtype == _complex32_dtype():
-        stacked = torch.randn((size, 2), dtype=torch.float16, device=device)
-        return torch.view_as_complex(stacked)
     raise TypeError(f"Unsupported dtype in test: {dtype}")


@@ -113,8 +101,6 @@ def _extra_gather_tolerance(value_dtype):
     ("scalar16", torch.float16, torch.int64),
     ("scalar16", torch.bfloat16, torch.int32),
     ("scalar16", torch.bfloat16, torch.int64),
-    ("complex16_pair", torch.float16, torch.int32),
-    ("complex16_pair", torch.float16, torch.int64),
     ("complex64", torch.complex64, torch.int32),
     ("complex64", torch.complex64, torch.int64),
 ]
@@ -122,8 +108,6 @@
     "half_i64",
     "bf16_i32",
     "bf16_i64",
-    "c16f_i32",
-    "c16f_i64",
     "c64_i32",
     "c64_i64",
 ]
@@ -144,6 +128,21 @@ def test_gather_matches_indexing(dense_size, nnz, dtype, index_dtype):
     assert torch.equal(ref, got)


+@pytest.mark.gather
+@pytest.mark.parametrize("index_dtype", INDEX_DTYPES, ids=INDEX_DTYPE_IDS)
+def test_gather_complex128_matches_indexing(index_dtype):
+    device = torch.device("cuda")
+    dense_size = 4096
+    nnz = 1024
+    real = torch.randn(dense_size, dtype=torch.float64, device=device)
+    imag = torch.randn(dense_size, dtype=torch.float64, device=device)
+    dense = torch.complex(real, imag)
+    indices = torch.randperm(dense_size, device=device)[:nnz].to(index_dtype)
+    ref = dense.index_select(0, indices.to(torch.int64))
+    got = flagsparse_gather(dense, indices)
+    assert torch.allclose(got, ref, atol=1e-10, rtol=1e-8)
+
+
 @pytest.mark.scatter
 @pytest.mark.parametrize("dense_size, nnz", GATHER_SCATTER_SHAPES)
 @pytest.mark.parametrize("dtype_name,dtype", SCATTER_DTYPE_CASES, ids=SCATTER_DTYPE_IDS)
@@ -323,94 +322,6 @@ def test_gather_cupy_same_backend_out_float16_i64(backend):
     assert torch.allclose(_as_torch_tensor(out), reference, atol=5e-3, rtol=5e-3)


-@pytest.mark.gather
-@pytest.mark.skipif(cp is None, reason="CuPy required")
-@pytest.mark.parametrize("backend", ["torch", "cupy"])
-def test_gather_cupy_same_backend_out_pair_complex32(backend):
-    device = torch.device("cuda")
-    dense_size = 65536
-    nnz = 4096
-    dense_t = torch.randn(dense_size, 2, dtype=torch.float16, device=device)
-    indices_t = torch.arange(nnz, device=device, dtype=torch.int64) * 17 % dense_size
-    reference = dense_t.index_select(0, indices_t)
-
-    dense_in = _to_backend_tensor(dense_t, backend)
-    indices_in = _to_backend_tensor(indices_t, backend)
-    out = _to_backend_tensor(torch.empty_like(reference), backend)
-    result = flagsparse_gather_cupy(dense_in, indices_in, out=out)
-
-    assert result is out
-    assert torch.allclose(_as_torch_tensor(out), reference, atol=5e-3, rtol=5e-3)
-
-
-@pytest.mark.gather
-@pytest.mark.skipif(cp is None, reason="CuPy required")
-def test_gather_cupy_native_complex32_out_matches_reference():
-    native_dtype = _complex32_dtype()
-    _skip_unavailable_dtype("complex32", native_dtype)
-
-    device = torch.device("cuda")
-    dense_size = 65536
-    nnz = 4096
-    dense_pair = torch.randn(dense_size, 2, dtype=torch.float16, device=device)
-    dense_native = torch.view_as_complex(dense_pair.contiguous())
-    indices = torch.arange(nnz, device=device, dtype=torch.int64) * 17 % dense_size
-    reference = dense_native.index_select(0, indices)
-
-    out = torch.empty_like(reference)
-    result = flagsparse_gather_cupy(dense_native, indices, out=out)
-
-    assert result is out
-    assert out.dtype == native_dtype
-    assert torch.allclose(
-        torch.view_as_real(out).contiguous(),
-        torch.view_as_real(reference).contiguous(),
-        atol=5e-3,
-        rtol=5e-3,
-    )
-
-
-@pytest.mark.gather
-@pytest.mark.skipif(cp is None, reason="CuPy required")
-def test_gather_cupy_native_complex32_matches_reference_and_pair_layout():
-    native_dtype = _complex32_dtype()
-    _skip_unavailable_dtype("complex32", native_dtype)
-
-    device = torch.device("cuda")
-    dense_size = 65536
-    nnz = 4096
-    dense_pair = torch.randn(dense_size, 2, dtype=torch.float16, device=device)
-    dense_native = torch.view_as_complex(dense_pair.contiguous())
-    indices = torch.arange(nnz, device=device, dtype=torch.int64) * 17 % dense_size
-
-    reference = dense_native.index_select(0, indices)
-    reference_pair = torch.view_as_real(reference).contiguous()
-    pair_got = flagsparse_gather_cupy(dense_pair, indices)
-    native_got = flagsparse_gather_cupy(dense_native, indices)
-    cusparse_values, _, _ = cusparse_spmv_gather_cupy(dense_native, indices)
-
-    atol, rtol = 5e-3, 5e-3
-    assert pair_got.shape == reference_pair.shape
-    assert pair_got.dtype == torch.float16
-    assert native_got.shape == reference.shape
-    assert native_got.dtype == native_dtype
-    assert cusparse_values.shape == reference.shape
-    assert cusparse_values.dtype == native_dtype
-    assert torch.allclose(pair_got, reference_pair, atol=atol, rtol=rtol)
-    assert torch.allclose(
-        torch.view_as_real(native_got).contiguous(),
-        reference_pair,
-        atol=atol,
-        rtol=rtol,
-    )
-    assert torch.allclose(
-        torch.view_as_real(cusparse_values).contiguous(),
-        reference_pair,
-        atol=atol,
-        rtol=rtol,
-    )
-
-
 @pytest.mark.gather
 @pytest.mark.skipif(cp is None, reason="CuPy required")
 def test_gather_cupy_int64_auto_fallback_to_int32(monkeypatch):
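
The accuracy tests that follow build diagonally dominant triangular systems so that forward and backward substitution stay well conditioned for every dtype. The construction is roughly:

import torch

def build_lower_triangular(n, dtype=torch.float64, device="cuda"):
    # Small off-diagonal entries plus a diagonal larger than any row sum
    # keep the triangular solve numerically stable.
    A = torch.tril(torch.randn(n, n, dtype=dtype, device=device) * 0.02)
    return A + 2.0 * torch.eye(n, dtype=dtype, device=device)

A = build_lower_triangular(8)
b = torch.randn(8, dtype=torch.float64, device="cuda")
x = torch.linalg.solve_triangular(A, b.unsqueeze(-1), upper=False).squeeze(-1)
assert torch.allclose(A @ x, b, atol=1e-10)
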
diff --git a/tests/pytest/test_spsv_csr_accuracy.py b/tests/pytest/test_spsv_csr_accuracy.py
index 94788f2..d44d7b5 100644
--- a/tests/pytest/test_spsv_csr_accuracy.py
+++ b/tests/pytest/test_spsv_csr_accuracy.py
@@ -1,7 +1,7 @@
 import pytest
 import torch

-from flagsparse import flagsparse_spsv_csr
+from flagsparse import flagsparse_spsv_coo, flagsparse_spsv_csr
 from tests.pytest.param_shapes import SPSV_N


@@ -16,6 +16,7 @@
 pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")

+<<<<<<< HEAD
 COMPLEX32_DTYPE = getattr(torch, "complex32", None)
 if COMPLEX32_DTYPE is None:
     COMPLEX32_DTYPE = getattr(torch, "chalf", None)
@@ -26,6 +27,13 @@
     SUPPORTED_COMPLEX_DTYPES.append(torch.complex64)

 SUPPORTED_DTYPES = [torch.float32, torch.float64, *SUPPORTED_COMPLEX_DTYPES]
+=======
+SUPPORTED_COMPLEX_DTYPES = [torch.complex64, torch.complex128]
+
+SUPPORTED_DTYPES = [torch.float32, torch.float64, *SUPPORTED_COMPLEX_DTYPES]
+NON_TRANS_DTYPES = [torch.float32, torch.float64, torch.complex64, torch.complex128]
+TRANS_DTYPES = [torch.float32, torch.float64, torch.complex64, torch.complex128]
+>>>>>>> 5a83e0f (test)


 def _dtype_id(dtype):
@@ -33,8 +41,11 @@ def _dtype_id(dtype):


 def _tol(dtype):
+<<<<<<< HEAD
     if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE:
         return 5e-3, 5e-3
+=======
+>>>>>>> 5a83e0f (test)
     if dtype in (torch.float32, torch.complex64):
         return 1e-4, 1e-3
     return 1e-10, 1e-8
@@ -43,30 +54,41 @@ def _tol(dtype):
 def _rand_like(dtype, shape, device):
     if dtype in (torch.float32, torch.float64):
         return torch.randn(shape, dtype=dtype, device=device)
+<<<<<<< HEAD
     if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE:
         pair = torch.randn((*shape, 2), dtype=torch.float16, device=device) * 0.1
         return torch.view_as_complex(pair)
     base = torch.float32
+=======
+    base = torch.float32 if dtype == torch.complex64 else torch.float64
+>>>>>>> 5a83e0f (test)
     r = torch.randn(shape, dtype=base, device=device)
     i = torch.randn(shape, dtype=base, device=device)
     return torch.complex(r, i)


 def _ref_dtype(dtype):
+<<<<<<< HEAD
     if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE:
         return torch.complex64
+=======
+>>>>>>> 5a83e0f (test)
     return dtype


 def _safe_cast_tensor(tensor, dtype):
+<<<<<<< HEAD
     if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE:
         real = tensor.real.to(torch.float16)
         imag = tensor.imag.to(torch.float16)
         return torch.view_as_complex(torch.stack([real, imag], dim=-1).contiguous())
+=======
+>>>>>>> 5a83e0f (test)
     return tensor.to(dtype)


 def _cmp_view(tensor, dtype):
+<<<<<<< HEAD
     if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE:
         return tensor.to(torch.complex64)
     return tensor
@@ -75,6 +97,14 @@ def _cmp_view(tensor, dtype):
 def _build_lower_triangular(n, dtype, device):
     off = _rand_like(dtype, (n, n), device) * 0.02
     A = torch.tril(off)
+=======
+    return tensor
+
+
+def _build_triangular(n, dtype, device, lower=True):
+    off = _rand_like(dtype, (n, n), device) * 0.02
+    A = torch.tril(off) if lower else torch.triu(off)
+>>>>>>> 5a83e0f (test)
     if torch.is_complex(A):
         diag = (torch.rand(n, device=device, dtype=A.real.dtype) + 2.0).to(A.real.dtype)
         A = A + torch.diag(torch.complex(diag, torch.zeros_like(diag)))
@@ -87,8 +117,12 @@ def _build_lower_triangular(n, dtype, device):
 def _cupy_csr_from_torch(data, indices, indptr, shape):
     if cp is None or cpx_sparse is None:
         return None
+<<<<<<< HEAD
     data_ref = data.to(torch.complex64) if COMPLEX32_DTYPE is not None and data.dtype == COMPLEX32_DTYPE else data
     data_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(data_ref.contiguous()))
+=======
+    data_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(data.contiguous()))
+>>>>>>> 5a83e0f (test)
     idx_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(indices.to(torch.int64).contiguous()))
     ptr_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(indptr.to(torch.int64).contiguous()))
     return cpx_sparse.csr_matrix((data_cp, idx_cp, ptr_cp), shape=shape)
@@ -97,12 +131,18 @@ def _cupy_csr_from_torch(data, indices, indptr, shape):
 def _cupy_ref_spsv(A_cp, b_t, *, lower, unit_diagonal=False):
     if cp is None or cpx_spsolve_triangular is None:
         return None
+<<<<<<< HEAD
     b_ref = b_t.to(torch.complex64) if COMPLEX32_DTYPE is not None and b_t.dtype == COMPLEX32_DTYPE else b_t
     b_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(b_ref.contiguous()))
     x_cp = cpx_spsolve_triangular(A_cp, b_cp, lower=lower, unit_diagonal=unit_diagonal)
     x_t = torch.utils.dlpack.from_dlpack(x_cp.toDlpack())
     if COMPLEX32_DTYPE is not None and b_t.dtype == COMPLEX32_DTYPE:
         return x_t.to(torch.complex64)
+=======
+    b_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(b_t.contiguous()))
+    x_cp = cpx_spsolve_triangular(A_cp, b_cp, lower=lower, unit_diagonal=unit_diagonal)
+    x_t = torch.utils.dlpack.from_dlpack(x_cp.toDlpack())
+>>>>>>> 5a83e0f (test)
     return x_t.to(b_t.dtype)


@@ -143,11 +183,19 @@ def test_spsv_csr_lower_matches_dense(n, dtype):

 @pytest.mark.spsv
 @pytest.mark.parametrize("n", SPSV_N)
+<<<<<<< HEAD
 @pytest.mark.parametrize("dtype", SUPPORTED_DTYPES, ids=_dtype_id)
 @pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"])
 def test_spsv_csr_non_trans_supported_combos(n, dtype, index_dtype):
     device = torch.device("cuda")
     A = _build_lower_triangular(n, dtype, device)
+=======
+@pytest.mark.parametrize("dtype", NON_TRANS_DTYPES, ids=_dtype_id)
+@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"])
+def test_spsv_csr_non_trans_supported_combos(n, dtype, index_dtype):
+    device = torch.device("cuda")
+    A = _build_triangular(n, dtype, device, lower=True)
+>>>>>>> 5a83e0f (test)
     b = _rand_like(dtype, (n,), device)
     x_ref = torch.linalg.solve_triangular(
         A.to(_ref_dtype(dtype)), b.to(_ref_dtype(dtype)).unsqueeze(-1), upper=False
@@ -174,10 +222,17 @@ def test_spsv_csr_non_trans_supported_combos(n, dtype, index_dtype):

 @pytest.mark.spsv
 @pytest.mark.parametrize("n", SPSV_N)
+<<<<<<< HEAD
 @pytest.mark.parametrize("dtype", SUPPORTED_DTYPES, ids=_dtype_id)
 def test_spsv_csr_trans_int32_supported_combos(n, dtype):
     device = torch.device("cuda")
     A = _build_lower_triangular(n, dtype, device)
+=======
+@pytest.mark.parametrize("dtype", TRANS_DTYPES, ids=_dtype_id)
+def test_spsv_csr_trans_int32_supported_combos(n, dtype):
+    device = torch.device("cuda")
+    A = _build_triangular(n, dtype, device, lower=True)
+>>>>>>> 5a83e0f (test)
     b = _rand_like(dtype, (n,), device)
     A_ref = A.to(_ref_dtype(dtype))
     b_ref = b.to(_ref_dtype(dtype))
@@ -205,12 +260,24 @@ def test_spsv_csr_trans_int32_supported_combos(n, dtype):


 @pytest.mark.spsv
+<<<<<<< HEAD
 @pytest.mark.skipif(cp is None or cpx_spsolve_triangular is None, reason="CuPy/cuSPARSE required")
 @pytest.mark.parametrize("n", SPSV_N)
 @pytest.mark.parametrize("dtype", SUPPORTED_DTYPES, ids=_dtype_id)
 def test_spsv_csr_matches_cusparse_non_trans_and_trans(n, dtype):
     device = torch.device("cuda")
     A = _build_lower_triangular(n, dtype, device)
+=======
+@pytest.mark.skipif(
+    cp is None or cpx_sparse is None or cpx_spsolve_triangular is None,
+    reason="CuPy/cuSPARSE required",
+)
+@pytest.mark.parametrize("n", SPSV_N)
+@pytest.mark.parametrize("dtype", NON_TRANS_DTYPES, ids=_dtype_id)
+def test_spsv_csr_matches_cusparse_non_trans(n, dtype):
+    device = torch.device("cuda")
+    A = _build_triangular(n, dtype, device, lower=True)
+>>>>>>> 5a83e0f (test)
     b = _rand_like(dtype, (n,), device)

     Asp = A.to_sparse_csr()
@@ -224,11 +291,187 @@ def test_spsv_csr_matches_cusparse_non_trans_and_trans(n, dtype):
     )
     x_non_ref = _cupy_ref_spsv(A_cp, b, lower=True, unit_diagonal=False)

+<<<<<<< HEAD
+=======
+    rtol, atol = _tol(dtype)
+    assert torch.allclose(_cmp_view(x_non, dtype), _cmp_view(x_non_ref, dtype), rtol=rtol, atol=atol)
+
+
+@pytest.mark.spsv
+@pytest.mark.skipif(
+    cp is None or cpx_sparse is None or cpx_spsolve_triangular is None,
+    reason="CuPy/cuSPARSE required",
+)
+@pytest.mark.parametrize("n", SPSV_N)
+@pytest.mark.parametrize("dtype", TRANS_DTYPES, ids=_dtype_id)
+def test_spsv_csr_matches_cusparse_trans(n, dtype):
+    device = torch.device("cuda")
+    A = _build_triangular(n, dtype, device, lower=True)
+    b = _rand_like(dtype, (n,), device)
+
+    Asp = A.to_sparse_csr()
+    data = Asp.values()
+    indices = Asp.col_indices().to(torch.int32)
+    indptr = Asp.crow_indices().to(torch.int32)
+    A_cp = _cupy_csr_from_torch(data, indices, indptr, (n, n))
+
+>>>>>>> 5a83e0f (test)
     x_trans = flagsparse_spsv_csr(
         data, indices, indptr, b, (n, n), lower=True, unit_diagonal=False, transpose=True
     )
     x_trans_ref = _cupy_ref_spsv(A_cp.transpose().tocsr(), b, lower=False, unit_diagonal=False)

     rtol, atol = _tol(dtype)
+<<<<<<< HEAD
     assert torch.allclose(_cmp_view(x_non, dtype), _cmp_view(x_non_ref, dtype), rtol=rtol, atol=atol)
     assert torch.allclose(_cmp_view(x_trans, dtype), _cmp_view(x_trans_ref, dtype), rtol=rtol, atol=atol)
+=======
+    assert torch.allclose(_cmp_view(x_trans, dtype), _cmp_view(x_trans_ref, dtype), rtol=rtol, atol=atol)
+
+
+@pytest.mark.spsv
+@pytest.mark.parametrize("n", SPSV_N)
+@pytest.mark.parametrize("dtype", NON_TRANS_DTYPES, ids=_dtype_id)
+@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"])
+def test_spsv_csr_non_trans_upper_supported_combos(n, dtype, index_dtype):
+    device = torch.device("cuda")
+    A = _build_triangular(n, dtype, device, lower=False)
+    b = _rand_like(dtype, (n,), device)
+    x_ref = torch.linalg.solve_triangular(
+        A.to(_ref_dtype(dtype)), b.to(_ref_dtype(dtype)).unsqueeze(-1), upper=True
+    ).squeeze(-1)
+
+    Asp = A.to_sparse_csr()
+    data = Asp.values()
+    indices = Asp.col_indices().to(index_dtype)
+    indptr = Asp.crow_indices().to(index_dtype)
+
+    x = flagsparse_spsv_csr(
+        data,
+        indices,
+        indptr,
+        b,
+        (n, n),
+        lower=False,
+        unit_diagonal=False,
+        transpose=False,
+    )
+    rtol, atol = _tol(dtype)
+    assert torch.allclose(_cmp_view(x, dtype), _cmp_view(x_ref, dtype), rtol=rtol, atol=atol)
+
+
+@pytest.mark.spsv
+@pytest.mark.parametrize("n", SPSV_N)
+@pytest.mark.parametrize("dtype", TRANS_DTYPES, ids=_dtype_id)
+def test_spsv_csr_trans_upper_int32_supported_combos(n, dtype):
+    device = torch.device("cuda")
+    A = _build_triangular(n, dtype, device, lower=False)
+    b = _rand_like(dtype, (n,), device)
+    A_ref = A.to(_ref_dtype(dtype))
+    b_ref = b.to(_ref_dtype(dtype))
+    x_ref = torch.linalg.solve_triangular(
+        A_ref.transpose(-2, -1), b_ref.unsqueeze(-1), upper=False
+    ).squeeze(-1)
+
+    Asp = A.to_sparse_csr()
+    data = Asp.values()
+    indices = Asp.col_indices().to(torch.int32)
+    indptr = Asp.crow_indices().to(torch.int32)
+
+    x = flagsparse_spsv_csr(
+        data,
+        indices,
+        indptr,
+        b,
+        (n, n),
+        lower=False,
+        unit_diagonal=False,
+        transpose=True,
+    )
+    rtol, atol = _tol(dtype)
+    assert torch.allclose(_cmp_view(x, dtype), _cmp_view(x_ref, dtype), rtol=rtol, atol=atol)
+
+
+@pytest.mark.spsv
+@pytest.mark.skipif(
+    cp is None or cpx_sparse is None or cpx_spsolve_triangular is None,
+    reason="CuPy/cuSPARSE required",
+)
+@pytest.mark.parametrize("n", SPSV_N)
+@pytest.mark.parametrize("dtype", NON_TRANS_DTYPES, ids=_dtype_id)
+def test_spsv_csr_matches_cusparse_upper_non_trans(n, dtype):
+    device = torch.device("cuda")
+    A = _build_triangular(n, dtype, device, lower=False)
+    b = _rand_like(dtype, (n,), device)
+
+    Asp = A.to_sparse_csr()
+    data = Asp.values()
+    indices = Asp.col_indices().to(torch.int32)
+    indptr = Asp.crow_indices().to(torch.int32)
+    A_cp = _cupy_csr_from_torch(data, indices, indptr, (n, n))
+
+    x_non = flagsparse_spsv_csr(
+        data, indices, indptr, b, (n, n), lower=False, unit_diagonal=False, transpose=False
+    )
+    x_non_ref = _cupy_ref_spsv(A_cp, b, lower=False, unit_diagonal=False)
+
+    rtol, atol = _tol(dtype)
+    assert torch.allclose(_cmp_view(x_non, dtype), _cmp_view(x_non_ref, dtype), rtol=rtol, atol=atol)
+
+
+@pytest.mark.spsv
+@pytest.mark.skipif(
+    cp is None or cpx_sparse is None or cpx_spsolve_triangular is None,
+    reason="CuPy/cuSPARSE required",
+)
+@pytest.mark.parametrize("n", SPSV_N)
+@pytest.mark.parametrize("dtype", TRANS_DTYPES, ids=_dtype_id)
+def test_spsv_csr_matches_cusparse_upper_trans(n, dtype):
+    device = torch.device("cuda")
+    A = _build_triangular(n, dtype, device, lower=False)
+    b = _rand_like(dtype, (n,), device)
+
+    Asp = A.to_sparse_csr()
+    data = Asp.values()
+    indices = Asp.col_indices().to(torch.int32)
+    indptr = Asp.crow_indices().to(torch.int32)
+    A_cp = _cupy_csr_from_torch(data, indices, indptr, (n, n))
+
+    x_trans = flagsparse_spsv_csr(
+        data, indices, indptr, b, (n, n), lower=False, unit_diagonal=False, transpose=True
+    )
+    x_trans_ref = _cupy_ref_spsv(A_cp.transpose().tocsr(), b, lower=True, unit_diagonal=False)
+
+    rtol, atol = _tol(dtype)
+    assert torch.allclose(_cmp_view(x_trans, dtype), _cmp_view(x_trans_ref, dtype), rtol=rtol, atol=atol)
+
+
+@pytest.mark.spsv
+@pytest.mark.parametrize("n", SPSV_N)
+def test_spsv_coo_transpose_complex128_routes_through_csr(n):
+    device = torch.device("cuda")
+    dtype = torch.complex128
+    A = _build_triangular(n, dtype, device, lower=True)
+    b = _rand_like(dtype, (n,), device)
+    x_ref = torch.linalg.solve_triangular(
+        A.transpose(-2, -1), b.unsqueeze(-1), upper=True
+    ).squeeze(-1)
+
+    A_coo = A.to_sparse_coo().coalesce()
+    row, col = A_coo.indices()
+    data = A_coo.values()
+
+    x = flagsparse_spsv_coo(
+        data,
+        row.to(torch.int32),
+        col.to(torch.int32),
+        b,
+        (n, n),
+        lower=True,
+        unit_diagonal=False,
+        transpose=True,
+        coo_mode="auto",
+    )
+    rtol, atol = _tol(dtype)
+    assert torch.allclose(x, x_ref, rtol=rtol, atol=atol)
+>>>>>>> 5a83e0f (test)
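
The TRANS cases above rely on the identity that solving A^T x = b with a lower-triangular A is an upper-triangular solve on A^T; the Triton path realizes this via _csr_transpose and the cuSPARSE reference via transpose().tocsr(). A quick numerical check of the equivalence:

import torch

n = 16
A = torch.tril(torch.randn(n, n, dtype=torch.float64)) + 4.0 * torch.eye(n, dtype=torch.float64)
b = torch.randn(n, dtype=torch.float64)
# A is lower triangular, so A.T is upper triangular and both solvers agree.
x_general = torch.linalg.solve(A.T, b.unsqueeze(-1)).squeeze(-1)
x_triangular = torch.linalg.solve_triangular(A.T, b.unsqueeze(-1), upper=True).squeeze(-1)
assert torch.allclose(x_general, x_triangular, atol=1e-10)
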
diff --git a/tests/test_gather.py b/tests/test_gather.py
index 481138a..b00a6c4 100644
--- a/tests/test_gather.py
+++ b/tests/test_gather.py
@@ -16,7 +16,7 @@
     (524_288, 16_384),
     (1_048_576, 65_536),
 ]
-DEFAULT_VALUE_DTYPES = "float16,bfloat16,float32,float64,complex32,complex64,complex128"
+DEFAULT_VALUE_DTYPES = "float16,bfloat16,float32,float64,complex64,complex128"
 DEFAULT_INDEX_DTYPES = "int32,int64"
 WARMUP = 20
 ITERS = 200
@@ -48,7 +48,6 @@ def _parse_value_dtypes(raw):
         "bfloat16",
         "float32",
         "float64",
-        "complex32",
         "complex64",
         "complex128",
     }
@@ -164,7 +163,7 @@ def _collect_samples(case_id, expected, flagsparse_out, limit):


 def _dtype_mode(value_dtype_req):
-    if value_dtype_req in ("float16", "bfloat16", "complex32", "complex64"):
+    if value_dtype_req in ("float16", "bfloat16", "complex64"):
         return "gather_cupy"
     return "gather_triton"

@@ -193,8 +192,6 @@ def _build_dense(value_dtype_req, dense_size, device):
         real = torch.randn(dense_size, dtype=torch.float64, device=device)
         imag = torch.randn(dense_size, dtype=torch.float64, device=device)
         return torch.complex(real, imag)
-    if value_dtype_req == "complex32":
-        return torch.randn(dense_size, 2, dtype=torch.float16, device=device)
     raise ValueError(f"Unsupported value dtype request: {value_dtype_req}")


@@ -206,13 +203,12 @@ def _effective_dtype_name(value_dtype_req):
         "float64": "float64",
         "complex64": "complex64",
         "complex128": "complex128",
-        "complex32": "complex16_pair_f16",
     }
     return mapping[value_dtype_req]


 def _tolerance(value_dtype_req):
-    if value_dtype_req in ("float16", "complex32"):
+    if value_dtype_req == "float16":
         return 5e-3, 5e-3
     if value_dtype_req in ("bfloat16",):
         return 1e-2, 1e-2
@@ -230,10 +226,10 @@ def _check_dtype_supported(value_dtype_req):

 def _is_supported_extra_gather_combo(value_dtype_req, index_dtype):
     # Required extra gather combos only:
-    # Half+Int32/Int64, Bfloat16+Int32/Int64, Complex32+Int32/Int64, Complex64+Int32/Int64
+    # Half+Int32/Int64, Bfloat16+Int32/Int64, Complex64+Int32/Int64
     if value_dtype_req == "float16":
         return index_dtype in (torch.int32, torch.int64)
-    if value_dtype_req in ("bfloat16", "complex32", "complex64"):
+    if value_dtype_req in ("bfloat16", "complex64"):
         return index_dtype in (torch.int32, torch.int64)
     # Original gather path dtypes keep original behavior.
     return True
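
The tolerance ladder used by these benchmarks (5e-3 for float16, 1e-2 for bfloat16, tighter bounds for the 32- and 64-bit formats) tracks the machine epsilon of each dtype, which torch.finfo reports directly:

import torch

# Machine epsilon per floating dtype; complex tolerances follow the
# component dtype (complex64 -> float32, complex128 -> float64).
for dt in (torch.float16, torch.bfloat16, torch.float32, torch.float64):
    print(dt, torch.finfo(dt).eps)
# ~9.8e-04, ~7.8e-03, ~1.2e-07, ~2.2e-16
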
diff --git a/tests/test_spsv.py b/tests/test_spsv.py
index 162dfc8..cdc7a31 100644
--- a/tests/test_spsv.py
+++ b/tests/test_spsv.py
@@ -22,8 +22,8 @@
 cpx_sparse = None
 cpx_spsolve_triangular = None

-VALUE_DTYPES = [torch.float32, torch.float64]
-INDEX_DTYPES = [torch.int32]
+VALUE_DTYPES = [torch.float32, torch.float64, torch.complex64, torch.complex128]
+INDEX_DTYPES = [torch.int32, torch.int64]
 TEST_SIZES = [256, 512, 1024, 2048]
 WARMUP = 5
 ITERS = 20
@@ -44,6 +44,15 @@
     CSR_FULL_VALUE_DTYPES.append(torch.complex64)
 CSR_FULL_INDEX_DTYPES = [torch.int32, torch.int64]

+# Full CSR combination coverage (added alongside the original csv-csr logic;
+# the original entry point is unaffected)
+CSR_FULL_VALUE_DTYPES = [
+    torch.float32,
+    torch.float64,
+    torch.complex64,
+    torch.complex128,
+]
+CSR_FULL_INDEX_DTYPES = [torch.int32, torch.int64]
+

 def _dtype_name(dtype):
     return str(dtype).replace("torch.", "")
@@ -64,8 +73,11 @@ def _fmt_err(v):


 def _tol_for_dtype(dtype):
+<<<<<<< HEAD
     if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE:
         return 5e-3, 5e-3
+=======
+>>>>>>> 5a83e0f (test)
     if dtype in (torch.float32, torch.complex64):
         return 1e-4, 1e-2
     return 1e-12, 1e-10
@@ -74,22 +86,30 @@ def _tol_for_dtype(dtype):
 def _randn_by_dtype(n, dtype, device):
     if dtype in (torch.float32, torch.float64):
         return torch.randn(n, dtype=dtype, device=device)
+<<<<<<< HEAD
     if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE:
         pair = torch.randn((n, 2), dtype=torch.float16, device=device) * 0.1
         return torch.view_as_complex(pair)
     base = torch.float32
+=======
+    base = torch.float32 if dtype == torch.complex64 else torch.float64
+>>>>>>> 5a83e0f (test)
     real = torch.randn(n, dtype=base, device=device)
     imag = torch.randn(n, dtype=base, device=device)
     return torch.complex(real, imag)


 def _dense_ref_dtype(dtype):
+<<<<<<< HEAD
     if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE:
         return torch.complex64
+=======
+>>>>>>> 5a83e0f (test)
     return dtype


 def _tensor_from_scalar_values(values, dtype, device):
+<<<<<<< HEAD
     if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE:
         real = torch.clamp(
             torch.tensor(values, dtype=torch.float32, device=device),
@@ -98,18 +118,24 @@
         ).to(torch.float16)
         imag = torch.zeros_like(real)
         return torch.view_as_complex(torch.stack([real, imag], dim=-1).contiguous())
+=======
+>>>>>>> 5a83e0f (test)
     return torch.tensor(values, dtype=dtype, device=device)


 def _safe_cast_tensor(tensor, dtype):
+<<<<<<< HEAD
     if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE:
         real = torch.clamp(tensor.real, min=-FLOAT16_LIMIT, max=FLOAT16_LIMIT).to(torch.float16)
         imag = torch.clamp(tensor.imag, min=-FLOAT16_LIMIT, max=FLOAT16_LIMIT).to(torch.float16)
         return torch.view_as_complex(torch.stack([real, imag], dim=-1).contiguous())
+=======
+>>>>>>> 5a83e0f (test)
     return tensor.to(dtype)


 def _cast_real_tensor_to_value_dtype(values, value_dtype):
+<<<<<<< HEAD
     if COMPLEX32_DTYPE is not None and value_dtype == COMPLEX32_DTYPE:
         real = torch.clamp(values, min=-FLOAT16_LIMIT, max=FLOAT16_LIMIT).to(torch.float16)
         imag = torch.zeros_like(real)
@@ -120,12 +146,45 @@

 def _cupy_ref_inputs(data, b):
     if COMPLEX32_DTYPE is not None and data.dtype == COMPLEX32_DTYPE:
         return data.to(torch.complex64), b.to(torch.complex64)
+=======
+    return values.to(value_dtype)
+
+
+def _matrix_market_value(parts, mm_field):
+    if mm_field == "complex":
+        if len(parts) < 4:
+            raise ValueError("MatrixMarket complex entry requires real and imag parts")
+        return complex(float(parts[2]), float(parts[3]))
+    if len(parts) >= 3:
+        return float(parts[2])
+    if mm_field == "pattern":
+        return 1.0
+    raise ValueError("MatrixMarket entry is missing a numeric value")
+
+
+def _triangular_solve_reference(A, b, *, lower, op_mode="NON"):
+    if op_mode == "TRANS":
+        A_eff = A.transpose(0, 1)
+        upper = lower
+    else:
+        A_eff = A
+        upper = not lower
+    return torch.linalg.solve_triangular(
+        A_eff, b.unsqueeze(1), upper=upper
+    ).squeeze(1)
+
+
+def _cupy_ref_inputs(data, b):
+>>>>>>> 5a83e0f (test)
     return data, b


 def _compare_view(tensor, value_dtype):
+<<<<<<< HEAD
     if COMPLEX32_DTYPE is not None and value_dtype == COMPLEX32_DTYPE:
         return tensor.to(torch.complex64)
+=======
+>>>>>>> 5a83e0f (test)
     return tensor


@@ -147,7 +206,7 @@ def _allow_dense_pytorch_ref(shape, dtype):


 def _build_random_triangular_csr(n, value_dtype, index_dtype, device, lower=True):
-    """Build a well-conditioned triangular CSR (float32/float64)."""
+    """Build a well-conditioned triangular CSR for real and complex dtypes."""
     max_bandwidth = max(4, min(n, 16))
     rows_host = []
     cols_host = []
@@ -156,10 +215,17 @@
         base_real_dtype = torch.float32
     elif value_dtype == torch.float64:
         base_real_dtype = torch.float64
+<<<<<<< HEAD
     elif COMPLEX32_DTYPE is not None and value_dtype == COMPLEX32_DTYPE:
         base_real_dtype = torch.float16
     else:
         base_real_dtype = torch.float32
+=======
+    elif value_dtype == torch.complex64:
+        base_real_dtype = torch.float32
+    else:
+        base_real_dtype = torch.float64
+>>>>>>> 5a83e0f (test)

     for i in range(n):
         if lower:
@@ -176,23 +242,44 @@
             off_cols = [off_cand[j] for j in perm]
         else:
             off_cols = []
-        off_vals_real = torch.randn(len(off_cols), dtype=base_real_dtype).mul_(0.01)
-        sum_abs = float(torch.sum(torch.abs(off_vals_real)).item()) if off_vals_real.numel() else 0.0
-        diag_val = sum_abs + 1.0
+        if value_dtype in (torch.complex64, torch.complex128):
+            off_vals = torch.complex(
+                torch.randn(len(off_cols), dtype=base_real_dtype, device=device).mul_(0.01),
+                torch.randn(len(off_cols), dtype=base_real_dtype, device=device).mul_(0.01),
+            )
+            sum_abs = (
+                float(torch.sum(torch.abs(off_vals)).item()) if off_vals.numel() else 0.0
+            )
+            diag_imag = float(
+                torch.randn((), dtype=base_real_dtype, device=device).mul_(0.05).item()
+            )
+            diag_val = complex(sum_abs + 1.0, diag_imag)
+            off_vals_host = [complex(v) for v in off_vals.cpu().tolist()]
+        else:
+            off_vals = torch.randn(len(off_cols), dtype=base_real_dtype, device=device).mul_(0.01)
+            sum_abs = (
+                float(torch.sum(torch.abs(off_vals)).item()) if off_vals.numel() else 0.0
+            )
+            diag_val = sum_abs + 1.0
+            off_vals_host = off_vals.cpu().tolist()
         rows_host.append(i)
         cols_host.append(diag_col)
         vals_host.append(diag_val)
-        for c, v in zip(off_cols, off_vals_real.tolist()):
+        for c, v in zip(off_cols, off_vals_host):
             rows_host.append(i)
             cols_host.append(int(c))
             vals_host.append(v)

     rows_t = torch.tensor(rows_host, dtype=torch.int64, device=device)
     cols_t = torch.tensor(cols_host, dtype=torch.int64, device=device)
+<<<<<<< HEAD
     vals_t = _cast_real_tensor_to_value_dtype(
         torch.tensor(vals_host, dtype=base_real_dtype, device=device),
         value_dtype,
     )
+=======
+    vals_t = torch.tensor(vals_host, dtype=value_dtype, device=device)
+>>>>>>> 5a83e0f (test)
     order = torch.argsort(rows_t * max(1, n) + cols_t)
     rows_t = rows_t[order]
     cols_t = cols_t[order]
@@ -257,7 +344,11 @@ def _csr_transpose(data, indices, indptr, shape):
     return data_t, col_t.to(torch.int64), indptr_t


+<<<<<<< HEAD
 def _load_mtx_to_csr_torch(file_path, dtype=torch.float32, device=None):
+=======
+def _load_mtx_to_csr_torch(file_path, dtype=torch.float32, device=None, lower=True):
+>>>>>>> 5a83e0f (test)
     if device is None:
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     with open(file_path, "r", encoding="utf-8") as f:
@@ -302,28 +393,26 @@ def _accum(r, c, v):
             continue
         r = int(parts[0]) - 1
         c = int(parts[1]) - 1
-        if len(parts) >= 3:
-            v = float(parts[2])
-        elif mm_field == "pattern":
-            v = 1.0
-        else:
-            continue
+        v = _matrix_market_value(parts, mm_field)
         _accum(r, c, v)
-        if mm_symmetry in ("symmetric", "hermitian") and r != c:
+        if mm_symmetry == "symmetric" and r != c:
             _accum(c, r, v)
+        elif mm_symmetry == "hermitian" and r != c:
+            _accum(c, r, v.conjugate() if isinstance(v, complex) else v)
         elif mm_symmetry == "skew-symmetric" and r != c:
             _accum(c, r, -v)

     for r in range(n_rows):
         row = row_maps[r]
-        lower_row = {}
+        tri_row = {}
         off_abs_sum = 0.0
         for c, v in row.items():
-            if c < r:
-                lower_row[c] = lower_row.get(c, 0.0) + v
+            keep = c < r if lower else c > r
+            if keep:
+                tri_row[c] = tri_row.get(c, 0.0) + v
             off_abs_sum += abs(v)
-        lower_row[r] = off_abs_sum + 1.0
-        row_maps[r] = lower_row
+        tri_row[r] = off_abs_sum + 1.0
+        row_maps[r] = tri_row

     cols_s = []
     vals_s = []
@@ -360,6 +449,7 @@ def _coo_inputs_for_csv(data, indices, indptr, shape, coo_mode):


 def _build_rhs_for_csr_op(data, indices, indptr, x_true, shape, op_mode):
+<<<<<<< HEAD
     if COMPLEX32_DTYPE is not None and data.dtype == COMPLEX32_DTYPE:
         data_ref = data.to(torch.complex64)
         x_ref = x_true.to(torch.complex64)
@@ -380,6 +470,8 @@
         )
         return _safe_cast_tensor(b_ref, x_true.dtype)
     raise ValueError("op_mode must be 'NON' or 'TRANS'")
+=======
+>>>>>>> 5a83e0f (test)
     if op_mode == "NON":
         b, _ = fs.flagsparse_spmv_csr(
             data, indices, indptr, x_true, shape, return_time=True
@@ -408,6 +500,7 @@
     b,
     warmup,
     iters,
+    lower,
 ):
     """Triangular solve via CuPy: CSR or COO storage.
     Returns (ms, x_torch) or (None, None)."""
     if (
@@ -439,7 +532,7 @@
         A_cp = cpx_sparse.csr_matrix((data_cp, idx_cp, ptr_cp), shape=shape)
         for _ in range(warmup):
             _ = cpx_spsolve_triangular(
-                A_cp, b_cp, lower=True, unit_diagonal=False
+                A_cp, b_cp, lower=lower, unit_diagonal=False
             )
         cp.cuda.runtime.deviceSynchronize()
         t0 = cp.cuda.Event()
@@ -447,22 +540,30 @@
         t0.record()
         for _ in range(iters):
             x_cu = cpx_spsolve_triangular(
-                A_cp, b_cp, lower=True, unit_diagonal=False
+                A_cp, b_cp, lower=lower, unit_diagonal=False
             )
         t1.record()
         t1.synchronize()
         cupy_ms = cp.cuda.get_elapsed_time(t0, t1) / iters
         x_cu_t = torch.utils.dlpack.from_dlpack(x_cu.toDlpack())
+<<<<<<< HEAD
         if COMPLEX32_DTYPE is not None and b.dtype == COMPLEX32_DTYPE:
             x_cu_t = x_cu_t.to(torch.complex64)
         else:
             x_cu_t = x_cu_t.to(b.dtype)
+=======
+        x_cu_t = x_cu_t.to(b.dtype)
+>>>>>>> 5a83e0f (test)
         return cupy_ms, x_cu_t
     except Exception:
         return None, None


+<<<<<<< HEAD
 def _cupy_spsolve_csr_with_op(data, indices, indptr, shape, b, op_mode):
+=======
+def _cupy_spsolve_csr_with_op(data, indices, indptr, shape, b, op_mode, lower):
+>>>>>>> 5a83e0f (test)
     if (
         cp is None
         or cpx_sparse is None
@@ -482,10 +583,17 @@
     A_cp = cpx_sparse.csr_matrix((data_cp, idx_cp, ptr_cp), shape=shape)
     if op_mode == "TRANS":
         A_eff = A_cp.transpose().tocsr()
+<<<<<<< HEAD
         lower_eff = False
     else:
         A_eff = A_cp
         lower_eff = True
+=======
+        lower_eff = not lower
+    else:
+        A_eff = A_cp
+        lower_eff = lower
+>>>>>>> 5a83e0f (test)

     for _ in range(WARMUP):
         _ = cpx_spsolve_triangular(
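
Both CuPy reference helpers above time the solve with CUDA events rather than host clocks, so Python dispatch overhead is excluded from the reported milliseconds. The pattern in isolation:

import cupy as cp

def time_gpu_ms(fn, iters=20):
    # Paired cp.cuda.Event records measure elapsed device time around the calls.
    start, stop = cp.cuda.Event(), cp.cuda.Event()
    start.record()
    for _ in range(iters):
        fn()
    stop.record()
    stop.synchronize()
    return cp.cuda.get_elapsed_time(start, stop) / iters

x = cp.random.rand(1 << 20)
print(f"{time_gpu_ms(lambda: cp.sqrt(x)):.4f} ms per call")
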
@@ -508,7 +616,11 @@
     return None, None


+<<<<<<< HEAD
 def run_spsv_synthetic_all():
+=======
+def run_spsv_synthetic_all(lower=True):
+>>>>>>> 5a83e0f (test)
     if not torch.cuda.is_available():
         print("CUDA is not available. Please run on a GPU-enabled system.")
         return
@@ -519,10 +631,11 @@ def run_spsv_synthetic_all():
     print(sep)
     print(f"GPU: {torch.cuda.get_device_name(0)}")
     print(f"Warmup: {WARMUP} | Iters: {ITERS}")
+    print(f"Triangle: {'LOWER' if lower else 'UPPER'}")
     print()

     hdr = (
-        f"{'Fmt':>5} {'N':>6} {'FlagSparse(ms)':>14} {'PyTorch(ms)':>12} {'CuPy(ms)':>10} "
+        f"{'Fmt':>5} {'opA':>5} {'N':>6} {'FlagSparse(ms)':>14} {'PyTorch(ms)':>12} {'CuPy(ms)':>10} "
         f"{'FS/PT':>8} {'FS/CU':>8} {'Status':>8} {'Err(PT)':>12} {'Err(CU)':>12}"
     )
@@ -540,101 +653,121 @@ def run_spsv_synthetic_all():
         print("-" * 110)
         for n in TEST_SIZES:
             for fmt in ("CSR", "COO"):
-                data, indices, indptr, shape = _build_random_triangular_csr(
-                    n, value_dtype, index_dtype, device, lower=True
+                op_modes = (
+                    _supported_csr_full_ops(value_dtype, index_dtype)
+                    if fmt == "CSR"
+                    else ["NON"]
                 )
-                A_dense = _csr_to_dense(
-                    data, indices.to(torch.int64), indptr, shape
-                )
-                x_true = torch.randn(n, dtype=value_dtype, device=device)
-                b = A_dense @ x_true
-
-                torch.cuda.synchronize()
-                if fmt == "CSR":
-                    x, t_ms = fs.flagsparse_spsv_csr(
-                        data,
-                        indices,
-                        indptr,
-                        b,
-                        shape,
-                        lower=True,
-                        return_time=True,
+                for op_mode in op_modes:
+                    data, indices, indptr, shape = _build_random_triangular_csr(
+                        n, value_dtype, index_dtype, device, lower=lower
                     )
-                else:
-                    dc, rr, cc = _csr_to_coo(data, indices, indptr, shape)
-                    x, t_ms = fs.flagsparse_spsv_coo(
-                        dc,
-                        rr,
-                        cc,
-                        b,
-                        shape,
-                        lower=True,
-                        coo_mode="auto",
-                        return_time=True,
+                    A_dense = _csr_to_dense(
+                        data, indices.to(torch.int64), indptr, shape
                     )
-                torch.cuda.synchronize()
-
-                A_ref = A_dense
-                b_ref = b
-                torch.cuda.synchronize()
-                e0 = torch.cuda.Event(True)
-                e1 = torch.cuda.Event(True)
-                e0.record()
-                x_pt = torch.linalg.solve(
-                    A_ref, b_ref.unsqueeze(1)
-                ).squeeze(1)
-                e1.record()
-                torch.cuda.synchronize()
-                pytorch_ms = e0.elapsed_time(e1)
-                err_pt = float(torch.max(torch.abs(x - x_pt)).item()) if n > 0 else 0.0
-
-                cupy_ms = None
-                err_cu = None
-                x_cu_t = None
-                if value_dtype in (torch.float32, torch.float64):
-                    cupy_ms, x_cu_t = _cupy_spsolve_lower_csr_or_coo(
-                        fmt,
-                        data,
-                        indices,
-                        indptr,
-                        shape,
-                        b,
-                        WARMUP,
-                        ITERS,
+                    x_true = _randn_by_dtype(n, value_dtype, device)
+                    if fmt == "CSR":
+                        b = _build_rhs_for_csr_op(data, indices, indptr, x_true, shape, op_mode)
+                    else:
+                        b = A_dense @ x_true
+
+                    torch.cuda.synchronize()
+                    if fmt == "CSR":
+                        x, t_ms = fs.flagsparse_spsv_csr(
+                            data,
+                            indices,
+                            indptr,
+                            b,
+                            shape,
+                            lower=lower,
+                            transpose=(op_mode == "TRANS"),
+                            return_time=True,
+                        )
+                    else:
+                        dc, rr, cc = _csr_to_coo(data, indices, indptr, shape)
+                        x, t_ms = fs.flagsparse_spsv_coo(
+                            dc,
+                            rr,
+                            cc,
+                            b,
+                            shape,
+                            lower=lower,
+                            coo_mode="auto",
+                            return_time=True,
+                        )
+                    torch.cuda.synchronize()
+
+                    A_ref = A_dense
+                    b_ref = b
+                    torch.cuda.synchronize()
+                    e0 = torch.cuda.Event(True)
+                    e1 = torch.cuda.Event(True)
+                    e0.record()
+                    x_pt = _triangular_solve_reference(
+                        A_ref, b_ref, lower=lower, op_mode=op_mode
                     )
+                    e1.record()
+                    torch.cuda.synchronize()
+                    pytorch_ms = e0.elapsed_time(e1)
+                    err_pt = float(torch.max(torch.abs(x - x_pt)).item()) if n > 0 else 0.0
+
+                    cupy_ms = None
+                    err_cu = None
+                    x_cu_t = None
+                    if fmt == "CSR":
+                        cupy_ms, x_cu_t = _cupy_spsolve_csr_with_op(
+                            data, indices, indptr, shape, b, op_mode, lower
+                        )
+                    elif value_dtype in (
+                        torch.float32,
+                        torch.float64,
+                        torch.complex64,
+                        torch.complex128,
+                    ):
+                        cupy_ms, x_cu_t = _cupy_spsolve_lower_csr_or_coo(
+                            fmt,
+                            data,
+                            indices,
+                            indptr,
+                            shape,
+                            b,
+                            WARMUP,
+                            ITERS,
+                            lower,
+                        )
                     if x_cu_t is not None and n > 0:
                         err_cu = float(
                             torch.max(torch.abs(x - x_cu_t)).item()
                         )
-                atol, rtol = _tol_for_dtype(value_dtype)
-                ok_pt = torch.allclose(x, x_pt, atol=atol, rtol=rtol)
-                ok_cu = (
-                    True
-                    if x_cu_t is None
-                    else torch.allclose(x, x_cu_t, atol=atol, rtol=rtol)
-                )
-                ok = ok_pt or ok_cu
-                status = "PASS" if ok else "FAIL"
-                if not ok:
-                    failed += 1
-                total += 1
-
-                fs_vs_pt = (
-                    (pytorch_ms / t_ms) if (t_ms and t_ms > 0) else None
-                )
-                fs_vs_cu = (
-                    (cupy_ms / t_ms)
-                    if (cupy_ms is not None and t_ms and t_ms > 0)
-                    else None
-                )
-                print(
-                    f"{fmt:>5} {n:>6} {_fmt_ms(t_ms):>14} {_fmt_ms(pytorch_ms):>12} "
-                    f"{_fmt_ms(cupy_ms):>10} "
-                    f"{(f'{fs_vs_pt:.2f}x' if fs_vs_pt is not None else 'N/A'):>8} "
-                    f"{(f'{fs_vs_cu:.2f}x' if fs_vs_cu is not None else 'N/A'):>8} "
-                    f"{status:>8} {_fmt_err(err_pt):>12} {_fmt_err(err_cu):>12}"
-                )
+                    atol, rtol = _tol_for_dtype(value_dtype)
+                    ok_pt = torch.allclose(x, x_pt, atol=atol, rtol=rtol)
+                    ok_cu = (
+                        True
+                        if x_cu_t is None
+                        else torch.allclose(x, x_cu_t, atol=atol, rtol=rtol)
+                    )
+                    ok = ok_pt or ok_cu
+                    status = "PASS" if ok else "FAIL"
+                    if not ok:
+                        failed += 1
+                    total += 1
+
+                    fs_vs_pt = (
+                        (pytorch_ms / t_ms) if (t_ms and t_ms > 0) else None
+                    )
+                    fs_vs_cu = (
+                        (cupy_ms / t_ms)
+                        if (cupy_ms is not None and t_ms and t_ms > 0)
+                        else None
+                    )
+                    print(
+                        f"{fmt:>5} {op_mode:>5} {n:>6} {_fmt_ms(t_ms):>14} {_fmt_ms(pytorch_ms):>12} "
+                        f"{_fmt_ms(cupy_ms):>10} "
+                        f"{(f'{fs_vs_pt:.2f}x' if fs_vs_pt is not None else 'N/A'):>8} "
+                        f"{(f'{fs_vs_cu:.2f}x' if fs_vs_cu is not None else 'N/A'):>8} "
+                        f"{status:>8} {_fmt_err(err_pt):>12} {_fmt_err(err_cu):>12}"
                    )

         print("-" * 110)
         print()
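
The synthetic table above manufactures each right-hand side from a known solution (b = A @ x_true, or the TRANS analogue via _build_rhs_for_csr_op), so the error columns compare against ground truth instead of another solver. Stripped to its essentials:

import torch

n = 512
A = torch.tril(torch.randn(n, n, dtype=torch.float64, device="cuda"))
A = A + 4.0 * torch.eye(n, dtype=torch.float64, device="cuda")
x_true = torch.randn(n, dtype=torch.float64, device="cuda")
b = A @ x_true  # right-hand side built from the known solution
x = torch.linalg.solve_triangular(A, b.unsqueeze(-1), upper=False).squeeze(-1)
print(float((x - x_true).abs().max()))  # ~1e-13 for this well-conditioned A
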
@@ -643,32 +776,32 @@
     print(sep)


-def _run_one_csv_row_csr(path, value_dtype, index_dtype, device):
+def _run_one_csv_row_csr(path, value_dtype, index_dtype, device, lower=True):
     data, indices, indptr, shape = _load_mtx_to_csr_torch(
-        path, dtype=value_dtype, device=device
+        path, dtype=value_dtype, device=device, lower=lower
     )
     indices = indices.to(index_dtype)
     n_rows, n_cols = shape
-    x_true = torch.randn(n_rows, dtype=value_dtype, device=device)
+    x_true = _randn_by_dtype(n_rows, value_dtype, device)
     b, _ = fs.flagsparse_spmv_csr(
         data, indices, indptr, x_true, shape, return_time=True
     )
     x, t_ms = fs.flagsparse_spsv_csr(
-        data, indices, indptr, b, shape, lower=True, return_time=True
+        data, indices, indptr, b, shape, lower=lower, return_time=True
     )
     return _finalize_csv_row(
         path, value_dtype, index_dtype, data, indices, indptr, shape,
-        x, t_ms, b, n_rows, n_cols,
+        x, t_ms, b, n_rows, n_cols, lower=lower,
     )


-def _run_one_csv_row_coo(path, value_dtype, index_dtype, device, coo_mode):
+def _run_one_csv_row_coo(path, value_dtype, index_dtype, device, coo_mode, lower=True):
     data, indices, indptr, shape = _load_mtx_to_csr_torch(
-        path, dtype=value_dtype, device=device
+        path, dtype=value_dtype, device=device, lower=lower
     )
     indices = indices.to(index_dtype)
     n_rows, n_cols = shape
-    x_true = torch.randn(n_rows, dtype=value_dtype, device=device)
+    x_true = _randn_by_dtype(n_rows, value_dtype, device)
     b, _ = fs.flagsparse_spmv_csr(
         data, indices, indptr, x_true, shape, return_time=True
     )
@@ -681,7 +814,7 @@ def _run_one_csv_row_coo(path, value_dtype, index_dtype, device, coo_mode):
         c_in,
         b,
         shape,
-        lower=True,
+        lower=lower,
         coo_mode=coo_mode,
         return_time=True,
     )
@@ -698,6 +831,7 @@
         b,
         n_rows,
         n_cols,
+        lower=lower,
         nnz_display=int(d_in.numel()),
         cupy_coo_data=d_in,
         cupy_coo_row=r_in,
@@ -719,6 +853,7 @@
     n_rows,
     n_cols,
     *,
+    lower=True,
     nnz_display=None,
     cupy_coo_data=None,
     cupy_coo_row=None,
@@ -741,7 +876,13 @@
     e1 = torch.cuda.Event(True)
     torch.cuda.synchronize()
     e0.record()
+<<<<<<< HEAD
     x_ref = torch.linalg.solve(A_ref, b_ref.unsqueeze(1)).squeeze(1)
+=======
+    x_ref = _triangular_solve_reference(
+        A_ref, b_ref, lower=lower, op_mode="NON"
+    )
+>>>>>>> 5a83e0f (test)
     x_cmp = _compare_view(x, value_dtype)
     x_ref_cmp = _compare_view(x_ref, value_dtype)
     e1.record()
@@ -814,7 +955,7 @@
     )
     for _ in range(WARMUP):
         _ = cpx_spsolve_triangular(
-            A_cp, b_cp, lower=True, unit_diagonal=False
+            A_cp, b_cp, lower=lower, unit_diagonal=False
         )
     cp.cuda.runtime.deviceSynchronize()
     c0 = cp.cuda.Event()
@@ -822,7 +963,7 @@
     c0.record()
     for _ in range(ITERS):
         x_cu = cpx_spsolve_triangular(
-            A_cp, b_cp, lower=True, unit_diagonal=False
+            A_cp, b_cp, lower=lower, unit_diagonal=False
         )
     c1.record()
     c1.synchronize()
@@ -865,9 +1006,15 @@
     return row, pt_skip_reason


+<<<<<<< HEAD
 def _run_one_csv_row_csr_full(path, value_dtype, index_dtype, op_mode, device):
     data, indices, indptr, shape = _load_mtx_to_csr_torch(
         path, dtype=value_dtype, device=device
+=======
+def _run_one_csv_row_csr_full(path, value_dtype, index_dtype, op_mode, device, lower=True):
+    data, indices, indptr, shape = _load_mtx_to_csr_torch(
+        path, dtype=value_dtype, device=device, lower=lower
+>>>>>>> 5a83e0f (test)
     )
     indices = indices.to(index_dtype)
     indptr = indptr.to(index_dtype)
@@ -880,7 +1027,11 @@ def _run_one_csv_row_csr_full(path, value_dtype, index_dtype, op_mode, device):
         indptr,
         b,
         shape,
+<<<<<<< HEAD
         lower=True,
+=======
+        lower=lower,
+>>>>>>> 5a83e0f (test)
         transpose=(op_mode == "TRANS"),
         return_time=True,
     )
@@ -898,6 +1049,10 @@
         b,
         n_rows,
         n_cols,
+<<<<<<< HEAD
+=======
+        lower=lower,
+>>>>>>> 5a83e0f (test)
     )


@@ -915,6 +1070,10 @@
     b,
     n_rows,
     n_cols,
+<<<<<<< HEAD
+=======
+    lower=True,
+>>>>>>> 5a83e0f (test)
 ):
     atol, rtol = _tol_for_dtype(value_dtype)

@@ -927,12 +1086,21 @@
     A_dense = _csr_to_dense(
         data, indices.to(torch.int64), indptr.to(torch.int64), shape
     ).to(_dense_ref_dtype(value_dtype))
+<<<<<<< HEAD
     A_ref = A_dense.transpose(0, 1) if op_mode == "TRANS" else A_dense
+=======
+>>>>>>> 5a83e0f (test)
     e0 = torch.cuda.Event(True)
     e1 = torch.cuda.Event(True)
     torch.cuda.synchronize()
     e0.record()
+<<<<<<< HEAD
     x_ref = torch.linalg.solve(A_ref, b.to(A_ref.dtype).unsqueeze(1)).squeeze(1)
+=======
+    x_ref = _triangular_solve_reference(
+        A_dense, b.to(A_dense.dtype), lower=lower, op_mode=op_mode
+    )
+>>>>>>> 5a83e0f (test)
     x_cmp = _compare_view(x, value_dtype)
     x_ref_cmp = _compare_view(x_ref, value_dtype)
     e1.record()
@@ -959,7 +1127,11 @@
     ok_cu = False
     x_cu_t = None
     cupy_ms, x_cu_t = _cupy_spsolve_csr_with_op(
+<<<<<<< HEAD
         data, indices, indptr, shape, b, op_mode
+=======
+        data, indices, indptr, shape, b, op_mode, lower
+>>>>>>> 5a83e0f (test)
     )
     if x_cu_t is not None:
         x_cmp = _compare_view(x, value_dtype)
@@ -994,6 +1166,7 @@
     return row, pt_skip_reason


+<<<<<<< HEAD
 def run_all_supported_spsv_csr_csv(mtx_paths, csv_path):
     if not torch.cuda.is_available():
         print("CUDA is not available.")
         return
@@ -1103,11 +1276,125 @@ def run_all_supported_spsv_csr_csv(mtx_paths, csv_path):


 def run_all_dtypes_spsv_csv(mtx_paths, csv_path, use_coo=False, coo_mode="auto"):
+=======
+def run_all_supported_spsv_csr_csv(mtx_paths, csv_path, lower=True):
+>>>>>>> 5a83e0f (test)
     if not torch.cuda.is_available():
         print("CUDA is not available.")
         return
     device = torch.device("cuda")
     rows_out = []
+    for value_dtype in CSR_FULL_VALUE_DTYPES:
+        for index_dtype in CSR_FULL_INDEX_DTYPES:
+            op_modes = _supported_csr_full_ops(value_dtype, index_dtype)
+            for op_mode in op_modes:
+                print("=" * 150)
+                print(
+                    f"Value dtype: {_dtype_name(value_dtype)} | Index dtype: {_dtype_name(index_dtype)} | CSR | triA={'LOWER' if lower else 'UPPER'} | opA={op_mode}"
+                )
+                print(
+                    "Formats: FlagSparse=CSR, cuSPARSE=CSR ref, PyTorch=Dense solve."
+                )
+                print(
+                    "Err(PT)=|FlagSparse-PyTorch|, Err(CU)=|FlagSparse-cuSPARSE|. "
+                    "PASS if either error within tolerance."
+                )
+                print("-" * 150)
+                print(
+                    f"{'Matrix':<28} {'N_rows':>7} {'N_cols':>7} {'NNZ':>10} "
+                    f"{'FlagSparse(ms)':>10} {'CSR(ms)':>10} {'CSC(ms)':>10} {'PyTorch(ms)':>11} "
+                    f"{'FS/CSR':>7} {'FS/PT':>7} {'Status':>6} {'Err(PT)':>10} {'Err(CU)':>10}"
+                )
+                print("-" * 150)
+                for path in mtx_paths:
+                    try:
+                        row, pt_skip = _run_one_csv_row_csr_full(
+                            path, value_dtype, index_dtype, op_mode, device, lower=lower
+                        )
+                        rows_out.append(row)
+                        name = os.path.basename(path)[:27]
+                        if len(os.path.basename(path)) > 27:
+                            name = name + "…"
+                        n_rows, n_cols = row["n_rows"], row["n_cols"]
+                        nnz = row["nnz"]
+                        t_ms = row["triton_ms"]
+                        cupy_ms = row["cusparse_ms"]
+                        pytorch_ms = row["pytorch_ms"]
+                        err_pt, err_cu = row["err_pt"], row["err_cu"]
+                        status = row["status"]
+                        print(
+                            f"{name:<28} {n_rows:>7} {n_cols:>7} {nnz:>10} "
+                            f"{_fmt_ms(t_ms):>10} {_fmt_ms(cupy_ms):>10} {_fmt_ms(None):>10} {_fmt_ms(pytorch_ms):>11} "
+                            f"{_fmt_speedup(cupy_ms, t_ms):>7} {_fmt_speedup(pytorch_ms, t_ms):>7} "
+                            f"{status:>6} {_fmt_err(err_pt):>10} {_fmt_err(err_cu):>10}"
+                        )
+                        if pt_skip:
+                            print(f"  NOTE: {pt_skip}")
+                    except Exception as e:
+                        err_msg = str(e)
+                        status = "SKIP" if "SpSV requires square matrices" in err_msg else "ERROR"
+                        rows_out.append(
+                            {
+                                "matrix": os.path.basename(path),
+                                "value_dtype": _dtype_name(value_dtype),
+                                "index_dtype": _dtype_name(index_dtype),
+                                "opA": op_mode,
+                                "n_rows": "ERR",
+                                "n_cols": "ERR",
+                                "nnz": "ERR",
+                                "triton_ms": None,
+                                "pytorch_ms": None,
+                                "cusparse_ms": None,
+                                "csc_ms": None,
+                                "status": status,
+                                "err_pt": None,
+                                "err_cu": None,
+                            }
+                        )
+                        name = os.path.basename(path)[:27]
+                        if len(os.path.basename(path)) > 27:
+                            name = name + "…"
+                        print(
+                            f"{name:<28} {'ERR':>7} {'ERR':>7} {'ERR':>10} "
+                            f"{_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>11} "
+                            f"{'N/A':>7} {'N/A':>7} "
+                            f"{status:>6} {_fmt_err(None):>10} {_fmt_err(None):>10}"
+                        )
+                        print(f"  {status}: {e}")
+                print("-" * 150)
+    fieldnames = [
+        "matrix",
+        "value_dtype",
+        "index_dtype",
+        "opA",
+        "n_rows",
+        "n_cols",
+        "nnz",
+        "triton_ms",
+        "pytorch_ms",
+        "cusparse_ms",
+        "csc_ms",
+        "status",
+        "err_pt",
+        "err_cu",
+    ]
+    with open(csv_path, "w", newline="", encoding="utf-8") as f:
+        w = csv.DictWriter(f, fieldnames=fieldnames)
+        w.writeheader()
+        for r in rows_out:
+            w.writerow(r)
+    print(f"Wrote {len(rows_out)} rows to {csv_path}")
+
+
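Every CSV report in this file follows the same csv.DictWriter recipe: a fixed fieldnames list and one dict per benchmark row. Reduced to a standalone example:

import csv

fieldnames = ["matrix", "value_dtype", "opA", "status", "err_pt"]
rows = [{"matrix": "demo.mtx", "value_dtype": "float64", "opA": "NON",
         "status": "PASS", "err_pt": 1.2e-12}]
with open("spsv_results.csv", "w", newline="", encoding="utf-8") as f:
    w = csv.DictWriter(f, fieldnames=fieldnames)
    w.writeheader()   # header row from fieldnames
    w.writerows(rows)  # one CSV line per result dict
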
csv_path, lower=lower) + return + device = torch.device("cuda") + rows_out = [] label = "COO" if use_coo else "CSR" cu_col = "COO(ms)" if use_coo else "CSR(ms)" fs_cu_hdr = "FS/COO" if use_coo else "FS/CSR" @@ -1117,7 +1404,7 @@ def run_all_dtypes_spsv_csv(mtx_paths, csv_path, use_coo=False, coo_mode="auto") print("=" * 150) print( f"Value dtype: {_dtype_name(value_dtype)} | Index dtype: {_dtype_name(index_dtype)} | {label}" - + (f" coo_mode={coo_mode}" if use_coo else "") + + (f" triA={'LOWER' if lower else 'UPPER'}" if not use_coo else f" triA={'LOWER' if lower else 'UPPER'} coo_mode={coo_mode}") ) if use_coo: print( @@ -1143,11 +1430,11 @@ def run_all_dtypes_spsv_csv(mtx_paths, csv_path, use_coo=False, coo_mode="auto") try: if use_coo: row, pt_skip = _run_one_csv_row_coo( - path, value_dtype, index_dtype, device, coo_mode + path, value_dtype, index_dtype, device, coo_mode, lower=lower ) else: row, pt_skip = _run_one_csv_row_csr( - path, value_dtype, index_dtype, device + path, value_dtype, index_dtype, device, lower=lower ) rows_out.append(row) name = os.path.basename(path)[:27] @@ -1254,10 +1541,16 @@ def main(): choices=["auto", "direct", "csr"], help="COO mode for --csv-coo (default: auto)", ) + parser.add_argument( + "--upper", + action="store_true", + help="Use upper-triangular inputs instead of the default lower-triangular inputs", + ) args = parser.parse_args() + lower = not args.upper if args.synthetic: - run_spsv_synthetic_all() + run_spsv_synthetic_all(lower=lower) return paths = [] @@ -1272,7 +1565,11 @@ def main(): if not paths: print("No .mtx files found for --csv-csr") return +<<<<<<< HEAD run_all_supported_spsv_csr_csv(paths, args.csv_csr) +======= + run_all_supported_spsv_csr_csv(paths, args.csv_csr, lower=lower) +>>>>>>> 5a83e0f (test) return if args.csv_coo: if not paths: @@ -1281,7 +1578,7 @@ def main(): print("No .mtx files found for --csv-coo") return run_all_dtypes_spsv_csv( - paths, args.csv_coo, use_coo=True, coo_mode=args.coo_mode + paths, args.csv_coo, use_coo=True, coo_mode=args.coo_mode, lower=lower ) return From f2b0b5bf2385295fd73b6629bb9e3d5963b52c11 Mon Sep 17 00:00:00 2001 From: berlin020 <2261128688@qq.com> Date: Mon, 20 Apr 2026 20:56:03 +0800 Subject: [PATCH 03/22] Support CONJ transpose mode in spsv --- src/flagsparse/sparse_operations/spsv.py | 253 +++++------- tests/pytest/test_spsv_csr_accuracy.py | 210 +++++----- tests/test_scatter.py | 1 - tests/test_spsv.py | 491 +++++++++-------------- 4 files changed, 379 insertions(+), 576 deletions(-) diff --git a/src/flagsparse/sparse_operations/spsv.py b/src/flagsparse/sparse_operations/spsv.py index dd867bc..5f6174c 100644 --- a/src/flagsparse/sparse_operations/spsv.py +++ b/src/flagsparse/sparse_operations/spsv.py @@ -17,23 +17,25 @@ ) SUPPORTED_SPSV_INDEX_DTYPES = (torch.int32, torch.int64) -SPSV_NON_TRANS_PRIMARY_COMBOS = ( +SPSV_NON_TRANS_SUPPORTED_COMBOS = ( (torch.float32, torch.int32), (torch.float64, torch.int32), (torch.complex64, torch.int32), (torch.complex128, torch.int32), -) -SPSV_NON_TRANS_EXTENDED_COMBOS = ( (torch.float32, torch.int64), (torch.float64, torch.int64), (torch.complex64, torch.int64), (torch.complex128, torch.int64), ) -SPSV_TRANS_PRIMARY_COMBOS = ( +SPSV_TRANS_SUPPORTED_COMBOS = ( (torch.float32, torch.int32), (torch.float64, torch.int32), (torch.complex64, torch.int32), (torch.complex128, torch.int32), + (torch.float32, torch.int64), + (torch.float64, torch.int64), + (torch.complex64, torch.int64), + (torch.complex128, torch.int64), ) SPSV_PROMOTE_FP32_TO_FP64 = str( 
os.environ.get("FLAGSPARSE_SPSV_PROMOTE_FP32_TO_FP64", "0") @@ -64,33 +66,24 @@ def _csr_to_dense(data, indices, indptr, shape): def _validate_spsv_non_trans_combo(data_dtype, index_dtype, fmt_name): """Validate NON_TRANS support matrix and keep error messages explicit.""" - if (data_dtype, index_dtype) in SPSV_NON_TRANS_PRIMARY_COMBOS: - return - if (data_dtype, index_dtype) in SPSV_NON_TRANS_EXTENDED_COMBOS: + if (data_dtype, index_dtype) in SPSV_NON_TRANS_SUPPORTED_COMBOS: return if data_dtype == torch.bfloat16 and index_dtype == torch.int32: return raise TypeError( f"{fmt_name} SpSV currently supports NON_TRANS combinations: " "(float32, int32/int64), (float64, int32/int64), " -<<<<<<< HEAD - "(complex32, int32/int64), (complex64, int32/int64), (bfloat16, int32)" -======= "(complex64, int32/int64), (complex128, int32/int64), (bfloat16, int32)" ->>>>>>> 5a83e0f (test) ) def _validate_spsv_trans_combo(data_dtype, index_dtype, fmt_name): - if (data_dtype, index_dtype) in SPSV_TRANS_PRIMARY_COMBOS: + if (data_dtype, index_dtype) in SPSV_TRANS_SUPPORTED_COMBOS: return raise TypeError( - f"{fmt_name} SpSV currently supports TRANS combinations with int32 indices only: " -<<<<<<< HEAD - "(float32, int32), (float64, int32), (complex32, int32), (complex64, int32)" -======= - "(float32, int32), (float64, int32), (complex64, int32), (complex128, int32)" ->>>>>>> 5a83e0f (test) + f"{fmt_name} SpSV currently supports TRANS/CONJ combinations: " + "(float32, int32/int64), (float64, int32/int64), " + "(complex64, int32/int64), (complex128, int32/int64)" ) @@ -102,8 +95,11 @@ def _normalize_spsv_transpose_mode(transpose): return "N" if token in ("T", "TRANS"): return "T" + if token in ("C", "H", "CONJ", "CONJ_TRANS", "CONJUGATE_TRANSPOSE"): + return "C" raise ValueError( - "transpose must be bool or one of: N/NON/NON_TRANS, T/TRANS" + "transpose must be bool or one of: " + "N/NON/NON_TRANS, T/TRANS, C/H/CONJ/CONJ_TRANS/CONJUGATE_TRANSPOSE" ) @@ -130,11 +126,7 @@ def _prepare_spsv_inputs(data, indices, indptr, b, shape): if data.dtype not in SUPPORTED_SPSV_VALUE_DTYPES: raise TypeError( -<<<<<<< HEAD - "data dtype must be one of: bfloat16, float32, float64, complex32, complex64" -======= "data dtype must be one of: bfloat16, float32, float64, complex64, complex128" ->>>>>>> 5a83e0f (test) ) if indices.dtype not in SUPPORTED_SPSV_INDEX_DTYPES: raise TypeError("indices dtype must be torch.int32 or torch.int64") @@ -174,23 +166,6 @@ def _prepare_spsv_inputs(data, indices, indptr, b, shape): ) -<<<<<<< HEAD -def _promote_complex32_spsv_inputs(data, b): - if _is_complex32_dtype(data.dtype): - return data.to(torch.complex64), b.to(torch.complex64), data.dtype - return data, b, None - - -def _restore_complex32_spsv_output(x, target_dtype): - if _is_complex32_dtype(target_dtype): - limit = 65504.0 - real = torch.clamp(x.real, min=-limit, max=limit).to(torch.float16) - imag = torch.clamp(x.imag, min=-limit, max=limit).to(torch.float16) - return torch.view_as_complex(torch.stack([real, imag], dim=-1).contiguous()) - return x.to(target_dtype) - - -======= def _prepare_spsv_working_inputs(data, b): return data, b, None @@ -245,7 +220,40 @@ def _csr_preprocess_cache_key(data, indices, indptr, shape, lower, trans_mode): ) ->>>>>>> 5a83e0f (test) +def _prepare_spsv_csr_system(data, indices64, indptr64, n_rows, n_cols, lower, trans_mode): + """Prepare an equivalent CSR triangular system before launching the solve kernel. 
+ + Keep TRANS/CONJ handling outside the Triton solve kernels so the kernels only + ever implement a single, fixed CSR triangular-solve semantics. + """ + if trans_mode == "N": + kernel_data = data + kernel_indices64 = indices64 + kernel_indptr64 = indptr64 + lower_eff = lower + else: + kernel_data, kernel_indices64, kernel_indptr64 = _csr_transpose( + data, + indices64, + indptr64, + n_rows, + n_cols, + conjugate=(trans_mode == "C"), + ) + lower_eff = not lower + + levels = _build_spsv_levels( + kernel_indptr64, kernel_indices64, n_rows, lower=lower_eff + ) + return ( + kernel_data, + kernel_indices64, + kernel_indptr64, + lower_eff, + levels, + ) + + @triton.jit def _spsv_csr_level_kernel( data_ptr, @@ -396,6 +404,7 @@ def _spsv_csr_level_kernel_complex( tl.store(x_ri_ptr + row * 2 + 1 + offs1, x_im_out) + @triton.jit def _spsv_coo_level_kernel_real( data_ptr, @@ -609,9 +618,6 @@ def _triton_spsv_csr_vector_complex( indptr, block_nnz=block_nnz, max_segments=max_segments ) -<<<<<<< HEAD - data_ri = torch.view_as_real(data.contiguous()).reshape(-1).contiguous() -======= # Some PyTorch builds return CSR values with a non-strided layout wrapper. # Materialize a plain 1D strided buffer before splitting into real/imag parts. if data.layout != torch.strided: @@ -621,7 +627,6 @@ def _triton_spsv_csr_vector_complex( data_strided = data.contiguous() data_ri = torch.view_as_real(data_strided).reshape(-1).contiguous() ->>>>>>> 5a83e0f (test) b_ri = torch.view_as_real(b_vec.contiguous()).reshape(-1).contiguous() component_dtype = _component_dtype_for_complex(data.dtype) use_fp64 = component_dtype == torch.float64 @@ -656,6 +661,22 @@ def _triton_spsv_csr_vector_complex( return x +def _choose_transpose_family_launch_config(indptr, block_nnz=None, max_segments=None): + if block_nnz is not None or max_segments is not None: + return _auto_spsv_launch_config(indptr, block_nnz=block_nnz, max_segments=max_segments) + + if indptr.numel() <= 1: + return 32, 1 + max_nnz_per_row = int((indptr[1:] - indptr[:-1]).max().item()) + for cand in (32, 64, 128, 256, 512, 1024): + req = max((max_nnz_per_row + cand - 1) // cand, 1) + if req <= 2048: + return cand, req + cand = 2048 + req = max((max_nnz_per_row + cand - 1) // cand, 1) + return cand, req + + def _prepare_spsv_coo_inputs(data, row, col, b, shape, transpose=False): if not all(torch.is_tensor(t) for t in (data, row, col, b)): raise TypeError("data, row, col, b must all be torch.Tensor") @@ -674,10 +695,6 @@ def _prepare_spsv_coo_inputs(data, row, col, b, shape, transpose=False): if b.ndim == 2 and b.shape[0] != n_rows: raise ValueError(f"b.shape[0] must equal n_rows={n_rows}") -<<<<<<< HEAD - if data.dtype not in (torch.bfloat16, torch.float32, torch.float64): - raise TypeError("data dtype must be one of: bfloat16, float32, float64") -======= if data.dtype not in ( torch.bfloat16, torch.float32, torch.float64, @@ -688,7 +705,6 @@ def _prepare_spsv_coo_inputs(data, row, col, b, shape, transpose=False): raise TypeError( "data dtype must be one of: bfloat16, float32, float64, complex64, complex128" ) ->>>>>>> 5a83e0f (test) if b.dtype != data.dtype: raise TypeError("b dtype must match data dtype") if row.dtype not in SUPPORTED_SPSV_INDEX_DTYPES: @@ -724,7 +740,7 @@ def _prepare_spsv_coo_inputs(data, row, col, b, shape, transpose=False): ) -def _csr_transpose(data, indices64, indptr64, n_rows, n_cols): +def _csr_transpose(data, indices64, indptr64, n_rows, n_cols, conjugate=False): if data.numel() == 0: out_data = data out_indices = torch.empty(0, dtype=torch.int64, device=data.device)
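For readers following the TRANS/CONJ path: _prepare_spsv_csr_system materializes op(A) eagerly, so the solve kernels never have to walk a transposed layout, and the essence of _csr_transpose is a COO coordinate swap plus a re-sort, with an optional conjugation of the values for the "C" mode. A minimal self-contained sketch of that recipe; the helper name csr_transpose_reference and the 2x2 example are illustrative only, not part of this patch:

    import torch

    def csr_transpose_reference(values, col, crow, n_rows, conjugate=False):
        # Expand the CSR row pointer into explicit row ids, swap (row, col),
        # then sort by the row-major key of the transposed matrix -- the same
        # recipe _csr_transpose feeds into _coo_to_csr_sorted_unique.
        row = torch.repeat_interleave(
            torch.arange(n_rows, dtype=torch.int64), crow[1:] - crow[:-1]
        )
        new_row, new_col = col.to(torch.int64), row
        vals = values.conj() if conjugate and torch.is_complex(values) else values
        order = torch.argsort(new_row * n_rows + new_col)
        return vals[order], new_row[order], new_col[order]

    A = torch.tensor([[2.0, 0.0], [3.0, 4.0]])
    sp = A.to_sparse_csr()
    vals, r, c = csr_transpose_reference(
        sp.values(), sp.col_indices(), sp.crow_indices(), 2
    )
    At = torch.zeros(2, 2)
    At[r, c] = vals
    assert torch.equal(At, A.t())

Because conjugation only touches the values, TRANS and CONJ share the entire preprocessing path apart from the conjugate flag, which is why _prepare_spsv_csr_system can treat both through one code path and one cache entry layout.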
@@ -737,31 +753,13 @@ def _csr_transpose(data, indices64, indptr64, n_rows, n_cols): ) new_row = indices64 new_col = row_ids + data_eff = data.conj() if conjugate and torch.is_complex(data) else data data_t, indices_t, indptr_t = _coo_to_csr_sorted_unique( - data, new_row, new_col, n_cols, n_rows + data_eff, new_row, new_col, n_cols, n_rows ) return data_t, indices_t, indptr_t -def _csr_reverse_rows_cols(data, indices64, indptr64, n_rows): - if data.numel() == 0: - out_data = data - out_indices = torch.empty(0, dtype=torch.int64, device=data.device) - out_indptr = torch.zeros(n_rows + 1, dtype=torch.int64, device=data.device) - return out_data, out_indices, out_indptr - - row_ids = torch.repeat_interleave( - torch.arange(n_rows, device=data.device, dtype=torch.int64), - indptr64[1:] - indptr64[:-1], - ) - new_row = (n_rows - 1) - row_ids - new_col = (n_rows - 1) - indices64 - data_r, indices_r, indptr_r = _coo_to_csr_sorted_unique( - data, new_row, new_col, n_rows, n_rows - ) - return data_r, indices_r, indptr_r - - def _coo_is_sorted_unique(row64, col64, n_cols): nnz = row64.numel() if nnz <= 1: @@ -877,47 +875,23 @@ def flagsparse_spsv_csr( """Sparse triangular solve using Triton level-scheduling kernels. Primary support matrix: -<<<<<<< HEAD - - NON_TRANS: float32/float64/complex32/complex64 with int32/int64 indices - - TRANS: float32/float64/complex32/complex64 with int32 indices - - bfloat16 remains NON_TRANS + int32 - """ -======= - NON_TRANS: float32/float64/complex64/complex128 with int32/int64 indices - - TRANS: float32/float64/complex64/complex128 with int32 indices + - TRANS/CONJ: float32/float64/complex64/complex128 with int32/int64 indices - bfloat16 remains NON_TRANS + int32 """ input_data = data input_indices = indices input_indptr = indptr ->>>>>>> 5a83e0f (test) trans_mode = _normalize_spsv_transpose_mode(transpose) data, input_index_dtype, indices, indptr, b, n_rows, n_cols = _prepare_spsv_inputs( data, indices, indptr, b, shape ) original_output_dtype = None -<<<<<<< HEAD - data, b, original_output_dtype = _promote_complex32_spsv_inputs(data, b) -======= - rev_perm = None data, b, original_output_dtype = _prepare_spsv_working_inputs(data, b) ->>>>>>> 5a83e0f (test) if n_rows != n_cols: raise ValueError(f"A must be square, got shape={shape}") if trans_mode == "N": _validate_spsv_non_trans_combo(data.dtype, input_index_dtype, "CSR") -<<<<<<< HEAD - lower_eff = lower - kernel_data = data - kernel_indices64 = indices - kernel_indptr64 = indptr - else: - _validate_spsv_trans_combo(data.dtype, input_index_dtype, "CSR") - lower_eff = not lower - kernel_data, kernel_indices64, kernel_indptr64 = _csr_transpose( - data, indices, indptr, n_rows, n_cols - ) -======= else: _validate_spsv_trans_combo(data.dtype, input_index_dtype, "CSR") @@ -926,28 +900,14 @@ def flagsparse_spsv_csr( ) cached = _spsv_cache_get(_SPSV_CSR_PREPROCESS_CACHE, preprocess_key) if cached is None: - if trans_mode == "N": - lower_eff = lower - kernel_data = data - kernel_indices64 = indices - kernel_indptr64 = indptr - rev_perm = None - else: - lower_eff = not lower - kernel_data, kernel_indices64, kernel_indptr64 = _csr_transpose( - data, indices, indptr, n_rows, n_cols - ) - rev_perm = None - levels = _build_spsv_levels( - kernel_indptr64, kernel_indices64, n_rows, lower=lower_eff - ) - cached = ( - kernel_data, - kernel_indices64, - kernel_indptr64, - rev_perm, - lower_eff, - levels, + cached = _prepare_spsv_csr_system( + data, + indices, + indptr, + n_rows, + n_cols, + lower, + trans_mode, ) _spsv_cache_put(
_SPSV_CSR_PREPROCESS_CACHE, @@ -955,8 +915,7 @@ def flagsparse_spsv_csr( cached, _SPSV_CSR_PREPROCESS_CACHE_SIZE, ) - kernel_data, kernel_indices64, kernel_indptr64, rev_perm, lower_eff, levels = cached ->>>>>>> 5a83e0f (test) + kernel_data, kernel_indices64, kernel_indptr64, lower_eff, levels = cached kernel_indices = ( kernel_indices64.to(torch.int32) @@ -971,39 +930,37 @@ def flagsparse_spsv_csr( compute_dtype = torch.float32 data_in = kernel_data.to(torch.float32) b_in = b.to(torch.float32) - elif data.dtype == torch.complex64 and trans_mode == "T": + elif data.dtype == torch.complex64 and trans_mode in ("T", "C"): compute_dtype = torch.complex128 data_in = kernel_data.to(torch.complex128) b_in = b.to(torch.complex128) elif data.dtype == torch.float32 and SPSV_PROMOTE_FP32_TO_FP64: - # Optional high-precision mode; disabled by default for throughput. - compute_dtype = torch.float64 - data_in = kernel_data.to(torch.float64) -<<<<<<< HEAD - b_in = b.to(torch.float64) - elif data.dtype == torch.float32 and trans_mode == "T": compute_dtype = torch.float64 data_in = kernel_data.to(torch.float64) b_in = b.to(torch.float64) - levels = _build_spsv_levels( - kernel_indptr, kernel_indices, n_rows, lower=lower_eff - ) -======= - b_in = b.to(torch.float64) - elif data.dtype == torch.float32 and trans_mode == "T": + elif data.dtype == torch.float32 and trans_mode in ("T", "C"): compute_dtype = torch.float64 data_in = kernel_data.to(torch.float64) b_in = b.to(torch.float64) ->>>>>>> 5a83e0f (test) - block_nnz_use, max_segments_use = _auto_spsv_launch_config( - kernel_indptr, block_nnz=block_nnz, max_segments=max_segments - ) + + is_transpose_family_op = trans_mode != "N" + if is_transpose_family_op: + block_nnz_use, max_segments_use = _choose_transpose_family_launch_config( + kernel_indptr, block_nnz=block_nnz, max_segments=max_segments + ) + else: + block_nnz_use, max_segments_use = _auto_spsv_launch_config( + kernel_indptr, block_nnz=block_nnz, max_segments=max_segments + ) diag_eps = _spsv_diag_eps_for_dtype(compute_dtype) + vec_real = _triton_spsv_csr_vector + vec_complex = _triton_spsv_csr_vector_complex + torch.cuda.synchronize() t0 = time.perf_counter() if b_in.ndim == 1: if torch.is_complex(data_in): - x = _triton_spsv_csr_vector_complex( + x = vec_complex( data_in, kernel_indices, kernel_indptr, @@ -1019,7 +976,7 @@ def flagsparse_spsv_csr( max_segments_use=max_segments_use, ) else: - x = _triton_spsv_csr_vector( + x = vec_real( data_in, kernel_indices, kernel_indptr, @@ -1040,7 +997,7 @@ def flagsparse_spsv_csr( bj = b_in[:, j].contiguous() if torch.is_complex(data_in): cols.append( - _triton_spsv_csr_vector_complex( + vec_complex( data_in, kernel_indices, kernel_indptr, @@ -1058,7 +1015,7 @@ def flagsparse_spsv_csr( ) else: cols.append( - _triton_spsv_csr_vector( + vec_real( data_in, kernel_indices, kernel_indptr, @@ -1077,11 +1034,7 @@ def flagsparse_spsv_csr( x = torch.stack(cols, dim=1) target_dtype = original_output_dtype if original_output_dtype is not None else data.dtype if x.dtype != target_dtype: -<<<<<<< HEAD - x = _restore_complex32_spsv_output(x, target_dtype) -======= x = _restore_spsv_output(x, target_dtype) ->>>>>>> 5a83e0f (test) torch.cuda.synchronize() elapsed_ms = (time.perf_counter() - t0) * 1000.0 if out is not None: @@ -1096,6 +1049,7 @@ def flagsparse_spsv_csr( def flagsparse_spsv_coo( + data, row, col, @@ -1117,7 +1071,7 @@ def flagsparse_spsv_coo( Notes: - direct mode currently supports only non-transposed real-valued inputs - - complex dtypes and transpose=True 
always route through the CSR implementation + - complex dtypes and TRANS/CONJ always route through the CSR implementation """ data, row64, col64, b, n_rows, n_cols = _prepare_spsv_coo_inputs( data, row, col, b, shape, transpose=transpose @@ -1130,12 +1084,13 @@ def flagsparse_spsv_coo( raise ValueError("coo_mode must be one of: 'auto', 'direct', 'csr'") sorted_unique = _coo_is_sorted_unique(row64, col64, n_cols) - direct_supported = (not transpose) and (not torch.is_complex(data)) + trans_mode = _normalize_spsv_transpose_mode(transpose) + direct_supported = (trans_mode == "N") and (not torch.is_complex(data)) use_direct = direct_supported and (mode == "direct" or (mode == "auto" and sorted_unique)) if mode == "direct" and not direct_supported: raise ValueError( "coo_mode='direct' supports only non-transposed real-valued inputs; " - "use coo_mode='csr' or 'auto' for transpose or complex dtypes" + "use coo_mode='csr' or 'auto' for TRANS/CONJ or complex dtypes" ) if mode == "direct" and not sorted_unique: raise ValueError( @@ -1147,8 +1102,6 @@ def flagsparse_spsv_coo( data_csr, indices_csr, indptr_csr = _coo_to_csr_sorted_unique( data, row64, col64, n_rows, n_cols ) - if transpose: - indices_csr = indices_csr.to(torch.int32) return flagsparse_spsv_csr( data_csr, indices_csr, diff --git a/tests/pytest/test_spsv_csr_accuracy.py b/tests/pytest/test_spsv_csr_accuracy.py index d44d7b5..fc9c891 100644 --- a/tests/pytest/test_spsv_csr_accuracy.py +++ b/tests/pytest/test_spsv_csr_accuracy.py @@ -16,24 +16,12 @@ pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") -<<<<<<< HEAD -COMPLEX32_DTYPE = getattr(torch, "complex32", None) -if COMPLEX32_DTYPE is None: - COMPLEX32_DTYPE = getattr(torch, "chalf", None) - -SUPPORTED_COMPLEX_DTYPES = [] -if COMPLEX32_DTYPE is not None: - SUPPORTED_COMPLEX_DTYPES.append(COMPLEX32_DTYPE) -SUPPORTED_COMPLEX_DTYPES.append(torch.complex64) - -SUPPORTED_DTYPES = [torch.float32, torch.float64, *SUPPORTED_COMPLEX_DTYPES] -======= SUPPORTED_COMPLEX_DTYPES = [torch.complex64, torch.complex128] SUPPORTED_DTYPES = [torch.float32, torch.float64, *SUPPORTED_COMPLEX_DTYPES] NON_TRANS_DTYPES = [torch.float32, torch.float64, torch.complex64, torch.complex128] -TRANS_DTYPES = [torch.float32, torch.float64, torch.complex64, torch.complex128] ->>>>>>> 5a83e0f (test) +TRANS_CONJ_DTYPES = [torch.float32, torch.float64, torch.complex64, torch.complex128] +TRANS_CONJ_MODES = ["TRANS", "CONJ"] def _dtype_id(dtype): @@ -41,11 +29,6 @@ def _dtype_id(dtype): def _tol(dtype): -<<<<<<< HEAD - if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: - return 5e-3, 5e-3 -======= ->>>>>>> 5a83e0f (test) if dtype in (torch.float32, torch.complex64): return 1e-4, 1e-3 return 1e-10, 1e-8 @@ -54,57 +37,57 @@ def _tol(dtype): def _rand_like(dtype, shape, device): if dtype in (torch.float32, torch.float64): return torch.randn(shape, dtype=dtype, device=device) -<<<<<<< HEAD - if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: - pair = torch.randn((*shape, 2), dtype=torch.float16, device=device) * 0.1 - return torch.view_as_complex(pair) - base = torch.float32 -======= base = torch.float32 if dtype == torch.complex64 else torch.float64 ->>>>>>> 5a83e0f (test) r = torch.randn(shape, dtype=base, device=device) i = torch.randn(shape, dtype=base, device=device) return torch.complex(r, i) def _ref_dtype(dtype): -<<<<<<< HEAD - if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: - return torch.complex64 -======= ->>>>>>> 5a83e0f (test) return dtype def 
_safe_cast_tensor(tensor, dtype): -<<<<<<< HEAD - if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: - real = tensor.real.to(torch.float16) - imag = tensor.imag.to(torch.float16) - return torch.view_as_complex(torch.stack([real, imag], dim=-1).contiguous()) -======= ->>>>>>> 5a83e0f (test) return tensor.to(dtype) def _cmp_view(tensor, dtype): -<<<<<<< HEAD - if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: - return tensor.to(torch.complex64) return tensor -def _build_lower_triangular(n, dtype, device): - off = _rand_like(dtype, (n, n), device) * 0.02 - A = torch.tril(off) -======= - return tensor +def _apply_ref_op(A, op_mode): + if op_mode == "TRANS": + return A.transpose(-2, -1) + if op_mode == "CONJ": + return A.transpose(-2, -1).conj() if torch.is_complex(A) else A.transpose(-2, -1) + return A + + +def _effective_upper(lower, op_mode): + return lower if op_mode in ("TRANS", "CONJ") else not lower + + +def _effective_lower_for_op(lower, op_mode): + return (not lower) if op_mode in ("TRANS", "CONJ") else lower + + +def _transpose_arg(op_mode): + if op_mode == "NON": + return False + return op_mode + + +def _cupy_apply_op(A_cp, op_mode): + if op_mode == "TRANS": + return A_cp.transpose().tocsr() + if op_mode == "CONJ": + return A_cp.transpose().conj().tocsr() + return A_cp def _build_triangular(n, dtype, device, lower=True): off = _rand_like(dtype, (n, n), device) * 0.02 A = torch.tril(off) if lower else torch.triu(off) ->>>>>>> 5a83e0f (test) if torch.is_complex(A): diag = (torch.rand(n, device=device, dtype=A.real.dtype) + 2.0).to(A.real.dtype) A = A + torch.diag(torch.complex(diag, torch.zeros_like(diag))) @@ -117,12 +100,7 @@ def _build_triangular(n, dtype, device, lower=True): def _cupy_csr_from_torch(data, indices, indptr, shape): if cp is None or cpx_sparse is None: return None -<<<<<<< HEAD - data_ref = data.to(torch.complex64) if COMPLEX32_DTYPE is not None and data.dtype == COMPLEX32_DTYPE else data - data_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(data_ref.contiguous())) -======= data_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(data.contiguous())) ->>>>>>> 5a83e0f (test) idx_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(indices.to(torch.int64).contiguous())) ptr_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(indptr.to(torch.int64).contiguous())) return cpx_sparse.csr_matrix((data_cp, idx_cp, ptr_cp), shape=shape) @@ -131,18 +109,9 @@ def _cupy_csr_from_torch(data, indices, indptr, shape): def _cupy_ref_spsv(A_cp, b_t, *, lower, unit_diagonal=False): if cp is None or cpx_spsolve_triangular is None: return None -<<<<<<< HEAD - b_ref = b_t.to(torch.complex64) if COMPLEX32_DTYPE is not None and b_t.dtype == COMPLEX32_DTYPE else b_t - b_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(b_ref.contiguous())) - x_cp = cpx_spsolve_triangular(A_cp, b_cp, lower=lower, unit_diagonal=unit_diagonal) - x_t = torch.utils.dlpack.from_dlpack(x_cp.toDlpack()) - if COMPLEX32_DTYPE is not None and b_t.dtype == COMPLEX32_DTYPE: - return x_t.to(torch.complex64) -======= b_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(b_t.contiguous())) x_cp = cpx_spsolve_triangular(A_cp, b_cp, lower=lower, unit_diagonal=unit_diagonal) x_t = torch.utils.dlpack.from_dlpack(x_cp.toDlpack()) ->>>>>>> 5a83e0f (test) return x_t.to(b_t.dtype) @@ -183,19 +152,11 @@ def test_spsv_csr_lower_matches_dense(n, dtype): @pytest.mark.spsv @pytest.mark.parametrize("n", SPSV_N) -<<<<<<< HEAD -@pytest.mark.parametrize("dtype", SUPPORTED_DTYPES, ids=_dtype_id) 
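As background for the _apply_ref_op / _effective_upper helpers added above: for a lower-triangular A, solving op(A) x = b with op = transpose or conjugate transpose is exactly an upper-triangular dense solve on op(A), which is why _effective_upper(True, "CONJ") evaluates to True. A standalone sanity check of that convention (small n and complex128 are arbitrary choices for illustration):

    import torch

    n = 4
    A = torch.tril(torch.randn(n, n, dtype=torch.complex128))
    A = A + 2.0 * torch.eye(n, dtype=torch.complex128)  # keep it well conditioned
    b = torch.randn(n, dtype=torch.complex128)

    # op_mode == "CONJ": op(A) = A^H is upper triangular when A is lower triangular.
    A_op = A.transpose(-2, -1).conj()
    x = torch.linalg.solve_triangular(A_op, b.unsqueeze(-1), upper=True).squeeze(-1)
    assert torch.allclose(A_op @ x, b, atol=1e-10)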
-@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) -def test_spsv_csr_non_trans_supported_combos(n, dtype, index_dtype): - device = torch.device("cuda") - A = _build_lower_triangular(n, dtype, device) -======= @pytest.mark.parametrize("dtype", NON_TRANS_DTYPES, ids=_dtype_id) @pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) def test_spsv_csr_non_trans_supported_combos(n, dtype, index_dtype): device = torch.device("cuda") A = _build_triangular(n, dtype, device, lower=True) ->>>>>>> 5a83e0f (test) b = _rand_like(dtype, (n,), device) x_ref = torch.linalg.solve_triangular( A.to(_ref_dtype(dtype)), b.to(_ref_dtype(dtype)).unsqueeze(-1), upper=False @@ -222,28 +183,23 @@ def test_spsv_csr_non_trans_supported_combos(n, dtype, index_dtype): @pytest.mark.spsv @pytest.mark.parametrize("n", SPSV_N) -<<<<<<< HEAD -@pytest.mark.parametrize("dtype", SUPPORTED_DTYPES, ids=_dtype_id) -def test_spsv_csr_trans_int32_supported_combos(n, dtype): - device = torch.device("cuda") - A = _build_lower_triangular(n, dtype, device) -======= -@pytest.mark.parametrize("dtype", TRANS_DTYPES, ids=_dtype_id) -def test_spsv_csr_trans_int32_supported_combos(n, dtype): +@pytest.mark.parametrize("dtype", TRANS_CONJ_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) +@pytest.mark.parametrize("op_mode", TRANS_CONJ_MODES) +def test_spsv_csr_transpose_family_supported_combos(n, dtype, index_dtype, op_mode): device = torch.device("cuda") A = _build_triangular(n, dtype, device, lower=True) ->>>>>>> 5a83e0f (test) b = _rand_like(dtype, (n,), device) A_ref = A.to(_ref_dtype(dtype)) b_ref = b.to(_ref_dtype(dtype)) x_ref = torch.linalg.solve_triangular( - A_ref.transpose(-2, -1), b_ref.unsqueeze(-1), upper=True + _apply_ref_op(A_ref, op_mode), b_ref.unsqueeze(-1), upper=_effective_upper(True, op_mode) ).squeeze(-1) Asp = A.to_sparse_csr() data = Asp.values() - indices = Asp.col_indices().to(torch.int32) - indptr = Asp.crow_indices().to(torch.int32) + indices = Asp.col_indices().to(index_dtype) + indptr = Asp.crow_indices().to(index_dtype) x = flagsparse_spsv_csr( data, @@ -253,21 +209,13 @@ def test_spsv_csr_trans_int32_supported_combos(n, dtype): (n, n), lower=True, unit_diagonal=False, - transpose=True, + transpose=_transpose_arg(op_mode), ) rtol, atol = _tol(dtype) assert torch.allclose(_cmp_view(x, dtype), _cmp_view(x_ref, dtype), rtol=rtol, atol=atol) @pytest.mark.spsv -<<<<<<< HEAD -@pytest.mark.skipif(cp is None or cpx_spsolve_triangular is None, reason="CuPy/cuSPARSE required") -@pytest.mark.parametrize("n", SPSV_N) -@pytest.mark.parametrize("dtype", SUPPORTED_DTYPES, ids=_dtype_id) -def test_spsv_csr_matches_cusparse_non_trans_and_trans(n, dtype): - device = torch.device("cuda") - A = _build_lower_triangular(n, dtype, device) -======= @pytest.mark.skipif( cp is None or cpx_sparse is None or cpx_spsolve_triangular is None, reason="CuPy/cuSPARSE required", @@ -277,7 +225,6 @@ def test_spsv_csr_matches_cusparse_non_trans_and_trans(n, dtype): def test_spsv_csr_matches_cusparse_non_trans(n, dtype): device = torch.device("cuda") A = _build_triangular(n, dtype, device, lower=True) ->>>>>>> 5a83e0f (test) b = _rand_like(dtype, (n,), device) Asp = A.to_sparse_csr() @@ -291,8 +238,6 @@ def test_spsv_csr_matches_cusparse_non_trans(n, dtype): ) x_non_ref = _cupy_ref_spsv(A_cp, b, lower=True, unit_diagonal=False) -<<<<<<< HEAD -======= rtol, atol = _tol(dtype) assert 
torch.allclose(_cmp_view(x_non, dtype), _cmp_view(x_non_ref, dtype), rtol=rtol, atol=atol) @@ -303,29 +248,38 @@ def test_spsv_csr_matches_cusparse_non_trans(n, dtype): reason="CuPy/cuSPARSE required", ) @pytest.mark.parametrize("n", SPSV_N) -@pytest.mark.parametrize("dtype", TRANS_DTYPES, ids=_dtype_id) -def test_spsv_csr_matches_cusparse_trans(n, dtype): +@pytest.mark.parametrize("dtype", TRANS_CONJ_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) +@pytest.mark.parametrize("op_mode", TRANS_CONJ_MODES) +def test_spsv_csr_matches_cusparse_transpose_family(n, dtype, index_dtype, op_mode): device = torch.device("cuda") A = _build_triangular(n, dtype, device, lower=True) b = _rand_like(dtype, (n,), device) Asp = A.to_sparse_csr() data = Asp.values() - indices = Asp.col_indices().to(torch.int32) - indptr = Asp.crow_indices().to(torch.int32) + indices = Asp.col_indices().to(index_dtype) + indptr = Asp.crow_indices().to(index_dtype) A_cp = _cupy_csr_from_torch(data, indices, indptr, (n, n)) ->>>>>>> 5a83e0f (test) x_trans = flagsparse_spsv_csr( - data, indices, indptr, b, (n, n), lower=True, unit_diagonal=False, transpose=True + data, + indices, + indptr, + b, + (n, n), + lower=True, + unit_diagonal=False, + transpose=_transpose_arg(op_mode), + ) + x_trans_ref = _cupy_ref_spsv( + _cupy_apply_op(A_cp, op_mode), + b, + lower=_effective_lower_for_op(True, op_mode), + unit_diagonal=False, ) - x_trans_ref = _cupy_ref_spsv(A_cp.transpose().tocsr(), b, lower=False, unit_diagonal=False) rtol, atol = _tol(dtype) -<<<<<<< HEAD - assert torch.allclose(_cmp_view(x_non, dtype), _cmp_view(x_non_ref, dtype), rtol=rtol, atol=atol) - assert torch.allclose(_cmp_view(x_trans, dtype), _cmp_view(x_trans_ref, dtype), rtol=rtol, atol=atol) -======= assert torch.allclose(_cmp_view(x_trans, dtype), _cmp_view(x_trans_ref, dtype), rtol=rtol, atol=atol) @@ -362,21 +316,23 @@ def test_spsv_csr_non_trans_upper_supported_combos(n, dtype, index_dtype): @pytest.mark.spsv @pytest.mark.parametrize("n", SPSV_N) -@pytest.mark.parametrize("dtype", TRANS_DTYPES, ids=_dtype_id) -def test_spsv_csr_trans_upper_int32_supported_combos(n, dtype): +@pytest.mark.parametrize("dtype", TRANS_CONJ_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) +@pytest.mark.parametrize("op_mode", TRANS_CONJ_MODES) +def test_spsv_csr_upper_transpose_family_supported_combos(n, dtype, index_dtype, op_mode): device = torch.device("cuda") A = _build_triangular(n, dtype, device, lower=False) b = _rand_like(dtype, (n,), device) A_ref = A.to(_ref_dtype(dtype)) b_ref = b.to(_ref_dtype(dtype)) x_ref = torch.linalg.solve_triangular( - A_ref.transpose(-2, -1), b_ref.unsqueeze(-1), upper=False + _apply_ref_op(A_ref, op_mode), b_ref.unsqueeze(-1), upper=_effective_upper(False, op_mode) ).squeeze(-1) Asp = A.to_sparse_csr() data = Asp.values() - indices = Asp.col_indices().to(torch.int32) - indptr = Asp.crow_indices().to(torch.int32) + indices = Asp.col_indices().to(index_dtype) + indptr = Asp.crow_indices().to(index_dtype) x = flagsparse_spsv_csr( data, @@ -386,7 +342,7 @@ def test_spsv_csr_trans_upper_int32_supported_combos(n, dtype): (n, n), lower=False, unit_diagonal=False, - transpose=True, + transpose=_transpose_arg(op_mode), ) rtol, atol = _tol(dtype) assert torch.allclose(_cmp_view(x, dtype), _cmp_view(x_ref, dtype), rtol=rtol, atol=atol) @@ -425,22 +381,36 @@ def test_spsv_csr_matches_cusparse_upper_non_trans(n, dtype): 
reason="CuPy/cuSPARSE required", ) @pytest.mark.parametrize("n", SPSV_N) -@pytest.mark.parametrize("dtype", TRANS_DTYPES, ids=_dtype_id) -def test_spsv_csr_matches_cusparse_upper_trans(n, dtype): +@pytest.mark.parametrize("dtype", TRANS_CONJ_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) +@pytest.mark.parametrize("op_mode", TRANS_CONJ_MODES) +def test_spsv_csr_matches_cusparse_upper_transpose_family(n, dtype, index_dtype, op_mode): device = torch.device("cuda") A = _build_triangular(n, dtype, device, lower=False) b = _rand_like(dtype, (n,), device) Asp = A.to_sparse_csr() data = Asp.values() - indices = Asp.col_indices().to(torch.int32) - indptr = Asp.crow_indices().to(torch.int32) + indices = Asp.col_indices().to(index_dtype) + indptr = Asp.crow_indices().to(index_dtype) A_cp = _cupy_csr_from_torch(data, indices, indptr, (n, n)) x_trans = flagsparse_spsv_csr( - data, indices, indptr, b, (n, n), lower=False, unit_diagonal=False, transpose=True + data, + indices, + indptr, + b, + (n, n), + lower=False, + unit_diagonal=False, + transpose=_transpose_arg(op_mode), + ) + x_trans_ref = _cupy_ref_spsv( + _cupy_apply_op(A_cp, op_mode), + b, + lower=_effective_lower_for_op(False, op_mode), + unit_diagonal=False, ) - x_trans_ref = _cupy_ref_spsv(A_cp.transpose().tocsr(), b, lower=True, unit_diagonal=False) rtol, atol = _tol(dtype) assert torch.allclose(_cmp_view(x_trans, dtype), _cmp_view(x_trans_ref, dtype), rtol=rtol, atol=atol) @@ -448,13 +418,14 @@ def test_spsv_csr_matches_cusparse_upper_trans(n, dtype): @pytest.mark.spsv @pytest.mark.parametrize("n", SPSV_N) -def test_spsv_coo_transpose_complex128_routes_through_csr(n): +@pytest.mark.parametrize("op_mode", TRANS_CONJ_MODES) +def test_spsv_coo_transpose_family_complex128_routes_through_csr(n, op_mode): device = torch.device("cuda") dtype = torch.complex128 A = _build_triangular(n, dtype, device, lower=True) b = _rand_like(dtype, (n,), device) x_ref = torch.linalg.solve_triangular( - A.transpose(-2, -1), b.unsqueeze(-1), upper=True + _apply_ref_op(A, op_mode), b.unsqueeze(-1), upper=_effective_upper(True, op_mode) ).squeeze(-1) A_coo = A.to_sparse_coo().coalesce() @@ -469,9 +440,8 @@ def test_spsv_coo_transpose_complex128_routes_through_csr(n): (n, n), lower=True, unit_diagonal=False, - transpose=True, + transpose=_transpose_arg(op_mode), coo_mode="auto", ) rtol, atol = _tol(dtype) assert torch.allclose(x, x_ref, rtol=rtol, atol=atol) ->>>>>>> 5a83e0f (test) diff --git a/tests/test_scatter.py b/tests/test_scatter.py index e92b41e..604adeb 100644 --- a/tests/test_scatter.py +++ b/tests/test_scatter.py @@ -54,7 +54,6 @@ def _parse_value_dtypes(raw): "bfloat16", "float32", "float64", - "complex32", "complex64", "complex128", } diff --git a/tests/test_spsv.py b/tests/test_spsv.py index cdc7a31..2737db2 100644 --- a/tests/test_spsv.py +++ b/tests/test_spsv.py @@ -29,21 +29,6 @@ ITERS = 20 DENSE_REF_MAX_BYTES = 2 * 1024 * 1024 * 1024 # 2 GiB -FLOAT16_LIMIT = 65504.0 -COMPLEX32_DTYPE = getattr(torch, "complex32", None) -if COMPLEX32_DTYPE is None: - COMPLEX32_DTYPE = getattr(torch, "chalf", None) - -# CSR 完整组合覆盖(在原 csv-csr 逻辑外新增,不影响原入口) -CSR_FULL_VALUE_DTYPES = [ - torch.float32, - torch.float64, -] -if COMPLEX32_DTYPE is not None: - CSR_FULL_VALUE_DTYPES.append(COMPLEX32_DTYPE) -CSR_FULL_VALUE_DTYPES.append(torch.complex64) -CSR_FULL_INDEX_DTYPES = [torch.int32, torch.int64] - # CSR 完整组合覆盖(在原 csv-csr 逻辑外新增,不影响原入口) CSR_FULL_VALUE_DTYPES = [ torch.float32, @@ -52,12 +37,53 @@ 
torch.complex128, ] CSR_FULL_INDEX_DTYPES = [torch.int32, torch.int64] +SPSV_OP_MODES = ["NON", "TRANS", "CONJ"] def _dtype_name(dtype): return str(dtype).replace("torch.", "") +VALUE_DTYPE_NAME_MAP = { + _dtype_name(dtype): dtype for dtype in CSR_FULL_VALUE_DTYPES +} +VALUE_DTYPE_NAME_MAP.update({ + "float": torch.float32, + "double": torch.float64, +}) +INDEX_DTYPE_NAME_MAP = { + _dtype_name(dtype): dtype for dtype in CSR_FULL_INDEX_DTYPES +} + + +def _parse_csv_tokens(raw): + return [tok.strip() for tok in str(raw).split(",") if tok.strip()] + + +def _parse_value_dtypes_filter(raw): + tokens = [tok.lower() for tok in _parse_csv_tokens(raw)] + invalid = [tok for tok in tokens if tok not in VALUE_DTYPE_NAME_MAP] + if invalid: + raise ValueError(f"unsupported value dtypes: {invalid}") + return [VALUE_DTYPE_NAME_MAP[tok] for tok in tokens] + + +def _parse_index_dtypes_filter(raw): + tokens = [tok.lower() for tok in _parse_csv_tokens(raw)] + invalid = [tok for tok in tokens if tok not in INDEX_DTYPE_NAME_MAP] + if invalid: + raise ValueError(f"unsupported index dtypes: {invalid}") + return [INDEX_DTYPE_NAME_MAP[tok] for tok in tokens] + + +def _parse_op_modes_filter(raw): + tokens = [tok.upper() for tok in _parse_csv_tokens(raw)] + invalid = [tok for tok in tokens if tok not in SPSV_OP_MODES] + if invalid: + raise ValueError(f"unsupported ops: {invalid}") + return tokens + + def _fmt_ms(v): return "N/A" if v is None else f"{v:.4f}" @@ -73,11 +99,6 @@ def _fmt_err(v): def _tol_for_dtype(dtype): -<<<<<<< HEAD - if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: - return 5e-3, 5e-3 -======= ->>>>>>> 5a83e0f (test) if dtype in (torch.float32, torch.complex64): return 1e-4, 1e-2 return 1e-12, 1e-10 @@ -86,67 +107,25 @@ def _tol_for_dtype(dtype): def _randn_by_dtype(n, dtype, device): if dtype in (torch.float32, torch.float64): return torch.randn(n, dtype=dtype, device=device) -<<<<<<< HEAD - if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: - pair = torch.randn((n, 2), dtype=torch.float16, device=device) * 0.1 - return torch.view_as_complex(pair) - base = torch.float32 -======= base = torch.float32 if dtype == torch.complex64 else torch.float64 ->>>>>>> 5a83e0f (test) real = torch.randn(n, dtype=base, device=device) imag = torch.randn(n, dtype=base, device=device) return torch.complex(real, imag) def _dense_ref_dtype(dtype): -<<<<<<< HEAD - if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: - return torch.complex64 -======= ->>>>>>> 5a83e0f (test) return dtype def _tensor_from_scalar_values(values, dtype, device): -<<<<<<< HEAD - if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: - real = torch.clamp( - torch.tensor(values, dtype=torch.float32, device=device), - min=-FLOAT16_LIMIT, - max=FLOAT16_LIMIT, - ).to(torch.float16) - imag = torch.zeros_like(real) - return torch.view_as_complex(torch.stack([real, imag], dim=-1).contiguous()) -======= ->>>>>>> 5a83e0f (test) return torch.tensor(values, dtype=dtype, device=device) def _safe_cast_tensor(tensor, dtype): -<<<<<<< HEAD - if COMPLEX32_DTYPE is not None and dtype == COMPLEX32_DTYPE: - real = torch.clamp(tensor.real, min=-FLOAT16_LIMIT, max=FLOAT16_LIMIT).to(torch.float16) - imag = torch.clamp(tensor.imag, min=-FLOAT16_LIMIT, max=FLOAT16_LIMIT).to(torch.float16) - return torch.view_as_complex(torch.stack([real, imag], dim=-1).contiguous()) -======= ->>>>>>> 5a83e0f (test) return tensor.to(dtype) def _cast_real_tensor_to_value_dtype(values, value_dtype): -<<<<<<< HEAD - if COMPLEX32_DTYPE is not None and 
value_dtype == COMPLEX32_DTYPE: - real = torch.clamp(values, min=-FLOAT16_LIMIT, max=FLOAT16_LIMIT).to(torch.float16) - imag = torch.zeros_like(real) - return torch.view_as_complex(torch.stack([real, imag], dim=-1).contiguous()) - return values.to(value_dtype) - - -def _cupy_ref_inputs(data, b): - if COMPLEX32_DTYPE is not None and data.dtype == COMPLEX32_DTYPE: - return data.to(torch.complex64), b.to(torch.complex64) -======= return values.to(value_dtype) @@ -166,6 +145,9 @@ def _triangular_solve_reference(A, b, *, lower, op_mode="NON"): if op_mode == "TRANS": A_eff = A.transpose(0, 1) upper = lower + elif op_mode == "CONJ": + A_eff = A.transpose(0, 1).conj() if torch.is_complex(A) else A.transpose(0, 1) + upper = lower else: A_eff = A upper = not lower @@ -175,16 +157,10 @@ def _triangular_solve_reference(A, b, *, lower, op_mode="NON"): def _cupy_ref_inputs(data, b): ->>>>>>> 5a83e0f (test) return data, b def _compare_view(tensor, value_dtype): -<<<<<<< HEAD - if COMPLEX32_DTYPE is not None and value_dtype == COMPLEX32_DTYPE: - return tensor.to(torch.complex64) -======= ->>>>>>> 5a83e0f (test) return tensor @@ -192,9 +168,9 @@ def _supported_csr_full_ops(value_dtype, index_dtype): if value_dtype not in CSR_FULL_VALUE_DTYPES: return [] if index_dtype == torch.int32: - return ["NON", "TRANS"] + return ["NON", "TRANS", "CONJ"] if index_dtype == torch.int64: - return ["NON"] + return ["NON", "TRANS", "CONJ"] return [] @@ -215,17 +191,10 @@ def _build_random_triangular_csr(n, value_dtype, index_dtype, device, lower=True base_real_dtype = torch.float32 elif value_dtype == torch.float64: base_real_dtype = torch.float64 -<<<<<<< HEAD - elif COMPLEX32_DTYPE is not None and value_dtype == COMPLEX32_DTYPE: - base_real_dtype = torch.float16 - else: - base_real_dtype = torch.float32 -======= elif value_dtype == torch.complex64: base_real_dtype = torch.float32 else: base_real_dtype = torch.float64 ->>>>>>> 5a83e0f (test) for i in range(n): if lower: @@ -272,14 +241,7 @@ def _build_random_triangular_csr(n, value_dtype, index_dtype, device, lower=True rows_t = torch.tensor(rows_host, dtype=torch.int64, device=device) cols_t = torch.tensor(cols_host, dtype=torch.int64, device=device) -<<<<<<< HEAD - vals_t = _cast_real_tensor_to_value_dtype( - torch.tensor(vals_host, dtype=base_real_dtype, device=device), - value_dtype, - ) -======= vals_t = torch.tensor(vals_host, dtype=value_dtype, device=device) ->>>>>>> 5a83e0f (test) order = torch.argsort(rows_t * max(1, n) + cols_t) rows_t = rows_t[order] cols_t = cols_t[order] @@ -293,16 +255,15 @@ def _build_random_triangular_csr(n, value_dtype, index_dtype, device, lower=True def _csr_to_dense(data, indices, indptr, shape): n_rows, n_cols = shape - coo_data = data.to(torch.complex64) if COMPLEX32_DTYPE is not None and data.dtype == COMPLEX32_DTYPE else data row_ind = torch.repeat_interleave( - torch.arange(n_rows, device=coo_data.device, dtype=torch.int64), + torch.arange(n_rows, device=data.device, dtype=torch.int64), indptr[1:] - indptr[:-1], ) coo = torch.sparse_coo_tensor( torch.stack([row_ind, indices.to(torch.int64)]), - coo_data, + data, (n_rows, n_cols), - device=coo_data.device, + device=data.device, ).coalesce() return coo.to_dense() @@ -317,7 +278,7 @@ def _csr_to_coo(data, indices, indptr, shape): return data, row, col -def _csr_transpose(data, indices, indptr, shape): +def _csr_transpose(data, indices, indptr, shape, conjugate=False): n_rows, n_cols = int(shape[0]), int(shape[1]) if data.numel() == 0: return ( @@ -337,18 +298,15 @@ def 
_csr_transpose(data, indices, indptr, shape): row_t = row_t[order] col_t = col_t[order] - data_t = data[order] + data_eff = data.conj() if conjugate and torch.is_complex(data) else data + data_t = data_eff[order] nnz_per_row = torch.bincount(row_t, minlength=n_cols) indptr_t = torch.zeros(n_cols + 1, dtype=torch.int64, device=data.device) indptr_t[1:] = torch.cumsum(nnz_per_row, dim=0) return data_t, col_t.to(torch.int64), indptr_t -<<<<<<< HEAD -def _load_mtx_to_csr_torch(file_path, dtype=torch.float32, device=None): -======= def _load_mtx_to_csr_torch(file_path, dtype=torch.float32, device=None, lower=True): ->>>>>>> 5a83e0f (test) if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") with open(file_path, "r", encoding="utf-8") as f: @@ -449,29 +407,6 @@ def _coo_inputs_for_csv(data, indices, indptr, shape, coo_mode): def _build_rhs_for_csr_op(data, indices, indptr, x_true, shape, op_mode): -<<<<<<< HEAD - if COMPLEX32_DTYPE is not None and data.dtype == COMPLEX32_DTYPE: - data_ref = data.to(torch.complex64) - x_ref = x_true.to(torch.complex64) - if op_mode == "NON": - b_ref, _ = fs.flagsparse_spmv_csr( - data_ref, indices, indptr, x_ref, shape, return_time=True - ) - return _safe_cast_tensor(b_ref, x_true.dtype) - if op_mode == "TRANS": - data_t, indices_t, indptr_t = _csr_transpose(data_ref, indices, indptr, shape) - b_ref, _ = fs.flagsparse_spmv_csr( - data_t, - indices_t.to(indices.dtype), - indptr_t.to(indptr.dtype), - x_ref, - (shape[1], shape[0]), - return_time=True, - ) - return _safe_cast_tensor(b_ref, x_true.dtype) - raise ValueError("op_mode must be 'NON' or 'TRANS'") -======= ->>>>>>> 5a83e0f (test) if op_mode == "NON": b, _ = fs.flagsparse_spmv_csr( data, indices, indptr, x_true, shape, return_time=True @@ -488,7 +423,41 @@ def _build_rhs_for_csr_op(data, indices, indptr, x_true, shape, op_mode): return_time=True, ) return b - raise ValueError("op_mode must be 'NON' or 'TRANS'") + if op_mode == "CONJ": + data_h, indices_h, indptr_h = _csr_transpose( + data, indices, indptr, shape, conjugate=True + ) + b, _ = fs.flagsparse_spmv_csr( + data_h, + indices_h.to(indices.dtype), + indptr_h.to(indptr.dtype), + x_true, + (shape[1], shape[0]), + return_time=True, + ) + return b + raise ValueError("op_mode must be 'NON', 'TRANS', or 'CONJ'") + + +def _known_solution_metrics(data, indices, indptr, shape, x, x_true, b, value_dtype, op_mode): + atol, rtol = _tol_for_dtype(value_dtype) + x_cmp = _compare_view(x, value_dtype) + x_true_cmp = _compare_view(x_true, value_dtype) + err_x = ( + float(torch.max(torch.abs(x_cmp - x_true_cmp)).item()) + if x.numel() > 0 + else 0.0 + ) + ok_x = torch.allclose(x_cmp, x_true_cmp, atol=atol, rtol=rtol) + + b_recon = _build_rhs_for_csr_op(data, indices, indptr, x, shape, op_mode) + err_res = ( + float(torch.max(torch.abs(b_recon - b)).item()) + if b.numel() > 0 + else 0.0 + ) + ok_res = torch.allclose(b_recon, b, atol=atol, rtol=rtol) + return err_x, ok_x, err_res, ok_res def _cupy_spsolve_lower_csr_or_coo( @@ -546,24 +515,13 @@ def _cupy_spsolve_lower_csr_or_coo( t1.synchronize() cupy_ms = cp.cuda.get_elapsed_time(t0, t1) / iters x_cu_t = torch.utils.dlpack.from_dlpack(x_cu.toDlpack()) -<<<<<<< HEAD - if COMPLEX32_DTYPE is not None and b.dtype == COMPLEX32_DTYPE: - x_cu_t = x_cu_t.to(torch.complex64) - else: - x_cu_t = x_cu_t.to(b.dtype) -======= x_cu_t = x_cu_t.to(b.dtype) ->>>>>>> 5a83e0f (test) return cupy_ms, x_cu_t except Exception: return None, None -<<<<<<< HEAD -def _cupy_spsolve_csr_with_op(data, indices, 
indptr, shape, b, op_mode): -======= def _cupy_spsolve_csr_with_op(data, indices, indptr, shape, b, op_mode, lower): ->>>>>>> 5a83e0f (test) if ( cp is None or cpx_sparse is None @@ -583,17 +541,13 @@ def _cupy_spsolve_csr_with_op(data, indices, indptr, shape, b, op_mode, lower): A_cp = cpx_sparse.csr_matrix((data_cp, idx_cp, ptr_cp), shape=shape) if op_mode == "TRANS": A_eff = A_cp.transpose().tocsr() -<<<<<<< HEAD - lower_eff = False - else: - A_eff = A_cp - lower_eff = True -======= + lower_eff = not lower + elif op_mode == "CONJ": + A_eff = A_cp.transpose().conj().tocsr() lower_eff = not lower else: A_eff = A_cp lower_eff = lower ->>>>>>> 5a83e0f (test) for _ in range(WARMUP): _ = cpx_spsolve_triangular( @@ -616,11 +570,7 @@ def _cupy_spsolve_csr_with_op(data, indices, indptr, shape, b, op_mode, lower): return None, None -<<<<<<< HEAD -def run_spsv_synthetic_all(): -======= def run_spsv_synthetic_all(lower=True): ->>>>>>> 5a83e0f (test) if not torch.cuda.is_available(): print("CUDA is not available. Please run on a GPU-enabled system.") return @@ -680,7 +630,7 @@ def run_spsv_synthetic_all(lower=True): b, shape, lower=lower, - transpose=(op_mode == "TRANS"), + transpose=op_mode, return_time=True, ) else: @@ -791,7 +741,7 @@ def _run_one_csv_row_csr(path, value_dtype, index_dtype, device, lower=True): ) return _finalize_csv_row( path, value_dtype, index_dtype, data, indices, indptr, shape, - x, t_ms, b, n_rows, n_cols, lower=lower, + x, t_ms, b, x_true, n_rows, n_cols, lower=lower, ) @@ -829,6 +779,7 @@ def _run_one_csv_row_coo(path, value_dtype, index_dtype, device, coo_mode, lower x, t_ms, b, + x_true, n_rows, n_cols, lower=lower, @@ -850,6 +801,7 @@ def _finalize_csv_row( x, t_ms, b, + x_true, n_rows, n_cols, *, @@ -860,6 +812,9 @@ def _finalize_csv_row( cupy_coo_col=None, ): atol, rtol = _tol_for_dtype(value_dtype) + err_x, ok_x, err_res, ok_res = _known_solution_metrics( + data, indices, indptr, shape, x, x_true, b, value_dtype, "NON" + ) pytorch_ms = None err_pt = None ok_pt = False @@ -876,13 +831,9 @@ def _finalize_csv_row( e1 = torch.cuda.Event(True) torch.cuda.synchronize() e0.record() -<<<<<<< HEAD - x_ref = torch.linalg.solve(A_ref, b_ref.unsqueeze(1)).squeeze(1) -======= x_ref = _triangular_solve_reference( A_ref, b_ref, lower=lower, op_mode="NON" ) ->>>>>>> 5a83e0f (test) x_cmp = _compare_view(x, value_dtype) x_ref_cmp = _compare_view(x_ref, value_dtype) e1.record() @@ -1000,21 +951,17 @@ def _finalize_csv_row( "cusparse_ms": cupy_ms, "csc_ms": None, "status": status, + "err_x": err_x, + "err_res": err_res, "err_pt": err_pt, "err_cu": err_cu, } return row, pt_skip_reason -<<<<<<< HEAD -def _run_one_csv_row_csr_full(path, value_dtype, index_dtype, op_mode, device): - data, indices, indptr, shape = _load_mtx_to_csr_torch( - path, dtype=value_dtype, device=device -======= def _run_one_csv_row_csr_full(path, value_dtype, index_dtype, op_mode, device, lower=True): data, indices, indptr, shape = _load_mtx_to_csr_torch( path, dtype=value_dtype, device=device, lower=lower ->>>>>>> 5a83e0f (test) ) indices = indices.to(index_dtype) indptr = indptr.to(index_dtype) @@ -1027,12 +974,8 @@ def _run_one_csv_row_csr_full(path, value_dtype, index_dtype, op_mode, device, l indptr, b, shape, -<<<<<<< HEAD - lower=True, -======= lower=lower, ->>>>>>> 5a83e0f (test) - transpose=(op_mode == "TRANS"), + transpose=op_mode, return_time=True, ) return _finalize_csv_row_csr_full( @@ -1047,12 +990,10 @@ def _run_one_csv_row_csr_full(path, value_dtype, index_dtype, op_mode, device, l x, t_ms, b, + 
x_true, n_rows, n_cols, -<<<<<<< HEAD -======= lower=lower, ->>>>>>> 5a83e0f (test) ) @@ -1068,14 +1009,15 @@ def _finalize_csv_row_csr_full( x, t_ms, b, + x_true, n_rows, n_cols, -<<<<<<< HEAD -======= lower=True, ->>>>>>> 5a83e0f (test) ): atol, rtol = _tol_for_dtype(value_dtype) + err_x, ok_x, err_res, ok_res = _known_solution_metrics( + data, indices, indptr, shape, x, x_true, b, value_dtype, op_mode + ) pytorch_ms = None err_pt = None @@ -1086,21 +1028,13 @@ def _finalize_csv_row_csr_full( A_dense = _csr_to_dense( data, indices.to(torch.int64), indptr.to(torch.int64), shape ).to(_dense_ref_dtype(value_dtype)) -<<<<<<< HEAD - A_ref = A_dense.transpose(0, 1) if op_mode == "TRANS" else A_dense -======= ->>>>>>> 5a83e0f (test) e0 = torch.cuda.Event(True) e1 = torch.cuda.Event(True) torch.cuda.synchronize() e0.record() -<<<<<<< HEAD - x_ref = torch.linalg.solve(A_ref, b.to(A_ref.dtype).unsqueeze(1)).squeeze(1) -======= x_ref = _triangular_solve_reference( A_dense, b.to(A_dense.dtype), lower=lower, op_mode=op_mode ) ->>>>>>> 5a83e0f (test) x_cmp = _compare_view(x, value_dtype) x_ref_cmp = _compare_view(x_ref, value_dtype) e1.record() @@ -1127,11 +1061,7 @@ def _finalize_csv_row_csr_full( ok_cu = False x_cu_t = None cupy_ms, x_cu_t = _cupy_spsolve_csr_with_op( -<<<<<<< HEAD - data, indices, indptr, shape, b, op_mode -======= data, indices, indptr, shape, b, op_mode, lower ->>>>>>> 5a83e0f (test) ) if x_cu_t is not None: x_cmp = _compare_view(x, value_dtype) @@ -1160,134 +1090,37 @@ def _finalize_csv_row_csr_full( "cusparse_ms": cupy_ms, "csc_ms": None, "status": status, + "err_x": err_x, + "err_res": err_res, "err_pt": err_pt, "err_cu": err_cu, } return row, pt_skip_reason -<<<<<<< HEAD -def run_all_supported_spsv_csr_csv(mtx_paths, csv_path): - if not torch.cuda.is_available(): - print("CUDA is not available.") - return - device = torch.device("cuda") - rows_out = [] - for value_dtype in CSR_FULL_VALUE_DTYPES: - for index_dtype in CSR_FULL_INDEX_DTYPES: - op_modes = _supported_csr_full_ops(value_dtype, index_dtype) - for op_mode in op_modes: - print("=" * 150) - print( - f"Value dtype: {_dtype_name(value_dtype)} | Index dtype: {_dtype_name(index_dtype)} | CSR | opA={op_mode}" - ) - print( - "Formats: FlagSparse=CSR, cuSPARSE=CSR ref, PyTorch=Dense solve." - ) - print( - "Err(PT)=|FlagSparse-PyTorch|, Err(CU)=|FlagSparse-cuSPARSE|. " - "PASS if either error within tolerance." 
- ) - print("-" * 150) - print( - f"{'Matrix':<28} {'N_rows':>7} {'N_cols':>7} {'NNZ':>10} " - f"{'FlagSparse(ms)':>10} {'CSR(ms)':>10} {'CSC(ms)':>10} {'PyTorch(ms)':>11} " - f"{'FS/CSR':>7} {'FS/PT':>7} {'Status':>6} {'Err(PT)':>10} {'Err(CU)':>10}" - ) - print("-" * 150) - for path in mtx_paths: - try: - row, pt_skip = _run_one_csv_row_csr_full( - path, value_dtype, index_dtype, op_mode, device - ) - rows_out.append(row) - name = os.path.basename(path)[:27] - if len(os.path.basename(path)) > 27: - name = name + "…" - n_rows, n_cols = row["n_rows"], row["n_cols"] - nnz = row["nnz"] - t_ms = row["triton_ms"] - cupy_ms = row["cusparse_ms"] - pytorch_ms = row["pytorch_ms"] - err_pt, err_cu = row["err_pt"], row["err_cu"] - status = row["status"] - print( - f"{name:<28} {n_rows:>7} {n_cols:>7} {nnz:>10} " - f"{_fmt_ms(t_ms):>10} {_fmt_ms(cupy_ms):>10} {_fmt_ms(None):>10} {_fmt_ms(pytorch_ms):>11} " - f"{_fmt_speedup(cupy_ms, t_ms):>7} {_fmt_speedup(pytorch_ms, t_ms):>7} " - f"{status:>6} {_fmt_err(err_pt):>10} {_fmt_err(err_cu):>10}" - ) - if pt_skip: - print(f" NOTE: {pt_skip}") - except Exception as e: - err_msg = str(e) - status = "SKIP" if "SpSV requires square matrices" in err_msg else "ERROR" - rows_out.append( - { - "matrix": os.path.basename(path), - "value_dtype": _dtype_name(value_dtype), - "index_dtype": _dtype_name(index_dtype), - "opA": op_mode, - "n_rows": "ERR", - "n_cols": "ERR", - "nnz": "ERR", - "triton_ms": None, - "pytorch_ms": None, - "cusparse_ms": None, - "csc_ms": None, - "status": status, - "err_pt": None, - "err_cu": None, - } - ) - name = os.path.basename(path)[:27] - if len(os.path.basename(path)) > 27: - name = name + "…" - print( - f"{name:<28} {'ERR':>7} {'ERR':>7} {'ERR':>10} " - f"{_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>11} " - f"{'N/A':>7} {'N/A':>7} " - f"{status:>6} {_fmt_err(None):>10} {_fmt_err(None):>10}" - ) - print(f" {status}: {e}") - print("-" * 150) - fieldnames = [ - "matrix", - "value_dtype", - "index_dtype", - "opA", - "n_rows", - "n_cols", - "nnz", - "triton_ms", - "pytorch_ms", - "cusparse_ms", - "csc_ms", - "status", - "err_pt", - "err_cu", - ] - with open(csv_path, "w", newline="", encoding="utf-8") as f: - w = csv.DictWriter(f, fieldnames=fieldnames) - w.writeheader() - for r in rows_out: - w.writerow(r) - print(f"Wrote {len(rows_out)} rows to {csv_path}") - - -def run_all_dtypes_spsv_csv(mtx_paths, csv_path, use_coo=False, coo_mode="auto"): -======= -def run_all_supported_spsv_csr_csv(mtx_paths, csv_path, lower=True): ->>>>>>> 5a83e0f (test) +def run_all_supported_spsv_csr_csv( + mtx_paths, + csv_path, + lower=True, + value_dtypes=None, + index_dtypes=None, + op_modes=None, +): if not torch.cuda.is_available(): print("CUDA is not available.") return device = torch.device("cuda") rows_out = [] - for value_dtype in CSR_FULL_VALUE_DTYPES: - for index_dtype in CSR_FULL_INDEX_DTYPES: - op_modes = _supported_csr_full_ops(value_dtype, index_dtype) - for op_mode in op_modes: + selected_value_dtypes = value_dtypes or CSR_FULL_VALUE_DTYPES + selected_index_dtypes = index_dtypes or CSR_FULL_INDEX_DTYPES + selected_op_modes = op_modes or SPSV_OP_MODES + for value_dtype in selected_value_dtypes: + for index_dtype in selected_index_dtypes: + supported_op_modes = [ + op for op in _supported_csr_full_ops(value_dtype, index_dtype) + if op in selected_op_modes + ] + for op_mode in supported_op_modes: print("=" * 150) print( f"Value dtype: {_dtype_name(value_dtype)} | Index dtype: {_dtype_name(index_dtype)} | CSR | 
triA={'LOWER' if lower else 'UPPER'} | opA={op_mode}" @@ -1296,14 +1129,15 @@ def run_all_supported_spsv_csr_csv(mtx_paths, csv_path, lower=True): "Formats: FlagSparse=CSR, cuSPARSE=CSR ref, PyTorch=Dense solve." ) print( + "Err(X)=|FlagSparse-x_true|, Err(Res)=|A*x-b|, " "Err(PT)=|FlagSparse-PyTorch|, Err(CU)=|FlagSparse-cuSPARSE|. " - "PASS if either error within tolerance." + "PASS if PyTorch / cuSPARSE reference passes. x_true / residual are diagnostics only." ) print("-" * 150) print( f"{'Matrix':<28} {'N_rows':>7} {'N_cols':>7} {'NNZ':>10} " f"{'FlagSparse(ms)':>10} {'CSR(ms)':>10} {'CSC(ms)':>10} {'PyTorch(ms)':>11} " - f"{'FS/CSR':>7} {'FS/PT':>7} {'Status':>6} {'Err(PT)':>10} {'Err(CU)':>10}" + f"{'FS/CSR':>7} {'FS/PT':>7} {'Status':>6} {'Err(X)':>10} {'Err(Res)':>10} {'Err(PT)':>10} {'Err(CU)':>10}" ) print("-" * 150) for path in mtx_paths: @@ -1320,13 +1154,14 @@ def run_all_supported_spsv_csr_csv(mtx_paths, csv_path, lower=True): t_ms = row["triton_ms"] cupy_ms = row["cusparse_ms"] pytorch_ms = row["pytorch_ms"] + err_x, err_res = row["err_x"], row["err_res"] err_pt, err_cu = row["err_pt"], row["err_cu"] status = row["status"] print( f"{name:<28} {n_rows:>7} {n_cols:>7} {nnz:>10} " f"{_fmt_ms(t_ms):>10} {_fmt_ms(cupy_ms):>10} {_fmt_ms(None):>10} {_fmt_ms(pytorch_ms):>11} " f"{_fmt_speedup(cupy_ms, t_ms):>7} {_fmt_speedup(pytorch_ms, t_ms):>7} " - f"{status:>6} {_fmt_err(err_pt):>10} {_fmt_err(err_cu):>10}" + f"{status:>6} {_fmt_err(err_x):>10} {_fmt_err(err_res):>10} {_fmt_err(err_pt):>10} {_fmt_err(err_cu):>10}" ) if pt_skip: print(f" NOTE: {pt_skip}") @@ -1347,6 +1182,8 @@ def run_all_supported_spsv_csr_csv(mtx_paths, csv_path, lower=True): "cusparse_ms": None, "csc_ms": None, "status": status, + "err_x": None, + "err_res": None, "err_pt": None, "err_cu": None, } @@ -1358,7 +1195,7 @@ def run_all_supported_spsv_csr_csv(mtx_paths, csv_path, lower=True): f"{name:<28} {'ERR':>7} {'ERR':>7} {'ERR':>10} " f"{_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>11} " f"{'N/A':>7} {'N/A':>7} " - f"{status:>6} {_fmt_err(None):>10} {_fmt_err(None):>10}" + f"{status:>6} {_fmt_err(None):>10} {_fmt_err(None):>10} {_fmt_err(None):>10} {_fmt_err(None):>10}" ) print(f" {status}: {e}") print("-" * 150) @@ -1375,6 +1212,8 @@ def run_all_supported_spsv_csr_csv(mtx_paths, csv_path, lower=True): "cusparse_ms", "csc_ms", "status", + "err_x", + "err_res", "err_pt", "err_cu", ] @@ -1416,14 +1255,15 @@ def run_all_dtypes_spsv_csv(mtx_paths, csv_path, use_coo=False, coo_mode="auto", "Formats: FlagSparse=CSR, cuSPARSE=CSR ref, PyTorch=Dense solve." ) print( + "Err(X)=|FlagSparse-x_true|, Err(Res)=|A*x-b|, " "Err(PT)=|FlagSparse-PyTorch|, Err(CU)=|FlagSparse-cuSPARSE|. " - "PASS if either error within tolerance." + "PASS if PyTorch / cuSPARSE reference passes. x_true / residual are diagnostics only." 
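The Err(X) and Err(Res) columns printed above come from a known-solution protocol: pick x_true, build b = op(A) x_true, solve, then measure both the distance to x_true and the residual after re-multiplying. A dense miniature of the same bookkeeping (illustrative only; the runner goes through _known_solution_metrics and flagsparse_spmv_csr rather than a dense matmul):

    import torch

    n = 8
    A = torch.tril(torch.randn(n, n, dtype=torch.float64))
    A = A + 4.0 * torch.eye(n, dtype=torch.float64)
    x_true = torch.randn(n, dtype=torch.float64)
    b = A.t() @ x_true  # op_mode == "TRANS"

    x = torch.linalg.solve_triangular(A.t(), b.unsqueeze(-1), upper=True).squeeze(-1)
    err_x = (x - x_true).abs().max().item()       # Err(X): max |x - x_true|
    err_res = (A.t() @ x - b).abs().max().item()  # Err(Res): max |op(A) x - b|
    assert err_x < 1e-10 and err_res < 1e-10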
) print("-" * 150) print( f"{'Matrix':<28} {'N_rows':>7} {'N_cols':>7} {'NNZ':>10} " f"{'FlagSparse(ms)':>10} {cu_col:>10} {'CSC(ms)':>10} {'PyTorch(ms)':>11} " - f"{fs_cu_hdr:>7} {'FS/PT':>7} {'Status':>6} {'Err(PT)':>10} {'Err(CU)':>10}" + f"{fs_cu_hdr:>7} {'FS/PT':>7} {'Status':>6} {'Err(X)':>10} {'Err(Res)':>10} {'Err(PT)':>10} {'Err(CU)':>10}" ) print("-" * 150) for path in mtx_paths: @@ -1445,13 +1285,14 @@ def run_all_dtypes_spsv_csv(mtx_paths, csv_path, use_coo=False, coo_mode="auto", t_ms = row["triton_ms"] cupy_ms = row["cusparse_ms"] pytorch_ms = row["pytorch_ms"] + err_x, err_res = row["err_x"], row["err_res"] err_pt, err_cu = row["err_pt"], row["err_cu"] status = row["status"] print( f"{name:<28} {n_rows:>7} {n_cols:>7} {nnz:>10} " f"{_fmt_ms(t_ms):>10} {_fmt_ms(cupy_ms):>10} {_fmt_ms(None):>10} {_fmt_ms(pytorch_ms):>11} " f"{_fmt_speedup(cupy_ms, t_ms):>7} {_fmt_speedup(pytorch_ms, t_ms):>7} " - f"{status:>6} {_fmt_err(err_pt):>10} {_fmt_err(err_cu):>10}" + f"{status:>6} {_fmt_err(err_x):>10} {_fmt_err(err_res):>10} {_fmt_err(err_pt):>10} {_fmt_err(err_cu):>10}" ) if pt_skip: print(f" NOTE: {pt_skip}") @@ -1471,6 +1312,8 @@ def run_all_dtypes_spsv_csv(mtx_paths, csv_path, use_coo=False, coo_mode="auto", "cusparse_ms": None, "csc_ms": None, "status": status, + "err_x": None, + "err_res": None, "err_pt": None, "err_cu": None, } @@ -1481,7 +1324,7 @@ def run_all_dtypes_spsv_csv(mtx_paths, csv_path, use_coo=False, coo_mode="auto", print( f"{name:<28} {'ERR':>7} {'ERR':>7} {'ERR':>10} " f"{_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>11} " - f"{'N/A':>7} {'N/A':>7} {status:>6} {_fmt_err(None):>10} {_fmt_err(None):>10}" + f"{'N/A':>7} {'N/A':>7} {status:>6} {_fmt_err(None):>10} {_fmt_err(None):>10} {_fmt_err(None):>10} {_fmt_err(None):>10}" ) print(f" {status}: {e}") print("-" * 150) @@ -1497,6 +1340,8 @@ def run_all_dtypes_spsv_csv(mtx_paths, csv_path, use_coo=False, coo_mode="auto", "cusparse_ms", "csc_ms", "status", + "err_x", + "err_res", "err_pt", "err_cu", ] @@ -1546,6 +1391,24 @@ def main(): action="store_true", help="Use upper-triangular inputs instead of the default lower-triangular inputs", ) + parser.add_argument( + "--ops", + type=str, + default=None, + help="Comma-separated opA filter for CSR CSV, e.g. TRANS,CONJ", + ) + parser.add_argument( + "--value-dtypes", + type=str, + default=None, + help="Comma-separated value dtype filter for CSR CSV, e.g. float,double,complex64,complex128", + ) + parser.add_argument( + "--index-dtypes", + type=str, + default=None, + help="Comma-separated index dtype filter for CSR CSV, e.g. 
int32,int64", + ) args = parser.parse_args() lower = not args.upper @@ -1565,11 +1428,29 @@ def main(): if not paths: print("No .mtx files found for --csv-csr") return -<<<<<<< HEAD - run_all_supported_spsv_csr_csv(paths, args.csv_csr) -======= - run_all_supported_spsv_csr_csv(paths, args.csv_csr, lower=lower) ->>>>>>> 5a83e0f (test) + value_dtypes = ( + _parse_value_dtypes_filter(args.value_dtypes) + if args.value_dtypes + else None + ) + index_dtypes = ( + _parse_index_dtypes_filter(args.index_dtypes) + if args.index_dtypes + else None + ) + op_modes = ( + _parse_op_modes_filter(args.ops) + if args.ops + else None + ) + run_all_supported_spsv_csr_csv( + paths, + args.csv_csr, + lower=lower, + value_dtypes=value_dtypes, + index_dtypes=index_dtypes, + op_modes=op_modes, + ) return if args.csv_coo: if not paths: From c8768642922531701cbbf7771dc50559d48b1e31 Mon Sep 17 00:00:00 2001 From: berlin020 <2261128688@qq.com> Date: Tue, 21 Apr 2026 11:38:57 +0800 Subject: [PATCH 04/22] test --- tests/test_spsv.py | 119 +++++++++++++++++++++++---------------------- 1 file changed, 60 insertions(+), 59 deletions(-) diff --git a/tests/test_spsv.py b/tests/test_spsv.py index 2737db2..b64e6ff 100644 --- a/tests/test_spsv.py +++ b/tests/test_spsv.py @@ -8,9 +8,16 @@ import csv import glob import os +import sys +from pathlib import Path import torch +_PROJECT_ROOT = Path(__file__).resolve().parents[1] +_SRC_ROOT = _PROJECT_ROOT / "src" +if str(_SRC_ROOT) not in sys.path: + sys.path.insert(0, str(_SRC_ROOT)) + import flagsparse as fs try: @@ -268,13 +275,13 @@ def _csr_to_dense(data, indices, indptr, shape): return coo.to_dense() -def _csr_to_coo(data, indices, indptr, shape): +def _csr_to_coo(data, indices, indptr, shape, index_dtype=torch.int64): n_rows = int(shape[0]) row = torch.repeat_interleave( - torch.arange(n_rows, device=data.device, dtype=torch.int64), + torch.arange(n_rows, device=data.device, dtype=index_dtype), indptr[1:] - indptr[:-1], ) - col = indices.to(torch.int64) + col = indices.to(index_dtype) return data, row, col @@ -387,9 +394,11 @@ def _accum(r, c, v): return data, indices, indptr, (n_rows, n_cols) -def _coo_inputs_for_csv(data, indices, indptr, shape, coo_mode): +def _coo_inputs_for_csv(data, indices, indptr, shape, coo_mode, index_dtype=torch.int64): """Sorted COO from CSR; optional shuffle/duplicate for csr|auto (与原先 CSV 行为一致).""" - data_c, row_c, col_c = _csr_to_coo(data, indices, indptr, shape) + data_c, row_c, col_c = _csr_to_coo( + data, indices, indptr, shape, index_dtype=index_dtype + ) if coo_mode in ("csr", "auto"): if data_c.numel() == 0: return data_c, row_c, col_c @@ -634,7 +643,9 @@ def run_spsv_synthetic_all(lower=True): return_time=True, ) else: - dc, rr, cc = _csr_to_coo(data, indices, indptr, shape) + dc, rr, cc = _csr_to_coo( + data, indices, indptr, shape, index_dtype=index_dtype + ) x, t_ms = fs.flagsparse_spsv_coo( dc, rr, @@ -726,25 +737,6 @@ def run_spsv_synthetic_all(lower=True): print(sep) -def _run_one_csv_row_csr(path, value_dtype, index_dtype, device, lower=True): - data, indices, indptr, shape = _load_mtx_to_csr_torch( - path, dtype=value_dtype, device=device, lower=lower - ) - indices = indices.to(index_dtype) - n_rows, n_cols = shape - x_true = _randn_by_dtype(n_rows, value_dtype, device) - b, _ = fs.flagsparse_spmv_csr( - data, indices, indptr, x_true, shape, return_time=True - ) - x, t_ms = fs.flagsparse_spsv_csr( - data, indices, indptr, b, shape, lower=lower, return_time=True - ) - return _finalize_csv_row( - path, value_dtype, index_dtype, 
data, indices, indptr, shape,
-        x, t_ms, b, x_true, n_rows, n_cols, lower=lower,
-    )
-
-
 def _run_one_csv_row_coo(path, value_dtype, index_dtype, device, coo_mode, lower=True):
     data, indices, indptr, shape = _load_mtx_to_csr_torch(
         path, dtype=value_dtype, device=device, lower=lower
@@ -756,7 +748,7 @@ def _run_one_csv_row_coo(path, value_dtype, index_dtype, device, coo_mode, lower
         data, indices, indptr, x_true, shape, return_time=True
     )
     d_in, r_in, c_in = _coo_inputs_for_csv(
-        data, indices, indptr, shape, coo_mode
+        data, indices, indptr, shape, coo_mode, index_dtype=index_dtype
     )
     x, t_ms = fs.flagsparse_spsv_coo(
         d_in,
@@ -1225,35 +1217,32 @@ def run_all_supported_spsv_csr_csv(
     print(f"Wrote {len(rows_out)} rows to {csv_path}")
 
 
-def run_all_dtypes_spsv_csv(mtx_paths, csv_path, use_coo=False, coo_mode="auto", lower=True):
+def run_all_dtypes_spsv_coo_csv(
+    mtx_paths,
+    csv_path,
+    coo_mode="auto",
+    lower=True,
+    value_dtypes=None,
+    index_dtypes=None,
+):
     if not torch.cuda.is_available():
         print("CUDA is not available.")
         return
-    if not use_coo:
-        run_all_supported_spsv_csr_csv(mtx_paths, csv_path, lower=lower)
-        return
     device = torch.device("cuda")
     rows_out = []
-    label = "COO" if use_coo else "CSR"
-    cu_col = "COO(ms)" if use_coo else "CSR(ms)"
-    fs_cu_hdr = "FS/COO" if use_coo else "FS/CSR"
-    for value_dtype in VALUE_DTYPES:
-        for index_dtype in INDEX_DTYPES:
-            atol, rtol = _tol_for_dtype(value_dtype)
+    selected_value_dtypes = value_dtypes or VALUE_DTYPES
+    selected_index_dtypes = index_dtypes or INDEX_DTYPES
+    for value_dtype in selected_value_dtypes:
+        for index_dtype in selected_index_dtypes:
             print("=" * 150)
             print(
-                f"Value dtype: {_dtype_name(value_dtype)} | Index dtype: {_dtype_name(index_dtype)} | {label}"
-                + (f" triA={'LOWER' if lower else 'UPPER'}" if not use_coo else f" triA={'LOWER' if lower else 'UPPER'} coo_mode={coo_mode}")
+                f"Value dtype: {_dtype_name(value_dtype)} | Index dtype: {_dtype_name(index_dtype)} | COO"
+                f" triA={'LOWER' if lower else 'UPPER'} coo_mode={coo_mode}"
+            )
+            print(
+                "Formats: FlagSparse=COO SpSV, cuSPARSE=COO ref, PyTorch=Dense solve. "
+                "b is built via CSR SpMV, consistent with the CSR tests."
+            )
-            if use_coo:
-                print(
-                    "Formats: FlagSparse=COO SpSV, cuSPARSE=COO ref, PyTorch=Dense solve. "
-                    "b is built via CSR SpMV, consistent with the CSR tests."
-                )
-            else:
-                print(
-                    "Formats: FlagSparse=CSR, cuSPARSE=CSR ref, PyTorch=Dense solve."
-                )
             print(
                 "Err(X)=|FlagSparse-x_true|, Err(Res)=|A*x-b|, "
                 "Err(PT)=|FlagSparse-PyTorch|, Err(CU)=|FlagSparse-cuSPARSE|. 
" @@ -1262,20 +1251,15 @@ def run_all_dtypes_spsv_csv(mtx_paths, csv_path, use_coo=False, coo_mode="auto", print("-" * 150) print( f"{'Matrix':<28} {'N_rows':>7} {'N_cols':>7} {'NNZ':>10} " - f"{'FlagSparse(ms)':>10} {cu_col:>10} {'CSC(ms)':>10} {'PyTorch(ms)':>11} " - f"{fs_cu_hdr:>7} {'FS/PT':>7} {'Status':>6} {'Err(X)':>10} {'Err(Res)':>10} {'Err(PT)':>10} {'Err(CU)':>10}" + f"{'FlagSparse(ms)':>10} {'COO(ms)':>10} {'CSC(ms)':>10} {'PyTorch(ms)':>11} " + f"{'FS/COO':>7} {'FS/PT':>7} {'Status':>6} {'Err(X)':>10} {'Err(Res)':>10} {'Err(PT)':>10} {'Err(CU)':>10}" ) print("-" * 150) for path in mtx_paths: try: - if use_coo: - row, pt_skip = _run_one_csv_row_coo( - path, value_dtype, index_dtype, device, coo_mode, lower=lower - ) - else: - row, pt_skip = _run_one_csv_row_csr( - path, value_dtype, index_dtype, device, lower=lower - ) + row, pt_skip = _run_one_csv_row_coo( + path, value_dtype, index_dtype, device, coo_mode, lower=lower + ) rows_out.append(row) name = os.path.basename(path)[:27] if len(os.path.basename(path)) > 27: @@ -1453,13 +1437,30 @@ def main(): ) return if args.csv_coo: + if args.ops: + parser.error("--ops is only supported with --csv-csr; COO tests only run opA=NON") if not paths: paths = sorted(glob.glob("*.mtx")) if not paths: print("No .mtx files found for --csv-coo") return - run_all_dtypes_spsv_csv( - paths, args.csv_coo, use_coo=True, coo_mode=args.coo_mode, lower=lower + value_dtypes = ( + _parse_value_dtypes_filter(args.value_dtypes) + if args.value_dtypes + else None + ) + index_dtypes = ( + _parse_index_dtypes_filter(args.index_dtypes) + if args.index_dtypes + else None + ) + run_all_dtypes_spsv_coo_csv( + paths, + args.csv_coo, + coo_mode=args.coo_mode, + lower=lower, + value_dtypes=value_dtypes, + index_dtypes=index_dtypes, ) return From 2b103136a247f98ec0b4d8d7a5d0f7c6335dcc9f Mon Sep 17 00:00:00 2001 From: berlin020 <2261128688@qq.com> Date: Tue, 21 Apr 2026 15:35:59 +0800 Subject: [PATCH 05/22] Update spsv csr tests --- src/flagsparse/sparse_operations/spsv.py | 33 +- tests/test_spsv.py | 469 ++++++++++++++++++----- 2 files changed, 379 insertions(+), 123 deletions(-) diff --git a/src/flagsparse/sparse_operations/spsv.py b/src/flagsparse/sparse_operations/spsv.py index 5f6174c..ffdb6ea 100644 --- a/src/flagsparse/sparse_operations/spsv.py +++ b/src/flagsparse/sparse_operations/spsv.py @@ -9,12 +9,10 @@ import triton.language as tl SUPPORTED_SPSV_VALUE_DTYPES = ( - torch.bfloat16, torch.float32, torch.float64, torch.complex64, torch.complex128, - ) SUPPORTED_SPSV_INDEX_DTYPES = (torch.int32, torch.int64) SPSV_NON_TRANS_SUPPORTED_COMBOS = ( @@ -68,12 +66,10 @@ def _validate_spsv_non_trans_combo(data_dtype, index_dtype, fmt_name): """Validate NON_TRANS support matrix and keep error messages explicit.""" if (data_dtype, index_dtype) in SPSV_NON_TRANS_SUPPORTED_COMBOS: return - if data_dtype == torch.bfloat16 and index_dtype == torch.int32: - return raise TypeError( f"{fmt_name} SpSV currently supports NON_TRANS combinations: " "(float32, int32/int64), (float64, int32/int64), " - "(complex64, int32/int64), (complex128, int32/int64), (bfloat16, int32)" + "(complex64, int32/int64), (complex128, int32/int64)" ) @@ -126,7 +122,7 @@ def _prepare_spsv_inputs(data, indices, indptr, b, shape): if data.dtype not in SUPPORTED_SPSV_VALUE_DTYPES: raise TypeError( - "data dtype must be one of: bfloat16, float32, float64, complex64, complex128" + "data dtype must be one of: float32, float64, complex64, complex128" ) if indices.dtype not in SUPPORTED_SPSV_INDEX_DTYPES: raise 
TypeError("indices dtype must be torch.int32 or torch.int64") @@ -696,14 +692,13 @@ def _prepare_spsv_coo_inputs(data, row, col, b, shape, transpose=False): raise ValueError(f"b.shape[0] must equal n_rows={n_rows}") if data.dtype not in ( - torch.bfloat16, torch.float32, torch.float64, torch.complex64, torch.complex128, ): raise TypeError( - "data dtype must be one of: bfloat16, float32, float64, complex64, complex128" + "data dtype must be one of: float32, float64, complex64, complex128" ) if b.dtype != data.dtype: raise TypeError("b dtype must match data dtype") @@ -796,13 +791,8 @@ def _coo_to_csr_sorted_unique(data, row64, col64, n_rows, n_cols): unique_key, inverse = torch.unique_consecutive(key_s, return_inverse=True) out_nnz = unique_key.numel() - if data_s.dtype == torch.bfloat16: - reduced_f32 = torch.zeros(out_nnz, dtype=torch.float32, device=data.device) - reduced_f32.scatter_add_(0, inverse, data_s.to(torch.float32)) - data_u = reduced_f32.to(torch.bfloat16) - else: - data_u = torch.zeros(out_nnz, dtype=data.dtype, device=data.device) - data_u.scatter_add_(0, inverse, data_s) + data_u = torch.zeros(out_nnz, dtype=data.dtype, device=data.device) + data_u.scatter_add_(0, inverse, data_s) row_u = torch.div(unique_key, max(1, n_cols), rounding_mode="floor") col_u = unique_key - row_u * max(1, n_cols) @@ -877,7 +867,6 @@ def flagsparse_spsv_csr( Primary support matrix: - NON_TRANS: float32/float64/complex64/complex128 with int32/int64 indices - TRANS/CONJ: float32/float64/complex64/complex128 with int32/int64 indices - - bfloat16 remains NON_TRANS + int32 """ input_data = data input_indices = indices @@ -926,11 +915,7 @@ def flagsparse_spsv_csr( compute_dtype = data.dtype data_in = kernel_data b_in = b - if data.dtype == torch.bfloat16: - compute_dtype = torch.float32 - data_in = kernel_data.to(torch.float32) - b_in = b.to(torch.float32) - elif data.dtype == torch.complex64 and trans_mode in ("T", "C"): + if data.dtype == torch.complex64 and trans_mode in ("T", "C"): compute_dtype = torch.complex128 data_in = kernel_data.to(torch.complex128) b_in = b.to(torch.complex128) @@ -1123,11 +1108,7 @@ def flagsparse_spsv_coo( compute_dtype = data.dtype data_in = data b_in = b - if data.dtype == torch.bfloat16: - compute_dtype = torch.float32 - data_in = data.to(torch.float32) - b_in = b.to(torch.float32) - elif data.dtype == torch.float32 and SPSV_PROMOTE_FP32_TO_FP64: + if data.dtype == torch.float32 and SPSV_PROMOTE_FP32_TO_FP64: compute_dtype = torch.float64 data_in = data.to(torch.float64) b_in = b.to(torch.float64) diff --git a/tests/test_spsv.py b/tests/test_spsv.py index b64e6ff..cb28cb4 100644 --- a/tests/test_spsv.py +++ b/tests/test_spsv.py @@ -7,6 +7,7 @@ import argparse import csv import glob +import hashlib import os import sys from pathlib import Path @@ -19,6 +20,7 @@ sys.path.insert(0, str(_SRC_ROOT)) import flagsparse as fs +import flagsparse.sparse_operations.spsv as fs_spsv_impl try: import cupy as cp @@ -36,6 +38,7 @@ ITERS = 20 DENSE_REF_MAX_BYTES = 2 * 1024 * 1024 * 1024 # 2 GiB +SPSV_TRIANGULAR_DIAG_DOMINANCE = 4.0 # CSR 完整组合覆盖(在原 csv-csr 逻辑外新增,不影响原入口) CSR_FULL_VALUE_DTYPES = [ torch.float32, @@ -111,12 +114,25 @@ def _tol_for_dtype(dtype): return 1e-12, 1e-10 -def _randn_by_dtype(n, dtype, device): +def _stable_case_seed(*parts): + raw = "|".join(str(part) for part in parts).encode("utf-8") + return int.from_bytes(hashlib.sha256(raw).digest()[:8], "little") % (2**63) + + +def _generator_for_seed(seed): + if seed is None: + return None + gen = torch.Generator() + 
gen.manual_seed(int(seed)) + return gen + + +def _randn_by_dtype(n, dtype, device, generator=None): if dtype in (torch.float32, torch.float64): - return torch.randn(n, dtype=dtype, device=device) + return torch.randn(n, dtype=dtype, device=device, generator=generator) base = torch.float32 if dtype == torch.complex64 else torch.float64 - real = torch.randn(n, dtype=base, device=device) - imag = torch.randn(n, dtype=base, device=device) + real = torch.randn(n, dtype=base, device=device, generator=generator) + imag = torch.randn(n, dtype=base, device=device, generator=generator) return torch.complex(real, imag) @@ -194,6 +210,8 @@ def _build_random_triangular_csr(n, value_dtype, index_dtype, device, lower=True rows_host = [] cols_host = [] vals_host = [] + row_off_abs = [0.0] * n + col_off_abs = [0.0] * n if value_dtype == torch.float32: base_real_dtype = torch.float32 elif value_dtype == torch.float64: @@ -223,28 +241,30 @@ def _build_random_triangular_csr(n, value_dtype, index_dtype, device, lower=True torch.randn(len(off_cols), dtype=base_real_dtype, device=device).mul_(0.01), torch.randn(len(off_cols), dtype=base_real_dtype, device=device).mul_(0.01), ) - sum_abs = ( - float(torch.sum(torch.abs(off_vals)).item()) if off_vals.numel() else 0.0 - ) - diag_imag = float( - torch.randn((), dtype=base_real_dtype, device=device).mul_(0.05).item() - ) - diag_val = complex(sum_abs + 1.0, diag_imag) off_vals_host = [complex(v) for v in off_vals.cpu().tolist()] else: off_vals = torch.randn(len(off_cols), dtype=base_real_dtype, device=device).mul_(0.01) - sum_abs = ( - float(torch.sum(torch.abs(off_vals)).item()) if off_vals.numel() else 0.0 - ) - diag_val = sum_abs + 1.0 off_vals_host = off_vals.cpu().tolist() - rows_host.append(i) - cols_host.append(diag_col) - vals_host.append(diag_val) for c, v in zip(off_cols, off_vals_host): rows_host.append(i) cols_host.append(int(c)) vals_host.append(v) + mag = abs(v) + row_off_abs[i] += mag + col_off_abs[int(c)] += mag + + for i in range(n): + diag_mag = ( + SPSV_TRIANGULAR_DIAG_DOMINANCE * max(row_off_abs[i], col_off_abs[i]) + 1.0 + ) + diag_val = ( + complex(diag_mag, 0.0) + if value_dtype in (torch.complex64, torch.complex128) + else diag_mag + ) + rows_host.append(i) + cols_host.append(i) + vals_host.append(diag_val) rows_t = torch.tensor(rows_host, dtype=torch.int64, device=device) cols_t = torch.tensor(cols_host, dtype=torch.int64, device=device) @@ -367,17 +387,27 @@ def _accum(r, c, v): elif mm_symmetry == "skew-symmetric" and r != c: _accum(c, r, -v) + tri_rows = [dict() for _ in range(n_rows)] + row_off_abs = [0.0] * n_rows + col_off_abs = [0.0] * n_cols for r in range(n_rows): - row = row_maps[r] - tri_row = {} - off_abs_sum = 0.0 - for c, v in row.items(): + for c, v in row_maps[r].items(): keep = c < r if lower else c > r if keep: - tri_row[c] = tri_row.get(c, 0.0) + v - off_abs_sum += abs(v) - tri_row[r] = off_abs_sum + 1.0 - row_maps[r] = tri_row + tri_rows[r][c] = tri_rows[r].get(c, 0.0) + v + + for r, row in enumerate(tri_rows): + for c, v in row.items(): + mag = abs(v) + row_off_abs[r] += mag + col_off_abs[c] += mag + + for r in range(n_rows): + # Make the generated triangular system stable for both A and op(A). 
+ tri_rows[r][r] = ( + SPSV_TRIANGULAR_DIAG_DOMINANCE * max(row_off_abs[r], col_off_abs[r]) + 1.0 + ) + row_maps = tri_rows cols_s = [] vals_s = [] @@ -415,58 +445,53 @@ def _coo_inputs_for_csv(data, indices, indptr, shape, coo_mode, index_dtype=torc return data_c, row_c, col_c -def _build_rhs_for_csr_op(data, indices, indptr, x_true, shape, op_mode): +def _random_rhs_for_spsv(shape, value_dtype, device, op_mode="NON", seed=None): + n_rows, n_cols = int(shape[0]), int(shape[1]) + rhs_size = n_rows if op_mode == "NON" else n_cols + if seed is None: + return _randn_by_dtype(rhs_size, value_dtype, device) + rhs = _randn_by_dtype( + rhs_size, + value_dtype, + torch.device("cpu"), + generator=_generator_for_seed(seed), + ) + return rhs.to(device) + + +def _apply_csr_op(data, indices, indptr, x, shape, op_mode): + n_rows, n_cols = int(shape[0]), int(shape[1]) + row = torch.repeat_interleave( + torch.arange(n_rows, device=data.device, dtype=torch.int64), + indptr.to(torch.int64)[1:] - indptr.to(torch.int64)[:-1], + ) + col = indices.to(torch.int64) if op_mode == "NON": - b, _ = fs.flagsparse_spmv_csr( - data, indices, indptr, x_true, shape, return_time=True - ) + b = torch.zeros(n_rows, dtype=data.dtype, device=data.device) + b.scatter_add_(0, row, data * x[col]) return b if op_mode == "TRANS": - data_t, indices_t, indptr_t = _csr_transpose(data, indices, indptr, shape) - b, _ = fs.flagsparse_spmv_csr( - data_t, - indices_t.to(indices.dtype), - indptr_t.to(indptr.dtype), - x_true, - (shape[1], shape[0]), - return_time=True, - ) + b = torch.zeros(n_cols, dtype=data.dtype, device=data.device) + b.scatter_add_(0, col, data * x[row]) return b if op_mode == "CONJ": - data_h, indices_h, indptr_h = _csr_transpose( - data, indices, indptr, shape, conjugate=True - ) - b, _ = fs.flagsparse_spmv_csr( - data_h, - indices_h.to(indices.dtype), - indptr_h.to(indptr.dtype), - x_true, - (shape[1], shape[0]), - return_time=True, - ) + b = torch.zeros(n_cols, dtype=data.dtype, device=data.device) + data_eff = data.conj() if torch.is_complex(data) else data + b.scatter_add_(0, col, data_eff * x[row]) return b raise ValueError("op_mode must be 'NON', 'TRANS', or 'CONJ'") -def _known_solution_metrics(data, indices, indptr, shape, x, x_true, b, value_dtype, op_mode): +def _solution_residual_metrics(data, indices, indptr, shape, x, b, value_dtype, op_mode): atol, rtol = _tol_for_dtype(value_dtype) - x_cmp = _compare_view(x, value_dtype) - x_true_cmp = _compare_view(x_true, value_dtype) - err_x = ( - float(torch.max(torch.abs(x_cmp - x_true_cmp)).item()) - if x.numel() > 0 - else 0.0 - ) - ok_x = torch.allclose(x_cmp, x_true_cmp, atol=atol, rtol=rtol) - - b_recon = _build_rhs_for_csr_op(data, indices, indptr, x, shape, op_mode) + b_recon = _apply_csr_op(data, indices, indptr, x, shape, op_mode) err_res = ( float(torch.max(torch.abs(b_recon - b)).item()) if b.numel() > 0 else 0.0 ) ok_res = torch.allclose(b_recon, b, atol=atol, rtol=rtol) - return err_x, ok_x, err_res, ok_res + return err_res, ok_res def _cupy_spsolve_lower_csr_or_coo( @@ -624,11 +649,22 @@ def run_spsv_synthetic_all(lower=True): A_dense = _csr_to_dense( data, indices.to(torch.int64), indptr, shape ) - x_true = _randn_by_dtype(n, value_dtype, device) - if fmt == "CSR": - b = _build_rhs_for_csr_op(data, indices, indptr, x_true, shape, op_mode) - else: - b = A_dense @ x_true + rhs_op = op_mode if fmt == "CSR" else "NON" + b = _random_rhs_for_spsv( + shape, + value_dtype, + device, + op_mode=rhs_op, + seed=_stable_case_seed( + "synthetic", + "LOWER" if lower 
else "UPPER", + fmt, + op_mode, + n, + _dtype_name(value_dtype), + _dtype_name(index_dtype), + ), + ) torch.cuda.synchronize() if fmt == "CSR": @@ -743,9 +779,19 @@ def _run_one_csv_row_coo(path, value_dtype, index_dtype, device, coo_mode, lower ) indices = indices.to(index_dtype) n_rows, n_cols = shape - x_true = _randn_by_dtype(n_rows, value_dtype, device) - b, _ = fs.flagsparse_spmv_csr( - data, indices, indptr, x_true, shape, return_time=True + b = _random_rhs_for_spsv( + shape, + value_dtype, + device, + op_mode="NON", + seed=_stable_case_seed( + "csv-coo", + os.path.basename(path), + "LOWER" if lower else "UPPER", + coo_mode, + _dtype_name(value_dtype), + _dtype_name(index_dtype), + ), ) d_in, r_in, c_in = _coo_inputs_for_csv( data, indices, indptr, shape, coo_mode, index_dtype=index_dtype @@ -771,7 +817,6 @@ def _run_one_csv_row_coo(path, value_dtype, index_dtype, device, coo_mode, lower x, t_ms, b, - x_true, n_rows, n_cols, lower=lower, @@ -793,7 +838,6 @@ def _finalize_csv_row( x, t_ms, b, - x_true, n_rows, n_cols, *, @@ -804,8 +848,8 @@ def _finalize_csv_row( cupy_coo_col=None, ): atol, rtol = _tol_for_dtype(value_dtype) - err_x, ok_x, err_res, ok_res = _known_solution_metrics( - data, indices, indptr, shape, x, x_true, b, value_dtype, "NON" + err_res, _ = _solution_residual_metrics( + data, indices, indptr, shape, x, b, value_dtype, "NON" ) pytorch_ms = None err_pt = None @@ -927,6 +971,8 @@ def _finalize_csv_row( status = "PASS" if (ok_pt or ok_cu) else "FAIL" if (not ok_pt) and (not ok_cu) and (err_pt is None and err_cu is None): status = "REF_FAIL" + ref_errors = [err for err in (err_pt, err_cu) if err is not None] + err_ref = min(ref_errors) if ref_errors else None nnz_out = ( int(data.numel()) if nnz_display is None else int(nnz_display) @@ -943,7 +989,7 @@ def _finalize_csv_row( "cusparse_ms": cupy_ms, "csc_ms": None, "status": status, - "err_x": err_x, + "err_ref": err_ref, "err_res": err_res, "err_pt": err_pt, "err_cu": err_cu, @@ -958,8 +1004,20 @@ def _run_one_csv_row_csr_full(path, value_dtype, index_dtype, op_mode, device, l indices = indices.to(index_dtype) indptr = indptr.to(index_dtype) n_rows, n_cols = shape - x_true = _randn_by_dtype(n_rows, value_dtype, device) - b = _build_rhs_for_csr_op(data, indices, indptr, x_true, shape, op_mode) + b = _random_rhs_for_spsv( + shape, + value_dtype, + device, + op_mode=op_mode, + seed=_stable_case_seed( + "csv-csr", + os.path.basename(path), + "LOWER" if lower else "UPPER", + op_mode, + _dtype_name(value_dtype), + _dtype_name(index_dtype), + ), + ) x, t_ms = fs.flagsparse_spsv_csr( data, indices, @@ -982,7 +1040,6 @@ def _run_one_csv_row_csr_full(path, value_dtype, index_dtype, op_mode, device, l x, t_ms, b, - x_true, n_rows, n_cols, lower=lower, @@ -1001,14 +1058,13 @@ def _finalize_csv_row_csr_full( x, t_ms, b, - x_true, n_rows, n_cols, lower=True, ): atol, rtol = _tol_for_dtype(value_dtype) - err_x, ok_x, err_res, ok_res = _known_solution_metrics( - data, indices, indptr, shape, x, x_true, b, value_dtype, op_mode + err_res, _ = _solution_residual_metrics( + data, indices, indptr, shape, x, b, value_dtype, op_mode ) pytorch_ms = None @@ -1068,6 +1124,8 @@ def _finalize_csv_row_csr_full( status = "PASS" if (ok_pt or ok_cu) else "FAIL" if (not ok_pt) and (not ok_cu) and (err_pt is None and err_cu is None): status = "REF_FAIL" + ref_errors = [err for err in (err_pt, err_cu) if err is not None] + err_ref = min(ref_errors) if ref_errors else None row = { "matrix": os.path.basename(path), @@ -1082,7 +1140,7 @@ def 
_finalize_csv_row_csr_full(
         "cusparse_ms": cupy_ms,
         "csc_ms": None,
         "status": status,
-        "err_x": err_x,
+        "err_ref": err_ref,
         "err_res": err_res,
         "err_pt": err_pt,
         "err_cu": err_cu,
@@ -1121,15 +1179,16 @@ def run_all_supported_spsv_csr_csv(
                     "Formats: FlagSparse=CSR, cuSPARSE=CSR ref, PyTorch=Dense solve."
                 )
                 print(
-                    "Err(X)=|FlagSparse-x_true|, Err(Res)=|A*x-b|, "
+                    "RHS is generated directly, matching Library-main's SpSV test style. "
+                    "Err(Ref)=best |FlagSparse-reference|, Err(Res)=|op(A)*x-b|, "
                     "Err(PT)=|FlagSparse-PyTorch|, Err(CU)=|FlagSparse-cuSPARSE|. "
-                    "PASS if PyTorch / cuSPARSE reference passes. x_true / residual are diagnostics only."
+                    "PASS if PyTorch / cuSPARSE reference passes. Residual is diagnostic only."
                 )
                 print("-" * 150)
                 print(
                     f"{'Matrix':<28} {'N_rows':>7} {'N_cols':>7} {'NNZ':>10} "
                     f"{'FlagSparse(ms)':>10} {'CSR(ms)':>10} {'CSC(ms)':>10} {'PyTorch(ms)':>11} "
-                    f"{'FS/CSR':>7} {'FS/PT':>7} {'Status':>6} {'Err(X)':>10} {'Err(Res)':>10} {'Err(PT)':>10} {'Err(CU)':>10}"
+                    f"{'FS/CSR':>7} {'FS/PT':>7} {'Status':>6} {'Err(Ref)':>10} {'Err(Res)':>10} {'Err(PT)':>10} {'Err(CU)':>10}"
                 )
                 print("-" * 150)
                 for path in mtx_paths:
@@ -1146,14 +1205,14 @@ def run_all_supported_spsv_csr_csv(
                         t_ms = row["triton_ms"]
                         cupy_ms = row["cusparse_ms"]
                         pytorch_ms = row["pytorch_ms"]
-                        err_x, err_res = row["err_x"], row["err_res"]
+                        err_ref, err_res = row["err_ref"], row["err_res"]
                         err_pt, err_cu = row["err_pt"], row["err_cu"]
                         status = row["status"]
                         print(
                             f"{name:<28} {n_rows:>7} {n_cols:>7} {nnz:>10} "
                             f"{_fmt_ms(t_ms):>10} {_fmt_ms(cupy_ms):>10} {_fmt_ms(None):>10} {_fmt_ms(pytorch_ms):>11} "
                             f"{_fmt_speedup(cupy_ms, t_ms):>7} {_fmt_speedup(pytorch_ms, t_ms):>7} "
-                            f"{status:>6} {_fmt_err(err_x):>10} {_fmt_err(err_res):>10} {_fmt_err(err_pt):>10} {_fmt_err(err_cu):>10}"
+                            f"{status:>6} {_fmt_err(err_ref):>10} {_fmt_err(err_res):>10} {_fmt_err(err_pt):>10} {_fmt_err(err_cu):>10}"
                         )
                         if pt_skip:
                             print(f"  NOTE: {pt_skip}")
@@ -1174,7 +1233,7 @@ def run_all_supported_spsv_csr_csv(
                                 "cusparse_ms": None,
                                 "csc_ms": None,
                                 "status": status,
-                                "err_x": None,
+                                "err_ref": None,
                                 "err_res": None,
                                 "err_pt": None,
                                 "err_cu": None,
@@ -1204,7 +1263,7 @@ def run_all_supported_spsv_csr_csv(
         "cusparse_ms",
         "csc_ms",
         "status",
-        "err_x",
+        "err_ref",
         "err_res",
         "err_pt",
         "err_cu",
@@ -1241,18 +1300,18 @@ def run_all_dtypes_spsv_coo_csv(
             )
             print(
                 "Formats: FlagSparse=COO SpSV, cuSPARSE=COO ref, PyTorch=Dense solve. "
-                "b is built via CSR SpMV, consistent with the CSR tests."
+                "RHS is generated directly, matching Library-main's SpSV test style."
            )
             print(
-                "Err(X)=|FlagSparse-x_true|, Err(Res)=|A*x-b|, "
+                "Err(Ref)=best |FlagSparse-reference|, Err(Res)=|A*x-b|, "
                 "Err(PT)=|FlagSparse-PyTorch|, Err(CU)=|FlagSparse-cuSPARSE|. "
-                "PASS if PyTorch / cuSPARSE reference passes. x_true / residual are diagnostics only."
+                "PASS if PyTorch / cuSPARSE reference passes. Residual is diagnostic only."
) print("-" * 150) print( f"{'Matrix':<28} {'N_rows':>7} {'N_cols':>7} {'NNZ':>10} " f"{'FlagSparse(ms)':>10} {'COO(ms)':>10} {'CSC(ms)':>10} {'PyTorch(ms)':>11} " - f"{'FS/COO':>7} {'FS/PT':>7} {'Status':>6} {'Err(X)':>10} {'Err(Res)':>10} {'Err(PT)':>10} {'Err(CU)':>10}" + f"{'FS/COO':>7} {'FS/PT':>7} {'Status':>6} {'Err(Ref)':>10} {'Err(Res)':>10} {'Err(PT)':>10} {'Err(CU)':>10}" ) print("-" * 150) for path in mtx_paths: @@ -1269,14 +1328,14 @@ def run_all_dtypes_spsv_coo_csv( t_ms = row["triton_ms"] cupy_ms = row["cusparse_ms"] pytorch_ms = row["pytorch_ms"] - err_x, err_res = row["err_x"], row["err_res"] + err_ref, err_res = row["err_ref"], row["err_res"] err_pt, err_cu = row["err_pt"], row["err_cu"] status = row["status"] print( f"{name:<28} {n_rows:>7} {n_cols:>7} {nnz:>10} " f"{_fmt_ms(t_ms):>10} {_fmt_ms(cupy_ms):>10} {_fmt_ms(None):>10} {_fmt_ms(pytorch_ms):>11} " f"{_fmt_speedup(cupy_ms, t_ms):>7} {_fmt_speedup(pytorch_ms, t_ms):>7} " - f"{status:>6} {_fmt_err(err_x):>10} {_fmt_err(err_res):>10} {_fmt_err(err_pt):>10} {_fmt_err(err_cu):>10}" + f"{status:>6} {_fmt_err(err_ref):>10} {_fmt_err(err_res):>10} {_fmt_err(err_pt):>10} {_fmt_err(err_cu):>10}" ) if pt_skip: print(f" NOTE: {pt_skip}") @@ -1296,7 +1355,7 @@ def run_all_dtypes_spsv_coo_csv( "cusparse_ms": None, "csc_ms": None, "status": status, - "err_x": None, + "err_ref": None, "err_res": None, "err_pt": None, "err_cu": None, @@ -1324,7 +1383,7 @@ def run_all_dtypes_spsv_coo_csv( "cusparse_ms", "csc_ms", "status", - "err_x", + "err_ref", "err_res", "err_pt", "err_cu", @@ -1337,6 +1396,188 @@ def run_all_dtypes_spsv_coo_csv( print(f"Wrote {len(rows_out)} rows to {csv_path}") +def _check_one_csr_transpose_case(path, value_dtype, index_dtype, op_mode, device, lower=True): + data, indices, indptr, shape = _load_mtx_to_csr_torch( + path, dtype=value_dtype, device=device, lower=lower + ) + indices = indices.to(index_dtype) + indptr = indptr.to(index_dtype) + n_rows, n_cols = shape + trans_data, trans_indices64, trans_indptr64 = fs_spsv_impl._csr_transpose( + data, + indices.to(torch.int64), + indptr.to(torch.int64), + n_rows, + n_cols, + conjugate=(op_mode == "CONJ"), + ) + trans_shape = (n_cols, n_rows) + trans_indices = trans_indices64.to(index_dtype) + trans_indptr = trans_indptr64.to(index_dtype) + + probe = _random_rhs_for_spsv( + trans_shape, + value_dtype, + device, + op_mode="NON", + seed=_stable_case_seed( + "check-transpose-action", + os.path.basename(path), + "LOWER" if lower else "UPPER", + op_mode, + _dtype_name(value_dtype), + _dtype_name(index_dtype), + ), + ) + action_ref = _apply_csr_op(data, indices, indptr, probe, shape, op_mode) + action_trans = _apply_csr_op( + trans_data, trans_indices, trans_indptr, probe, trans_shape, "NON" + ) + action_err = ( + float(torch.max(torch.abs(action_trans - action_ref)).item()) + if action_ref.numel() > 0 + else 0.0 + ) + atol, rtol = _tol_for_dtype(value_dtype) + action_ok = torch.allclose(action_trans, action_ref, atol=atol, rtol=rtol) + + b = _random_rhs_for_spsv( + shape, + value_dtype, + device, + op_mode=op_mode, + seed=_stable_case_seed( + "check-transpose-solve", + os.path.basename(path), + "LOWER" if lower else "UPPER", + op_mode, + _dtype_name(value_dtype), + _dtype_name(index_dtype), + ), + ) + x_op = fs.flagsparse_spsv_csr( + data, + indices, + indptr, + b, + shape, + lower=lower, + transpose=op_mode, + ) + x_mat = fs.flagsparse_spsv_csr( + trans_data, + trans_indices, + trans_indptr, + b, + trans_shape, + lower=not lower, + transpose="NON", + ) + solve_err = ( + 
float(torch.max(torch.abs(x_op - x_mat)).item()) if x_op.numel() > 0 else 0.0 + ) + solve_ok = torch.allclose(x_op, x_mat, atol=atol, rtol=rtol) + + ref_err = None + ref_ok = None + if _allow_dense_pytorch_ref(shape, value_dtype): + A_dense = _csr_to_dense( + data, indices.to(torch.int64), indptr.to(torch.int64), shape + ).to(_dense_ref_dtype(value_dtype)) + x_ref = _triangular_solve_reference( + A_dense, b.to(A_dense.dtype), lower=lower, op_mode=op_mode + ) + ref_err = ( + float(torch.max(torch.abs(x_op - x_ref)).item()) if x_op.numel() > 0 else 0.0 + ) + ref_ok = torch.allclose(x_op, x_ref, atol=atol, rtol=rtol) + + status = "PASS" if action_ok and solve_ok and (ref_ok is not False) else "FAIL" + return { + "matrix": os.path.basename(path), + "value_dtype": _dtype_name(value_dtype), + "index_dtype": _dtype_name(index_dtype), + "opA": op_mode, + "n_rows": n_rows, + "nnz": int(data.numel()), + "action_err": action_err, + "solve_err": solve_err, + "ref_err": ref_err, + "status": status, + } + + +def run_csr_transpose_check( + mtx_paths, + lower=True, + value_dtypes=None, + index_dtypes=None, + op_modes=None, +): + if not torch.cuda.is_available(): + print("CUDA is not available.") + return + device = torch.device("cuda") + selected_value_dtypes = value_dtypes or CSR_FULL_VALUE_DTYPES + selected_index_dtypes = index_dtypes or CSR_FULL_INDEX_DTYPES + selected_op_modes = [op for op in (op_modes or ("TRANS", "CONJ")) if op in ("TRANS", "CONJ")] + if not selected_op_modes: + print("--check-transpose only checks TRANS/CONJ; no matching op selected.") + return + + print("=" * 150) + print( + "CSR TRANS/CONJ preprocessing check: " + "ActionErr compares materialized op(A) against direct CSR scatter; " + "SolveErr compares transpose path against materialized NON path." + ) + print("-" * 150) + print( + f"{'Matrix':<28} {'dtype':>10} {'index':>7} {'opA':>5} " + f"{'N':>7} {'NNZ':>10} {'Status':>6} {'ActionErr':>10} {'SolveErr':>10} {'RefErr':>10}" + ) + print("-" * 150) + total = 0 + failed = 0 + for value_dtype in selected_value_dtypes: + for index_dtype in selected_index_dtypes: + for op_mode in selected_op_modes: + for path in mtx_paths: + try: + row = _check_one_csr_transpose_case( + path, + value_dtype, + index_dtype, + op_mode, + device, + lower=lower, + ) + total += 1 + failed += int(row["status"] != "PASS") + name = row["matrix"][:27] + if len(row["matrix"]) > 27: + name += "..." + print( + f"{name:<28} {row['value_dtype']:>10} {row['index_dtype']:>7} {row['opA']:>5} " + f"{row['n_rows']:>7} {row['nnz']:>10} {row['status']:>6} " + f"{_fmt_err(row['action_err']):>10} {_fmt_err(row['solve_err']):>10} {_fmt_err(row['ref_err']):>10}" + ) + except Exception as e: + total += 1 + failed += 1 + name = os.path.basename(path)[:27] + if len(os.path.basename(path)) > 27: + name += "..." + print( + f"{name:<28} {_dtype_name(value_dtype):>10} {_dtype_name(index_dtype):>7} {op_mode:>5} " + f"{'ERR':>7} {'ERR':>10} {'ERROR':>6} " + f"{_fmt_err(None):>10} {_fmt_err(None):>10} {_fmt_err(None):>10}" + ) + print(f" ERROR: {e}") + print("-" * 150) + print(f"Total cases: {total} Failed: {failed}") + + def main(): parser = argparse.ArgumentParser( description="SpSV test: synthetic triangular systems and optional .mtx (CSR/COO), same baselines as CSR." 
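The TRANS/CONJ check added above rests on a single identity: materializing op(A) once and running the NON kernel must agree with the in-kernel transpose path, and the materialization step is nothing more than a CSR transpose. Below is a minimal CPU sketch of that regrouping, assuming int64 indices throughout; csr_transpose_reference is an illustrative name, not the module's _csr_transpose:

    import torch

    def csr_transpose_reference(data, indices, indptr, shape):
        # Expand CSR to COO, then sort entries by (col, row); grouped by
        # their new rows, that ordering is exactly A^T in CSR form.
        n_rows, n_cols = shape
        counts = indptr[1:] - indptr[:-1]
        row = torch.repeat_interleave(torch.arange(n_rows, dtype=torch.int64), counts)
        col = indices.to(torch.int64)
        order = torch.argsort(col * n_rows + row)
        t_indptr = torch.zeros(n_cols + 1, dtype=torch.int64)
        t_indptr[1:] = torch.bincount(col, minlength=n_cols).cumsum(0)
        return data[order], row[order], t_indptr

    # A = [[2, 0, 0], [3, 1, 0], [0, 0, 4]] in CSR.
    data = torch.tensor([2.0, 3.0, 1.0, 4.0])
    indices = torch.tensor([0, 0, 1, 2])
    indptr = torch.tensor([0, 1, 3, 4])
    t_data, t_indices, t_indptr = csr_transpose_reference(data, indices, indptr, (3, 3))
    assert t_data.tolist() == [2.0, 3.0, 1.0, 4.0]  # A^T row 0 holds 2 and 3
    assert t_indices.tolist() == [0, 1, 1, 2]
    assert t_indptr.tolist() == [0, 2, 3, 4]

For the CONJ case the same regrouping applies with data.conj(), which is why the check can route both modes through one materialization helper.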
@@ -1363,6 +1604,11 @@ def main(): metavar="FILE", help="Run all dtypes on .mtx (COO SpSV), same CSV columns as --csv-csr", ) + parser.add_argument( + "--check-transpose", + action="store_true", + help="Check CSR TRANS/CONJ preprocessing against direct CSR scatter and materialized NON solve", + ) parser.add_argument( "--coo-mode", type=str, @@ -1406,6 +1652,35 @@ def main(): paths.append(p) elif os.path.isdir(p): paths.extend(sorted(glob.glob(os.path.join(p, "*.mtx")))) + if args.check_transpose: + if not paths: + paths = sorted(glob.glob("*.mtx")) + if not paths: + print("No .mtx files found for --check-transpose") + return + value_dtypes = ( + _parse_value_dtypes_filter(args.value_dtypes) + if args.value_dtypes + else None + ) + index_dtypes = ( + _parse_index_dtypes_filter(args.index_dtypes) + if args.index_dtypes + else None + ) + op_modes = ( + _parse_op_modes_filter(args.ops) + if args.ops + else None + ) + run_csr_transpose_check( + paths, + lower=lower, + value_dtypes=value_dtypes, + index_dtypes=index_dtypes, + op_modes=op_modes, + ) + return if args.csv_csr: if not paths: paths = sorted(glob.glob("*.mtx")) From 5bf5ee4ea7f38435dd9f46e2d20ac7429c3c669d Mon Sep 17 00:00:00 2001 From: berlin020 <2261128688@qq.com> Date: Tue, 21 Apr 2026 21:43:25 +0800 Subject: [PATCH 06/22] Guard SPSV timing and benchmark gather paths --- src/flagsparse/sparse_operations/spsv.py | 20 ++++--- tests/test_gather.py | 7 +-- tests/test_spsv.py | 73 +++++++++++++++++++++--- 3 files changed, 79 insertions(+), 21 deletions(-) diff --git a/src/flagsparse/sparse_operations/spsv.py b/src/flagsparse/sparse_operations/spsv.py index ffdb6ea..29adfa1 100644 --- a/src/flagsparse/sparse_operations/spsv.py +++ b/src/flagsparse/sparse_operations/spsv.py @@ -941,8 +941,9 @@ def flagsparse_spsv_csr( vec_real = _triton_spsv_csr_vector vec_complex = _triton_spsv_csr_vector_complex - torch.cuda.synchronize() - t0 = time.perf_counter() + if return_time: + torch.cuda.synchronize() + t0 = time.perf_counter() if b_in.ndim == 1: if torch.is_complex(data_in): x = vec_complex( @@ -1020,8 +1021,9 @@ def flagsparse_spsv_csr( target_dtype = original_output_dtype if original_output_dtype is not None else data.dtype if x.dtype != target_dtype: x = _restore_spsv_output(x, target_dtype) - torch.cuda.synchronize() - elapsed_ms = (time.perf_counter() - t0) * 1000.0 + if return_time: + torch.cuda.synchronize() + elapsed_ms = (time.perf_counter() - t0) * 1000.0 if out is not None: if out.shape != x.shape or out.dtype != x.dtype: raise ValueError("out shape/dtype must match result") @@ -1118,8 +1120,9 @@ def flagsparse_spsv_coo( ) diag_eps = _spsv_diag_eps_for_dtype(compute_dtype) - torch.cuda.synchronize() - t0 = time.perf_counter() + if return_time: + torch.cuda.synchronize() + t0 = time.perf_counter() if b_in.ndim == 1: x = _triton_spsv_coo_vector( data_in, @@ -1159,8 +1162,9 @@ def flagsparse_spsv_coo( x = torch.stack(cols_out, dim=1) if compute_dtype != data.dtype: x = x.to(data.dtype) - torch.cuda.synchronize() - elapsed_ms = (time.perf_counter() - t0) * 1000.0 + if return_time: + torch.cuda.synchronize() + elapsed_ms = (time.perf_counter() - t0) * 1000.0 if out is not None: if out.shape != x.shape or out.dtype != x.dtype: diff --git a/tests/test_gather.py b/tests/test_gather.py index b00a6c4..2bc2dd1 100644 --- a/tests/test_gather.py +++ b/tests/test_gather.py @@ -163,15 +163,12 @@ def _collect_samples(case_id, expected, flagsparse_out, limit): def _dtype_mode(value_dtype_req): - if value_dtype_req in ("float16", "bfloat16", 
"complex64"): - return "gather_cupy" return "gather_triton" def _select_mode(value_dtype_req, index_dtype): - # Keep original gather path for half+int32 while retaining cupy path for new combos. - if value_dtype_req == "float16" and index_dtype == torch.int32: - return "gather_triton" + # The required gather coverage is the full 6 value dtypes x 2 index dtypes + # matrix, so benchmark the primary Triton gather path for every combo. return _dtype_mode(value_dtype_req) diff --git a/tests/test_spsv.py b/tests/test_spsv.py index cb28cb4..e4a8846 100644 --- a/tests/test_spsv.py +++ b/tests/test_spsv.py @@ -494,6 +494,67 @@ def _solution_residual_metrics(data, indices, indptr, shape, x, b, value_dtype, return err_res, ok_res +def _benchmark_flagsparse(call): + x = None + for _ in range(WARMUP): + x = call() + torch.cuda.synchronize() + e0 = torch.cuda.Event(True) + e1 = torch.cuda.Event(True) + e0.record() + for _ in range(ITERS): + x = call() + e1.record() + torch.cuda.synchronize() + return x, e0.elapsed_time(e1) / ITERS + + +def _benchmark_flagsparse_spsv_csr( + data, + indices, + indptr, + b, + shape, + *, + lower=True, + transpose=False, +): + return _benchmark_flagsparse( + lambda: fs.flagsparse_spsv_csr( + data, + indices, + indptr, + b, + shape, + lower=lower, + transpose=transpose, + ) + ) + + +def _benchmark_flagsparse_spsv_coo( + data, + row, + col, + b, + shape, + *, + lower=True, + coo_mode="auto", +): + return _benchmark_flagsparse( + lambda: fs.flagsparse_spsv_coo( + data, + row, + col, + b, + shape, + lower=lower, + coo_mode=coo_mode, + ) + ) + + def _cupy_spsolve_lower_csr_or_coo( fmt, data, @@ -668,7 +729,7 @@ def run_spsv_synthetic_all(lower=True): torch.cuda.synchronize() if fmt == "CSR": - x, t_ms = fs.flagsparse_spsv_csr( + x, t_ms = _benchmark_flagsparse_spsv_csr( data, indices, indptr, @@ -676,13 +737,12 @@ def run_spsv_synthetic_all(lower=True): shape, lower=lower, transpose=op_mode, - return_time=True, ) else: dc, rr, cc = _csr_to_coo( data, indices, indptr, shape, index_dtype=index_dtype ) - x, t_ms = fs.flagsparse_spsv_coo( + x, t_ms = _benchmark_flagsparse_spsv_coo( dc, rr, cc, @@ -690,7 +750,6 @@ def run_spsv_synthetic_all(lower=True): shape, lower=lower, coo_mode="auto", - return_time=True, ) torch.cuda.synchronize() @@ -796,7 +855,7 @@ def _run_one_csv_row_coo(path, value_dtype, index_dtype, device, coo_mode, lower d_in, r_in, c_in = _coo_inputs_for_csv( data, indices, indptr, shape, coo_mode, index_dtype=index_dtype ) - x, t_ms = fs.flagsparse_spsv_coo( + x, t_ms = _benchmark_flagsparse_spsv_coo( d_in, r_in, c_in, @@ -804,7 +863,6 @@ def _run_one_csv_row_coo(path, value_dtype, index_dtype, device, coo_mode, lower shape, lower=lower, coo_mode=coo_mode, - return_time=True, ) return _finalize_csv_row( path, @@ -1018,7 +1076,7 @@ def _run_one_csv_row_csr_full(path, value_dtype, index_dtype, op_mode, device, l _dtype_name(index_dtype), ), ) - x, t_ms = fs.flagsparse_spsv_csr( + x, t_ms = _benchmark_flagsparse_spsv_csr( data, indices, indptr, @@ -1026,7 +1084,6 @@ def _run_one_csv_row_csr_full(path, value_dtype, index_dtype, op_mode, device, l shape, lower=lower, transpose=op_mode, - return_time=True, ) return _finalize_csv_row_csr_full( path, From 1b71727b8b33d5de53274930908f84be4b7ea299 Mon Sep 17 00:00:00 2001 From: berlin020 <2261128688@qq.com> Date: Tue, 21 Apr 2026 22:33:19 +0800 Subject: [PATCH 07/22] =?UTF-8?q?=E6=98=AF=EF=BC=8C=E5=B7=B2=E7=BB=8F?= =?UTF-8?q?=E7=AC=A6=E5=90=88=E8=A6=81=E6=B1=82=EF=BC=9A12=20=E7=BB=84=20g?= 
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/flagsparse/__init__.py                   |   4 -
 src/flagsparse/sparse_operations/__init__.py |   4 -
 .../sparse_operations/gather_scatter.py      | 328 ------------------
 tests/pytest/test_gather_scatter_accuracy.py | 219 +-----------
 tests/test_gather.py                         |  49 +--
 5 files changed, 31 insertions(+), 573 deletions(-)

diff --git a/src/flagsparse/__init__.py b/src/flagsparse/__init__.py
index 21aa8a6..02b5d76 100644
--- a/src/flagsparse/__init__.py
+++ b/src/flagsparse/__init__.py
@@ -4,12 +4,10 @@
 
 __all__ = [
     "flagsparse_gather",
-    "flagsparse_gather_cupy",
     "flagsparse_scatter",
     "pytorch_index_gather",
     "pytorch_index_scatter",
     "cusparse_spmv_gather",
-    "cusparse_spmv_gather_cupy",
     "cusparse_spmv_scatter",
     "benchmark_gather_case",
     "benchmark_scatter_case",
@@ -70,12 +68,10 @@
 
 _OPS_EXPORTS = {
     "flagsparse_gather",
-    "flagsparse_gather_cupy",
     "flagsparse_scatter",
     "pytorch_index_gather",
     "pytorch_index_scatter",
     "cusparse_spmv_gather",
-    "cusparse_spmv_gather_cupy",
     "cusparse_spmv_scatter",
     "benchmark_gather_case",
     "benchmark_scatter_case",
diff --git a/src/flagsparse/sparse_operations/__init__.py b/src/flagsparse/sparse_operations/__init__.py
index c7db901..0fac868 100644
--- a/src/flagsparse/sparse_operations/__init__.py
+++ b/src/flagsparse/sparse_operations/__init__.py
@@ -18,10 +18,8 @@
 )
 from .gather_scatter import (
     cusparse_spmv_gather,
-    cusparse_spmv_gather_cupy,
     cusparse_spmv_scatter,
     flagsparse_gather,
-    flagsparse_gather_cupy,
     flagsparse_scatter,
     pytorch_index_gather,
     pytorch_index_scatter,
@@ -75,10 +73,8 @@
     "comprehensive_spmm_test",
     "comprehensive_spsm_test",
     "cusparse_spmv_gather",
-    "cusparse_spmv_gather_cupy",
     "cusparse_spmv_scatter",
     "flagsparse_gather",
-    "flagsparse_gather_cupy",
     "flagsparse_sddmm_csr",
     "flagsparse_spgemm_csr",
     "flagsparse_spmm_coo",
diff --git a/src/flagsparse/sparse_operations/gather_scatter.py b/src/flagsparse/sparse_operations/gather_scatter.py
index 67066b0..ab417a6 100644
--- a/src/flagsparse/sparse_operations/gather_scatter.py
+++ b/src/flagsparse/sparse_operations/gather_scatter.py
@@ -539,331 +539,3 @@ def cusparse_spmv_scatter(
     )
 
     return dense_values, execution_time_ms, selector_matrix
-
-
-_CUPY_GATHER_KERNEL_CACHE = {}
-_CUPY_RAW_KIND_TO_DTYPE = {
-    16: cp.uint16 if cp is not None else None,
-    32: cp.uint32 if cp is not None else None,
-    64: cp.uint64 if cp is not None else None,
-}
-_CUPY_RAW_KIND_TO_CTYPE = {
-    16: "unsigned short",
-    32: "unsigned int",
-    64: "unsigned long long",
-}
-_CUPY_INDEX_TO_CTYPE = {
-    "i32": "int",
-    "i64": "long long",
-}
-
-
-def _cupy_gather_detect_layout(dense_vector):
-    if dense_vector.ndim == 1 and dense_vector.dtype in (torch.float16, torch.bfloat16):
-        return "scalar16"
-    if dense_vector.ndim == 1 and dense_vector.dtype == torch.complex64:
-        return "complex64"
-    raise TypeError(
-        "Unsupported gather input format. Expected one of: "
-        "1D float16/bfloat16 or 1D complex64." 
- ) - - -def _cupy_gather_dense_size(dense_vector, layout): - if layout in ("scalar16", "complex64"): - return int(dense_vector.shape[0]) - raise RuntimeError(f"Unknown gather layout: {layout}") - - -def _cupy_gather_validate_inputs(dense_vector, indices): - if not dense_vector.is_cuda or not indices.is_cuda: - raise ValueError("dense_vector and indices must both be CUDA tensors") - if indices.ndim != 1: - raise ValueError("indices must be a 1D tensor") - if indices.dtype not in SUPPORTED_INDEX_DTYPES: - raise TypeError("indices dtype must be torch.int32 or torch.int64") - - layout = _cupy_gather_detect_layout(dense_vector) - dense_size = _cupy_gather_dense_size(dense_vector, layout) - - if indices.numel() > 0: - if torch.any(indices < 0).item(): - raise IndexError("indices must be non-negative") - max_index = int(indices.max().item()) - if max_index >= dense_size: - raise IndexError( - f"indices out of range: max index {max_index}, dense size {dense_size}" - ) - - return layout, dense_size - - -def _cupy_gather_validate_combo(dense_vector, indices, layout): - # Keep only the required extra gather combos: - # Half+Int64, Bfloat16+Int32/Int64, Complex64+Int32/Int64 - if layout == "scalar16": - if dense_vector.dtype == torch.float16: - if indices.dtype != torch.int64: - raise TypeError("float16 gather_cupy supports only int64 indices") - return - if dense_vector.dtype == torch.bfloat16: - return - raise TypeError("scalar16 gather_cupy supports only float16/bfloat16") - - if layout == "complex64": - return - - raise TypeError(f"Unsupported gather_cupy layout: {layout}") - - -def _cupy_gather_layout_raw_kind(layout): - if layout == "scalar16": - return 16 - if layout == "complex64": - return 64 - raise RuntimeError(f"Unknown gather layout: {layout}") - - -def _cupy_gather_index_tag(index_dtype): - return "i32" if index_dtype == torch.int32 else "i64" - - -def _cupy_gather_kernel_name(raw_kind, index_tag): - return f"flagsparse_gather_cupy_raw{raw_kind}_{index_tag}" - - -def _cupy_gather_build_source(raw_kind, index_tag): - value_ctype = _CUPY_RAW_KIND_TO_CTYPE[raw_kind] - index_ctype = _CUPY_INDEX_TO_CTYPE[index_tag] - kernel_name = _cupy_gather_kernel_name(raw_kind, index_tag) - return f""" -extern "C" __global__ -void {kernel_name}(const {value_ctype}* y, const {index_ctype}* x_ind, {value_ctype}* x_val, long long nnz) -{{ - long long tid = (long long)threadIdx.x + (long long)blockIdx.x * blockDim.x; - long long stride = (long long)gridDim.x * blockDim.x; - for (long long i = tid; i < nnz; i += stride) - {{ - long long dense_idx = (long long)x_ind[i]; - x_val[i] = y[dense_idx]; - }} -}} -""" - - -def _cupy_gather_get_kernel(raw_kind, index_dtype): - _require_cupy() - index_tag = _cupy_gather_index_tag(index_dtype) - key = (raw_kind, index_tag) - if key not in _CUPY_GATHER_KERNEL_CACHE: - src = _cupy_gather_build_source(raw_kind, index_tag) - name = _cupy_gather_kernel_name(raw_kind, index_tag) - _CUPY_GATHER_KERNEL_CACHE[key] = cp.RawKernel(src, name, options=("--std=c++14",)) - return _CUPY_GATHER_KERNEL_CACHE[key] - - -def _cupy_gather_dense_to_raw_torch(dense_t, layout): - if layout == "scalar16": - return dense_t.reshape(-1).view(torch.uint16) - if layout == "complex64": - return dense_t.reshape(-1).view(torch.uint64) - raise RuntimeError(f"Unknown gather layout: {layout}") - - -def _cupy_gather_raw_to_dense_torch(out_raw_t, layout, dense_t_dtype): - if layout == "scalar16": - return out_raw_t.view(dense_t_dtype).reshape(-1) - if layout == "complex64": - return 
out_raw_t.view(torch.complex64).reshape(-1) - raise RuntimeError(f"Unknown gather layout: {layout}") - - -def _cupy_gather_empty(layout, dense_dtype, device): - if layout in ("scalar16", "complex64"): - return torch.empty(0, dtype=dense_dtype, device=device) - raise RuntimeError(f"Unknown gather layout: {layout}") - - -def _cupy_gather_selector_dtype(layout, dense_dtype): - if layout == "scalar16": - return dense_dtype - if layout == "complex64": - return torch.complex64 - raise RuntimeError(f"Unknown gather layout: {layout}") - - -def _cupy_gather_prepare_dense(dense_vector, indices): - layout, dense_size = _cupy_gather_validate_inputs(dense_vector, indices) - _cupy_gather_validate_combo(dense_vector, indices, layout) - return dense_vector, layout, dense_size, None - - -def _cupy_gather_restore_output(gathered_t, restore_mode): - return gathered_t - - -def _launch_cupy_gather_kernel(gather_kernel, dense_raw_cp, indices_cp, out_raw_cp, nnz): - thread_per_block = 256 - block_per_grid = min(2, (nnz + thread_per_block - 1) // thread_per_block) - if block_per_grid <= 0: - block_per_grid = 1 - gather_kernel( - (int(block_per_grid),), - (int(thread_per_block),), - (dense_raw_cp, indices_cp, out_raw_cp, cp.int64(nnz)), - ) - - -def flagsparse_gather_cupy( - a, - indices, - out=None, - mode="raise", - index_fallback_policy="auto", - return_time=False, - return_metadata=False, -): - """CuPy-style gather (take) for extra gather dtype/index combinations.""" - if mode != "raise": - raise NotImplementedError("Only mode='raise' is currently supported") - index_fallback_policy = str(index_fallback_policy).lower() - if index_fallback_policy not in ("auto", "strict"): - raise ValueError("index_fallback_policy must be 'auto' or 'strict'") - - _require_cupy() - dense_t, dense_backend = _to_torch_tensor(a, "a") - indices_t, _ = _to_torch_tensor(indices, "indices") - - dense_t = dense_t.contiguous() - indices_t = indices_t.contiguous() - runtime_dense_t, layout, _, restore_mode = _cupy_gather_prepare_dense(dense_t, indices_t) - gather_meta = { - "index_fallback_applied": False, - "index_fallback_reason": None, - "kernel_index_dtype": str(indices_t.dtype).replace("torch.", ""), - } - - nnz = int(indices_t.numel()) - if nnz == 0: - gathered_t = _cupy_gather_empty(layout, runtime_dense_t.dtype, runtime_dense_t.device) - execution_time_ms = 0.0 - else: - raw_kind = _cupy_gather_layout_raw_kind(layout) - raw_dtype = _CUPY_RAW_KIND_TO_DTYPE[raw_kind] - indices_cp = _cupy_from_torch(indices_t) - dense_raw_t = _cupy_gather_dense_to_raw_torch(runtime_dense_t, layout) - dense_raw_cp = _cupy_from_torch(dense_raw_t) - out_raw_cp = cp.empty(nnz, dtype=raw_dtype) - gather_kernel = _cupy_gather_get_kernel(raw_kind, indices_t.dtype) - - torch.cuda.synchronize() - start_time = time.perf_counter() - stream_ptr = torch.cuda.current_stream(device=runtime_dense_t.device).cuda_stream - try: - with cp.cuda.ExternalStream(stream_ptr): - _launch_cupy_gather_kernel( - gather_kernel, - dense_raw_cp, - indices_cp, - out_raw_cp, - nnz, - ) - except Exception as exc: - if indices_t.dtype != torch.int64 or index_fallback_policy != "auto": - raise RuntimeError( - f"CuPy gather failed for index dtype {indices_t.dtype}: " - f"{exc.__class__.__name__}: {str(exc)}" - ) from exc - - max_index = int(indices_t.max().item()) if nnz > 0 else -1 - if max_index > _INDEX_LIMIT_INT32: - raise RuntimeError( - "CuPy gather failed for int64 indices, and int32 fallback is invalid: " - f"max index {max_index} exceeds int32 range" - ) from exc - - 
fallback_indices_t = indices_t.to(torch.int32) - fallback_indices_cp = _cupy_from_torch(fallback_indices_t) - fallback_kernel = _cupy_gather_get_kernel(raw_kind, torch.int32) - try: - with cp.cuda.ExternalStream(stream_ptr): - _launch_cupy_gather_kernel( - fallback_kernel, - dense_raw_cp, - fallback_indices_cp, - out_raw_cp, - nnz, - ) - except Exception as fallback_exc: - raise RuntimeError( - "CuPy gather failed for int64 indices, and int32 fallback also failed: " - f"{fallback_exc.__class__.__name__}: {str(fallback_exc)}" - ) from fallback_exc - gather_meta["index_fallback_applied"] = True - gather_meta["index_fallback_reason"] = ( - f"int64 kernel launch failed: {exc.__class__.__name__}: {str(exc)}" - ) - gather_meta["kernel_index_dtype"] = "int32" - torch.cuda.synchronize() - execution_time_ms = (time.perf_counter() - start_time) * 1000.0 - - out_raw_t = _torch_from_cupy(out_raw_cp) - gathered_t = _cupy_gather_raw_to_dense_torch(out_raw_t, layout, runtime_dense_t.dtype) - - gathered_t = _cupy_gather_restore_output(gathered_t, restore_mode) - - if out is not None: - out_t, _ = _to_torch_tensor(out, "out") - if out_t.shape != gathered_t.shape: - raise ValueError("out shape must match gather output shape") - if out_t.dtype != gathered_t.dtype: - raise TypeError("out dtype must match gather output dtype") - if not out_t.is_cuda: - raise ValueError("out must be a CUDA tensor/array") - out_t.copy_(gathered_t) - result = out if dense_backend == "cupy" else out_t - else: - result = _to_backend_like(gathered_t, a) - - if return_time and return_metadata: - return result, execution_time_ms, gather_meta - if return_time: - return result, execution_time_ms - if return_metadata: - return result, gather_meta - return result - - -def cusparse_spmv_gather_cupy(dense_vector, indices, selector_matrix=None): - """Equivalent gather baseline via cuSPARSE-backed COO SpMV for cupy gather path.""" - _require_cupy() - dense_t, _ = _to_torch_tensor(dense_vector, "dense_vector") - indices_t, _ = _to_torch_tensor(indices, "indices") - - dense_t = dense_t.contiguous() - indices_t = indices_t.contiguous() - runtime_dense_t, layout, dense_size, restore_mode = _cupy_gather_prepare_dense( - dense_t, indices_t - ) - - selector_dtype = _cupy_gather_selector_dtype(layout, runtime_dense_t.dtype) - if selector_matrix is None: - selector_matrix = _make_gather_selector_matrix(indices_t, dense_size, selector_dtype) - - try: - torch.cuda.synchronize() - start_time = time.perf_counter() - if layout in ("scalar16", "complex64"): - gathered_t = _cusparse_spmv(selector_matrix, runtime_dense_t) - else: - raise RuntimeError(f"Unknown gather layout: {layout}") - torch.cuda.synchronize() - execution_time_ms = (time.perf_counter() - start_time) * 1000.0 - except Exception as exc: - raise RuntimeError( - "cuSPARSE gather baseline is unavailable in this PyTorch/CUDA environment" - ) from exc - - gathered_t = _cupy_gather_restore_output(gathered_t, restore_mode) - result = _to_backend_like(gathered_t, dense_vector) - return result, execution_time_ms, selector_matrix diff --git a/tests/pytest/test_gather_scatter_accuracy.py b/tests/pytest/test_gather_scatter_accuracy.py index e67c648..a30beec 100644 --- a/tests/pytest/test_gather_scatter_accuracy.py +++ b/tests/pytest/test_gather_scatter_accuracy.py @@ -1,16 +1,10 @@ import pytest import torch -from flagsparse import ( - cusparse_spmv_gather_cupy, - flagsparse_gather, - flagsparse_gather_cupy, - flagsparse_scatter, -) +from flagsparse import flagsparse_gather, flagsparse_scatter from 
flagsparse.sparse_operations import gather_scatter as gather_scatter_ops -from flagsparse.sparse_operations._common import cp -from tests.pytest.param_shapes import FLOAT_DTYPE_IDS, FLOAT_DTYPES, GATHER_SCATTER_SHAPES +from tests.pytest.param_shapes import FLOAT_DTYPES, GATHER_SCATTER_SHAPES pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") @@ -19,6 +13,17 @@ RESET_OUTPUT_CASES = [True, False] RESET_OUTPUT_IDS = ["reset", "inplace"] +GATHER_DTYPE_CASES = [ + ("float", torch.float32), + ("double", torch.float64), + ("half", torch.float16), + ("bfloat16", torch.bfloat16), + ("complex64", torch.complex64), + ("complex128", torch.complex128), +] +GATHER_DTYPE_IDS = [name for name, _ in GATHER_DTYPE_CASES] + + def _scatter_dtype_cases(): cases = [(str(dtype).replace("torch.", ""), dtype) for dtype in FLOAT_DTYPES] cases.append(("complex64", torch.complex64)) @@ -53,75 +58,15 @@ def _build_random_values(size, dtype, device): raise TypeError(f"Unsupported dtype in test: {dtype}") -def _to_cupy(tensor): - return cp.from_dlpack(torch.utils.dlpack.to_dlpack(tensor)) - - -def _to_torch(array): - try: - dlpack_capsule = array.toDlpack() - except AttributeError: - dlpack_capsule = array.to_dlpack() - return torch.utils.dlpack.from_dlpack(dlpack_capsule) - - -def _to_backend_tensor(tensor, backend): - if backend == "torch": - return tensor - return _to_cupy(tensor) - - -def _as_torch_tensor(value): - return value if torch.is_tensor(value) else _to_torch(value) - - -def _build_extra_gather_dense(layout, value_dtype, dense_size, device): - if layout == "scalar16": - return torch.randn(dense_size, dtype=value_dtype, device=device) - if layout == "complex16_pair": - return torch.randn(dense_size, 2, dtype=value_dtype, device=device) - if layout == "complex64": - real = torch.randn(dense_size, dtype=torch.float32, device=device) - imag = torch.randn(dense_size, dtype=torch.float32, device=device) - return torch.complex(real, imag) - raise RuntimeError(f"Unknown layout: {layout}") - - -def _extra_gather_tolerance(value_dtype): - if value_dtype == torch.float16: - return 5e-3, 5e-3 - if value_dtype == torch.bfloat16: - return 1e-2, 1e-2 - if value_dtype == torch.complex64: - return 1e-6, 1e-5 - return 1e-6, 1e-5 - - -EXTRA_GATHER_CASES = [ - ("scalar16", torch.float16, torch.int64), - ("scalar16", torch.bfloat16, torch.int32), - ("scalar16", torch.bfloat16, torch.int64), - ("complex64", torch.complex64, torch.int32), - ("complex64", torch.complex64, torch.int64), -] -EXTRA_GATHER_CASE_IDS = [ - "half_i64", - "bf16_i32", - "bf16_i64", - "c64_i32", - "c64_i64", -] - - @pytest.mark.gather @pytest.mark.parametrize("dense_size, nnz", GATHER_SCATTER_SHAPES) -@pytest.mark.parametrize("dtype", FLOAT_DTYPES, ids=FLOAT_DTYPE_IDS) +@pytest.mark.parametrize("dtype_name,dtype", GATHER_DTYPE_CASES, ids=GATHER_DTYPE_IDS) @pytest.mark.parametrize("index_dtype", INDEX_DTYPES, ids=INDEX_DTYPE_IDS) -def test_gather_matches_indexing(dense_size, nnz, dtype, index_dtype): - _skip_unavailable_dtype(str(dtype).replace("torch.", ""), dtype) +def test_gather_matches_indexing(dense_size, nnz, dtype_name, dtype, index_dtype): + _skip_unavailable_dtype(dtype_name, dtype) device = torch.device("cuda") nnz = min(nnz, dense_size) - dense = torch.randn(dense_size, dtype=dtype, device=device) + dense = _build_random_values(dense_size, dtype, device) indices = torch.randperm(dense_size, device=device)[:nnz].to(index_dtype) ref = dense[indices.to(torch.int64)] got = flagsparse_gather(dense, indices) @@ 
-245,133 +190,3 @@ def fake_launch(dense_values, sparse_values, kernel_indices, nnz, block_size=102 dtype_policy="auto", index_fallback_policy="strict", ) - - -@pytest.mark.gather -@pytest.mark.skipif(cp is None, reason="CuPy required") -@pytest.mark.parametrize("backend", ["torch", "cupy"]) -@pytest.mark.parametrize( - "layout,value_dtype,index_dtype", - EXTRA_GATHER_CASES, - ids=EXTRA_GATHER_CASE_IDS, -) -def test_gather_cupy_extra_dtypes_match_torch_and_cusparse( - layout, - value_dtype, - index_dtype, - backend, -): - if value_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported(): - pytest.skip("bfloat16 not supported on this GPU") - - device = torch.device("cuda") - dense_size = 65536 - nnz = 4096 - - dense_t = _build_extra_gather_dense(layout, value_dtype, dense_size, device) - indices_t = torch.arange(nnz, device=device, dtype=index_dtype) * 17 % dense_size - reference = dense_t.index_select(0, indices_t.to(torch.int64)) - - dense_in = _to_backend_tensor(dense_t, backend) - indices_in = _to_backend_tensor(indices_t, backend) - - got = flagsparse_gather_cupy(dense_in, indices_in) - cusparse_values, _, _ = cusparse_spmv_gather_cupy(dense_in, indices_in) - - got_t = _as_torch_tensor(got) - cusparse_t = _as_torch_tensor(cusparse_values) - - atol, rtol = _extra_gather_tolerance(value_dtype) - assert torch.allclose(got_t, reference, atol=atol, rtol=rtol) - assert torch.allclose(cusparse_t, reference, atol=atol, rtol=rtol) - - -@pytest.mark.gather -@pytest.mark.skipif(cp is None, reason="CuPy required") -def test_gather_cupy_rejects_bfloat16_pair_layout(): - if not torch.cuda.is_bf16_supported(): - pytest.skip("bfloat16 not supported on this GPU") - - device = torch.device("cuda") - dense_size = 4096 - nnz = 256 - dense = torch.randn(dense_size, 2, dtype=torch.bfloat16, device=device) - indices = torch.arange(nnz, device=device, dtype=torch.int32) * 17 % dense_size - - with pytest.raises(TypeError, match="Unsupported gather input format"): - flagsparse_gather_cupy(dense, indices) - - -@pytest.mark.gather -@pytest.mark.skipif(cp is None, reason="CuPy required") -@pytest.mark.parametrize("backend", ["torch", "cupy"]) -def test_gather_cupy_same_backend_out_float16_i64(backend): - device = torch.device("cuda") - dense_size = 65536 - nnz = 4096 - dense_t = torch.randn(dense_size, dtype=torch.float16, device=device) - indices_t = torch.arange(nnz, device=device, dtype=torch.int64) * 17 % dense_size - reference = dense_t.index_select(0, indices_t) - - dense_in = _to_backend_tensor(dense_t, backend) - indices_in = _to_backend_tensor(indices_t, backend) - out = _to_backend_tensor(torch.empty_like(reference), backend) - result = flagsparse_gather_cupy(dense_in, indices_in, out=out) - - assert result is out - assert torch.allclose(_as_torch_tensor(out), reference, atol=5e-3, rtol=5e-3) - - -@pytest.mark.gather -@pytest.mark.skipif(cp is None, reason="CuPy required") -def test_gather_cupy_int64_auto_fallback_to_int32(monkeypatch): - device = torch.device("cuda") - dense_size = 257 - nnz = 129 - dense = torch.randn(dense_size, dtype=torch.float16, device=device) - indices = torch.randperm(dense_size, device=device)[:nnz].to(torch.int64) - ref = dense.index_select(0, indices.to(torch.int64)) - - original_launch = gather_scatter_ops._launch_cupy_gather_kernel - state = {"forced_once": False} - - def fake_launch(gather_kernel, dense_raw_cp, indices_cp, out_raw_cp, nnz): - if not state["forced_once"]: - state["forced_once"] = True - raise RuntimeError("forced int64 launch failure") - return 
original_launch(gather_kernel, dense_raw_cp, indices_cp, out_raw_cp, nnz) - - monkeypatch.setattr(gather_scatter_ops, "_launch_cupy_gather_kernel", fake_launch) - - got, meta = flagsparse_gather_cupy( - dense, - indices, - index_fallback_policy="auto", - return_metadata=True, - ) - assert state["forced_once"] - assert meta["index_fallback_applied"] - assert meta["kernel_index_dtype"] == "int32" - assert torch.allclose(got, ref, atol=5e-3, rtol=5e-3) - - -@pytest.mark.gather -@pytest.mark.skipif(cp is None, reason="CuPy required") -def test_gather_cupy_int64_strict_no_fallback(monkeypatch): - device = torch.device("cuda") - dense_size = 257 - nnz = 129 - dense = torch.randn(dense_size, dtype=torch.float16, device=device) - indices = torch.randperm(dense_size, device=device)[:nnz].to(torch.int64) - - def fake_launch(gather_kernel, dense_raw_cp, indices_cp, out_raw_cp, nnz): - raise RuntimeError("forced int64 launch failure") - - monkeypatch.setattr(gather_scatter_ops, "_launch_cupy_gather_kernel", fake_launch) - - with pytest.raises(RuntimeError, match="CuPy gather failed for index dtype"): - flagsparse_gather_cupy( - dense, - indices, - index_fallback_policy="strict", - ) diff --git a/tests/test_gather.py b/tests/test_gather.py index 2bc2dd1..9508a02 100644 --- a/tests/test_gather.py +++ b/tests/test_gather.py @@ -163,12 +163,12 @@ def _collect_samples(case_id, expected, flagsparse_out, limit): def _dtype_mode(value_dtype_req): + _ = value_dtype_req return "gather_triton" def _select_mode(value_dtype_req, index_dtype): - # The required gather coverage is the full 6 value dtypes x 2 index dtypes - # matrix, so benchmark the primary Triton gather path for every combo. + _ = index_dtype return _dtype_mode(value_dtype_req) @@ -221,15 +221,9 @@ def _check_dtype_supported(value_dtype_req): raise RuntimeError("bfloat16 not supported on this GPU") -def _is_supported_extra_gather_combo(value_dtype_req, index_dtype): - # Required extra gather combos only: - # Half+Int32/Int64, Bfloat16+Int32/Int64, Complex64+Int32/Int64 - if value_dtype_req == "float16": - return index_dtype in (torch.int32, torch.int64) - if value_dtype_req in ("bfloat16", "complex64"): - return index_dtype in (torch.int32, torch.int64) - # Original gather path dtypes keep original behavior. - return True +def _is_supported_gather_combo(index_dtype): + # Required gather coverage is the full 6 value dtypes x 2 index dtypes matrix. 
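+    # Value dtypes: float16/bfloat16/float32/float64/complex64/complex128;
+    # per-GPU availability (e.g. bfloat16 on older hardware) is gated
+    # separately by _check_dtype_supported, so only index dtype is filtered here.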
+ return index_dtype in (torch.int32, torch.int64) def _build_indices(dense_size, nnz, index_dtype, device): @@ -253,28 +247,13 @@ def _benchmark_gather_case( expected = dense_vector.index_select(0, indices.to(torch.int64)) mode = _select_mode(value_dtype_req, index_dtype) - if mode == "gather_cupy": - preview_output, gather_meta = ast.flagsparse_gather_cupy( - dense_vector, - indices, - index_fallback_policy=index_fallback_policy, - return_metadata=True, - ) - _ = preview_output - flagsparse_op = lambda: ast.flagsparse_gather_cupy( - dense_vector, - indices, - index_fallback_policy=index_fallback_policy, - ) - cusparse_op = lambda: ast.cusparse_spmv_gather_cupy(dense_vector, indices)[0] - else: - gather_meta = { - "index_fallback_applied": False, - "index_fallback_reason": None, - "kernel_index_dtype": str(index_dtype).replace("torch.", ""), - } - flagsparse_op = lambda: ast.flagsparse_gather(dense_vector, indices) - cusparse_op = lambda: ast.cusparse_spmv_gather(dense_vector, indices)[0] + gather_meta = { + "index_fallback_applied": False, + "index_fallback_reason": None, + "kernel_index_dtype": str(index_dtype).replace("torch.", ""), + } + flagsparse_op = lambda: ast.flagsparse_gather(dense_vector, indices) + cusparse_op = lambda: ast.cusparse_spmv_gather(dense_vector, indices)[0] pytorch_op = lambda: dense_vector.index_select(0, indices.to(torch.int64)) pytorch_values, pytorch_ms = _bench_cuda_op(pytorch_op, warmup=warmup, iters=iters) @@ -401,7 +380,7 @@ def run_cli(args): for value_dtype in value_dtype_tokens: for index_name, index_dtype in index_dtype_pairs: - if not _is_supported_extra_gather_combo(value_dtype, index_dtype): + if not _is_supported_gather_combo(index_dtype): continue for dense_size, nnz in work_cases: dense_size = int(dense_size) @@ -477,7 +456,7 @@ def run_cli(args): "index_dtype": index_name, "dense_size": dense_size, "nnz": nnz, - "mode": _dtype_mode(value_dtype), + "mode": _select_mode(value_dtype, index_dtype), "index_fallback_policy": args.index_fallback_policy, "index_fallback_applied": False, "triton_ms": None, From fe5786a8904064df99aa39d1a9ebb4cd3c606dcb Mon Sep 17 00:00:00 2001 From: berlin020 <2261128688@qq.com> Date: Mon, 27 Apr 2026 00:17:17 +0800 Subject: [PATCH 08/22] Add files from flagsparse_new --- ops_support.py | 395 +++++++++++++++++++++ ops_support_sort_check.csv | 177 +++++++++ run_flagsparse_pytest.py | 264 ++++++++++++++ tests/pytest/test_spgemm_sddmm_accuracy.py | 100 ++++++ tests/pytest/test_spmm_coo_accuracy.py | 45 +++ tests/pytest/test_spsm_accuracy.py | 68 ++++ 6 files changed, 1049 insertions(+) create mode 100644 ops_support.py create mode 100644 ops_support_sort_check.csv create mode 100644 run_flagsparse_pytest.py create mode 100644 tests/pytest/test_spgemm_sddmm_accuracy.py create mode 100644 tests/pytest/test_spmm_coo_accuracy.py create mode 100644 tests/pytest/test_spsm_accuracy.py diff --git a/ops_support.py b/ops_support.py new file mode 100644 index 0000000..62be4fb --- /dev/null +++ b/ops_support.py @@ -0,0 +1,395 @@ +#!/usr/bin/env python3 +"""Export the declared FlagSparse sparse-operator support matrix to CSV. + +The script is intentionally static: it does not import torch/triton/cupy or +flagsparse, and it never launches kernels. It reads source files and reports the +support declared by constants such as SUPPORTED_*_VALUE_DTYPES. 
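+
+Typical invocation (both flags are optional; defaults are the project-local
+paths wired up in parse_args below):
+
+    python ops_support.py --src-root src/flagsparse/sparse_operations \
+        --output ops_support.csv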
+""" + +from __future__ import annotations + +import argparse +import ast +import csv +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Iterable + + +CSV_FIELDS = ( + "operator", + "format", + "index_dtype", + "value_dtype", + "op", + "route", + "status", +) + +DEFAULT_VALUE_DTYPES = ("float16", "bfloat16", "float32", "float64", "complex64", "complex128") +DEFAULT_INDEX_DTYPES = ("int32", "int64") +NA = "N/A" +VALUE_DTYPE_ORDER = { + "float16": 0, + "bfloat16": 1, + "float32": 2, + "float64": 3, + "complex32": 4, + "complex64": 5, + "complex128": 6, +} +INDEX_DTYPE_ORDER = {"int32": 0, "int64": 1} +OP_ORDER = {"non": 0, "trans": 1, "conj": 2} + + +@dataclass(frozen=True) +class ApiSpec: + operator: str + api: str + module: str + fmt: str + route: str + value_const: str | None = None + index_const: str | None = None + values: tuple[str, ...] | None = None + indices: tuple[str, ...] | None = None + ops: tuple[str, ...] | str | None = None + notes: str = "" + + +class SourceModule: + """Small AST-backed view of one sparse operation module.""" + + def __init__(self, path: Path, shared: dict[str, Any] | None = None): + self.path = path + self.name = path.stem + self.tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path)) + self.assignments: dict[str, ast.AST] = {} + self.functions: set[str] = set() + self.shared = shared or {} + self._collect() + + def _collect(self) -> None: + for node in self.tree.body: + if isinstance(node, (ast.Assign, ast.AnnAssign)): + targets = node.targets if isinstance(node, ast.Assign) else [node.target] + for target in targets: + if isinstance(target, ast.Name): + self.assignments[target.id] = node.value + elif isinstance(node, ast.FunctionDef): + self.functions.add(node.name) + + def get(self, name: str) -> Any: + if name in self.assignments: + return self._eval(self.assignments[name]) + return self.shared.get(name) + + def _eval(self, node: ast.AST) -> Any: + if isinstance(node, ast.Tuple): + return tuple(self._flatten(el) for el in node.elts) + if isinstance(node, ast.List): + return tuple(self._flatten(el) for el in node.elts) + if isinstance(node, ast.Set): + return tuple(self._flatten(el) for el in node.elts) + if isinstance(node, ast.Dict): + return {self._eval(k): self._eval(v) for k, v in zip(node.keys, node.values) if k is not None} + if isinstance(node, ast.Starred): + return self._eval(node.value) + if isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name) and node.value.id == "torch": + return node.attr + if isinstance(node, ast.Name): + if node.id in self.assignments: + return self._eval(self.assignments[node.id]) + if node.id in self.shared: + return self.shared[node.id] + return node.id + if isinstance(node, ast.Constant): + return node.value + if isinstance(node, ast.Call): + return self._eval_call(node) + if isinstance(node, ast.IfExp): + # Optional complex32/chalf support is represented as an if-expression. + # Report the declared capability conservatively, independent of the + # local torch build where this script is executed. 
+            return self._eval(node.body)
+        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Add):
+            left = self._as_tuple(self._eval(node.left))
+            right = self._as_tuple(self._eval(node.right))
+            return left + right
+        if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
+            value = self._eval(node.operand)
+            return -value if isinstance(value, (int, float)) else value
+        return None
+
+    def _eval_call(self, node: ast.Call) -> Any:
+        if isinstance(node.func, ast.Name):
+            if node.func.id == "_torch_complex32_dtype":
+                return "complex32"
+            if node.func.id == "tuple" and node.args:
+                return self._as_tuple(self._eval(node.args[0]))
+        return None
+
+    def _flatten(self, node: ast.AST) -> Any:
+        # Starred and plain elements reduce to the same evaluated value here;
+        # any nesting this produces is collapsed later by flatten().
+        return self._eval(node)
+
+    @staticmethod
+    def _as_tuple(value: Any) -> tuple[Any, ...]:
+        if value is None:
+            return ()
+        if isinstance(value, tuple):
+            return value
+        if isinstance(value, list):
+            return tuple(value)
+        return (value,)
+
+
+def normalize_dtype_values(value: Any) -> tuple[str, ...]:
+    raw = flatten(value)
+    dtypes: list[str] = []
+    for item in raw:
+        if item is None:
+            continue
+        token = str(item).replace("torch.", "")
+        if token == "chalf":
+            token = "complex32"
+        if token not in dtypes:
+            dtypes.append(token)
+    return tuple(dtypes)
+
+
+def flatten(value: Any) -> tuple[Any, ...]:
+    if value is None:
+        return ()
+    if isinstance(value, (tuple, list, set)):
+        out: list[Any] = []
+        for item in value:
+            out.extend(flatten(item))
+        return tuple(out)
+    return (value,)
+
+
+def discover_modules(src_root: Path) -> dict[str, SourceModule]:
+    common = SourceModule(src_root / "_common.py")
+    common_values = normalize_dtype_values(common.get("SUPPORTED_VALUE_DTYPES"))
+    if "complex64" not in common_values:
+        # _common.py builds this list with append/extend after initial assignment.
+        # Static expression evaluation intentionally avoids executing that code,
+        # so keep this fallback aligned with the declared global support set.
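+        # Fallback ordering is cosmetic: _sort_rows orders emitted rows via
+        # VALUE_DTYPE_ORDER, so complex32 trailing the defaults is harmless.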
+ common_values = DEFAULT_VALUE_DTYPES + ("complex32",) + shared = { + "SUPPORTED_VALUE_DTYPES": common_values, + "SUPPORTED_INDEX_DTYPES": normalize_dtype_values(common.get("SUPPORTED_INDEX_DTYPES")), + } + modules = {"_common": common} + for path in sorted(src_root.glob("*.py")): + if path.name == "_common.py": + continue + modules[path.stem] = SourceModule(path, shared=shared) + return modules + + +def collect_public_apis(src_root: Path) -> set[str]: + init_path = src_root / "__init__.py" + if not init_path.exists(): + return set() + init_module = SourceModule(init_path) + values = normalize_dtype_values(init_module.get("__all__")) + return set(values) + + +def op_names(module: SourceModule) -> tuple[str, ...]: + names = module.get("SPMV_OP_NAMES") + if isinstance(names, dict): + values = [str(v) for _, v in sorted(names.items(), key=lambda item: item[0])] + if values: + return tuple(values) + return ("non", "trans", "conj") + + +def normalize_op_label(op: Any) -> str: + token = str(op).strip().lower().replace("-", "_") + token = token.replace(" ", "_") + if token in ("n/a", "na", "none"): + return NA + if token in ("n", "non", "non_trans", "non_transpose"): + return "non" + if token in ("t", "trans", "transpose"): + return "trans" + if token in ("c", "conj", "conj_trans", "conj_transpose", "conjugate_transpose"): + return "conj" + return token + + +def registry(modules: dict[str, SourceModule]) -> tuple[ApiSpec, ...]: + spmv_ops = op_names(modules["spmv_csr"]) if "spmv_csr" in modules else ("non", "trans", "conj") + spmm_values = normalize_dtype_values(modules["spmm_csr"].get("SUPPORTED_SPMM_VALUE_DTYPES")) if "spmm_csr" in modules else None + return ( + ApiSpec("gather", "flagsparse_gather", "gather_scatter", "index", "triton", values=DEFAULT_VALUE_DTYPES + ("complex32",), indices=DEFAULT_INDEX_DTYPES), + ApiSpec("scatter", "flagsparse_scatter", "gather_scatter", "index", "triton", value_const="SUPPORTED_SCATTER_VALUE_DTYPES", index_const="SUPPORTED_INDEX_DTYPES"), + ApiSpec("spmv", "flagsparse_spmv_csr", "spmv_csr", "CSR", "triton", value_const="SUPPORTED_SPMV_VALUE_DTYPES", index_const="SUPPORTED_INDEX_DTYPES", ops=spmv_ops, notes="op supports non/trans/conj; conj on real dtypes is transpose-equivalent"), + ApiSpec("spmv", "flagsparse_spmv_coo", "spmv_coo", "COO", "triton", values=("float32", "float64"), indices=DEFAULT_INDEX_DTYPES, ops=("non",), notes="COO path casts kernel indices to int32 after range check"), + ApiSpec("spmv", "flagsparse_spmv_coo_tocsr", "spmv_csr", "COO->CSR", "triton", value_const="SUPPORTED_SPMV_VALUE_DTYPES", index_const="SUPPORTED_INDEX_DTYPES", ops=("non",), notes="COO input is converted to CSR before compute"), + ApiSpec("spmm", "flagsparse_spmm_csr", "spmm_csr", "CSR", "triton", value_const="SUPPORTED_SPMM_VALUE_DTYPES", index_const="SUPPORTED_INDEX_DTYPES", ops=("non",)), + ApiSpec("spmm", "flagsparse_spmm_csr_opt", "spmm_csr", "CSR", "triton_opt", value_const="SUPPORTED_SPMM_VALUE_DTYPES", index_const="SUPPORTED_INDEX_DTYPES", ops=("non",)), + ApiSpec("spmm", "flagsparse_spmm_coo", "spmm_coo", "COO", "triton", values=spmm_values, index_const="SUPPORTED_INDEX_DTYPES", ops=("non",), notes="COO SpMM reuses CSR SpMM dtype declaration"), + ApiSpec("spgemm", "flagsparse_spgemm_csr", "spgemm_csr", "CSR", "triton", value_const="SUPPORTED_SPGEMM_VALUE_DTYPES", index_const="SUPPORTED_INDEX_DTYPES", ops=("non",)), + ApiSpec("sddmm", "flagsparse_sddmm_csr", "sddmm_csr", "CSR", "triton", value_const="SUPPORTED_SDDMM_VALUE_DTYPES", 
index_const="SUPPORTED_INDEX_DTYPES", ops=("non",)), + ApiSpec("spsv", "flagsparse_spsv_csr", "spsv", "CSR", "triton", value_const="SUPPORTED_SPSV_VALUE_DTYPES", index_const="SUPPORTED_SPSV_INDEX_DTYPES", ops=("NON_TRANS", "TRANS"), notes="TRANS support is narrower than NON_TRANS; see combo constants"), + ApiSpec("spsv", "flagsparse_spsv_coo", "spsv", "COO", "triton", value_const="SUPPORTED_SPSV_VALUE_DTYPES", index_const="SUPPORTED_SPSV_INDEX_DTYPES", ops=("NON_TRANS", "TRANS"), notes="TRANS support is narrower than NON_TRANS; see combo constants"), + ApiSpec("spsm", "flagsparse_spsm_csr", "spsm", "CSR", "triton", value_const="SUPPORTED_SPSM_VALUE_DTYPES", index_const="SUPPORTED_SPSM_INDEX_DTYPES", ops=("NON_TRANS",), notes="opA/opB must both be NON_TRANS; row-major dense layout only"), + ApiSpec("spsm", "flagsparse_spsm_coo", "spsm", "COO", "triton", value_const="SUPPORTED_SPSM_VALUE_DTYPES", index_const="SUPPORTED_SPSM_INDEX_DTYPES", ops=("NON_TRANS",), notes="opA/opB must both be NON_TRANS; row-major dense layout only"), + ) + + +def rows_for_spec(spec: ApiSpec, modules: dict[str, SourceModule], public_apis: set[str], src_root: Path) -> list[dict[str, str]]: + module = modules.get(spec.module) + notes = [spec.notes] if spec.notes else [] + status = "SUPPORTED" + if module is None: + return [row(spec, NA, NA, NA, "PARTIAL")] + + values = spec.values + if values is None and spec.value_const: + values = normalize_dtype_values(module.get(spec.value_const)) + indices = spec.indices + if indices is None and spec.index_const: + indices = normalize_dtype_values(module.get(spec.index_const)) + if not values: + values = (NA,) + status = "PARTIAL" + notes.append(f"value dtype constant {spec.value_const or ''} not found") + if not indices: + indices = (NA,) + status = "PARTIAL" + notes.append(f"index dtype constant {spec.index_const or ''} not found") + if spec.api not in module.functions and spec.api not in public_apis: + status = "PARTIAL" + notes.append(f"api {spec.api} not found in source exports/functions") + + ops = spec.ops + if ops is None: + ops_tuple = ("non",) + elif isinstance(ops, str): + ops_tuple = (ops,) + else: + ops_tuple = ops + + return [ + row(spec, value_dtype, index_dtype, op, status) + for value_dtype in values + for index_dtype in indices + for op in ops_tuple + ] + + +def row(spec: ApiSpec, value_dtype: str, index_dtype: str, op: str, status: str) -> dict[str, str]: + return { + "operator": spec.operator, + "format": spec.fmt, + "index_dtype": str(index_dtype), + "value_dtype": str(value_dtype), + "op": normalize_op_label(op), + "route": spec.route, + "status": status, + } + + +def discovered_unmapped_rows(modules: dict[str, SourceModule], specs: Iterable[ApiSpec], src_root: Path) -> list[dict[str, str]]: + mapped = {(spec.module, spec.value_const) for spec in specs if spec.value_const} + mapped |= {(spec.module, spec.index_const) for spec in specs if spec.index_const} + out: list[dict[str, str]] = [] + for module_name, module in sorted(modules.items()): + if module_name in {"_common", "__init__", "benchmarks"}: + continue + for const_name in sorted(module.assignments): + if not (const_name.startswith("SUPPORTED_") and const_name.endswith("_DTYPES")): + continue + if (module_name, const_name) in mapped: + continue + values = normalize_dtype_values(module.get(const_name)) + dtype_kind = "index_dtype" if "INDEX" in const_name else "value_dtype" + for dtype in values or (NA,): + data = { + "operator": module_name, + "format": NA, + "value_dtype": dtype if dtype_kind == 
"value_dtype" else NA, + "index_dtype": dtype if dtype_kind == "index_dtype" else NA, + "op": NA, + "route": NA, + "status": "DISCOVERED_UNMAPPED", + } + out.append(data) + return out + + +def _ordered_value(mapping: dict[str, int], value: str) -> tuple[int, str]: + return (mapping.get(value, len(mapping)), value) + + +def _sort_rows(rows: list[dict[str, str]]) -> list[dict[str, str]]: + return sorted( + rows, + key=lambda item: ( + item["operator"], + item["format"], + _ordered_value(INDEX_DTYPE_ORDER, item["index_dtype"]), + _ordered_value(VALUE_DTYPE_ORDER, item["value_dtype"]), + _ordered_value(OP_ORDER, item["op"]), + item["route"], + item["status"], + ), + ) + + +def build_rows(src_root: Path) -> list[dict[str, str]]: + modules = discover_modules(src_root) + public_apis = collect_public_apis(src_root) + specs = registry(modules) + rows: list[dict[str, str]] = [] + for spec in specs: + rows.extend(rows_for_spec(spec, modules, public_apis, src_root)) + rows.extend(discovered_unmapped_rows(modules, specs, src_root)) + return _sort_rows(rows) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + script_dir = Path(__file__).resolve().parent + parser.add_argument( + "--src-root", + type=Path, + default=script_dir / "src" / "flagsparse" / "sparse_operations", + help="Path to src/flagsparse/sparse_operations (default: project-local path).", + ) + parser.add_argument( + "--output", + type=Path, + default=script_dir / "ops_support.csv", + help="CSV output path (default: ./ops_support.csv).", + ) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + src_root = args.src_root.resolve() + output = args.output.resolve() + if not src_root.exists(): + raise FileNotFoundError(f"sparse operations source root not found: {src_root}") + + rows = build_rows(src_root) + output.parent.mkdir(parents=True, exist_ok=True) + with output.open("w", newline="", encoding="utf-8") as handle: + writer = csv.DictWriter(handle, fieldnames=CSV_FIELDS) + writer.writeheader() + writer.writerows(rows) + + print(f"Wrote {len(rows)} support rows to {output}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/ops_support_sort_check.csv b/ops_support_sort_check.csv new file mode 100644 index 0000000..61ddfaa --- /dev/null +++ b/ops_support_sort_check.csv @@ -0,0 +1,177 @@ +operator,format,index_dtype,value_dtype,op,route,status +gather,index,int32,float16,non,triton,SUPPORTED +gather,index,int32,bfloat16,non,triton,SUPPORTED +gather,index,int32,float32,non,triton,SUPPORTED +gather,index,int32,float64,non,triton,SUPPORTED +gather,index,int32,complex32,non,triton,SUPPORTED +gather,index,int32,complex64,non,triton,SUPPORTED +gather,index,int32,complex128,non,triton,SUPPORTED +gather,index,int64,float16,non,triton,SUPPORTED +gather,index,int64,bfloat16,non,triton,SUPPORTED +gather,index,int64,float32,non,triton,SUPPORTED +gather,index,int64,float64,non,triton,SUPPORTED +gather,index,int64,complex32,non,triton,SUPPORTED +gather,index,int64,complex64,non,triton,SUPPORTED +gather,index,int64,complex128,non,triton,SUPPORTED +scatter,index,int32,float16,non,triton,SUPPORTED +scatter,index,int32,bfloat16,non,triton,SUPPORTED +scatter,index,int32,float32,non,triton,SUPPORTED +scatter,index,int32,float64,non,triton,SUPPORTED +scatter,index,int32,complex64,non,triton,SUPPORTED +scatter,index,int32,complex128,non,triton,SUPPORTED +scatter,index,int64,float16,non,triton,SUPPORTED +scatter,index,int64,bfloat16,non,triton,SUPPORTED 
+scatter,index,int64,float32,non,triton,SUPPORTED +scatter,index,int64,float64,non,triton,SUPPORTED +scatter,index,int64,complex64,non,triton,SUPPORTED +scatter,index,int64,complex128,non,triton,SUPPORTED +sddmm,CSR,int32,float32,non,triton,SUPPORTED +sddmm,CSR,int32,float64,non,triton,SUPPORTED +sddmm,CSR,int64,float32,non,triton,SUPPORTED +sddmm,CSR,int64,float64,non,triton,SUPPORTED +spgemm,CSR,int32,float32,non,triton,SUPPORTED +spgemm,CSR,int32,float64,non,triton,SUPPORTED +spgemm,CSR,int64,float32,non,triton,SUPPORTED +spgemm,CSR,int64,float64,non,triton,SUPPORTED +spmm,COO,int32,float16,non,triton,SUPPORTED +spmm,COO,int32,bfloat16,non,triton,SUPPORTED +spmm,COO,int32,float32,non,triton,SUPPORTED +spmm,COO,int32,float64,non,triton,SUPPORTED +spmm,COO,int32,complex32,non,triton,SUPPORTED +spmm,COO,int32,complex64,non,triton,SUPPORTED +spmm,COO,int32,complex128,non,triton,SUPPORTED +spmm,COO,int64,float16,non,triton,SUPPORTED +spmm,COO,int64,bfloat16,non,triton,SUPPORTED +spmm,COO,int64,float32,non,triton,SUPPORTED +spmm,COO,int64,float64,non,triton,SUPPORTED +spmm,COO,int64,complex32,non,triton,SUPPORTED +spmm,COO,int64,complex64,non,triton,SUPPORTED +spmm,COO,int64,complex128,non,triton,SUPPORTED +spmm,CSR,int32,float16,non,triton,SUPPORTED +spmm,CSR,int32,float16,non,triton_opt,SUPPORTED +spmm,CSR,int32,bfloat16,non,triton,SUPPORTED +spmm,CSR,int32,bfloat16,non,triton_opt,SUPPORTED +spmm,CSR,int32,float32,non,triton,SUPPORTED +spmm,CSR,int32,float32,non,triton_opt,SUPPORTED +spmm,CSR,int32,float64,non,triton,SUPPORTED +spmm,CSR,int32,float64,non,triton_opt,SUPPORTED +spmm,CSR,int32,complex32,non,triton,SUPPORTED +spmm,CSR,int32,complex32,non,triton_opt,SUPPORTED +spmm,CSR,int32,complex64,non,triton,SUPPORTED +spmm,CSR,int32,complex64,non,triton_opt,SUPPORTED +spmm,CSR,int32,complex128,non,triton,SUPPORTED +spmm,CSR,int32,complex128,non,triton_opt,SUPPORTED +spmm,CSR,int64,float16,non,triton,SUPPORTED +spmm,CSR,int64,float16,non,triton_opt,SUPPORTED +spmm,CSR,int64,bfloat16,non,triton,SUPPORTED +spmm,CSR,int64,bfloat16,non,triton_opt,SUPPORTED +spmm,CSR,int64,float32,non,triton,SUPPORTED +spmm,CSR,int64,float32,non,triton_opt,SUPPORTED +spmm,CSR,int64,float64,non,triton,SUPPORTED +spmm,CSR,int64,float64,non,triton_opt,SUPPORTED +spmm,CSR,int64,complex32,non,triton,SUPPORTED +spmm,CSR,int64,complex32,non,triton_opt,SUPPORTED +spmm,CSR,int64,complex64,non,triton,SUPPORTED +spmm,CSR,int64,complex64,non,triton_opt,SUPPORTED +spmm,CSR,int64,complex128,non,triton,SUPPORTED +spmm,CSR,int64,complex128,non,triton_opt,SUPPORTED +spmv,COO,int32,float32,non,triton,SUPPORTED +spmv,COO,int32,float64,non,triton,SUPPORTED +spmv,COO,int64,float32,non,triton,SUPPORTED +spmv,COO,int64,float64,non,triton,SUPPORTED +spmv,COO->CSR,int32,float16,non,triton,SUPPORTED +spmv,COO->CSR,int32,bfloat16,non,triton,SUPPORTED +spmv,COO->CSR,int32,float32,non,triton,SUPPORTED +spmv,COO->CSR,int32,float64,non,triton,SUPPORTED +spmv,COO->CSR,int32,complex64,non,triton,SUPPORTED +spmv,COO->CSR,int32,complex128,non,triton,SUPPORTED +spmv,COO->CSR,int64,float16,non,triton,SUPPORTED +spmv,COO->CSR,int64,bfloat16,non,triton,SUPPORTED +spmv,COO->CSR,int64,float32,non,triton,SUPPORTED +spmv,COO->CSR,int64,float64,non,triton,SUPPORTED +spmv,COO->CSR,int64,complex64,non,triton,SUPPORTED +spmv,COO->CSR,int64,complex128,non,triton,SUPPORTED +spmv,CSR,int32,float16,non,triton,SUPPORTED +spmv,CSR,int32,float16,trans,triton,SUPPORTED +spmv,CSR,int32,float16,conj,triton,SUPPORTED +spmv,CSR,int32,bfloat16,non,triton,SUPPORTED 
+spmv,CSR,int32,bfloat16,trans,triton,SUPPORTED +spmv,CSR,int32,bfloat16,conj,triton,SUPPORTED +spmv,CSR,int32,float32,non,triton,SUPPORTED +spmv,CSR,int32,float32,trans,triton,SUPPORTED +spmv,CSR,int32,float32,conj,triton,SUPPORTED +spmv,CSR,int32,float64,non,triton,SUPPORTED +spmv,CSR,int32,float64,trans,triton,SUPPORTED +spmv,CSR,int32,float64,conj,triton,SUPPORTED +spmv,CSR,int32,complex64,non,triton,SUPPORTED +spmv,CSR,int32,complex64,trans,triton,SUPPORTED +spmv,CSR,int32,complex64,conj,triton,SUPPORTED +spmv,CSR,int32,complex128,non,triton,SUPPORTED +spmv,CSR,int32,complex128,trans,triton,SUPPORTED +spmv,CSR,int32,complex128,conj,triton,SUPPORTED +spmv,CSR,int64,float16,non,triton,SUPPORTED +spmv,CSR,int64,float16,trans,triton,SUPPORTED +spmv,CSR,int64,float16,conj,triton,SUPPORTED +spmv,CSR,int64,bfloat16,non,triton,SUPPORTED +spmv,CSR,int64,bfloat16,trans,triton,SUPPORTED +spmv,CSR,int64,bfloat16,conj,triton,SUPPORTED +spmv,CSR,int64,float32,non,triton,SUPPORTED +spmv,CSR,int64,float32,trans,triton,SUPPORTED +spmv,CSR,int64,float32,conj,triton,SUPPORTED +spmv,CSR,int64,float64,non,triton,SUPPORTED +spmv,CSR,int64,float64,trans,triton,SUPPORTED +spmv,CSR,int64,float64,conj,triton,SUPPORTED +spmv,CSR,int64,complex64,non,triton,SUPPORTED +spmv,CSR,int64,complex64,trans,triton,SUPPORTED +spmv,CSR,int64,complex64,conj,triton,SUPPORTED +spmv,CSR,int64,complex128,non,triton,SUPPORTED +spmv,CSR,int64,complex128,trans,triton,SUPPORTED +spmv,CSR,int64,complex128,conj,triton,SUPPORTED +spsm,COO,int32,float32,non,triton,SUPPORTED +spsm,COO,int32,float64,non,triton,SUPPORTED +spsm,COO,int64,float32,non,triton,SUPPORTED +spsm,COO,int64,float64,non,triton,SUPPORTED +spsm,CSR,int32,float32,non,triton,SUPPORTED +spsm,CSR,int32,float64,non,triton,SUPPORTED +spsm,CSR,int64,float32,non,triton,SUPPORTED +spsm,CSR,int64,float64,non,triton,SUPPORTED +spsv,COO,int32,bfloat16,non,triton,SUPPORTED +spsv,COO,int32,bfloat16,trans,triton,SUPPORTED +spsv,COO,int32,float32,non,triton,SUPPORTED +spsv,COO,int32,float32,trans,triton,SUPPORTED +spsv,COO,int32,float64,non,triton,SUPPORTED +spsv,COO,int32,float64,trans,triton,SUPPORTED +spsv,COO,int32,complex32,non,triton,SUPPORTED +spsv,COO,int32,complex32,trans,triton,SUPPORTED +spsv,COO,int32,complex64,non,triton,SUPPORTED +spsv,COO,int32,complex64,trans,triton,SUPPORTED +spsv,COO,int64,bfloat16,non,triton,SUPPORTED +spsv,COO,int64,bfloat16,trans,triton,SUPPORTED +spsv,COO,int64,float32,non,triton,SUPPORTED +spsv,COO,int64,float32,trans,triton,SUPPORTED +spsv,COO,int64,float64,non,triton,SUPPORTED +spsv,COO,int64,float64,trans,triton,SUPPORTED +spsv,COO,int64,complex32,non,triton,SUPPORTED +spsv,COO,int64,complex32,trans,triton,SUPPORTED +spsv,COO,int64,complex64,non,triton,SUPPORTED +spsv,COO,int64,complex64,trans,triton,SUPPORTED +spsv,CSR,int32,bfloat16,non,triton,SUPPORTED +spsv,CSR,int32,bfloat16,trans,triton,SUPPORTED +spsv,CSR,int32,float32,non,triton,SUPPORTED +spsv,CSR,int32,float32,trans,triton,SUPPORTED +spsv,CSR,int32,float64,non,triton,SUPPORTED +spsv,CSR,int32,float64,trans,triton,SUPPORTED +spsv,CSR,int32,complex32,non,triton,SUPPORTED +spsv,CSR,int32,complex32,trans,triton,SUPPORTED +spsv,CSR,int32,complex64,non,triton,SUPPORTED +spsv,CSR,int32,complex64,trans,triton,SUPPORTED +spsv,CSR,int64,bfloat16,non,triton,SUPPORTED +spsv,CSR,int64,bfloat16,trans,triton,SUPPORTED +spsv,CSR,int64,float32,non,triton,SUPPORTED +spsv,CSR,int64,float32,trans,triton,SUPPORTED +spsv,CSR,int64,float64,non,triton,SUPPORTED 
+spsv,CSR,int64,float64,trans,triton,SUPPORTED
+spsv,CSR,int64,complex32,non,triton,SUPPORTED
+spsv,CSR,int64,complex32,trans,triton,SUPPORTED
+spsv,CSR,int64,complex64,non,triton,SUPPORTED
+spsv,CSR,int64,complex64,trans,triton,SUPPORTED
diff --git a/run_flagsparse_pytest.py b/run_flagsparse_pytest.py
new file mode 100644
index 0000000..489de8e
--- /dev/null
+++ b/run_flagsparse_pytest.py
@@ -0,0 +1,264 @@
+#!/usr/bin/env python3
+"""Run FlagSparse pytest accuracy suites per operator and summarize results."""
+
+from __future__ import annotations
+
+import argparse
+import csv
+import datetime as _dt
+import json
+import os
+import re
+import shlex
+import subprocess
+import sys
+import threading
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+
+try:
+    from openpyxl import Workbook
+except Exception:
+    Workbook = None
+
+
+DEFAULT_OPS = (
+    "gather",
+    "scatter",
+    "spmv_csr",
+    "spmv_coo",
+    "spmv_coo_tocsr",
+    "spmm_csr",
+    "spmm_csr_opt",
+    "spmm_coo",
+    "spsv_csr",
+    "spsv_coo",
+    "spsm_csr",
+    "spsm_coo",
+    "spgemm_csr",
+    "sddmm_csr",
+)
+
+ANSI_RE = re.compile(r"\x1B\[[0-?]*[ -/]*[@-~]")
+SUMMARY_RE = re.compile(r"(\d+)\s+([A-Za-z_]+)")
+SUMMARY_LOCK = threading.Lock()
+
+
+def now_ts() -> str:
+    return _dt.datetime.now().strftime("%Y%m%d_%H%M%S")
+
+
+def ensure_dir(path: Path) -> None:
+    path.mkdir(parents=True, exist_ok=True)
+
+
+def parse_pytest_summary(text: str) -> dict[str, int]:
+    clean = ANSI_RE.sub("", text)
+    counts = {"passed": 0, "failed": 0, "skipped": 0, "errors": 0}
+    for match in SUMMARY_RE.finditer(clean):
+        key = match.group(2).lower()
+        if key == "error":
+            # pytest prints the singular form when there is exactly one error.
+            key = "errors"
+        if key in counts:
+            counts[key] = int(match.group(1))
+    counts["total"] = counts["passed"] + counts["failed"] + counts["skipped"]
+    return counts
+
+
+def status_from_counts(counts: dict[str, int], returncode: int) -> str:
+    if returncode not in (0, 5) and not any(
+        counts[key] for key in ("passed", "failed", "skipped", "errors")
+    ):
+        return "CRASH"
+    if counts["failed"] or counts["errors"] or returncode not in (0, 5):
+        return "FAIL"
+    if counts["passed"]:
+        return "PASS"
+    if counts["skipped"]:
+        return "SKIP"
+    return "NO_TESTS"
+
+
+def read_ops(op_list: str | None, ops_arg: str | None) -> list[str]:
+    if ops_arg:
+        return [op.strip() for op in ops_arg.split(",") if op.strip()]
+    if op_list:
+        with open(op_list, encoding="utf-8") as handle:
+            return [
+                line.strip()
+                for line in handle
+                if line.strip() and not line.lstrip().startswith("#")
+            ]
+    return list(DEFAULT_OPS)
+
+
+def run_one_op(
+    project_root: Path,
+    op: str,
+    gpu_id: int,
+    mode: str,
+    results_dir: Path,
+    extra_pytest_args: list[str],
+) -> dict[str, object]:
+    op_dir = results_dir / op
+    ensure_dir(op_dir)
+    env = os.environ.copy()
+    env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
+    cmd = [
+        sys.executable,
+        "-m",
+        "pytest",
+        "tests/pytest",
+        "-m",
+        op,
+        "--mode",
+        mode,
+        "-vs",
+        "-p",
+        "no:cacheprovider",
+        *extra_pytest_args,
+    ]
+    proc = subprocess.run(
+        cmd,
+        cwd=str(project_root),
+        env=env,
+        text=True,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+    )
+    combined = (proc.stdout or "") + "\n" + (proc.stderr or "")
+    log_path = op_dir / "accuracy.log"
+    log_path.write_text(combined, encoding="utf-8")
+
+    counts = parse_pytest_summary(combined)
+    return {
+        "operator": op,
+        "gpu": gpu_id,
+        "status": status_from_counts(counts, proc.returncode),
+        "returncode": proc.returncode,
+        "log_path": str(log_path),
+        **counts,
+    }
+
+
+def run_gpu_ops(
+    project_root: Path,
+    gpu_id: int,
+    ops: 
list[str], + mode: str, + results_dir: Path, + extra_pytest_args: list[str], + results: list[dict[str, object]], +) -> None: + for op in ops: + result = run_one_op( + project_root, + op, + gpu_id, + mode, + results_dir, + extra_pytest_args, + ) + with SUMMARY_LOCK: + results.append(result) + write_summary(results, results_dir) + print( + f"[{result['status']}] {result['operator']} " + f"gpu={result['gpu']} passed={result['passed']} failed={result['failed']} " + f"skipped={result['skipped']} errors={result['errors']}" + ) + + +def write_summary(results: list[dict[str, object]], results_dir: Path) -> None: + ordered = sorted(results, key=lambda item: str(item["operator"])) + json_path = results_dir / "summary.json" + json_path.write_text(json.dumps(ordered, indent=2), encoding="utf-8") + + csv_path = results_dir / "summary.csv" + headers = [ + "operator", + "gpu", + "status", + "passed", + "failed", + "skipped", + "errors", + "total", + "returncode", + "log_path", + ] + with open(csv_path, "w", newline="", encoding="utf-8") as handle: + writer = csv.DictWriter(handle, fieldnames=headers) + writer.writeheader() + for row in ordered: + writer.writerow({key: row.get(key, "") for key in headers}) + + if Workbook is None: + return + wb = Workbook() + ws = wb.active + ws.title = "Summary" + ws.append(headers) + for row in ordered: + ws.append([row.get(key, "") for key in headers]) + wb.save(results_dir / "summary.xlsx") + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--op-list", default=None, help="File with one pytest marker per line.") + parser.add_argument("--ops", default=None, help="Comma-separated pytest markers; overrides --op-list.") + parser.add_argument("--gpus", default="0", help="Comma-separated GPU ids for CUDA_VISIBLE_DEVICES.") + parser.add_argument("--mode", default="quick", choices=("quick", "normal")) + parser.add_argument("--results-dir", default=None) + parser.add_argument( + "--pytest-args", + default="", + help="Extra pytest args appended to every per-op invocation.", + ) + args = parser.parse_args() + + project_root = Path(__file__).resolve().parent + ops = read_ops(args.op_list, args.ops) + if not ops: + raise SystemExit("no operators to run") + gpus = [int(item.strip()) for item in args.gpus.split(",") if item.strip()] + if not gpus: + raise SystemExit("no GPUs provided") + results_dir = ( + Path(args.results_dir).resolve() + if args.results_dir + else project_root / f"pytest_results_{now_ts()}" + ) + ensure_dir(results_dir) + extra_pytest_args = shlex.split(args.pytest_args) if args.pytest_args else [] + + tasks = {gpu: [] for gpu in gpus} + for index, op in enumerate(ops): + tasks[gpus[index % len(gpus)]].append(op) + + results: list[dict[str, object]] = [] + with ThreadPoolExecutor(max_workers=len(gpus)) as executor: + futures = [] + for gpu, gpu_ops in tasks.items(): + if not gpu_ops: + continue + futures.append( + executor.submit( + run_gpu_ops, + project_root, + gpu, + gpu_ops, + args.mode, + results_dir, + extra_pytest_args, + results, + ) + ) + for future in as_completed(futures): + future.result() + + write_summary(results, results_dir) + return 1 if any(result["status"] in ("FAIL", "NO_TESTS", "CRASH") for result in results) else 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tests/pytest/test_spgemm_sddmm_accuracy.py b/tests/pytest/test_spgemm_sddmm_accuracy.py new file mode 100644 index 0000000..42c6019 --- /dev/null +++ b/tests/pytest/test_spgemm_sddmm_accuracy.py @@ -0,0 +1,100 
@@ +import pytest +import torch + +from flagsparse import flagsparse_sddmm_csr, flagsparse_spgemm_csr + +from tests.pytest.param_shapes import ( + SDDMM_DTYPES, + SDDMM_DTYPE_IDS, + SDDMM_MNK_SHAPES, + SPGEMM_DTYPES, + SPGEMM_DTYPE_IDS, + SPGEMM_MNK_SHAPES, +) + + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + + +def _random_csr(rows, cols, dtype, device): + denom = max(rows * cols, 1) + p = min(0.25, max(0.06, 32.0 / denom)) + mask = torch.rand(rows, cols, device=device) < p + if int(mask.sum().item()) == 0: + mask[0, 0] = True + vals = torch.randn(rows, cols, dtype=dtype, device=device) * mask.to(dtype=dtype) + return vals.to_sparse_csr() + + +def _csr_to_dense(data, indices, indptr, shape): + csr = torch.sparse_csr_tensor( + indptr, + indices, + data, + size=shape, + dtype=data.dtype, + device=data.device, + ) + return csr.to_dense() + + +def _tol(dtype): + if dtype == torch.float32: + return 1e-4, 1e-4 + return 1e-10, 1e-8 + + +@pytest.mark.spgemm_csr +@pytest.mark.parametrize("M, N, K", SPGEMM_MNK_SHAPES) +@pytest.mark.parametrize("dtype", SPGEMM_DTYPES, ids=SPGEMM_DTYPE_IDS) +def test_spgemm_csr_matches_torch(M, N, K, dtype): + device = torch.device("cuda") + A = _random_csr(M, K, dtype, device) + B = _random_csr(K, N, dtype, device) + c_data, c_indices, c_indptr, c_shape = flagsparse_spgemm_csr( + A.values(), + A.col_indices().to(torch.int32), + A.crow_indices(), + (M, K), + B.values(), + B.col_indices().to(torch.int32), + B.crow_indices(), + (K, N), + ) + got = _csr_to_dense(c_data, c_indices, c_indptr, c_shape) + ref = torch.sparse.mm(A, B.to_dense()) + rtol, atol = _tol(dtype) + assert torch.allclose(got, ref, rtol=rtol, atol=atol) + + +@pytest.mark.sddmm_csr +@pytest.mark.parametrize("M, N, K", SDDMM_MNK_SHAPES) +@pytest.mark.parametrize("dtype", SDDMM_DTYPES, ids=SDDMM_DTYPE_IDS) +def test_sddmm_csr_matches_sampled_dense_reference(M, N, K, dtype): + device = torch.device("cuda") + pattern = _random_csr(M, N, dtype, device) + indices = pattern.col_indices().to(torch.int32) + indptr = pattern.crow_indices() + data = pattern.values() + x = torch.randn(M, K, dtype=dtype, device=device) + y = torch.randn(N, K, dtype=dtype, device=device) + alpha = 1.25 + beta = 0.5 + + got = flagsparse_sddmm_csr( + data=data, + indices=indices, + indptr=indptr, + x=x, + y=y, + shape=(M, N), + alpha=alpha, + beta=beta, + ) + row_ids = torch.repeat_interleave( + torch.arange(M, dtype=torch.int64, device=device), + indptr[1:] - indptr[:-1], + ) + ref = alpha * torch.sum(x[row_ids] * y[indices.to(torch.int64)], dim=1) + beta * data + rtol, atol = _tol(dtype) + assert torch.allclose(got, ref, rtol=rtol, atol=atol) diff --git a/tests/pytest/test_spmm_coo_accuracy.py b/tests/pytest/test_spmm_coo_accuracy.py new file mode 100644 index 0000000..3cde262 --- /dev/null +++ b/tests/pytest/test_spmm_coo_accuracy.py @@ -0,0 +1,45 @@ +import pytest +import torch + +from flagsparse import flagsparse_spmm_coo + +from tests.pytest.param_shapes import MNK_SHAPES, SPMM_OPT_DTYPES, SPMM_OPT_DTYPE_IDS + + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + + +def _random_coo_mk(M, K, dtype, device): + denom = max(M * K, 1) + p = min(0.25, max(0.06, 32.0 / denom)) + mask = torch.rand(M, K, device=device) < p + if int(mask.sum().item()) == 0: + mask[0, 0] = True + vals = torch.randn(M, K, dtype=dtype, device=device) * mask.to(dtype=dtype) + return vals.to_sparse_coo().coalesce() + + +def _tol(dtype): + if dtype == torch.float32: + return 1e-4, 
1e-4 + return 1e-10, 1e-8 + + +@pytest.mark.spmm_coo +@pytest.mark.parametrize("M, N, K", MNK_SHAPES) +@pytest.mark.parametrize("dtype", SPMM_OPT_DTYPES, ids=SPMM_OPT_DTYPE_IDS) +def test_spmm_coo_matches_torch(M, N, K, dtype): + device = torch.device("cuda") + Asp = _random_coo_mk(M, K, dtype, device) + indices = Asp.indices() + data = Asp.values() + row = indices[0].contiguous() + col = indices[1].contiguous() + B = torch.randn(K, N, dtype=dtype, device=device) + if dtype == torch.float32: + ref = torch.sparse.mm(Asp.double(), B.double()).float() + else: + ref = torch.sparse.mm(Asp, B) + out = flagsparse_spmm_coo(data, row, col, B, (M, K)) + rtol, atol = _tol(dtype) + assert torch.allclose(out, ref, rtol=rtol, atol=atol) diff --git a/tests/pytest/test_spsm_accuracy.py b/tests/pytest/test_spsm_accuracy.py new file mode 100644 index 0000000..76e11e0 --- /dev/null +++ b/tests/pytest/test_spsm_accuracy.py @@ -0,0 +1,68 @@ +import pytest +import torch + +from flagsparse import flagsparse_spsm_coo, flagsparse_spsm_csr + +from tests.pytest.param_shapes import SPSM_N_RHS, TRIANGULAR_DTYPE_IDS, TRIANGULAR_DTYPES + + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + + +def _build_lower_dense(n, dtype, device): + base = torch.tril(torch.randn(n, n, dtype=dtype, device=device)) + eye = torch.eye(n, dtype=dtype, device=device) + return base + eye * (float(n) * 0.5 + 2.0) + + +def _tol(dtype): + if dtype == torch.float32: + return 1e-4, 1e-5 + return 1e-10, 1e-10 + + +@pytest.mark.spsm +@pytest.mark.spsm_csr +@pytest.mark.parametrize("n, n_rhs", SPSM_N_RHS) +@pytest.mark.parametrize("dtype", TRIANGULAR_DTYPES, ids=TRIANGULAR_DTYPE_IDS) +def test_spsm_csr_lower_matches_dense(n, n_rhs, dtype): + device = torch.device("cuda") + A = _build_lower_dense(n, dtype, device) + B = torch.randn(n, n_rhs, dtype=dtype, device=device) + ref = torch.linalg.solve_triangular(A, B, upper=False) + Acsr = A.to_sparse_csr() + out = flagsparse_spsm_csr( + Acsr.values(), + Acsr.col_indices(), + Acsr.crow_indices(), + B, + (n, n), + lower=True, + unit_diagonal=False, + ) + rtol, atol = _tol(dtype) + assert torch.allclose(out, ref, rtol=rtol, atol=atol) + + +@pytest.mark.spsm +@pytest.mark.spsm_coo +@pytest.mark.parametrize("n, n_rhs", SPSM_N_RHS) +@pytest.mark.parametrize("dtype", TRIANGULAR_DTYPES, ids=TRIANGULAR_DTYPE_IDS) +def test_spsm_coo_lower_matches_dense(n, n_rhs, dtype): + device = torch.device("cuda") + A = _build_lower_dense(n, dtype, device) + B = torch.randn(n, n_rhs, dtype=dtype, device=device) + ref = torch.linalg.solve_triangular(A, B, upper=False) + Acoo = A.to_sparse_coo().coalesce() + indices = Acoo.indices() + out = flagsparse_spsm_coo( + Acoo.values(), + indices[0].contiguous(), + indices[1].contiguous(), + B, + (n, n), + lower=True, + unit_diagonal=False, + ) + rtol, atol = _tol(dtype) + assert torch.allclose(out, ref, rtol=rtol, atol=atol) From 3bf7cafdd9f8e98e33234b6f85ed0d7202a73ea3 Mon Sep 17 00:00:00 2001 From: berlin020 <2261128688@qq.com> Date: Mon, 27 Apr 2026 12:10:42 +0800 Subject: [PATCH 09/22] speedup and opt --- .gitignore | 10 + LICENSE | 201 ++ README.md | 114 + README_cn.md | 114 + benchmark/benchmark_gather.py | 4 + benchmark/benchmark_scatter.py | 4 + benchmark/benchmark_sddmm.py | 9 + benchmark/benchmark_spgemm.py | 9 + benchmark/benchmark_spmm.py | 9 + benchmark/benchmark_spmm_opt.py | 9 + benchmark/benchmark_spmv.py | 9 + benchmark/benchmark_spsm.py | 9 + benchmark/benchmark_spsv.py | 14 + ops_support.py | 395 +++ ops_support_sort_check.csv | 
177 ++ pyproject.toml | 30 + pytest.ini | 13 + run_flagsparse_pytest.py | 264 ++ setup.py | 12 + src/flagsparse/__init__.py | 148 + src/flagsparse/sparse_formats.py | 670 +++++ src/flagsparse/sparse_operations/__init__.py | 100 + src/flagsparse/sparse_operations/_common.py | 334 +++ .../sparse_operations/benchmarks.py | 558 ++++ .../sparse_operations/gather_scatter.py | 541 ++++ src/flagsparse/sparse_operations/sddmm_csr.py | 572 ++++ .../sparse_operations/spgemm_csr.py | 1313 +++++++++ src/flagsparse/sparse_operations/spmm_coo.py | 1172 ++++++++ src/flagsparse/sparse_operations/spmm_csr.py | 2394 +++++++++++++++++ src/flagsparse/sparse_operations/spmv_coo.py | 348 +++ src/flagsparse/sparse_operations/spmv_csr.py | 758 ++++++ src/flagsparse/sparse_operations/spsm.py | 794 ++++++ src/flagsparse/sparse_operations/spsv.py | 1658 ++++++++++++ tests/__init__.py | 1 + tests/diagnose_spmm_opt.py | 628 +++++ tests/pytest/Note.md | 189 ++ tests/pytest/__init__.py | 1 + tests/pytest/conftest.py | 20 + tests/pytest/param_shapes.py | 47 + tests/pytest/test_gather_scatter_accuracy.py | 192 ++ tests/pytest/test_spgemm_sddmm_accuracy.py | 100 + tests/pytest/test_spmm_coo_accuracy.py | 45 + tests/pytest/test_spmm_csr_accuracy.py | 58 + tests/pytest/test_spmv_coo_accuracy.py | 46 + tests/pytest/test_spmv_csr_accuracy.py | 49 + tests/pytest/test_spsm_accuracy.py | 68 + tests/pytest/test_spsv_csr_accuracy.py | 447 +++ tests/test_gather.py | 548 ++++ tests/test_scatter.py | 397 +++ tests/test_sddmm.py | 692 +++++ tests/test_spgemm.py | 1779 ++++++++++++ tests/test_spmm.py | 1112 ++++++++ tests/test_spmm_coo.py | 1376 ++++++++++ tests/test_spmm_opt.py | 449 ++++ tests/test_spmv.py | 693 +++++ tests/test_spmv_coo.py | 589 ++++ tests/test_spmv_opt.py | 404 +++ tests/test_spsm.py | 529 ++++ tests/test_spsv.py | 1932 +++++++++++++ 59 files changed, 25157 insertions(+) create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 README.md create mode 100644 README_cn.md create mode 100644 benchmark/benchmark_gather.py create mode 100644 benchmark/benchmark_scatter.py create mode 100644 benchmark/benchmark_sddmm.py create mode 100644 benchmark/benchmark_spgemm.py create mode 100644 benchmark/benchmark_spmm.py create mode 100644 benchmark/benchmark_spmm_opt.py create mode 100644 benchmark/benchmark_spmv.py create mode 100644 benchmark/benchmark_spsm.py create mode 100644 benchmark/benchmark_spsv.py create mode 100644 ops_support.py create mode 100644 ops_support_sort_check.csv create mode 100644 pyproject.toml create mode 100644 pytest.ini create mode 100644 run_flagsparse_pytest.py create mode 100644 setup.py create mode 100644 src/flagsparse/__init__.py create mode 100644 src/flagsparse/sparse_formats.py create mode 100644 src/flagsparse/sparse_operations/__init__.py create mode 100644 src/flagsparse/sparse_operations/_common.py create mode 100644 src/flagsparse/sparse_operations/benchmarks.py create mode 100644 src/flagsparse/sparse_operations/gather_scatter.py create mode 100644 src/flagsparse/sparse_operations/sddmm_csr.py create mode 100644 src/flagsparse/sparse_operations/spgemm_csr.py create mode 100644 src/flagsparse/sparse_operations/spmm_coo.py create mode 100644 src/flagsparse/sparse_operations/spmm_csr.py create mode 100644 src/flagsparse/sparse_operations/spmv_coo.py create mode 100644 src/flagsparse/sparse_operations/spmv_csr.py create mode 100644 src/flagsparse/sparse_operations/spsm.py create mode 100644 src/flagsparse/sparse_operations/spsv.py create mode 100644 tests/__init__.py 
create mode 100644 tests/diagnose_spmm_opt.py create mode 100644 tests/pytest/Note.md create mode 100644 tests/pytest/__init__.py create mode 100644 tests/pytest/conftest.py create mode 100644 tests/pytest/param_shapes.py create mode 100644 tests/pytest/test_gather_scatter_accuracy.py create mode 100644 tests/pytest/test_spgemm_sddmm_accuracy.py create mode 100644 tests/pytest/test_spmm_coo_accuracy.py create mode 100644 tests/pytest/test_spmm_csr_accuracy.py create mode 100644 tests/pytest/test_spmv_coo_accuracy.py create mode 100644 tests/pytest/test_spmv_csr_accuracy.py create mode 100644 tests/pytest/test_spsm_accuracy.py create mode 100644 tests/pytest/test_spsv_csr_accuracy.py create mode 100644 tests/test_gather.py create mode 100644 tests/test_scatter.py create mode 100644 tests/test_sddmm.py create mode 100644 tests/test_spgemm.py create mode 100644 tests/test_spmm.py create mode 100644 tests/test_spmm_coo.py create mode 100644 tests/test_spmm_opt.py create mode 100644 tests/test_spmv.py create mode 100644 tests/test_spmv_coo.py create mode 100644 tests/test_spmv_opt.py create mode 100644 tests/test_spsm.py create mode 100644 tests/test_spsv.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e44113e --- /dev/null +++ b/.gitignore @@ -0,0 +1,10 @@ +build/ +dist/ +*.egg-info/ +*.egg +__pycache__/ +*.py[cod] +.pytest_cache/ +.coverage +htmlcov/ +.DS_Store diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. 
For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..b0e21ea
--- /dev/null
+++ b/README.md
@@ -0,0 +1,114 @@
+# FlagSparse
+
+GPU sparse operations package (SpMV, SpMM, SpGEMM, SDDMM, gather, scatter, sparse formats).
+
+## Install
+
+```bash
+pip install . --no-deps --no-build-isolation
+```
+
+Use `--no-build-isolation` to avoid downloading build dependencies when you are offline.
+
+Runtime dependencies (install when needed):
+
+```bash
+pip install torch triton cupy-cuda12x
+```
+
+## Layout
+
+- `src/flagsparse/` - core package (the `sparse_operations/` modules are generated as `.py` files from string literals embedded in `flagsparse.py`)
+- `tests/` - pytest tests
+- `benchmark/` - performance benchmarks
+
+## Tests
+
+Run the scripts from the project root, or `cd tests` first and use relative paths (such as `../matrix` for the `.mtx` directory).
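+
+For example, the same SpMV batch run can be launched either way (the `matrix/` directory below is a placeholder for your own `.mtx` collection):
+
+```bash
+# from the project root
+python tests/test_spmv.py matrix/
+
+# or from inside tests/, pointing back at the matrix directory
+cd tests && python test_spmv.py ../matrix
+```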
+
+**test_spmv.py** - CSR SpMV (SuiteSparse `.mtx`, synthetic, or CSR CSV export):
+
+```bash
+python tests/test_spmv.py <dir-or-file.mtx>              # batch run, default float32
+python tests/test_spmv.py <dir/> --dtype float64         # optional: --index-dtype int32|int64, --warmup, --iters, --no-cusparse
+python tests/test_spmv.py --synthetic                    # synthetic benchmark
+python tests/test_spmv.py <dir/> --csv-csr results.csv   # all value×index dtypes -> one CSV (per-matrix lines while running)
+```
+
+**test_spmv_coo.py** - COO SpMV (requires `--synthetic` or `--csv-coo`; no standalone `.mtx` batch):
+
+```bash
+python tests/test_spmv_coo.py --synthetic
+python tests/test_spmv_coo.py <dir/> --csv-coo out.csv
+```
+
+**test_spmv_opt.py** - SpMV baseline vs optimised A/B (`float32` / `float64` only):
+
+```bash
+python tests/test_spmv_opt.py <dir-or-file.mtx> [...]
+python tests/test_spmv_opt.py <dir/> --csv out.csv
+```
+
+**test_spmm.py** - CSR SpMM (`.mtx` batch, synthetic, or `--csv`):
+
+```bash
+python tests/test_spmm.py <dir-or-file.mtx>
+python tests/test_spmm.py --synthetic                    # optional: --skip-api-checks, --skip-alg1-coverage
+python tests/test_spmm.py <dir/> --csv results.csv       # float32/float64 + int32 in CSV; per-matrix console output
+# common options: --dtype, --index-dtype, --dense-cols, --block-n, --block-nnz, --max-segments, --warmup, --iters, --no-cusparse
+```
+
+**test_spmm_opt.py** - CSR SpMM baseline vs optimised A/B:
+
+```bash
+python tests/test_spmm_opt.py <dir-or-file.mtx> --dense-cols 32
+python tests/test_spmm_opt.py <dir/> --csv spmm_opt.csv  # optional: --dtype float32|float64, --dense-cols
+# common options: --dtype, --dense-cols, --warmup, --iters
+```
+
+**test_spmm_coo.py** - native COO SpMM:
+
+```bash
+python tests/test_spmm_coo.py <dir-or-file.mtx>
+python tests/test_spmm_coo.py --synthetic                # optional: --route rowrun|atomic|compare, --skip-api-checks, --skip-coo-coverage
+python tests/test_spmm_coo.py <dir/> --csv out.csv       # only --route rowrun or atomic (not compare)
+# same tuning flags as CSR SpMM where applicable: --dense-cols, --block-n, --block-nnz, --warmup, --iters, --no-cusparse
+```
+
+**test_sddmm.py** - CSR SDDMM (`.mtx` batch or `--csv`):
+
+```bash
+python tests/test_sddmm.py <dir-or-file.mtx> --k 64
+python tests/test_sddmm.py <dir/> --csv out.csv          # optional: --dtype float32|float64, --acc_mode f32|f64, --k 64
+# common options: --dtype, --index-dtype, --acc_mode, --k, --alpha, --beta, --warmup, --iters, --no-cupy-ref, --skip-api-checks
+```
+
+**test_spgemm.py** - CSR SpGEMM (`.mtx` batch or `--csv`):
+
+```bash
+python tests/test_spgemm.py <dir-or-file.mtx> --input-mode auto
+python tests/test_spgemm.py <dir/> --csv results.csv     # optional: --dtype float32|float64, --input-mode auto|a_equals_b|a_at, --compare-device cpu|gpu
+# common options: --dtype, --index-dtype, --warmup, --iters, --input-mode, --adaptive-loops, --no-cusparse, --ref-blocked-retry, --ref-isolated-retry, --ref-block-rows, --compare-device, --run-api-checks
+```
+
+**test_spsv.py** - SpSV (triangular solve; **square** matrices only). CSR and COO share this script; there is **no** `test_spsv_coo.py`.
+
+```bash
+python tests/test_spsv.py --synthetic
+python tests/test_spsv.py <dir/> --csv-csr spsv.csv
+python tests/test_spsv.py <dir/> --csv-coo out.csv       # same CSV columns as CSR; optional --coo-mode auto|direct|csr (default auto)
+```
+
+**test_spsm.py** - SpSM (triangular matrix-matrix solve; **square** matrices only):
+
+```bash
+python tests/test_spsm.py --synthetic --n 512 --rhs 32
+python tests/test_spsm.py <dir/> --csv-csr spsm_csr.csv --rhs 32
+python tests/test_spsm.py <dir/> --csv-coo spsm_coo.csv --rhs 32
+```
+
+**test_gather.py** / **test_scatter.py** - gather/scatter benchmarks (pytest or `python tests/test_gather.py`).
+
+## License
+
+This project is licensed under the [Apache License, Version 2.0](./LICENSE).
diff --git a/README_cn.md b/README_cn.md
new file mode 100644
index 0000000..87478d6
--- /dev/null
+++ b/README_cn.md
@@ -0,0 +1,114 @@
+# FlagSparse
+
+GPU sparse operations library (SpMV, SpMM, SpGEMM, SDDMM, gather, scatter, multiple sparse formats).
+
+## Install
+
+```bash
+pip install . --no-deps --no-build-isolation
+```
+
+Add `--no-build-isolation` when offline to avoid fetching build dependencies.
+
+Runtime dependencies (install as needed):
+
+```bash
+pip install torch triton cupy-cuda12x
+```
+
+## Layout
+
+- `src/flagsparse/` - core package (the `sparse_operations/` modules are generated as multiple `.py` files from string literals embedded in `flagsparse.py`)
+- `tests/` - pytest tests
+- `benchmark/` - performance benchmarks
+
+## Tests
+
+Run from the project root, or `cd tests` first and then run the scripts (the `.mtx` directory can be given as a relative path such as `../matrix`).
+
+**test_spmv.py** - CSR SpMV (SuiteSparse `.mtx`, synthetic data, or CSR CSV):
+
+```bash
+python tests/test_spmv.py <dir-or-file.mtx>              # batch run, default float32
+python tests/test_spmv.py <dir/> --dtype float64         # optional: --index-dtype int32|int64, --warmup, --iters, --no-cusparse
+python tests/test_spmv.py --synthetic                    # synthetic benchmark
+python tests/test_spmv.py <dir/> --csv-csr results.csv   # all value×index dtypes into one CSV (per-matrix lines printed while running)
+```
+
+**test_spmv_coo.py** - COO SpMV (requires `--synthetic` or `--csv-coo`; no standalone `.mtx` batch run):
+
+```bash
+python tests/test_spmv_coo.py --synthetic
+python tests/test_spmv_coo.py <dir/> --csv-coo out.csv
+```
+
+**test_spmv_opt.py** - SpMV baseline vs optimised A/B comparison (`float32` / `float64` only):
+
+```bash
+python tests/test_spmv_opt.py <dir-or-file.mtx> [...]
+python tests/test_spmv_opt.py <dir/> --csv out.csv
+```
+
+**test_spmm.py** - CSR SpMM (`.mtx` batch, synthetic, or `--csv`):
+
+```bash
+python tests/test_spmm.py <dir-or-file.mtx>
+python tests/test_spmm.py --synthetic                    # optional: --skip-api-checks, --skip-alg1-coverage
+python tests/test_spmm.py <dir/> --csv results.csv       # float32/float64 + int32 in the CSV; per-matrix console output
+# common options: --dtype, --index-dtype, --dense-cols, --block-n, --block-nnz, --max-segments, --warmup, --iters, --no-cusparse
+```
+
+**test_spmm_opt.py** - CSR SpMM baseline vs optimised A/B comparison:
+
+```bash
+python tests/test_spmm_opt.py <dir-or-file.mtx> --dense-cols 32
+python tests/test_spmm_opt.py <dir/> --csv spmm_opt.csv  # optional: --dtype float32|float64, --dense-cols
+# common options: --dtype, --dense-cols, --warmup, --iters
+```
+
+**test_spmm_coo.py** - native COO SpMM:
+
+```bash
+python tests/test_spmm_coo.py <dir-or-file.mtx>
+python tests/test_spmm_coo.py --synthetic                # optional: --route rowrun|atomic|compare, --skip-api-checks, --skip-coo-coverage
+python tests/test_spmm_coo.py <dir/> --csv out.csv       # only --route rowrun or atomic (compare cannot be combined with --csv)
+# tuning flags similar to CSR SpMM: --dense-cols, --block-n, --block-nnz, --warmup, --iters, --no-cusparse
+```
+
+**test_sddmm.py** - CSR SDDMM (`.mtx` batch or `--csv`):
+
+```bash
+python tests/test_sddmm.py <dir-or-file.mtx> --k 64
+python tests/test_sddmm.py <dir/> --csv out.csv          # optional: --dtype float32|float64, --acc_mode f32|f64, --k 64
+# common options: --dtype, --index-dtype, --acc_mode, --k, --alpha, --beta, --warmup, --iters, --no-cupy-ref, --skip-api-checks
+```
+
+**test_spgemm.py** - CSR SpGEMM (`.mtx` batch or `--csv`):
+
+```bash
+python tests/test_spgemm.py <dir-or-file.mtx> --input-mode auto
+python tests/test_spgemm.py <dir/> --csv results.csv     # optional: --dtype float32|float64, --input-mode auto|a_equals_b|a_at, --compare-device cpu|gpu
+# common options: --dtype, --index-dtype, --warmup, --iters, --input-mode, --adaptive-loops, --no-cusparse, --ref-blocked-retry, --ref-isolated-retry, --ref-block-rows, --compare-device, --run-api-checks
+```
+
+**test_spsv.py** - SpSV (triangular solve; **square** matrices only). CSR and COO share this script; there is **no** `test_spsv_coo.py`.
+
+```bash
+python tests/test_spsv.py --synthetic
+python tests/test_spsv.py <dir/> --csv-csr spsv.csv
+python tests/test_spsv.py <dir/> --csv-coo out.csv       # same columns as CSR; optional --coo-mode auto|direct|csr (default auto)
+```
+
+**test_spsm.py** - SpSM (triangular matrix-dense matrix solve; **square** matrices only):
+
+```bash
+python tests/test_spsm.py --synthetic --n 512 --rhs 32
+python tests/test_spsm.py <dir/> --csv-csr spsm_csr.csv --rhs 32
+python tests/test_spsm.py <dir/> --csv-coo spsm_coo.csv --rhs 32
+```
+
+**test_gather.py** / **test_scatter.py** - gather/scatter benchmarks (pytest or `python tests/test_gather.py`).
+
+## License
+
+This project is licensed under the [Apache License, Version 2.0](./LICENSE).
diff --git a/benchmark/benchmark_gather.py b/benchmark/benchmark_gather.py
new file mode 100644
index 0000000..0f4e6fb
--- /dev/null
+++ b/benchmark/benchmark_gather.py
@@ -0,0 +1,4 @@
+"""Run gather benchmark (FlagSparse vs PyTorch vs cuSPARSE)."""
+import flagsparse as fs
+if __name__ == "__main__":
+    fs.comprehensive_gather_test()
diff --git a/benchmark/benchmark_scatter.py b/benchmark/benchmark_scatter.py
new file mode 100644
index 0000000..50412e2
--- /dev/null
+++ b/benchmark/benchmark_scatter.py
@@ -0,0 +1,4 @@
+"""Run scatter benchmark (FlagSparse vs PyTorch vs cuSPARSE)."""
+import flagsparse as fs
+if __name__ == "__main__":
+    fs.comprehensive_scatter_test()
diff --git a/benchmark/benchmark_sddmm.py b/benchmark/benchmark_sddmm.py
new file mode 100644
index 0000000..8ed2b95
--- /dev/null
+++ b/benchmark/benchmark_sddmm.py
@@ -0,0 +1,9 @@
+"""Run SDDMM benchmark (.mtx batch). 
From project root: python benchmark/benchmark_sddmm.py --csv results.csv.""" +import sys +from pathlib import Path +root = Path(__file__).resolve().parent.parent +if str(root) not in sys.path: + sys.path.insert(0, str(root)) +from tests.test_sddmm import main +if __name__ == "__main__": + main() diff --git a/benchmark/benchmark_spgemm.py b/benchmark/benchmark_spgemm.py new file mode 100644 index 0000000..26db5be --- /dev/null +++ b/benchmark/benchmark_spgemm.py @@ -0,0 +1,9 @@ +"""Run SpGEMM benchmark (.mtx batch). From project root: python benchmark/benchmark_spgemm.py --csv results.csv.""" +import sys +from pathlib import Path +root = Path(__file__).resolve().parent.parent +if str(root) not in sys.path: + sys.path.insert(0, str(root)) +from tests.test_spgemm import main +if __name__ == "__main__": + main() diff --git a/benchmark/benchmark_spmm.py b/benchmark/benchmark_spmm.py new file mode 100644 index 0000000..56ce2f4 --- /dev/null +++ b/benchmark/benchmark_spmm.py @@ -0,0 +1,9 @@ +"""Run SpMM benchmark (synthetic or .mtx batch). From project root: python benchmark/benchmark_spmm.py [--synthetic] [*.mtx].""" +import sys +from pathlib import Path +root = Path(__file__).resolve().parent.parent +if str(root) not in sys.path: + sys.path.insert(0, str(root)) +from tests.test_spmm import main +if __name__ == "__main__": + main() diff --git a/benchmark/benchmark_spmm_opt.py b/benchmark/benchmark_spmm_opt.py new file mode 100644 index 0000000..c9cc8fa --- /dev/null +++ b/benchmark/benchmark_spmm_opt.py @@ -0,0 +1,9 @@ +"""Run SpMM-opt A/B benchmark (.mtx batch). From project root: python benchmark/benchmark_spmm_opt.py --csv results.csv.""" +import sys +from pathlib import Path +root = Path(__file__).resolve().parent.parent +if str(root) not in sys.path: + sys.path.insert(0, str(root)) +from tests.test_spmm_opt import main +if __name__ == "__main__": + main() diff --git a/benchmark/benchmark_spmv.py b/benchmark/benchmark_spmv.py new file mode 100644 index 0000000..9de5d02 --- /dev/null +++ b/benchmark/benchmark_spmv.py @@ -0,0 +1,9 @@ +"""Run SpMV benchmark (synthetic or .mtx batch). From project root: python benchmark/benchmark_spmv.py [--synthetic] [*.mtx].""" +import sys +from pathlib import Path +root = Path(__file__).resolve().parent.parent +if str(root) not in sys.path: + sys.path.insert(0, str(root)) +from tests.test_spmv import main +if __name__ == "__main__": + main() diff --git a/benchmark/benchmark_spsm.py b/benchmark/benchmark_spsm.py new file mode 100644 index 0000000..100a3bd --- /dev/null +++ b/benchmark/benchmark_spsm.py @@ -0,0 +1,9 @@ +"""Run SpSM synthetic benchmark. From project root: python benchmark/benchmark_spsm.py [--n 512 --rhs 32].""" +import sys +from pathlib import Path +root = Path(__file__).resolve().parent.parent +if str(root) not in sys.path: + sys.path.insert(0, str(root)) +from tests.test_spsm import main +if __name__ == "__main__": + main() diff --git a/benchmark/benchmark_spsv.py b/benchmark/benchmark_spsv.py new file mode 100644 index 0000000..3fa258f --- /dev/null +++ b/benchmark/benchmark_spsv.py @@ -0,0 +1,14 @@ +"""Run SpSV benchmark. 
From project root: python benchmark/benchmark_spsv.py [--synthetic | --csv-csr out.csv].""" + +import sys +from pathlib import Path + +root = Path(__file__).resolve().parent.parent +if str(root) not in sys.path: + sys.path.insert(0, str(root)) + +from tests.test_spsv import main + + +if __name__ == "__main__": + main() diff --git a/ops_support.py b/ops_support.py new file mode 100644 index 0000000..62be4fb --- /dev/null +++ b/ops_support.py @@ -0,0 +1,395 @@ +#!/usr/bin/env python3 +"""Export the declared FlagSparse sparse-operator support matrix to CSV. + +The script is intentionally static: it does not import torch/triton/cupy or +flagsparse, and it never launches kernels. It reads source files and reports the +support declared by constants such as SUPPORTED_*_VALUE_DTYPES. +""" + +from __future__ import annotations + +import argparse +import ast +import csv +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Iterable + + +CSV_FIELDS = ( + "operator", + "format", + "index_dtype", + "value_dtype", + "op", + "route", + "status", +) + +DEFAULT_VALUE_DTYPES = ("float16", "bfloat16", "float32", "float64", "complex64", "complex128") +DEFAULT_INDEX_DTYPES = ("int32", "int64") +NA = "N/A" +VALUE_DTYPE_ORDER = { + "float16": 0, + "bfloat16": 1, + "float32": 2, + "float64": 3, + "complex32": 4, + "complex64": 5, + "complex128": 6, +} +INDEX_DTYPE_ORDER = {"int32": 0, "int64": 1} +OP_ORDER = {"non": 0, "trans": 1, "conj": 2} + + +@dataclass(frozen=True) +class ApiSpec: + operator: str + api: str + module: str + fmt: str + route: str + value_const: str | None = None + index_const: str | None = None + values: tuple[str, ...] | None = None + indices: tuple[str, ...] | None = None + ops: tuple[str, ...] | str | None = None + notes: str = "" + + +class SourceModule: + """Small AST-backed view of one sparse operation module.""" + + def __init__(self, path: Path, shared: dict[str, Any] | None = None): + self.path = path + self.name = path.stem + self.tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path)) + self.assignments: dict[str, ast.AST] = {} + self.functions: set[str] = set() + self.shared = shared or {} + self._collect() + + def _collect(self) -> None: + for node in self.tree.body: + if isinstance(node, (ast.Assign, ast.AnnAssign)): + targets = node.targets if isinstance(node, ast.Assign) else [node.target] + for target in targets: + if isinstance(target, ast.Name): + self.assignments[target.id] = node.value + elif isinstance(node, ast.FunctionDef): + self.functions.add(node.name) + + def get(self, name: str) -> Any: + if name in self.assignments: + return self._eval(self.assignments[name]) + return self.shared.get(name) + + def _eval(self, node: ast.AST) -> Any: + if isinstance(node, ast.Tuple): + return tuple(self._flatten(el) for el in node.elts) + if isinstance(node, ast.List): + return tuple(self._flatten(el) for el in node.elts) + if isinstance(node, ast.Set): + return tuple(self._flatten(el) for el in node.elts) + if isinstance(node, ast.Dict): + return {self._eval(k): self._eval(v) for k, v in zip(node.keys, node.values) if k is not None} + if isinstance(node, ast.Starred): + return self._eval(node.value) + if isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name) and node.value.id == "torch": + return node.attr + if isinstance(node, ast.Name): + if node.id in self.assignments: + return self._eval(self.assignments[node.id]) + if node.id in self.shared: + return self.shared[node.id] + return node.id + if isinstance(node, 
ast.Constant):
+            return node.value
+        if isinstance(node, ast.Call):
+            return self._eval_call(node)
+        if isinstance(node, ast.IfExp):
+            # Optional complex32/chalf support is represented as an if-expression.
+            # Report the declared capability conservatively, independent of the
+            # local torch build where this script is executed.
+            return self._eval(node.body)
+        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Add):
+            left = self._as_tuple(self._eval(node.left))
+            right = self._as_tuple(self._eval(node.right))
+            return left + right
+        if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
+            value = self._eval(node.operand)
+            return -value if isinstance(value, (int, float)) else value
+        return None
+
+    def _eval_call(self, node: ast.Call) -> Any:
+        if isinstance(node.func, ast.Name):
+            if node.func.id == "_torch_complex32_dtype":
+                return "complex32"
+            if node.func.id == "tuple" and node.args:
+                return self._as_tuple(self._eval(node.args[0]))
+        return None
+
+    def _flatten(self, node: ast.AST) -> Any:
+        # Starred elements evaluate to the tuple they wrap; flatten() splices
+        # such nested tuples later, so no special-casing is needed here.
+        return self._eval(node)
+
+    @staticmethod
+    def _as_tuple(value: Any) -> tuple[Any, ...]:
+        if value is None:
+            return ()
+        if isinstance(value, tuple):
+            return value
+        if isinstance(value, list):
+            return tuple(value)
+        return (value,)
+
+
+def normalize_dtype_values(value: Any) -> tuple[str, ...]:
+    raw = flatten(value)
+    dtypes: list[str] = []
+    for item in raw:
+        if item is None:
+            continue
+        token = str(item).replace("torch.", "")
+        if token == "chalf":
+            token = "complex32"
+        if token not in dtypes:
+            dtypes.append(token)
+    return tuple(dtypes)
+
+
+def flatten(value: Any) -> tuple[Any, ...]:
+    if value is None:
+        return ()
+    if isinstance(value, (tuple, list, set)):
+        out: list[Any] = []
+        for item in value:
+            out.extend(flatten(item))
+        return tuple(out)
+    return (value,)
+
+
+def discover_modules(src_root: Path) -> dict[str, SourceModule]:
+    common = SourceModule(src_root / "_common.py")
+    common_values = normalize_dtype_values(common.get("SUPPORTED_VALUE_DTYPES"))
+    if "complex64" not in common_values:
+        # _common.py builds this list with append/extend after initial assignment.
+        # Static expression evaluation intentionally avoids executing that code,
+        # so keep this fallback aligned with the declared global support set.
+ common_values = DEFAULT_VALUE_DTYPES + ("complex32",) + shared = { + "SUPPORTED_VALUE_DTYPES": common_values, + "SUPPORTED_INDEX_DTYPES": normalize_dtype_values(common.get("SUPPORTED_INDEX_DTYPES")), + } + modules = {"_common": common} + for path in sorted(src_root.glob("*.py")): + if path.name == "_common.py": + continue + modules[path.stem] = SourceModule(path, shared=shared) + return modules + + +def collect_public_apis(src_root: Path) -> set[str]: + init_path = src_root / "__init__.py" + if not init_path.exists(): + return set() + init_module = SourceModule(init_path) + values = normalize_dtype_values(init_module.get("__all__")) + return set(values) + + +def op_names(module: SourceModule) -> tuple[str, ...]: + names = module.get("SPMV_OP_NAMES") + if isinstance(names, dict): + values = [str(v) for _, v in sorted(names.items(), key=lambda item: item[0])] + if values: + return tuple(values) + return ("non", "trans", "conj") + + +def normalize_op_label(op: Any) -> str: + token = str(op).strip().lower().replace("-", "_") + token = token.replace(" ", "_") + if token in ("n/a", "na", "none"): + return NA + if token in ("n", "non", "non_trans", "non_transpose"): + return "non" + if token in ("t", "trans", "transpose"): + return "trans" + if token in ("c", "conj", "conj_trans", "conj_transpose", "conjugate_transpose"): + return "conj" + return token + + +def registry(modules: dict[str, SourceModule]) -> tuple[ApiSpec, ...]: + spmv_ops = op_names(modules["spmv_csr"]) if "spmv_csr" in modules else ("non", "trans", "conj") + spmm_values = normalize_dtype_values(modules["spmm_csr"].get("SUPPORTED_SPMM_VALUE_DTYPES")) if "spmm_csr" in modules else None + return ( + ApiSpec("gather", "flagsparse_gather", "gather_scatter", "index", "triton", values=DEFAULT_VALUE_DTYPES + ("complex32",), indices=DEFAULT_INDEX_DTYPES), + ApiSpec("scatter", "flagsparse_scatter", "gather_scatter", "index", "triton", value_const="SUPPORTED_SCATTER_VALUE_DTYPES", index_const="SUPPORTED_INDEX_DTYPES"), + ApiSpec("spmv", "flagsparse_spmv_csr", "spmv_csr", "CSR", "triton", value_const="SUPPORTED_SPMV_VALUE_DTYPES", index_const="SUPPORTED_INDEX_DTYPES", ops=spmv_ops, notes="op supports non/trans/conj; conj on real dtypes is transpose-equivalent"), + ApiSpec("spmv", "flagsparse_spmv_coo", "spmv_coo", "COO", "triton", values=("float32", "float64"), indices=DEFAULT_INDEX_DTYPES, ops=("non",), notes="COO path casts kernel indices to int32 after range check"), + ApiSpec("spmv", "flagsparse_spmv_coo_tocsr", "spmv_csr", "COO->CSR", "triton", value_const="SUPPORTED_SPMV_VALUE_DTYPES", index_const="SUPPORTED_INDEX_DTYPES", ops=("non",), notes="COO input is converted to CSR before compute"), + ApiSpec("spmm", "flagsparse_spmm_csr", "spmm_csr", "CSR", "triton", value_const="SUPPORTED_SPMM_VALUE_DTYPES", index_const="SUPPORTED_INDEX_DTYPES", ops=("non",)), + ApiSpec("spmm", "flagsparse_spmm_csr_opt", "spmm_csr", "CSR", "triton_opt", value_const="SUPPORTED_SPMM_VALUE_DTYPES", index_const="SUPPORTED_INDEX_DTYPES", ops=("non",)), + ApiSpec("spmm", "flagsparse_spmm_coo", "spmm_coo", "COO", "triton", values=spmm_values, index_const="SUPPORTED_INDEX_DTYPES", ops=("non",), notes="COO SpMM reuses CSR SpMM dtype declaration"), + ApiSpec("spgemm", "flagsparse_spgemm_csr", "spgemm_csr", "CSR", "triton", value_const="SUPPORTED_SPGEMM_VALUE_DTYPES", index_const="SUPPORTED_INDEX_DTYPES", ops=("non",)), + ApiSpec("sddmm", "flagsparse_sddmm_csr", "sddmm_csr", "CSR", "triton", value_const="SUPPORTED_SDDMM_VALUE_DTYPES", 
index_const="SUPPORTED_INDEX_DTYPES", ops=("non",)), + ApiSpec("spsv", "flagsparse_spsv_csr", "spsv", "CSR", "triton", value_const="SUPPORTED_SPSV_VALUE_DTYPES", index_const="SUPPORTED_SPSV_INDEX_DTYPES", ops=("NON_TRANS", "TRANS"), notes="TRANS support is narrower than NON_TRANS; see combo constants"), + ApiSpec("spsv", "flagsparse_spsv_coo", "spsv", "COO", "triton", value_const="SUPPORTED_SPSV_VALUE_DTYPES", index_const="SUPPORTED_SPSV_INDEX_DTYPES", ops=("NON_TRANS", "TRANS"), notes="TRANS support is narrower than NON_TRANS; see combo constants"), + ApiSpec("spsm", "flagsparse_spsm_csr", "spsm", "CSR", "triton", value_const="SUPPORTED_SPSM_VALUE_DTYPES", index_const="SUPPORTED_SPSM_INDEX_DTYPES", ops=("NON_TRANS",), notes="opA/opB must both be NON_TRANS; row-major dense layout only"), + ApiSpec("spsm", "flagsparse_spsm_coo", "spsm", "COO", "triton", value_const="SUPPORTED_SPSM_VALUE_DTYPES", index_const="SUPPORTED_SPSM_INDEX_DTYPES", ops=("NON_TRANS",), notes="opA/opB must both be NON_TRANS; row-major dense layout only"), + ) + + +def rows_for_spec(spec: ApiSpec, modules: dict[str, SourceModule], public_apis: set[str], src_root: Path) -> list[dict[str, str]]: + module = modules.get(spec.module) + notes = [spec.notes] if spec.notes else [] + status = "SUPPORTED" + if module is None: + return [row(spec, NA, NA, NA, "PARTIAL")] + + values = spec.values + if values is None and spec.value_const: + values = normalize_dtype_values(module.get(spec.value_const)) + indices = spec.indices + if indices is None and spec.index_const: + indices = normalize_dtype_values(module.get(spec.index_const)) + if not values: + values = (NA,) + status = "PARTIAL" + notes.append(f"value dtype constant {spec.value_const or ''} not found") + if not indices: + indices = (NA,) + status = "PARTIAL" + notes.append(f"index dtype constant {spec.index_const or ''} not found") + if spec.api not in module.functions and spec.api not in public_apis: + status = "PARTIAL" + notes.append(f"api {spec.api} not found in source exports/functions") + + ops = spec.ops + if ops is None: + ops_tuple = ("non",) + elif isinstance(ops, str): + ops_tuple = (ops,) + else: + ops_tuple = ops + + return [ + row(spec, value_dtype, index_dtype, op, status) + for value_dtype in values + for index_dtype in indices + for op in ops_tuple + ] + + +def row(spec: ApiSpec, value_dtype: str, index_dtype: str, op: str, status: str) -> dict[str, str]: + return { + "operator": spec.operator, + "format": spec.fmt, + "index_dtype": str(index_dtype), + "value_dtype": str(value_dtype), + "op": normalize_op_label(op), + "route": spec.route, + "status": status, + } + + +def discovered_unmapped_rows(modules: dict[str, SourceModule], specs: Iterable[ApiSpec], src_root: Path) -> list[dict[str, str]]: + mapped = {(spec.module, spec.value_const) for spec in specs if spec.value_const} + mapped |= {(spec.module, spec.index_const) for spec in specs if spec.index_const} + out: list[dict[str, str]] = [] + for module_name, module in sorted(modules.items()): + if module_name in {"_common", "__init__", "benchmarks"}: + continue + for const_name in sorted(module.assignments): + if not (const_name.startswith("SUPPORTED_") and const_name.endswith("_DTYPES")): + continue + if (module_name, const_name) in mapped: + continue + values = normalize_dtype_values(module.get(const_name)) + dtype_kind = "index_dtype" if "INDEX" in const_name else "value_dtype" + for dtype in values or (NA,): + data = { + "operator": module_name, + "format": NA, + "value_dtype": dtype if dtype_kind == 
"value_dtype" else NA, + "index_dtype": dtype if dtype_kind == "index_dtype" else NA, + "op": NA, + "route": NA, + "status": "DISCOVERED_UNMAPPED", + } + out.append(data) + return out + + +def _ordered_value(mapping: dict[str, int], value: str) -> tuple[int, str]: + return (mapping.get(value, len(mapping)), value) + + +def _sort_rows(rows: list[dict[str, str]]) -> list[dict[str, str]]: + return sorted( + rows, + key=lambda item: ( + item["operator"], + item["format"], + _ordered_value(INDEX_DTYPE_ORDER, item["index_dtype"]), + _ordered_value(VALUE_DTYPE_ORDER, item["value_dtype"]), + _ordered_value(OP_ORDER, item["op"]), + item["route"], + item["status"], + ), + ) + + +def build_rows(src_root: Path) -> list[dict[str, str]]: + modules = discover_modules(src_root) + public_apis = collect_public_apis(src_root) + specs = registry(modules) + rows: list[dict[str, str]] = [] + for spec in specs: + rows.extend(rows_for_spec(spec, modules, public_apis, src_root)) + rows.extend(discovered_unmapped_rows(modules, specs, src_root)) + return _sort_rows(rows) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + script_dir = Path(__file__).resolve().parent + parser.add_argument( + "--src-root", + type=Path, + default=script_dir / "src" / "flagsparse" / "sparse_operations", + help="Path to src/flagsparse/sparse_operations (default: project-local path).", + ) + parser.add_argument( + "--output", + type=Path, + default=script_dir / "ops_support.csv", + help="CSV output path (default: ./ops_support.csv).", + ) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + src_root = args.src_root.resolve() + output = args.output.resolve() + if not src_root.exists(): + raise FileNotFoundError(f"sparse operations source root not found: {src_root}") + + rows = build_rows(src_root) + output.parent.mkdir(parents=True, exist_ok=True) + with output.open("w", newline="", encoding="utf-8") as handle: + writer = csv.DictWriter(handle, fieldnames=CSV_FIELDS) + writer.writeheader() + writer.writerows(rows) + + print(f"Wrote {len(rows)} support rows to {output}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/ops_support_sort_check.csv b/ops_support_sort_check.csv new file mode 100644 index 0000000..61ddfaa --- /dev/null +++ b/ops_support_sort_check.csv @@ -0,0 +1,177 @@ +operator,format,index_dtype,value_dtype,op,route,status +gather,index,int32,float16,non,triton,SUPPORTED +gather,index,int32,bfloat16,non,triton,SUPPORTED +gather,index,int32,float32,non,triton,SUPPORTED +gather,index,int32,float64,non,triton,SUPPORTED +gather,index,int32,complex32,non,triton,SUPPORTED +gather,index,int32,complex64,non,triton,SUPPORTED +gather,index,int32,complex128,non,triton,SUPPORTED +gather,index,int64,float16,non,triton,SUPPORTED +gather,index,int64,bfloat16,non,triton,SUPPORTED +gather,index,int64,float32,non,triton,SUPPORTED +gather,index,int64,float64,non,triton,SUPPORTED +gather,index,int64,complex32,non,triton,SUPPORTED +gather,index,int64,complex64,non,triton,SUPPORTED +gather,index,int64,complex128,non,triton,SUPPORTED +scatter,index,int32,float16,non,triton,SUPPORTED +scatter,index,int32,bfloat16,non,triton,SUPPORTED +scatter,index,int32,float32,non,triton,SUPPORTED +scatter,index,int32,float64,non,triton,SUPPORTED +scatter,index,int32,complex64,non,triton,SUPPORTED +scatter,index,int32,complex128,non,triton,SUPPORTED +scatter,index,int64,float16,non,triton,SUPPORTED +scatter,index,int64,bfloat16,non,triton,SUPPORTED 
+scatter,index,int64,float32,non,triton,SUPPORTED +scatter,index,int64,float64,non,triton,SUPPORTED +scatter,index,int64,complex64,non,triton,SUPPORTED +scatter,index,int64,complex128,non,triton,SUPPORTED +sddmm,CSR,int32,float32,non,triton,SUPPORTED +sddmm,CSR,int32,float64,non,triton,SUPPORTED +sddmm,CSR,int64,float32,non,triton,SUPPORTED +sddmm,CSR,int64,float64,non,triton,SUPPORTED +spgemm,CSR,int32,float32,non,triton,SUPPORTED +spgemm,CSR,int32,float64,non,triton,SUPPORTED +spgemm,CSR,int64,float32,non,triton,SUPPORTED +spgemm,CSR,int64,float64,non,triton,SUPPORTED +spmm,COO,int32,float16,non,triton,SUPPORTED +spmm,COO,int32,bfloat16,non,triton,SUPPORTED +spmm,COO,int32,float32,non,triton,SUPPORTED +spmm,COO,int32,float64,non,triton,SUPPORTED +spmm,COO,int32,complex32,non,triton,SUPPORTED +spmm,COO,int32,complex64,non,triton,SUPPORTED +spmm,COO,int32,complex128,non,triton,SUPPORTED +spmm,COO,int64,float16,non,triton,SUPPORTED +spmm,COO,int64,bfloat16,non,triton,SUPPORTED +spmm,COO,int64,float32,non,triton,SUPPORTED +spmm,COO,int64,float64,non,triton,SUPPORTED +spmm,COO,int64,complex32,non,triton,SUPPORTED +spmm,COO,int64,complex64,non,triton,SUPPORTED +spmm,COO,int64,complex128,non,triton,SUPPORTED +spmm,CSR,int32,float16,non,triton,SUPPORTED +spmm,CSR,int32,float16,non,triton_opt,SUPPORTED +spmm,CSR,int32,bfloat16,non,triton,SUPPORTED +spmm,CSR,int32,bfloat16,non,triton_opt,SUPPORTED +spmm,CSR,int32,float32,non,triton,SUPPORTED +spmm,CSR,int32,float32,non,triton_opt,SUPPORTED +spmm,CSR,int32,float64,non,triton,SUPPORTED +spmm,CSR,int32,float64,non,triton_opt,SUPPORTED +spmm,CSR,int32,complex32,non,triton,SUPPORTED +spmm,CSR,int32,complex32,non,triton_opt,SUPPORTED +spmm,CSR,int32,complex64,non,triton,SUPPORTED +spmm,CSR,int32,complex64,non,triton_opt,SUPPORTED +spmm,CSR,int32,complex128,non,triton,SUPPORTED +spmm,CSR,int32,complex128,non,triton_opt,SUPPORTED +spmm,CSR,int64,float16,non,triton,SUPPORTED +spmm,CSR,int64,float16,non,triton_opt,SUPPORTED +spmm,CSR,int64,bfloat16,non,triton,SUPPORTED +spmm,CSR,int64,bfloat16,non,triton_opt,SUPPORTED +spmm,CSR,int64,float32,non,triton,SUPPORTED +spmm,CSR,int64,float32,non,triton_opt,SUPPORTED +spmm,CSR,int64,float64,non,triton,SUPPORTED +spmm,CSR,int64,float64,non,triton_opt,SUPPORTED +spmm,CSR,int64,complex32,non,triton,SUPPORTED +spmm,CSR,int64,complex32,non,triton_opt,SUPPORTED +spmm,CSR,int64,complex64,non,triton,SUPPORTED +spmm,CSR,int64,complex64,non,triton_opt,SUPPORTED +spmm,CSR,int64,complex128,non,triton,SUPPORTED +spmm,CSR,int64,complex128,non,triton_opt,SUPPORTED +spmv,COO,int32,float32,non,triton,SUPPORTED +spmv,COO,int32,float64,non,triton,SUPPORTED +spmv,COO,int64,float32,non,triton,SUPPORTED +spmv,COO,int64,float64,non,triton,SUPPORTED +spmv,COO->CSR,int32,float16,non,triton,SUPPORTED +spmv,COO->CSR,int32,bfloat16,non,triton,SUPPORTED +spmv,COO->CSR,int32,float32,non,triton,SUPPORTED +spmv,COO->CSR,int32,float64,non,triton,SUPPORTED +spmv,COO->CSR,int32,complex64,non,triton,SUPPORTED +spmv,COO->CSR,int32,complex128,non,triton,SUPPORTED +spmv,COO->CSR,int64,float16,non,triton,SUPPORTED +spmv,COO->CSR,int64,bfloat16,non,triton,SUPPORTED +spmv,COO->CSR,int64,float32,non,triton,SUPPORTED +spmv,COO->CSR,int64,float64,non,triton,SUPPORTED +spmv,COO->CSR,int64,complex64,non,triton,SUPPORTED +spmv,COO->CSR,int64,complex128,non,triton,SUPPORTED +spmv,CSR,int32,float16,non,triton,SUPPORTED +spmv,CSR,int32,float16,trans,triton,SUPPORTED +spmv,CSR,int32,float16,conj,triton,SUPPORTED +spmv,CSR,int32,bfloat16,non,triton,SUPPORTED 
+spmv,CSR,int32,bfloat16,trans,triton,SUPPORTED +spmv,CSR,int32,bfloat16,conj,triton,SUPPORTED +spmv,CSR,int32,float32,non,triton,SUPPORTED +spmv,CSR,int32,float32,trans,triton,SUPPORTED +spmv,CSR,int32,float32,conj,triton,SUPPORTED +spmv,CSR,int32,float64,non,triton,SUPPORTED +spmv,CSR,int32,float64,trans,triton,SUPPORTED +spmv,CSR,int32,float64,conj,triton,SUPPORTED +spmv,CSR,int32,complex64,non,triton,SUPPORTED +spmv,CSR,int32,complex64,trans,triton,SUPPORTED +spmv,CSR,int32,complex64,conj,triton,SUPPORTED +spmv,CSR,int32,complex128,non,triton,SUPPORTED +spmv,CSR,int32,complex128,trans,triton,SUPPORTED +spmv,CSR,int32,complex128,conj,triton,SUPPORTED +spmv,CSR,int64,float16,non,triton,SUPPORTED +spmv,CSR,int64,float16,trans,triton,SUPPORTED +spmv,CSR,int64,float16,conj,triton,SUPPORTED +spmv,CSR,int64,bfloat16,non,triton,SUPPORTED +spmv,CSR,int64,bfloat16,trans,triton,SUPPORTED +spmv,CSR,int64,bfloat16,conj,triton,SUPPORTED +spmv,CSR,int64,float32,non,triton,SUPPORTED +spmv,CSR,int64,float32,trans,triton,SUPPORTED +spmv,CSR,int64,float32,conj,triton,SUPPORTED +spmv,CSR,int64,float64,non,triton,SUPPORTED +spmv,CSR,int64,float64,trans,triton,SUPPORTED +spmv,CSR,int64,float64,conj,triton,SUPPORTED +spmv,CSR,int64,complex64,non,triton,SUPPORTED +spmv,CSR,int64,complex64,trans,triton,SUPPORTED +spmv,CSR,int64,complex64,conj,triton,SUPPORTED +spmv,CSR,int64,complex128,non,triton,SUPPORTED +spmv,CSR,int64,complex128,trans,triton,SUPPORTED +spmv,CSR,int64,complex128,conj,triton,SUPPORTED +spsm,COO,int32,float32,non,triton,SUPPORTED +spsm,COO,int32,float64,non,triton,SUPPORTED +spsm,COO,int64,float32,non,triton,SUPPORTED +spsm,COO,int64,float64,non,triton,SUPPORTED +spsm,CSR,int32,float32,non,triton,SUPPORTED +spsm,CSR,int32,float64,non,triton,SUPPORTED +spsm,CSR,int64,float32,non,triton,SUPPORTED +spsm,CSR,int64,float64,non,triton,SUPPORTED +spsv,COO,int32,bfloat16,non,triton,SUPPORTED +spsv,COO,int32,bfloat16,trans,triton,SUPPORTED +spsv,COO,int32,float32,non,triton,SUPPORTED +spsv,COO,int32,float32,trans,triton,SUPPORTED +spsv,COO,int32,float64,non,triton,SUPPORTED +spsv,COO,int32,float64,trans,triton,SUPPORTED +spsv,COO,int32,complex32,non,triton,SUPPORTED +spsv,COO,int32,complex32,trans,triton,SUPPORTED +spsv,COO,int32,complex64,non,triton,SUPPORTED +spsv,COO,int32,complex64,trans,triton,SUPPORTED +spsv,COO,int64,bfloat16,non,triton,SUPPORTED +spsv,COO,int64,bfloat16,trans,triton,SUPPORTED +spsv,COO,int64,float32,non,triton,SUPPORTED +spsv,COO,int64,float32,trans,triton,SUPPORTED +spsv,COO,int64,float64,non,triton,SUPPORTED +spsv,COO,int64,float64,trans,triton,SUPPORTED +spsv,COO,int64,complex32,non,triton,SUPPORTED +spsv,COO,int64,complex32,trans,triton,SUPPORTED +spsv,COO,int64,complex64,non,triton,SUPPORTED +spsv,COO,int64,complex64,trans,triton,SUPPORTED +spsv,CSR,int32,bfloat16,non,triton,SUPPORTED +spsv,CSR,int32,bfloat16,trans,triton,SUPPORTED +spsv,CSR,int32,float32,non,triton,SUPPORTED +spsv,CSR,int32,float32,trans,triton,SUPPORTED +spsv,CSR,int32,float64,non,triton,SUPPORTED +spsv,CSR,int32,float64,trans,triton,SUPPORTED +spsv,CSR,int32,complex32,non,triton,SUPPORTED +spsv,CSR,int32,complex32,trans,triton,SUPPORTED +spsv,CSR,int32,complex64,non,triton,SUPPORTED +spsv,CSR,int32,complex64,trans,triton,SUPPORTED +spsv,CSR,int64,bfloat16,non,triton,SUPPORTED +spsv,CSR,int64,bfloat16,trans,triton,SUPPORTED +spsv,CSR,int64,float32,non,triton,SUPPORTED +spsv,CSR,int64,float32,trans,triton,SUPPORTED +spsv,CSR,int64,float64,non,triton,SUPPORTED 
+spsv,CSR,int64,float64,trans,triton,SUPPORTED
+spsv,CSR,int64,complex32,non,triton,SUPPORTED
+spsv,CSR,int64,complex32,trans,triton,SUPPORTED
+spsv,CSR,int64,complex64,non,triton,SUPPORTED
+spsv,CSR,int64,complex64,trans,triton,SUPPORTED
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..f54de0b
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,30 @@
+[build-system]
+requires = ["setuptools>=61", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "flagsparse"
+version = "1.0.0"
+description = "FlagSparse - GPU sparse operations package"
+readme = "README.md"
+requires-python = ">=3.8"
+dependencies = []
+authors = [{ name = "Your Name", email = "your.email@example.com" }]
+classifiers = [
+    "Development Status :: 4 - Beta",
+    "Intended Audience :: Developers",
+    "License :: OSI Approved :: Apache Software License",
+    "Programming Language :: Python :: 3",
+    "Programming Language :: Python :: 3.8",
+    "Programming Language :: Python :: 3.9",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12"
+]
+
+[project.urls]
+Homepage = "https://github.com/yourusername/flagsparse"
+Repository = "https://github.com/yourusername/flagsparse"
+
+[tool.setuptools.packages.find]
+where = ["src"]
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 0000000..36c01fe
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,13 @@
+[pytest]
+testpaths = tests
+pythonpath = src .
+python_files = test_*.py
+python_functions = test_*
+addopts = -v
+filterwarnings = ignore::pytest.PytestUnknownMarkWarning
+markers =
+    gather: gather op accuracy (tests/pytest)
+    scatter: scatter op accuracy (tests/pytest)
+    spmv_csr: CSR SpMV accuracy (tests/pytest)
+    spmv_coo: COO SpMV accuracy (tests/pytest)
+    spsv: CSR SpSV accuracy only (tests/pytest; no COO SpSV)
diff --git a/run_flagsparse_pytest.py b/run_flagsparse_pytest.py
new file mode 100644
index 0000000..489de8e
--- /dev/null
+++ b/run_flagsparse_pytest.py
@@ -0,0 +1,264 @@
+#!/usr/bin/env python3
+"""Run FlagSparse pytest accuracy suites per operator and summarize results."""
+
+from __future__ import annotations
+
+import argparse
+import csv
+import datetime as _dt
+import json
+import os
+import re
+import shlex
+import subprocess
+import sys
+import threading
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+
+try:
+    from openpyxl import Workbook
+except Exception:
+    Workbook = None
+
+
+DEFAULT_OPS = (
+    "gather",
+    "scatter",
+    "spmv_csr",
+    "spmv_coo",
+    "spmv_coo_tocsr",
+    "spmm_csr",
+    "spmm_csr_opt",
+    "spmm_coo",
+    "spsv_csr",
+    "spsv_coo",
+    "spsm_csr",
+    "spsm_coo",
+    "spgemm_csr",
+    "sddmm_csr",
+)
+
+ANSI_RE = re.compile(r"\x1B\[[0-?]*[ -/]*[@-~]")
+SUMMARY_RE = re.compile(r"(\d+)\s+([A-Za-z_]+)")
+SUMMARY_LOCK = threading.Lock()
+
+
+def now_ts() -> str:
+    return _dt.datetime.now().strftime("%Y%m%d_%H%M%S")
+
+
+def ensure_dir(path: Path) -> None:
+    path.mkdir(parents=True, exist_ok=True)
+
+
+def parse_pytest_summary(text: str) -> dict[str, int]:
+    clean = ANSI_RE.sub("", text)
+    counts = {"passed": 0, "failed": 0, "skipped": 0, "errors": 0}
+    for match in SUMMARY_RE.finditer(clean):
+        key = match.group(2).lower()
+        if key == "error":
+            # pytest reports a single error as "1 error"; fold it into "errors".
+            key = "errors"
+        if key in counts:
+            counts[key] = int(match.group(1))
+    counts["total"] = counts["passed"] + counts["failed"] + counts["skipped"]
+    return counts
+
+
+def status_from_counts(counts: dict[str, int], returncode: int) -> str:
+    if returncode not in (0, 5) and not any(
+        counts[key] for key in 
("passed", "failed", "skipped", "errors") + ): + return "CRASH" + if counts["failed"] or counts["errors"] or returncode not in (0, 5): + return "FAIL" + if counts["passed"]: + return "PASS" + if counts["skipped"]: + return "SKIP" + return "NO_TESTS" + + +def read_ops(op_list: str | None, ops_arg: str | None) -> list[str]: + if ops_arg: + return [op.strip() for op in ops_arg.split(",") if op.strip()] + if op_list: + with open(op_list, encoding="utf-8") as handle: + return [ + line.strip() + for line in handle + if line.strip() and not line.lstrip().startswith("#") + ] + return list(DEFAULT_OPS) + + +def run_one_op( + project_root: Path, + op: str, + gpu_id: int, + mode: str, + results_dir: Path, + extra_pytest_args: list[str], +) -> dict[str, object]: + op_dir = results_dir / op + ensure_dir(op_dir) + env = os.environ.copy() + env["CUDA_VISIBLE_DEVICES"] = str(gpu_id) + cmd = [ + sys.executable, + "-m", + "pytest", + "tests/pytest", + "-m", + op, + "--mode", + mode, + "-vs", + "-p", + "no:cacheprovider", + *extra_pytest_args, + ] + proc = subprocess.run( + cmd, + cwd=str(project_root), + env=env, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + combined = (proc.stdout or "") + "\n" + (proc.stderr or "") + log_path = op_dir / "accuracy.log" + log_path.write_text(combined, encoding="utf-8") + + counts = parse_pytest_summary(combined) + return { + "operator": op, + "gpu": gpu_id, + "status": status_from_counts(counts, proc.returncode), + "returncode": proc.returncode, + "log_path": str(log_path), + **counts, + } + + +def run_gpu_ops( + project_root: Path, + gpu_id: int, + ops: list[str], + mode: str, + results_dir: Path, + extra_pytest_args: list[str], + results: list[dict[str, object]], +) -> None: + for op in ops: + result = run_one_op( + project_root, + op, + gpu_id, + mode, + results_dir, + extra_pytest_args, + ) + with SUMMARY_LOCK: + results.append(result) + write_summary(results, results_dir) + print( + f"[{result['status']}] {result['operator']} " + f"gpu={result['gpu']} passed={result['passed']} failed={result['failed']} " + f"skipped={result['skipped']} errors={result['errors']}" + ) + + +def write_summary(results: list[dict[str, object]], results_dir: Path) -> None: + ordered = sorted(results, key=lambda item: str(item["operator"])) + json_path = results_dir / "summary.json" + json_path.write_text(json.dumps(ordered, indent=2), encoding="utf-8") + + csv_path = results_dir / "summary.csv" + headers = [ + "operator", + "gpu", + "status", + "passed", + "failed", + "skipped", + "errors", + "total", + "returncode", + "log_path", + ] + with open(csv_path, "w", newline="", encoding="utf-8") as handle: + writer = csv.DictWriter(handle, fieldnames=headers) + writer.writeheader() + for row in ordered: + writer.writerow({key: row.get(key, "") for key in headers}) + + if Workbook is None: + return + wb = Workbook() + ws = wb.active + ws.title = "Summary" + ws.append(headers) + for row in ordered: + ws.append([row.get(key, "") for key in headers]) + wb.save(results_dir / "summary.xlsx") + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--op-list", default=None, help="File with one pytest marker per line.") + parser.add_argument("--ops", default=None, help="Comma-separated pytest markers; overrides --op-list.") + parser.add_argument("--gpus", default="0", help="Comma-separated GPU ids for CUDA_VISIBLE_DEVICES.") + parser.add_argument("--mode", default="quick", choices=("quick", "normal")) + 
parser.add_argument("--results-dir", default=None) + parser.add_argument( + "--pytest-args", + default="", + help="Extra pytest args appended to every per-op invocation.", + ) + args = parser.parse_args() + + project_root = Path(__file__).resolve().parent + ops = read_ops(args.op_list, args.ops) + if not ops: + raise SystemExit("no operators to run") + gpus = [int(item.strip()) for item in args.gpus.split(",") if item.strip()] + if not gpus: + raise SystemExit("no GPUs provided") + results_dir = ( + Path(args.results_dir).resolve() + if args.results_dir + else project_root / f"pytest_results_{now_ts()}" + ) + ensure_dir(results_dir) + extra_pytest_args = shlex.split(args.pytest_args) if args.pytest_args else [] + + tasks = {gpu: [] for gpu in gpus} + for index, op in enumerate(ops): + tasks[gpus[index % len(gpus)]].append(op) + + results: list[dict[str, object]] = [] + with ThreadPoolExecutor(max_workers=len(gpus)) as executor: + futures = [] + for gpu, gpu_ops in tasks.items(): + if not gpu_ops: + continue + futures.append( + executor.submit( + run_gpu_ops, + project_root, + gpu, + gpu_ops, + args.mode, + results_dir, + extra_pytest_args, + results, + ) + ) + for future in as_completed(futures): + future.result() + + write_summary(results, results_dir) + return 1 if any(result["status"] in ("FAIL", "NO_TESTS", "CRASH") for result in results) else 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..b930a79 --- /dev/null +++ b/setup.py @@ -0,0 +1,12 @@ +from setuptools import find_packages, setup + +setup( + name="flagsparse", + version="1.0.0", + description="FlagSparse - GPU sparse operations package", + package_dir={"": "src"}, + packages=find_packages(where="src"), + python_requires=">=3.8", + install_requires=[], + include_package_data=True, +) diff --git a/src/flagsparse/__init__.py b/src/flagsparse/__init__.py new file mode 100644 index 0000000..02b5d76 --- /dev/null +++ b/src/flagsparse/__init__.py @@ -0,0 +1,148 @@ +"""FlagSparse package.""" + +__version__ = "1.0.0" + +__all__ = [ + "flagsparse_gather", + "flagsparse_scatter", + "pytorch_index_gather", + "pytorch_index_scatter", + "cusparse_spmv_gather", + "cusparse_spmv_scatter", + "benchmark_gather_case", + "benchmark_scatter_case", + "benchmark_performance", + "comprehensive_gather_test", + "comprehensive_scatter_test", + "PreparedCoo", + "PreparedCsrSpmv", + "PreparedCsrSpmmOpt", + "SpGEMMPrepared", + "SDDMMPrepared", + "prepare_spmv_csr", + "prepare_spmm_csr_opt", + "prepare_spmv_coo", + "prepare_spmv_coo_tocsr", + "flagsparse_spmv_csr", + "flagsparse_spmv_coo", + "flagsparse_spmv_coo_tocsr", + "flagsparse_spmm_csr_opt", + "flagsparse_spsv_csr", + "flagsparse_spsv_coo", + "flagsparse_spsm_csr", + "flagsparse_spsm_coo", + "benchmark_spsm_case", + "prepare_spgemm_csr", + "flagsparse_spgemm_csr", + "prepare_sddmm_csr", + "flagsparse_sddmm_csr", + "flagsparse_spmm_csr", + "flagsparse_spmm_coo", + "benchmark_spmm_case", + "benchmark_spmm_opt_case", + "benchmark_spgemm_case", + "benchmark_sddmm_case", + "comprehensive_spmm_test", + "comprehensive_spsm_test", + "benchmark_spmv_case", + "create_csr_matrix", + "create_coo_matrix", + "create_csc_matrix", + "create_bsr_matrix", + "create_sell_matrix", + "create_blocked_ell_matrix", + "coo_to_csr", + "coo_to_csc", + "coo_to_bsr", + "coo_to_sell", + "coo_to_blocked_ell", + "CSRMatrix", + "COOMatrix", + "CSCMatrix", + "BSRMatrix", + "SELLMatrix", + "BLOCKEDELLMatrix", + "generate_random_sparse_matrix", + 
"read_mtx_file", +] + +_OPS_EXPORTS = { + "flagsparse_gather", + "flagsparse_scatter", + "pytorch_index_gather", + "pytorch_index_scatter", + "cusparse_spmv_gather", + "cusparse_spmv_scatter", + "benchmark_gather_case", + "benchmark_scatter_case", + "benchmark_performance", + "comprehensive_gather_test", + "comprehensive_scatter_test", + "PreparedCoo", + "PreparedCsrSpmv", + "PreparedCsrSpmmOpt", + "SpGEMMPrepared", + "SDDMMPrepared", + "prepare_spmv_csr", + "prepare_spmm_csr_opt", + "prepare_spmv_coo", + "prepare_spmv_coo_tocsr", + "flagsparse_spmv_csr", + "flagsparse_spmv_coo", + "flagsparse_spmv_coo_tocsr", + "flagsparse_spmm_csr_opt", + "flagsparse_spsv_csr", + "flagsparse_spsv_coo", + "flagsparse_spsm_csr", + "flagsparse_spsm_coo", + "benchmark_spsm_case", + "prepare_spgemm_csr", + "flagsparse_spgemm_csr", + "prepare_sddmm_csr", + "flagsparse_sddmm_csr", + "flagsparse_spmm_csr", + "flagsparse_spmm_coo", + "benchmark_spmm_case", + "benchmark_spmm_opt_case", + "benchmark_spgemm_case", + "benchmark_sddmm_case", + "comprehensive_spmm_test", + "benchmark_spmv_case", + "comprehensive_spsm_test", +} + +_FORMAT_EXPORTS = { + "create_csr_matrix", + "create_coo_matrix", + "create_csc_matrix", + "create_bsr_matrix", + "create_sell_matrix", + "create_blocked_ell_matrix", + "coo_to_csr", + "coo_to_csc", + "coo_to_bsr", + "coo_to_sell", + "coo_to_blocked_ell", + "CSRMatrix", + "COOMatrix", + "CSCMatrix", + "BSRMatrix", + "SELLMatrix", + "BLOCKEDELLMatrix", + "generate_random_sparse_matrix", + "read_mtx_file", +} + + +def __getattr__(name): + if name in _OPS_EXPORTS: + from . import sparse_operations as _ops + return getattr(_ops, name) + if name in _FORMAT_EXPORTS: + from . import sparse_formats as _formats + return getattr(_formats, name) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__(): + return sorted(set(globals()) | _OPS_EXPORTS | _FORMAT_EXPORTS) diff --git a/src/flagsparse/sparse_formats.py b/src/flagsparse/sparse_formats.py new file mode 100644 index 0000000..5646f86 --- /dev/null +++ b/src/flagsparse/sparse_formats.py @@ -0,0 +1,670 @@ +"""Sparse matrix formats aligned with CuPy/cupyx.""" + +try: + import cupy as cp + import cupyx.scipy.sparse as cpx_sparse +except ImportError as exc: + raise ImportError( + "CuPy is required for sparse format utilities. 
" + "Install a CUDA-matched wheel, for example: pip install cupy-cuda12x" + ) from exc + +try: + import torch +except ImportError: + torch = None + + +def _resolve_dtype(dtype): + if dtype is None: + return cp.dtype(cp.float32) + if torch is not None and isinstance(dtype, torch.dtype): + try: + cupy_bfloat16 = cp.dtype("bfloat16") + except TypeError: + cupy_bfloat16 = cp.float32 + mapping = { + torch.float16: cp.float16, + torch.bfloat16: cupy_bfloat16, + torch.float32: cp.float32, + torch.float64: cp.float64, + torch.complex64: cp.complex64, + torch.complex128: cp.complex128, + torch.int32: cp.int32, + torch.int64: cp.int64, + } + if dtype not in mapping: + raise TypeError(f"Unsupported torch dtype: {dtype}") + return cp.dtype(mapping[dtype]) + return cp.dtype(dtype) + + +def _to_cupy_array(x, dtype=None): + target_dtype = _resolve_dtype(dtype) if dtype is not None else None + if isinstance(x, cp.ndarray): + return x.astype(target_dtype, copy=False) if target_dtype else x + if torch is not None and torch.is_tensor(x): + arr = cp.from_dlpack(torch.utils.dlpack.to_dlpack(x)) + return arr.astype(target_dtype, copy=False) if target_dtype else arr + arr = cp.asarray(x) + return arr.astype(target_dtype, copy=False) if target_dtype else arr + + +def _random_values(nnz, dtype): + dtype = cp.dtype(dtype) + if dtype == cp.dtype(cp.complex64): + real = cp.random.standard_normal(nnz, dtype=cp.float32) + imag = cp.random.standard_normal(nnz, dtype=cp.float32) + return (real + 1j * imag).astype(dtype, copy=False) + if dtype == cp.dtype(cp.complex128): + real = cp.random.standard_normal(nnz, dtype=cp.float64) + imag = cp.random.standard_normal(nnz, dtype=cp.float64) + return (real + 1j * imag).astype(dtype, copy=False) + return cp.random.standard_normal(nnz).astype(dtype, copy=False) + + +class CSRMatrix: + def __init__(self, values, indices=None, indptr=None, shape=None, dtype=None): + if isinstance(values, cpx_sparse.csr_matrix): + self.matrix = values + else: + resolved_dtype = _resolve_dtype(dtype) + data = _to_cupy_array(values, dtype=resolved_dtype) + cols = _to_cupy_array(indices, dtype=cp.int64) + row_ptr = _to_cupy_array(indptr, dtype=cp.int64) + self.matrix = cpx_sparse.csr_matrix( + (data, cols, row_ptr), shape=shape, dtype=resolved_dtype + ) + + @property + def values(self): + return self.matrix.data + + @property + def indices(self): + return self.matrix.indices + + @property + def indptr(self): + return self.matrix.indptr + + @property + def shape(self): + return self.matrix.shape + + @property + def dtype(self): + return self.matrix.dtype + + def to_dense(self): + return self.matrix.toarray() + + def to_coo(self): + return COOMatrix(self.matrix.tocoo()) + + def __repr__(self): + return f"CSRMatrix(shape={self.shape}, nnz={self.matrix.nnz})" + + +class CSCMatrix: + """Compressed Sparse Column matrix (CuPy-style).""" + def __init__(self, values, indices=None, indptr=None, shape=None, dtype=None): + if isinstance(values, cpx_sparse.csc_matrix): + self.matrix = values + else: + resolved_dtype = _resolve_dtype(dtype) + data = _to_cupy_array(values, dtype=resolved_dtype) + rows = _to_cupy_array(indices, dtype=cp.int64) + col_ptr = _to_cupy_array(indptr, dtype=cp.int64) + self.matrix = cpx_sparse.csc_matrix( + (data, rows, col_ptr), shape=shape, dtype=resolved_dtype + ) + + @property + def values(self): + return self.matrix.data + + @property + def indices(self): + return self.matrix.indices + + @property + def indptr(self): + return self.matrix.indptr + + @property + def shape(self): + return 
self.matrix.shape + + @property + def dtype(self): + return self.matrix.dtype + + def to_dense(self): + return self.matrix.toarray() + + def to_coo(self): + return COOMatrix(self.matrix.tocoo()) + + def __repr__(self): + return f"CSCMatrix(shape={self.shape}, nnz={self.matrix.nnz})" + + +class BSRMatrix: + """Block Sparse Row matrix (CuPy/SciPy-style). blocksize=(R,C).""" + def __init__(self, data, indices=None, indptr=None, shape=None, blocksize=None, dtype=None): + if isinstance(data, cpx_sparse.bsr_matrix): + self.matrix = data + else: + resolved_dtype = _resolve_dtype(dtype) + data_arr = _to_cupy_array(data, dtype=resolved_dtype) + inds = _to_cupy_array(indices, dtype=cp.int64) + ptr = _to_cupy_array(indptr, dtype=cp.int64) + if blocksize is None: + raise ValueError("BSRMatrix requires blocksize (R, C)") + if isinstance(blocksize, int): + blocksize = (blocksize, blocksize) + self.matrix = cpx_sparse.bsr_matrix( + (data_arr, inds, ptr), shape=shape, blocksize=blocksize + ) + + @property + def data(self): + return self.matrix.data + + @property + def indices(self): + return self.matrix.indices + + @property + def indptr(self): + return self.matrix.indptr + + @property + def blocksize(self): + return self.matrix.blocksize + + @property + def shape(self): + return self.matrix.shape + + @property + def dtype(self): + return self.matrix.dtype + + def to_dense(self): + return self.matrix.toarray() + + def to_coo(self): + return COOMatrix(self.matrix.tocoo()) + + def __repr__(self): + return f"BSRMatrix(shape={self.shape}, blocksize={self.blocksize}, nnz_blocks={self.indices.size})" + + +class SELLMatrix: + """ + Sliced ELLPACK format. Stores: values, indices (column), slice_ptr, rows_per_slice. + CuPy-compatible interface; backend uses CuPy arrays. + """ + def __init__(self, values, indices, slice_ptr, rows_per_slice, shape, dtype=None): + self._values = _to_cupy_array(values, dtype=_resolve_dtype(dtype)) + self._indices = _to_cupy_array(indices, dtype=cp.int64) + self._slice_ptr = _to_cupy_array(slice_ptr, dtype=cp.int64) + self._rows_per_slice = _to_cupy_array(rows_per_slice, dtype=cp.int64) + self._shape = tuple(shape) + + @property + def values(self): + return self._values + + @property + def indices(self): + return self._indices + + @property + def slice_ptr(self): + return self._slice_ptr + + @property + def rows_per_slice(self): + return self._rows_per_slice + + @property + def shape(self): + return self._shape + + @property + def dtype(self): + return self._values.dtype + + def to_dense(self): + return self.to_coo().to_dense() + + def to_coo(self): + return _sell_to_coo(self) + + def __repr__(self): + nnz = int(self._values.size) + return f"SELLMatrix(shape={self.shape}, nnz={nnz}, n_slices={self._slice_ptr.size - 1})" + + +class BLOCKEDELLMatrix: + """ + Blocked ELL format. data shape (n_block_rows, max_blocks_per_row, r, c), + indices shape (n_block_rows, max_blocks_per_row). CuPy arrays. 
+ """ + def __init__(self, data, indices, block_shape, shape, dtype=None): + self._data = _to_cupy_array(data, dtype=_resolve_dtype(dtype)) + self._indices = _to_cupy_array(indices, dtype=cp.int64) + self._block_shape = tuple(block_shape) + self._shape = tuple(shape) + + @property + def data(self): + return self._data + + @property + def indices(self): + return self._indices + + @property + def block_shape(self): + return self._block_shape + + @property + def shape(self): + return self._shape + + @property + def dtype(self): + return self._data.dtype + + def to_dense(self): + return self.to_coo().to_dense() + + def to_coo(self): + return _blocked_ell_to_coo(self) + + def __repr__(self): + return f"BLOCKEDELLMatrix(shape={self.shape}, block_shape={self._block_shape})" + + +class COOMatrix: + def __init__(self, row_indices, col_indices=None, values=None, shape=None, dtype=None): + if isinstance(row_indices, cpx_sparse.coo_matrix): + self.matrix = row_indices + else: + resolved_dtype = _resolve_dtype(dtype) + rows = _to_cupy_array(row_indices, dtype=cp.int64) + cols = _to_cupy_array(col_indices, dtype=cp.int64) + data = _to_cupy_array(values, dtype=resolved_dtype) + self.matrix = cpx_sparse.coo_matrix( + (data, (rows, cols)), shape=shape, dtype=resolved_dtype + ) + + @property + def row_indices(self): + return self.matrix.row + + @property + def col_indices(self): + return self.matrix.col + + @property + def values(self): + return self.matrix.data + + @property + def shape(self): + return self.matrix.shape + + @property + def dtype(self): + return self.matrix.dtype + + def to_dense(self): + return self.matrix.toarray() + + def to_csr(self): + return CSRMatrix(self.matrix.tocsr()) + + def to_csc(self): + return CSCMatrix(self.matrix.tocsc()) + + def to_bsr(self, blocksize=None): + if blocksize is None: + blocksize = (1, 1) + return BSRMatrix(self.matrix.tobsr(blocksize=blocksize)) + + def to_sell(self, slice_size=32): + return coo_to_sell(self, slice_size=slice_size) + + def to_blocked_ell(self, block_shape): + return coo_to_blocked_ell(self, block_shape=block_shape) + + def __repr__(self): + return f"COOMatrix(shape={self.shape}, nnz={self.matrix.nnz})" + + +def _sell_to_coo(sell_mat): + n_rows, n_cols = sell_mat.shape + slice_ptr = sell_mat.slice_ptr + rows_per_slice = sell_mat.rows_per_slice + values = sell_mat.values + indices = sell_mat.indices + n_slices = int(slice_ptr.size) - 1 + rows_list = [] + cols_list = [] + vals_list = [] + base_row = 0 + for s in range(n_slices): + start = int(slice_ptr[s]) + end = int(slice_ptr[s + 1]) + rps = int(rows_per_slice[s]) + if rps <= 0: + base_row += rps + continue + max_nnz = (end - start) // rps + for r in range(rps): + row = base_row + r + for k in range(max_nnz): + idx = start + r * max_nnz + k + col = int(indices[idx]) + val = values[idx] + nonzero = (val != 0).item() if hasattr(val, "item") else (val != 0) + if row < n_rows and 0 <= col < n_cols and nonzero: + rows_list.append(row) + cols_list.append(col) + vals_list.append(val) + base_row += rps + if not rows_list: + return COOMatrix( + cp.array([], dtype=cp.int64), + cp.array([], dtype=cp.int64), + cp.array([], dtype=sell_mat.dtype), + sell_mat.shape, + dtype=sell_mat.dtype, + ) + return COOMatrix( + cp.asarray(rows_list, dtype=cp.int64), + cp.asarray(cols_list, dtype=cp.int64), + cp.asarray(vals_list, dtype=sell_mat.dtype), + sell_mat.shape, + dtype=sell_mat.dtype, + ) + + +def _blocked_ell_to_coo(be_mat): + data = be_mat.data + indices = be_mat.indices + br, bc = be_mat.block_shape + 
n_block_rows, max_blocks = indices.shape
+    n_rows, n_cols = be_mat.shape
+    rows_list = []
+    cols_list = []
+    vals_list = []
+    for i in range(n_block_rows):
+        for j in range(max_blocks):
+            bcol = int(indices[i, j])
+            if bcol < 0:
+                continue
+            block = data[i, j]
+            for di in range(br):
+                for dj in range(bc):
+                    r = i * br + di
+                    c = bcol * bc + dj
+                    if r < n_rows and c < n_cols:
+                        v = block[di, dj]
+                        nonzero = (v != 0).item() if hasattr(v, "item") else (v != 0)
+                        if nonzero:
+                            rows_list.append(r)
+                            cols_list.append(c)
+                            vals_list.append(v)
+    if not rows_list:
+        return COOMatrix(
+            cp.array([], dtype=cp.int64),
+            cp.array([], dtype=cp.int64),
+            cp.array([], dtype=be_mat.dtype),
+            be_mat.shape,
+            dtype=be_mat.dtype,
+        )
+    return COOMatrix(
+        cp.asarray(rows_list, dtype=cp.int64),
+        cp.asarray(cols_list, dtype=cp.int64),
+        cp.asarray(vals_list, dtype=be_mat.dtype),
+        be_mat.shape,
+        dtype=be_mat.dtype,
+    )
+
+
+def _coo_to_sell_impl(rows, cols, data, shape, slice_size):
+    n_rows, n_cols = shape
+    if slice_size is None or slice_size <= 0:
+        slice_size = 32
+    slice_size = int(slice_size)
+    # cupy.lexsort expects a single (k, N) key array, not a tuple of key arrays;
+    # the last key (rows) is the primary sort key, matching CSR-style ordering.
+    sort_idx = cp.lexsort(cp.stack((cols, rows)))
+    rows = rows[sort_idx]
+    cols = cols[sort_idx]
+    data = data[sort_idx]
+    nnz_per_row = cp.bincount(rows, minlength=n_rows)
+    n_slices = (n_rows + slice_size - 1) // slice_size
+    slice_ptr = cp.zeros(n_slices + 1, dtype=cp.int64)
+    rows_per_slice = cp.zeros(n_slices, dtype=cp.int64)
+    total_entries = 0
+    for s in range(n_slices):
+        r0 = s * slice_size
+        r1 = min(r0 + slice_size, n_rows)
+        rps = r1 - r0
+        rows_per_slice[s] = rps
+        if rps > 0:
+            max_nnz = int(cp.max(nnz_per_row[r0:r1]))
+        else:
+            max_nnz = 0
+        total_entries += rps * max_nnz
+        slice_ptr[s + 1] = total_entries
+    values = cp.zeros(total_entries, dtype=data.dtype)
+    indices = cp.zeros(total_entries, dtype=cp.int64)
+    row_start = cp.zeros(n_rows + 1, dtype=cp.int64)
+    row_start[1:] = cp.cumsum(nnz_per_row)
+    base = 0
+    for s in range(n_slices):
+        r0 = s * slice_size
+        r1 = min(r0 + slice_size, n_rows)
+        rps = int(rows_per_slice[s])
+        if rps == 0:
+            continue
+        max_nnz = (int(slice_ptr[s + 1]) - int(slice_ptr[s])) // rps
+        for r in range(rps):
+            row = r0 + r
+            start = int(row_start[row])
+            end = int(row_start[row + 1])
+            nnz = end - start
+            dst_start = base + r * max_nnz
+            if nnz > 0:
+                values[dst_start : dst_start + nnz] = data[start:end]
+                indices[dst_start : dst_start + nnz] = cols[start:end]
+        base = int(slice_ptr[s + 1])
+    return SELLMatrix(values, indices, slice_ptr, rows_per_slice, shape, dtype=data.dtype)
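+
+
+# NOTE: Illustrative sketch only, not part of the public API: with slice_size=2
+# the 3x3 matrix below yields two slices, and slice 0 pads row 1 with explicit
+# zeros to the width of row 0.
+def _example_sell_layout():
+    rows = cp.asarray([0, 0, 1, 2], dtype=cp.int64)
+    cols = cp.asarray([0, 2, 1, 2], dtype=cp.int64)
+    vals = cp.asarray([1.0, 2.0, 3.0, 4.0], dtype=cp.float32)
+    sell = _coo_to_sell_impl(rows, cols, vals, (3, 3), slice_size=2)
+    # Slice 0 holds rows 0-1 with 2 slots per row; slice 1 holds row 2 with 1 slot.
+    assert [int(v) for v in sell.slice_ptr] == [0, 4, 5]
+    return sell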
+
+
+def _coo_to_blocked_ell_impl(rows, cols, data, shape, block_shape):
+    br, bc = block_shape
+    n_rows, n_cols = shape
+    if n_rows % br != 0 or n_cols % bc != 0:
+        raise ValueError(
+            f"shape {shape} must be divisible by block_shape {block_shape}"
+        )
+    n_block_rows = n_rows // br
+    n_block_cols = n_cols // bc
+    block_rows = rows // br
+    block_cols = cols // bc
+    in_block_r = rows % br
+    in_block_c = cols % bc
+    nnz = rows.size
+    blocks = {}
+    for k in range(nnz):
+        i = int(block_rows[k])
+        j = int(block_cols[k])
+        ir = int(in_block_r[k])
+        ic = int(in_block_c[k])
+        key = (i, j)
+        if key not in blocks:
+            blocks[key] = cp.zeros((br, bc), dtype=data.dtype)
+        blocks[key][ir, ic] += data[k]
+    max_blocks_per_row = 0
+    for i in range(n_block_rows):
+        count = sum(1 for (bi, _) in blocks if bi == i)
+        max_blocks_per_row = max(max_blocks_per_row, count)
+    if max_blocks_per_row == 0:
+        max_blocks_per_row = 1
+    data_out = cp.zeros((n_block_rows, max_blocks_per_row, br, bc), dtype=data.dtype)
+    indices_out = cp.full((n_block_rows, max_blocks_per_row), -1, dtype=cp.int64)
+    for i in range(n_block_rows):
+        cols_in_row = sorted([j for (bi, j) in blocks if bi == i])
+        for t, j in enumerate(cols_in_row):
+            data_out[i, t] = blocks[(i, j)]
+            indices_out[i, t] = j
+    return BLOCKEDELLMatrix(
+        data_out, indices_out, block_shape, shape, dtype=data.dtype
+    )
+
+
+def create_csr_matrix(values, indices, indptr, shape, dtype=None):
+    return CSRMatrix(values, indices, indptr, shape, dtype=dtype)
+
+
+def create_coo_matrix(row_indices, col_indices, values, shape, dtype=None):
+    return COOMatrix(row_indices, col_indices, values, shape, dtype=dtype)
+
+
+def coo_to_csr(coo_matrix):
+    if not isinstance(coo_matrix, COOMatrix):
+        raise TypeError("coo_matrix must be an instance of COOMatrix")
+    return coo_matrix.to_csr()
+
+
+def coo_to_csc(coo_matrix):
+    """Convert COO to CSC (CuPy-style: .tocsc())."""
+    if not isinstance(coo_matrix, COOMatrix):
+        raise TypeError("coo_matrix must be an instance of COOMatrix")
+    return coo_matrix.to_csc()
+
+
+def coo_to_bsr(coo_matrix, blocksize=None):
+    """Convert COO to BSR. blocksize: (R, C) or int for square block."""
+    if not isinstance(coo_matrix, COOMatrix):
+        raise TypeError("coo_matrix must be an instance of COOMatrix")
+    return coo_matrix.to_bsr(blocksize=blocksize)
+
+
+def coo_to_sell(coo_matrix, slice_size=32):
+    """Convert COO to SELL (Sliced ELLPACK). slice_size: rows per slice."""
+    if not isinstance(coo_matrix, COOMatrix):
+        raise TypeError("coo_matrix must be an instance of COOMatrix")
+    rows = coo_matrix.row_indices
+    cols = coo_matrix.col_indices
+    data = coo_matrix.values
+    return _coo_to_sell_impl(rows, cols, data, coo_matrix.shape, slice_size)
+
+
+def coo_to_blocked_ell(coo_matrix, block_shape):
+    """Convert COO to BLOCKED-ELL. block_shape: (r, c) block dimensions."""
+    if not isinstance(coo_matrix, COOMatrix):
+        raise TypeError("coo_matrix must be an instance of COOMatrix")
+    rows = coo_matrix.row_indices
+    cols = coo_matrix.col_indices
+    data = coo_matrix.values
+    return _coo_to_blocked_ell_impl(
+        rows, cols, data, coo_matrix.shape, block_shape
+    )
+
+
+def create_csc_matrix(values, indices, indptr, shape, dtype=None):
+    """Create CSC from (data, row_indices, col_ptr), CuPy-style."""
+    return CSCMatrix(values, indices, indptr, shape, dtype=dtype)
+
+
+def create_bsr_matrix(data, indices, indptr, shape, blocksize, dtype=None):
+    """Create BSR from (data, indices, indptr), shape, blocksize=(R,C)."""
+    return BSRMatrix(data, indices, indptr, shape, blocksize=blocksize, dtype=dtype)
+
+
+def create_sell_matrix(values, indices, slice_ptr, rows_per_slice, shape, dtype=None):
+    """Create SELL from values, indices, slice_ptr, rows_per_slice, shape."""
+    return SELLMatrix(
+        values, indices, slice_ptr, rows_per_slice, shape, dtype=dtype
+    )
+
+
+def create_blocked_ell_matrix(data, indices, block_shape, shape, dtype=None):
+    """Create BLOCKED-ELL from data, indices, block_shape, shape."""
+    return BLOCKEDELLMatrix(data, indices, block_shape, shape, dtype=dtype)
+
+
+def generate_random_sparse_matrix(
+    n_rows, n_cols, density=0.1, dtype=cp.float32, device=None
+):
+    if device is not None:
+        cp.cuda.Device(device).use()
+    total = int(n_rows) * int(n_cols)
+    nnz = max(0, int(total * float(density)))
+    nnz = min(nnz, total)
+    if nnz == 0:
+        rows = cp.asarray([], dtype=cp.int64)
+        cols = cp.asarray([], dtype=cp.int64)
+        vals = cp.asarray([], dtype=_resolve_dtype(dtype))
+    else:
+        linear = cp.random.permutation(total)[:nnz]
+        rows = linear // n_cols
+        cols = linear % n_cols
+        vals = _random_values(nnz, _resolve_dtype(dtype))
+    coo_matrix = COOMatrix(rows, cols, vals, (n_rows, n_cols), dtype=dtype)
+    csr_matrix = coo_to_csr(coo_matrix)
+    return coo_matrix, csr_matrix
+
+
+def read_mtx_file(file_path, dtype=cp.float32, device=None):
+    if device is not None:
+        cp.cuda.Device(device).use()
+
+    with open(file_path, "r", encoding="utf-8") as f:
+        lines = f.readlines()
+
+    data_lines = []
+    header_info = None
+    for line in lines:
+        line = line.strip()
+        if line.startswith("%"):
+            continue
+        if not header_info and line:
+            header_parts = line.split()
+            n_rows = int(header_parts[0])
+            n_cols = int(header_parts[1])
+            nnz = int(header_parts[2])
+            header_info = (n_rows, n_cols, nnz)
+            continue
+        if line:
+            data_lines.append(line)
+
+    if header_info is None:
+        raise ValueError("Could not parse matrix dimensions from .mtx file")
+
+    n_rows, n_cols, nnz = header_info
+    rows_host = []
+    cols_host = []
+    vals_host = []
+    for line in data_lines[:nnz]:
+        parts = line.split()
+        if len(parts) >= 3:
+            rows_host.append(int(parts[0]) - 1)
+            cols_host.append(int(parts[1]) - 1)
+            vals_host.append(float(parts[2]))
+
+    resolved_dtype = _resolve_dtype(dtype)
+    rows = cp.asarray(rows_host, dtype=cp.int64)
+    cols = cp.asarray(cols_host, dtype=cp.int64)
+    vals = cp.asarray(vals_host, dtype=resolved_dtype)
+    coo_matrix = COOMatrix(rows, cols, vals, (n_rows, n_cols), dtype=resolved_dtype)
+    csr_matrix = coo_to_csr(coo_matrix)
+    return coo_matrix, csr_matrix
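+
+
+# NOTE: Illustrative sketch only, not part of the public API. Assumes a CUDA
+# device is visible to CuPy; round-trips a small random matrix through the
+# converters above, so every format reproduces the same dense array.
+def _example_format_roundtrip():
+    coo, csr = generate_random_sparse_matrix(8, 8, density=0.25, dtype=cp.float32)
+    sell = coo.to_sell(slice_size=4)
+    dense = coo.to_dense()
+    assert bool(cp.allclose(csr.to_dense(), dense))
+    assert bool(cp.allclose(sell.to_coo().to_dense(), dense))
+    return coo, csr, sell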
diff --git a/src/flagsparse/sparse_operations/__init__.py b/src/flagsparse/sparse_operations/__init__.py
new file mode 100644
index 0000000..0fac868
--- /dev/null
+++ b/src/flagsparse/sparse_operations/__init__.py
@@ -0,0 +1,100 @@
+"""FlagSparse sparse operations (gather, scatter, SpMV, SpMM, SpGEMM, SDDMM, SpSM)."""
+
+from ._common import SUPPORTED_INDEX_DTYPES, SUPPORTED_VALUE_DTYPES, cp, cpx_sparse
+from .benchmarks import (
+    benchmark_gather_case,
+    benchmark_performance,
+    benchmark_scatter_case,
+    benchmark_sddmm_case,
+    benchmark_spgemm_case,
+    benchmark_spmm_case,
+    benchmark_spmm_opt_case,
+    benchmark_spmv_case,
+    benchmark_spsm_case,
+    comprehensive_gather_test,
+    comprehensive_scatter_test,
+    comprehensive_spmm_test,
+    comprehensive_spsm_test,
+)
+from .gather_scatter import (
+    cusparse_spmv_gather,
+    cusparse_spmv_scatter,
+    flagsparse_gather,
+    flagsparse_scatter,
+    pytorch_index_gather,
+    pytorch_index_scatter,
+    triton_cusparse_gather,
+    triton_cusparse_scatter,
+)
+from .spmv_csr import (
+    PreparedCsrSpmv,
+    flagsparse_spmv_coo_tocsr,
+    flagsparse_spmv_csr,
+    prepare_spmv_coo_tocsr,
+    prepare_spmv_csr,
+)
+from .spmv_coo import (
+    PreparedCoo,
+    flagsparse_spmv_coo,
+    prepare_spmv_coo,
+)
+from .spmm_csr import (
+    PreparedCsrSpmmOpt,
+    flagsparse_spmm_csr,
+    flagsparse_spmm_csr_opt,
+    prepare_spmm_csr_opt,
+)
+from .spmm_coo import flagsparse_spmm_coo
+from .spgemm_csr import SpGEMMPrepared, flagsparse_spgemm_csr, prepare_spgemm_csr
+from .sddmm_csr import SDDMMPrepared, flagsparse_sddmm_csr, prepare_sddmm_csr
+from .spsv import flagsparse_spsv_coo, flagsparse_spsv_csr
+from .spsm import flagsparse_spsm_coo, flagsparse_spsm_csr
+
+__all__ = [
+    "PreparedCoo",
+    "PreparedCsrSpmv",
+    "PreparedCsrSpmmOpt",
+    "SDDMMPrepared",
+    "SpGEMMPrepared",
+    "SUPPORTED_INDEX_DTYPES",
+    "SUPPORTED_VALUE_DTYPES",
+    "benchmark_gather_case",
+    "benchmark_performance",
+    "benchmark_scatter_case",
+    "benchmark_sddmm_case",
+    "benchmark_spgemm_case",
+    "benchmark_spmm_case",
+    "benchmark_spmm_opt_case",
+    "benchmark_spmv_case",
+    "benchmark_spsm_case",
+    "comprehensive_gather_test",
+    "comprehensive_scatter_test",
+    "comprehensive_spmm_test",
+    "comprehensive_spsm_test",
+    "cusparse_spmv_gather",
+    "cusparse_spmv_scatter",
+    "flagsparse_gather",
+    "flagsparse_scatter",
+    "flagsparse_sddmm_csr",
+    "flagsparse_spgemm_csr",
+    "flagsparse_spmm_coo",
+    "flagsparse_spmm_csr",
+    "flagsparse_spmm_csr_opt",
+    "flagsparse_spmv_coo",
+    "flagsparse_spmv_coo_tocsr",
+    "flagsparse_spmv_csr",
+    "flagsparse_spsm_coo",
+    "flagsparse_spsm_csr",
+    "flagsparse_spsv_coo",
+    "flagsparse_spsv_csr",
+    "prepare_sddmm_csr",
+    "prepare_spgemm_csr",
+    "prepare_spmm_csr_opt",
+    "prepare_spmv_coo",
+    "prepare_spmv_coo_tocsr",
+    "prepare_spmv_csr",
+    "pytorch_index_gather",
+    "pytorch_index_scatter",
+    "triton_cusparse_gather",
+    "triton_cusparse_scatter",
+]
diff --git a/src/flagsparse/sparse_operations/_common.py b/src/flagsparse/sparse_operations/_common.py
new file mode 100644
index 0000000..6f7725d
--- /dev/null
+++ b/src/flagsparse/sparse_operations/_common.py
@@ -0,0 +1,334 @@
+"""Shared imports, dtypes, and helpers for FlagSparse sparse ops."""
+
+import time
+
+try:
+    import torch
+    import triton
+    import triton.language as tl
+except ImportError as exc:
+    raise ImportError(
+        "Runtime dependencies are missing. 
Install them manually: pip install torch triton" + ) from exc + +try: + import cupy as cp + import cupyx.scipy.sparse as cpx_sparse +except ImportError: + cp = None + cpx_sparse = None + +_SUPPORTED_VALUE_DTYPES = [ + torch.float16, + torch.bfloat16, + torch.float32, + torch.float64, + torch.complex64, + torch.complex128, +] +SUPPORTED_VALUE_DTYPES = tuple(_SUPPORTED_VALUE_DTYPES) +SUPPORTED_INDEX_DTYPES = (torch.int32, torch.int64) +_INDEX_LIMIT_INT32 = 2**31 - 1 + +# Star-import exposes only non-underscore names unless listed here. +__all__ = ( + "SUPPORTED_VALUE_DTYPES", + "SUPPORTED_INDEX_DTYPES", + "_INDEX_LIMIT_INT32", + "_is_complex_dtype", + "_resolve_scatter_value_dtype", + "_component_dtype_for_complex", + "_tolerance_for_dtype", + "_require_cupy", + "_cupy_dtype_from_torch", + "_cupy_from_torch", + "_torch_from_cupy", + "_to_torch_tensor", + "_to_backend_like", + "_cusparse_baseline_skip_reason", + "_build_random_dense", + "_build_indices", + "_build_random_csr", + "_validate_common_inputs", + "_prepare_inputs", + "_prepare_scatter_inputs", + "_benchmark_cuda_op", + "cp", + "cpx_sparse", + "time", + "torch", + "triton", + "tl", +) + +def _is_complex_dtype(value_dtype): + return value_dtype in (torch.complex64, torch.complex128) + + +def _resolve_scatter_value_dtype(value_dtype, dtype_policy="auto"): + dtype_policy = str(dtype_policy).lower() + if dtype_policy not in ("auto", "strict"): + raise ValueError("dtype_policy must be 'auto' or 'strict'") + if isinstance(value_dtype, str): + token = value_dtype.strip().lower() + mapping = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float32": torch.float32, + "float64": torch.float64, + "complex64": torch.complex64, + "complex128": torch.complex128, + } + if token not in mapping: + raise TypeError(f"Unsupported dtype token: {value_dtype}") + value_dtype = mapping[token] + return value_dtype, False, None + + +def _component_dtype_for_complex(value_dtype): + if value_dtype == torch.complex64: + return torch.float32 + if value_dtype == torch.complex128: + return torch.float64 + raise TypeError(f"Unsupported complex dtype: {value_dtype}") + + +def _tolerance_for_dtype(value_dtype): + if value_dtype == torch.float16: + return 2e-3, 2e-3 + if value_dtype == torch.bfloat16: + return 1e-1, 1e-1 + if value_dtype in (torch.float32, torch.complex64): + return 1e-6, 1e-5 + if value_dtype in (torch.float64, torch.complex128): + return 1e-10, 1e-8 + return 1e-6, 1e-5 + + +def _require_cupy(): + if cp is None or cpx_sparse is None: + raise RuntimeError( + "CuPy is required for cuSPARSE baseline. " + "Install a CUDA-matched wheel, for example: pip install cupy-cuda12x" + ) + + +def _cupy_dtype_from_torch(torch_dtype): + _require_cupy() + mapping = { + torch.float16: cp.float16, + # Keep cuSPARSE baseline stable for bf16 by computing in fp32 on CuPy path. 
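+        # CuPy has no native bfloat16 dtype, so bf16 values round-trip through fp32 here.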
+ torch.bfloat16: cp.float32, + torch.float32: cp.float32, + torch.float64: cp.float64, + torch.complex64: cp.complex64, + torch.complex128: cp.complex128, + torch.int32: cp.int32, + torch.int64: cp.int64, + } + if torch_dtype not in mapping: + raise TypeError(f"Unsupported dtype conversion to CuPy: {torch_dtype}") + return mapping[torch_dtype] + + +def _cupy_from_torch(tensor): + _require_cupy() + return cp.from_dlpack(torch.utils.dlpack.to_dlpack(tensor)) + + +def _torch_from_cupy(array): + try: + dlpack_capsule = array.toDlpack() + except AttributeError: + dlpack_capsule = array.to_dlpack() + return torch.utils.dlpack.from_dlpack(dlpack_capsule) + + +def _to_torch_tensor(x, name): + if torch.is_tensor(x): + return x, "torch" + if cp is not None and isinstance(x, cp.ndarray): + return _torch_from_cupy(x), "cupy" + raise TypeError(f"{name} must be a torch.Tensor or cupy.ndarray") + + +def _to_backend_like(torch_tensor, ref_obj): + if cp is not None and isinstance(ref_obj, cp.ndarray): + return _cupy_from_torch(torch_tensor) + return torch_tensor + + +def _cusparse_baseline_skip_reason(value_dtype): + if value_dtype == torch.bfloat16: + return "bfloat16 is not supported by the cuSPARSE baseline path; skipped" + if cp is None and value_dtype == torch.float16: + return "float16 is not supported by torch sparse fallback when CuPy is unavailable; skipped" + return None + + +def _build_random_dense(dense_size, value_dtype, device): + if value_dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64): + return torch.randn(dense_size, dtype=value_dtype, device=device) + if _is_complex_dtype(value_dtype): + component_dtype = _component_dtype_for_complex(value_dtype) + real = torch.randn(dense_size, dtype=component_dtype, device=device) + imag = torch.randn(dense_size, dtype=component_dtype, device=device) + return torch.complex(real, imag) + raise TypeError(f"Unsupported value dtype: {value_dtype}") + + +def _build_indices(nnz, dense_size, index_dtype, device, unique=False): + if unique and nnz <= dense_size: + return torch.randperm(dense_size, device=device)[:nnz].to(index_dtype) + return torch.randint(0, dense_size, (nnz,), dtype=index_dtype, device=device) + + +def _build_random_csr(n_rows, n_cols, nnz, value_dtype, index_dtype, device): + if nnz <= 0 or n_rows <= 0 or n_cols <= 0: + indptr = torch.zeros(n_rows + 1, dtype=torch.int64, device=device) + return ( + torch.empty(0, dtype=value_dtype, device=device), + torch.empty(0, dtype=index_dtype, device=device), + indptr, + ) + row_choices = torch.randint(0, n_rows, (nnz,), device=device) + row_choices, _ = torch.sort(row_choices) + nnz_per_row = torch.bincount(row_choices, minlength=n_rows) + indptr = torch.zeros(n_rows + 1, dtype=torch.int64, device=device) + indptr[1:] = torch.cumsum(nnz_per_row, dim=0) + indices = torch.randint(0, n_cols, (nnz,), dtype=index_dtype, device=device) + data = _build_random_dense(nnz, value_dtype, device) + return data, indices, indptr + + +def _validate_common_inputs(dense_vector, indices): + if dense_vector.ndim != 1: + raise ValueError("dense_vector must be a 1D tensor") + if indices.ndim != 1: + raise ValueError("indices must be a 1D tensor") + if not dense_vector.is_cuda or not indices.is_cuda: + raise ValueError("dense_vector and indices must both be CUDA tensors") + if dense_vector.dtype not in SUPPORTED_VALUE_DTYPES: + raise TypeError( + f"dense_vector dtype must be one of: {', '.join(str(dt) for dt in SUPPORTED_VALUE_DTYPES)}" + ) + if indices.dtype not in SUPPORTED_INDEX_DTYPES: + raise 
TypeError("indices dtype must be torch.int32 or torch.int64") + + +def _prepare_inputs(dense_vector, indices): + _validate_common_inputs(dense_vector, indices) + + dense_vector = dense_vector.contiguous() + indices = indices.contiguous() + + max_index = -1 + if indices.numel() > 0: + if torch.any(indices < 0).item(): + raise IndexError("indices must be non-negative") + max_index = int(indices.max().item()) + if max_index >= dense_vector.numel(): + raise IndexError( + f"indices out of range: max index {max_index}, dense size {dense_vector.numel()}" + ) + + kernel_indices = indices + if indices.dtype == torch.int64: + if max_index > _INDEX_LIMIT_INT32: + raise ValueError( + f"int64 index value {max_index} exceeds Triton int32 kernel range" + ) + kernel_indices = indices.to(torch.int32) + + return dense_vector, indices, kernel_indices + + +def _prepare_scatter_inputs( + sparse_values, indices, dense_size=None, out=None, dtype_policy="auto", return_metadata=False +): + if sparse_values.ndim != 1: + raise ValueError("sparse_values must be a 1D tensor") + if indices.ndim != 1: + raise ValueError("indices must be a 1D tensor") + if sparse_values.numel() != indices.numel(): + raise ValueError("sparse_values and indices must have the same number of elements") + if not sparse_values.is_cuda or not indices.is_cuda: + raise ValueError("sparse_values and indices must both be CUDA tensors") + if sparse_values.dtype not in SUPPORTED_VALUE_DTYPES: + raise TypeError( + f"sparse_values dtype must be one of: {', '.join(str(dt) for dt in SUPPORTED_VALUE_DTYPES)}" + ) + if indices.dtype not in SUPPORTED_INDEX_DTYPES: + raise TypeError("indices dtype must be torch.int32 or torch.int64") + + requested_value_dtype = sparse_values.dtype + effective_value_dtype, fallback_applied, fallback_reason = _resolve_scatter_value_dtype( + requested_value_dtype, dtype_policy=dtype_policy + ) + if effective_value_dtype != sparse_values.dtype: + sparse_values = sparse_values.to(effective_value_dtype) + + sparse_values = sparse_values.contiguous() + indices = indices.contiguous() + + if dense_size is None: + dense_size = int(indices.max().item()) + 1 if indices.numel() > 0 else 0 + dense_size = int(dense_size) + if dense_size < 0: + raise ValueError("dense_size must be non-negative") + + max_index = -1 + if indices.numel() > 0: + if torch.any(indices < 0).item(): + raise IndexError("indices must be non-negative") + max_index = int(indices.max().item()) + if max_index >= dense_size: + raise IndexError( + f"indices out of range: max index {max_index}, dense size {dense_size}" + ) + + kernel_indices = indices + + if out is not None: + if out.ndim != 1: + raise ValueError("out must be a 1D tensor") + if not out.is_cuda: + raise ValueError("out must be a CUDA tensor") + if out.dtype != sparse_values.dtype: + raise TypeError("out dtype must match sparse_values dtype") + if out.numel() != dense_size: + raise ValueError("out size must equal dense_size") + if out.device != sparse_values.device: + raise ValueError("out must be on the same device as sparse_values") + + metadata = { + "requested_value_dtype": requested_value_dtype, + "effective_value_dtype": sparse_values.dtype, + "fallback_applied": bool(fallback_applied), + "fallback_reason": fallback_reason, + "dtype_policy": str(dtype_policy).lower(), + } + if return_metadata: + return sparse_values, indices, kernel_indices, dense_size, metadata + return sparse_values, indices, kernel_indices, dense_size + + +def _benchmark_cuda_op(op, warmup, iters): + warmup = max(0, int(warmup)) + 
iters = max(1, int(iters)) + + output = None + for _ in range(warmup): + output = op() + + torch.cuda.synchronize() + if cp is not None: + cp.cuda.runtime.deviceSynchronize() + start_time = time.perf_counter() + for _ in range(iters): + output = op() + torch.cuda.synchronize() + if cp is not None: + cp.cuda.runtime.deviceSynchronize() + elapsed_ms = (time.perf_counter() - start_time) * 1000.0 / iters + return output, elapsed_ms diff --git a/src/flagsparse/sparse_operations/benchmarks.py b/src/flagsparse/sparse_operations/benchmarks.py new file mode 100644 index 0000000..395735b --- /dev/null +++ b/src/flagsparse/sparse_operations/benchmarks.py @@ -0,0 +1,558 @@ +"""Benchmarks for gather, scatter, SpMV, SpMM, SpGEMM, SDDMM, and SpSM.""" + +from ._common import * + +from .gather_scatter import ( + _cusparse_spmv, + _make_gather_selector_matrix, + _make_scatter_selector_matrix, + _pytorch_scatter_impl, + _triton_gather_impl, + _triton_scatter_impl, +) +from .spmv_csr import flagsparse_spmv_csr, prepare_spmv_csr +from .spmm_csr import ( + benchmark_spmm_case, + benchmark_spmm_opt_case, + comprehensive_spmm_test, +) +from .spgemm_csr import benchmark_spgemm_case +from .sddmm_csr import benchmark_sddmm_case +from .spsm import benchmark_spsm_case + + +def _normalize_dtype_name(value): + if isinstance(value, str): + return value.strip().lower() + return str(value).replace("torch.", "") + + +def _resolve_scatter_benchmark_dtype(value_dtype, dtype_policy): + requested_name = _normalize_dtype_name(value_dtype) + effective_dtype, fallback_applied, fallback_reason = _resolve_scatter_value_dtype( + value_dtype, dtype_policy=dtype_policy + ) + return requested_name, effective_dtype, bool(fallback_applied), fallback_reason + + +def benchmark_gather_case( + dense_size=65536, + nnz=4096, + value_dtype=torch.float32, + index_dtype=torch.int32, + warmup=20, + iters=200, + block_size=1024, + run_cusparse=True, +): + """Benchmark Triton vs PyTorch indexing vs cuSPARSE-backed COO SpMV.""" + device = torch.device("cuda") + dense_vector = _build_random_dense(dense_size, value_dtype, device) + indices = _build_indices(nnz, dense_size, index_dtype, device, unique=False) + + dense_vector, indices, kernel_indices = _prepare_inputs(dense_vector, indices) + expected = dense_vector[indices] + + pytorch_op = lambda: dense_vector[indices] + triton_op = lambda: _triton_gather_impl(dense_vector, kernel_indices, block_size=block_size) + + pytorch_values, pytorch_ms = _benchmark_cuda_op(pytorch_op, warmup=warmup, iters=iters) + triton_values, triton_ms = _benchmark_cuda_op(triton_op, warmup=warmup, iters=iters) + + atol, rtol = _tolerance_for_dtype(value_dtype) + triton_match = torch.allclose(triton_values, expected, atol=atol, rtol=rtol) + triton_max_error = ( + float(torch.max(torch.abs(triton_values - expected)).item()) + if nnz > 0 + else 0.0 + ) + + cusparse_ms = None + cusparse_match = None + cusparse_max_error = None + cusparse_reason = None + if run_cusparse: + skip_reason = _cusparse_baseline_skip_reason(value_dtype) + if skip_reason: + cusparse_reason = skip_reason + else: + try: + selector_matrix = _make_gather_selector_matrix( + indices, dense_vector.numel(), dense_vector.dtype + ) + cusparse_op = lambda: _cusparse_spmv(selector_matrix, dense_vector) + cusparse_values, cusparse_ms = _benchmark_cuda_op( + cusparse_op, warmup=warmup, iters=iters + ) + cusparse_match = torch.allclose( + cusparse_values, expected, atol=atol, rtol=rtol + ) + cusparse_max_error = ( + float(torch.max(torch.abs(cusparse_values - 
expected)).item()) + if nnz > 0 + else 0.0 + ) + except Exception as exc: + cusparse_reason = str(exc) + + triton_speedup_vs_pytorch = ( + pytorch_ms / triton_ms if triton_ms > 0 else float("inf") + ) + triton_speedup_vs_cusparse = ( + cusparse_ms / triton_ms + if (cusparse_ms is not None and triton_ms > 0) + else None + ) + + return { + "parameters": { + "dense_size": dense_size, + "nnz": nnz, + "value_dtype": str(value_dtype), + "index_dtype": str(index_dtype), + "warmup": warmup, + "iters": iters, + }, + "performance": { + "pytorch_ms": pytorch_ms, + "triton_ms": triton_ms, + "cusparse_ms": cusparse_ms, + "triton_speedup_vs_pytorch": triton_speedup_vs_pytorch, + "triton_speedup_vs_cusparse": triton_speedup_vs_cusparse, + }, + "verification": { + "triton_match_pytorch": triton_match, + "triton_max_error": triton_max_error, + "cusparse_match_pytorch": cusparse_match, + "cusparse_max_error": cusparse_max_error, + }, + "backend_status": { + "cusparse_unavailable_reason": cusparse_reason, + }, + "samples": { + "pytorch": pytorch_values, + "triton": triton_values, + }, + } + + +def benchmark_scatter_case( + dense_size=65536, + nnz=4096, + value_dtype=torch.float32, + index_dtype=torch.int32, + warmup=20, + iters=200, + block_size=1024, + run_cusparse=True, + unique_indices=True, + reset_output=True, + dtype_policy="auto", + index_fallback_policy="auto", +): + """Benchmark Triton scatter vs PyTorch index_copy vs cuSPARSE-backed COO SpMV.""" + device = torch.device("cuda") + requested_value_dtype, requested_effective_dtype, fallback_applied, fallback_reason = ( + _resolve_scatter_benchmark_dtype(value_dtype, dtype_policy) + ) + sparse_values = _build_random_dense(nnz, requested_effective_dtype, device) + indices = _build_indices(nnz, dense_size, index_dtype, device, unique=unique_indices) + + sparse_values, indices, kernel_indices, dense_size, prep_meta = _prepare_scatter_inputs( + sparse_values, + indices, + dense_size=dense_size, + out=None, + dtype_policy=dtype_policy, + return_metadata=True, + ) + effective_value_dtype = prep_meta["effective_value_dtype"] + fallback_applied = fallback_applied or prep_meta["fallback_applied"] + fallback_reason = fallback_reason or prep_meta["fallback_reason"] + + base_out = _build_random_dense(dense_size, sparse_values.dtype, device) + expected = _pytorch_scatter_impl( + sparse_values, + indices, + dense_size, + out=base_out.clone(), + reset_output=reset_output, + ) + + pytorch_out = base_out.clone() + triton_probe_out = base_out.clone() + _, triton_index_meta = _triton_scatter_impl( + sparse_values, + kernel_indices, + dense_size=dense_size, + out=triton_probe_out, + block_size=block_size, + reset_output=reset_output, + index_fallback_policy=index_fallback_policy, + return_metadata=True, + ) + effective_kernel_indices = kernel_indices + if triton_index_meta["index_fallback_applied"] and kernel_indices.dtype == torch.int64: + effective_kernel_indices = kernel_indices.to(torch.int32) + triton_out = base_out.clone() + + pytorch_op = lambda: _pytorch_scatter_impl( + sparse_values, + indices, + dense_size, + out=pytorch_out, + reset_output=reset_output, + ) + triton_op = lambda: _triton_scatter_impl( + sparse_values, + effective_kernel_indices, + dense_size=dense_size, + out=triton_out, + block_size=block_size, + reset_output=reset_output, + index_fallback_policy="strict", + ) + + pytorch_values, pytorch_ms = _benchmark_cuda_op(pytorch_op, warmup=warmup, iters=iters) + triton_values, triton_ms = _benchmark_cuda_op(triton_op, warmup=warmup, iters=iters) + + atol, 
rtol = _tolerance_for_dtype(effective_value_dtype) + triton_match = torch.allclose(triton_values, expected, atol=atol, rtol=rtol) + triton_max_error = ( + float(torch.max(torch.abs(triton_values - expected)).item()) + if dense_size > 0 + else 0.0 + ) + + cusparse_ms = None + cusparse_match = None + cusparse_max_error = None + cusparse_reason = None + if run_cusparse: + if not reset_output: + skip_reason = "cuSPARSE scatter baseline only matches reset_output=True semantics" + else: + skip_reason = _cusparse_baseline_skip_reason(sparse_values.dtype) + if skip_reason: + cusparse_reason = skip_reason + else: + try: + selector_matrix = _make_scatter_selector_matrix( + indices, dense_size, sparse_values.dtype + ) + cusparse_op = lambda: _cusparse_spmv(selector_matrix, sparse_values) + cusparse_values, cusparse_ms = _benchmark_cuda_op( + cusparse_op, warmup=warmup, iters=iters + ) + cusparse_match = torch.allclose( + cusparse_values, expected, atol=atol, rtol=rtol + ) + cusparse_max_error = ( + float(torch.max(torch.abs(cusparse_values - expected)).item()) + if dense_size > 0 + else 0.0 + ) + except Exception as exc: + cusparse_reason = str(exc) + + triton_speedup_vs_pytorch = ( + pytorch_ms / triton_ms if triton_ms > 0 else float("inf") + ) + triton_speedup_vs_cusparse = ( + cusparse_ms / triton_ms + if (cusparse_ms is not None and triton_ms > 0) + else None + ) + + return { + "parameters": { + "dense_size": dense_size, + "nnz": nnz, + "value_dtype": requested_value_dtype, + "effective_value_dtype": str(effective_value_dtype), + "index_dtype": str(index_dtype), + "warmup": warmup, + "iters": iters, + "unique_indices": unique_indices, + "reset_output": bool(reset_output), + "dtype_policy": str(dtype_policy).lower(), + "fallback_applied": bool(fallback_applied), + "index_fallback_policy": str(index_fallback_policy).lower(), + "kernel_index_dtype": triton_index_meta["kernel_index_dtype"], + }, + "performance": { + "pytorch_ms": pytorch_ms, + "triton_ms": triton_ms, + "cusparse_ms": cusparse_ms, + "triton_speedup_vs_pytorch": triton_speedup_vs_pytorch, + "triton_speedup_vs_cusparse": triton_speedup_vs_cusparse, + }, + "verification": { + "triton_match_pytorch": triton_match, + "triton_max_error": triton_max_error, + "cusparse_match_pytorch": cusparse_match, + "cusparse_max_error": cusparse_max_error, + }, + "backend_status": { + "cusparse_unavailable_reason": cusparse_reason, + "fallback_reason": fallback_reason, + "index_fallback_applied": bool(triton_index_meta["index_fallback_applied"]), + "index_fallback_reason": triton_index_meta["index_fallback_reason"], + }, + "samples": { + "pytorch": pytorch_values, + "triton": triton_values, + }, + } + + +def benchmark_spmv_case( + n_rows=4096, + n_cols=4096, + nnz=65536, + value_dtype=torch.float32, + index_dtype=torch.int32, + warmup=20, + iters=200, + block_nnz=256, + max_segments=None, + run_cusparse=True, +): + """Benchmark Triton CSR SpMV vs cuSPARSE (CuPy CSR @ x).""" + device = torch.device("cuda") + data, indices, indptr = _build_random_csr( + n_rows, n_cols, nnz, value_dtype, index_dtype, device + ) + x = _build_random_dense(n_cols, value_dtype, device) + shape = (n_rows, n_cols) + prepared = prepare_spmv_csr( + data, + indices, + indptr, + shape, + block_nnz=block_nnz, + max_segments=max_segments, + ) + triton_op = lambda: flagsparse_spmv_csr( + x=x, + prepared=prepared, + return_time=False, + ) + triton_y, triton_ms = _benchmark_cuda_op(triton_op, warmup=warmup, iters=iters) + _cupy_supported_dtypes = ( + torch.float32, + torch.float64, + 
torch.complex64, + torch.complex128, + ) + if ( + cp is not None + and cpx_sparse is not None + and value_dtype in _cupy_supported_dtypes + ): + data_cp = _cupy_from_torch(data) + indices_cp = _cupy_from_torch(indices.to(torch.int64)) + indptr_cp = _cupy_from_torch(indptr) + x_cp = _cupy_from_torch(x) + A_csr = cpx_sparse.csr_matrix( + (data_cp, indices_cp, indptr_cp), shape=shape + ) + ref_y = A_csr @ x_cp + expected = _torch_from_cupy(ref_y) + else: + row_indices = torch.repeat_interleave( + torch.arange(n_rows, device=device, dtype=torch.int64), + indptr[1:] - indptr[:-1], + ) + col_ind = indices.to(torch.int64) + coo = torch.sparse_coo_tensor( + torch.stack([row_indices, col_ind]), + data, + shape, + device=device, + ).coalesce() + x_2d = x.unsqueeze(1) + if value_dtype in (torch.float16, torch.bfloat16): + coo_f32 = coo.to(torch.float32) + x_2d_f32 = x_2d.to(torch.float32) + expected = torch.sparse.mm(coo_f32, x_2d_f32).squeeze(1).to(value_dtype) + else: + expected = torch.sparse.mm(coo, x_2d).squeeze(1) + atol, rtol = _tolerance_for_dtype(value_dtype) + triton_match = torch.allclose(triton_y, expected, atol=atol, rtol=rtol) + triton_max_error = ( + float(torch.max(torch.abs(triton_y - expected)).item()) + if n_rows > 0 + else 0.0 + ) + cusparse_ms = None + cusparse_match = None + cusparse_max_error = None + cusparse_reason = None + if ( + run_cusparse + and cp is not None + and cpx_sparse is not None + and value_dtype in _cupy_supported_dtypes + ): + skip_reason = _cusparse_baseline_skip_reason(value_dtype) + if skip_reason: + cusparse_reason = skip_reason + else: + try: + cusparse_op = lambda: _torch_from_cupy( + A_csr @ _cupy_from_torch(x) + ) + cusparse_values, cusparse_ms = _benchmark_cuda_op( + cusparse_op, warmup=warmup, iters=iters + ) + cusparse_match = torch.allclose( + cusparse_values, expected, atol=atol, rtol=rtol + ) + cusparse_max_error = ( + float(torch.max(torch.abs(cusparse_values - expected)).item()) + if n_rows > 0 + else 0.0 + ) + except Exception as exc: + cusparse_reason = str(exc) + elif run_cusparse and value_dtype not in _cupy_supported_dtypes: + cusparse_reason = ( + "float16/bfloat16 not supported by CuPy sparse; skipped" + ) + triton_speedup_vs_cusparse = ( + cusparse_ms / triton_ms + if (cusparse_ms is not None and triton_ms > 0) + else None + ) + return { + "parameters": { + "n_rows": n_rows, + "n_cols": n_cols, + "nnz": nnz, + "value_dtype": str(value_dtype), + "index_dtype": str(index_dtype), + "warmup": warmup, + "iters": iters, + }, + "performance": { + "triton_ms": triton_ms, + "cusparse_ms": cusparse_ms, + "triton_speedup_vs_cusparse": triton_speedup_vs_cusparse, + }, + "verification": { + "triton_match_reference": triton_match, + "triton_max_error": triton_max_error, + "cusparse_match_reference": cusparse_match, + "cusparse_max_error": cusparse_max_error, + }, + "backend_status": { + "cusparse_unavailable_reason": cusparse_reason, + }, + "samples": {"triton": triton_y, "reference": expected}, + } + + +def benchmark_performance( + dense_size=65536, + nnz=4096, + dtype=torch.float32, + index_dtype=torch.int32, +): + """Backward-compatible benchmark entry.""" + result = benchmark_gather_case( + dense_size=dense_size, + nnz=nnz, + value_dtype=dtype, + index_dtype=index_dtype, + warmup=10, + iters=100, + run_cusparse=False, + ) + return { + "triton_time_ms": result["performance"]["triton_ms"], + "results_match": result["verification"]["triton_match_pytorch"], + "dtype": str(dtype), + "index_dtype": str(index_dtype), + "dense_size": dense_size, + "nnz": 
nnz, + } + + +def comprehensive_gather_test( + dense_size=100000, + nnz=10000, + dtype=torch.float32, + index_dtype=torch.int32, + warmup=20, + iters=200, + run_cusparse=True, +): + """Full test entry for one configuration.""" + return benchmark_gather_case( + dense_size=dense_size, + nnz=nnz, + value_dtype=dtype, + index_dtype=index_dtype, + warmup=warmup, + iters=iters, + run_cusparse=run_cusparse, + ) + + +def comprehensive_scatter_test( + dense_size=100000, + nnz=10000, + dtype=torch.float32, + index_dtype=torch.int32, + warmup=20, + iters=200, + run_cusparse=True, + unique_indices=True, + reset_output=True, + dtype_policy="auto", + index_fallback_policy="auto", +): + """Full scatter test entry for one configuration.""" + return benchmark_scatter_case( + dense_size=dense_size, + nnz=nnz, + value_dtype=dtype, + index_dtype=index_dtype, + warmup=warmup, + iters=iters, + run_cusparse=run_cusparse, + unique_indices=unique_indices, + reset_output=reset_output, + dtype_policy=dtype_policy, + index_fallback_policy=index_fallback_policy, + ) + + +def comprehensive_spsm_test( + fmt="csr", + n_rows=1024, + n_rhs=32, + nnz=8192, + dtype=torch.float32, + index_dtype=torch.int32, + alpha=1.0, + lower=True, + unit_diagonal=False, + warmup=10, + iters=50, +): + """Full SpSM benchmark entry for one configuration.""" + return benchmark_spsm_case( + fmt=fmt, + n_rows=n_rows, + n_rhs=n_rhs, + nnz=nnz, + value_dtype=dtype, + index_dtype=index_dtype, + alpha=alpha, + lower=lower, + unit_diagonal=unit_diagonal, + warmup=warmup, + iters=iters, + ) diff --git a/src/flagsparse/sparse_operations/gather_scatter.py b/src/flagsparse/sparse_operations/gather_scatter.py new file mode 100644 index 0000000..ab417a6 --- /dev/null +++ b/src/flagsparse/sparse_operations/gather_scatter.py @@ -0,0 +1,541 @@ +"""Gather and scatter (Triton kernels + cuSPARSE-style baselines).""" + +from ._common import * + +import triton +import triton.language as tl + +@triton.jit +def _gather_real_kernel( + sparse_values_ptr, + dense_values_ptr, + indices_ptr, + nnz, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < nnz + indices = tl.load(indices_ptr + offsets, mask=mask, other=0) + gathered_values = tl.load(dense_values_ptr + indices, mask=mask, other=0.0) + tl.store(sparse_values_ptr + offsets, gathered_values, mask=mask) + + +@triton.jit +def _gather_complex_kernel( + sparse_values_ri_ptr, + dense_values_ri_ptr, + indices_ptr, + nnz, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < nnz + indices = tl.load(indices_ptr + offsets, mask=mask, other=0) + + dense_offsets = indices * 2 + sparse_offsets = offsets * 2 + + gathered_real = tl.load(dense_values_ri_ptr + dense_offsets, mask=mask, other=0.0) + gathered_imag = tl.load(dense_values_ri_ptr + dense_offsets + 1, mask=mask, other=0.0) + + tl.store(sparse_values_ri_ptr + sparse_offsets, gathered_real, mask=mask) + tl.store(sparse_values_ri_ptr + sparse_offsets + 1, gathered_imag, mask=mask) + + +@triton.jit +def _scatter_real_kernel( + dense_values_ptr, + sparse_values_ptr, + indices_ptr, + nnz, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < nnz + + indices = tl.load(indices_ptr + offsets, mask=mask, other=0) + values = tl.load(sparse_values_ptr + 
offsets, mask=mask, other=0.0) + tl.store(dense_values_ptr + indices, values, mask=mask) + + +@triton.jit +def _scatter_complex_kernel( + dense_values_ri_ptr, + sparse_values_ri_ptr, + indices_ptr, + nnz, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < nnz + + indices = tl.load(indices_ptr + offsets, mask=mask, other=0) + dense_offsets = indices * 2 + sparse_offsets = offsets * 2 + + values_real = tl.load(sparse_values_ri_ptr + sparse_offsets, mask=mask, other=0.0) + values_imag = tl.load(sparse_values_ri_ptr + sparse_offsets + 1, mask=mask, other=0.0) + + tl.store(dense_values_ri_ptr + dense_offsets, values_real, mask=mask) + tl.store(dense_values_ri_ptr + dense_offsets + 1, values_imag, mask=mask) + + +def _triton_gather_impl(dense_vector, kernel_indices, block_size=1024): + nnz = kernel_indices.numel() + if nnz == 0: + return torch.empty(0, dtype=dense_vector.dtype, device=dense_vector.device) + + grid = lambda meta: (triton.cdiv(nnz, meta["BLOCK_SIZE"]),) + + if not _is_complex_dtype(dense_vector.dtype): + sparse_values = torch.empty(nnz, dtype=dense_vector.dtype, device=dense_vector.device) + _gather_real_kernel[grid]( + sparse_values, + dense_vector, + kernel_indices, + nnz, + BLOCK_SIZE=block_size, + ) + return sparse_values + + sparse_values = torch.empty(nnz, dtype=dense_vector.dtype, device=dense_vector.device) + dense_values_ri = torch.view_as_real(dense_vector).reshape(-1) + sparse_values_ri = torch.view_as_real(sparse_values).reshape(-1) + + _gather_complex_kernel[grid]( + sparse_values_ri, + dense_values_ri, + kernel_indices, + nnz, + BLOCK_SIZE=block_size, + ) + return sparse_values + + +def _triton_scatter_impl( + sparse_values, + kernel_indices, + dense_size, + out=None, + block_size=1024, + reset_output=True, + index_fallback_policy="auto", + return_metadata=False, +): + index_fallback_policy = str(index_fallback_policy).lower() + if index_fallback_policy not in ("auto", "strict"): + raise ValueError("index_fallback_policy must be 'auto' or 'strict'") + + if out is None: + dense_values = torch.zeros( + dense_size, dtype=sparse_values.dtype, device=sparse_values.device + ) + else: + dense_values = out + if reset_output: + dense_values.zero_() + + nnz = kernel_indices.numel() + scatter_meta = { + "index_fallback_applied": False, + "index_fallback_reason": None, + "kernel_index_dtype": str(kernel_indices.dtype).replace("torch.", ""), + } + if nnz == 0: + if return_metadata: + return dense_values, scatter_meta + return dense_values + + try: + _launch_triton_scatter_kernel( + dense_values, + sparse_values, + kernel_indices, + nnz, + block_size=block_size, + ) + except Exception as exc: + if kernel_indices.dtype != torch.int64 or index_fallback_policy != "auto": + raise RuntimeError( + f"Triton scatter failed for index dtype {kernel_indices.dtype}: " + f"{exc.__class__.__name__}: {str(exc)}" + ) from exc + + max_index = int(kernel_indices.max().item()) if nnz > 0 else -1 + if max_index > _INDEX_LIMIT_INT32: + raise RuntimeError( + "Triton scatter failed for int64 indices, and int32 fallback is invalid: " + f"max index {max_index} exceeds int32 range" + ) from exc + + fallback_indices = kernel_indices.to(torch.int32) + try: + _launch_triton_scatter_kernel( + dense_values, + sparse_values, + fallback_indices, + nnz, + block_size=block_size, + ) + except Exception as fallback_exc: + raise RuntimeError( + "Triton scatter failed for int64 indices, and int32 fallback 
also failed: " + f"{fallback_exc.__class__.__name__}: {str(fallback_exc)}" + ) from fallback_exc + + scatter_meta["index_fallback_applied"] = True + scatter_meta["index_fallback_reason"] = ( + f"int64 kernel launch failed: {exc.__class__.__name__}: {str(exc)}" + ) + scatter_meta["kernel_index_dtype"] = "int32" + + if return_metadata: + return dense_values, scatter_meta + return dense_values + + +def _launch_triton_scatter_kernel( + dense_values, sparse_values, kernel_indices, nnz, block_size=1024 +): + grid = lambda meta: (triton.cdiv(nnz, meta["BLOCK_SIZE"]),) + if not _is_complex_dtype(sparse_values.dtype): + _scatter_real_kernel[grid]( + dense_values, + sparse_values, + kernel_indices, + nnz, + BLOCK_SIZE=block_size, + ) + return + dense_values_ri = torch.view_as_real(dense_values).reshape(-1) + sparse_values_ri = torch.view_as_real(sparse_values).reshape(-1) + _scatter_complex_kernel[grid]( + dense_values_ri, + sparse_values_ri, + kernel_indices, + nnz, + BLOCK_SIZE=block_size, + ) + + +def _validate_gather_value_dtype(dense_vector, op_name): + return None + + +def _cusparse_spmv(selector_matrix, dense_vector): + if cp is not None and cpx_sparse is not None and isinstance(selector_matrix, cpx_sparse.spmatrix): + if torch.is_tensor(dense_vector): + out_dtype = dense_vector.dtype + dense_for_compute = ( + dense_vector.to(torch.float32) + if dense_vector.dtype == torch.bfloat16 + else dense_vector + ) + dense_cp = _cupy_from_torch(dense_for_compute) + out_cp = selector_matrix @ dense_cp + out_torch = _torch_from_cupy(out_cp) + if out_dtype == torch.bfloat16: + out_torch = out_torch.to(torch.bfloat16) + return out_torch + + if cp is not None and isinstance(dense_vector, cp.ndarray): + return selector_matrix @ dense_vector + + raise TypeError("dense_vector must be torch.Tensor or cupy.ndarray") + + # Fallback path: torch sparse SpMV (still CUDA-backed). 
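+    # Used when CuPy is unavailable or the caller built a torch.sparse selector;
+    # bfloat16 inputs are upcast to fp32 for torch.sparse.mm and cast back after.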
+ if torch.is_tensor(selector_matrix) and selector_matrix.is_sparse: + if not torch.is_tensor(dense_vector): + raise TypeError("dense_vector must be torch.Tensor for torch sparse fallback") + out_dtype = dense_vector.dtype + dense_for_compute = ( + dense_vector.to(torch.float32) + if dense_vector.dtype == torch.bfloat16 + else dense_vector + ) + out = torch.sparse.mm(selector_matrix, dense_for_compute.unsqueeze(1)).squeeze(1) + if out_dtype == torch.bfloat16: + out = out.to(torch.bfloat16) + return out + + if cp is None or cpx_sparse is None: + raise RuntimeError( + "CuPy is not available and torch sparse fallback selector is not provided" + ) + raise TypeError( + "selector_matrix must be a cupyx sparse matrix or torch sparse tensor" + ) + + +def _make_gather_selector_matrix(indices, dense_size, value_dtype): + if cp is not None and cpx_sparse is not None: + rows_cp = cp.arange(indices.numel(), dtype=cp.int64) + cols_cp = _cupy_from_torch(indices.to(torch.int64)) + vals_cp = cp.ones(indices.numel(), dtype=_cupy_dtype_from_torch(value_dtype)) + return cpx_sparse.coo_matrix( + (vals_cp, (rows_cp, cols_cp)), + shape=(indices.numel(), dense_size), + ) + + rows = torch.arange(indices.numel(), dtype=torch.int64, device=indices.device) + cols = indices.to(torch.int64) + coords = torch.stack([rows, cols], dim=0) + values = torch.ones(indices.numel(), dtype=value_dtype, device=indices.device) + return torch.sparse_coo_tensor( + coords, values, size=(indices.numel(), dense_size), device=indices.device + ).coalesce() + + +def _make_scatter_selector_matrix(indices, dense_size, value_dtype): + if cp is not None and cpx_sparse is not None: + rows_cp = _cupy_from_torch(indices.to(torch.int64)) + cols_cp = cp.arange(indices.numel(), dtype=cp.int64) + vals_cp = cp.ones(indices.numel(), dtype=_cupy_dtype_from_torch(value_dtype)) + return cpx_sparse.coo_matrix( + (vals_cp, (rows_cp, cols_cp)), + shape=(dense_size, indices.numel()), + ) + + rows = indices.to(torch.int64) + cols = torch.arange(indices.numel(), dtype=torch.int64, device=indices.device) + coords = torch.stack([rows, cols], dim=0) + values = torch.ones(indices.numel(), dtype=value_dtype, device=indices.device) + return torch.sparse_coo_tensor( + coords, values, size=(dense_size, indices.numel()), device=indices.device + ).coalesce() + + +def _pytorch_scatter_impl(sparse_values, indices, dense_size, out=None, reset_output=True): + if out is None: + dense_values = torch.zeros( + dense_size, dtype=sparse_values.dtype, device=sparse_values.device + ) + else: + dense_values = out + if reset_output: + dense_values.zero_() + dense_values.index_copy_(0, indices.to(torch.int64), sparse_values) + return dense_values + + +def flagsparse_gather(a, indices, out=None, mode="raise", block_size=1024, return_time=False): + """CuPy-style gather (take): out = a[indices].""" + if mode != "raise": + raise NotImplementedError("Only mode='raise' is currently supported") + + dense_vector, dense_backend = _to_torch_tensor(a, "a") + indices_tensor, _ = _to_torch_tensor(indices, "indices") + dense_vector, indices_tensor, kernel_indices = _prepare_inputs(dense_vector, indices_tensor) + _validate_gather_value_dtype(dense_vector, "flagsparse_gather") + + torch.cuda.synchronize() + start_time = time.perf_counter() + sparse_values = _triton_gather_impl(dense_vector, kernel_indices, block_size=block_size) + torch.cuda.synchronize() + execution_time_ms = (time.perf_counter() - start_time) * 1000.0 + + if out is not None: + out_tensor, _ = _to_torch_tensor(out, "out") + if 
out_tensor.shape != sparse_values.shape: + raise ValueError("out shape must match gather output shape") + if out_tensor.dtype != sparse_values.dtype: + raise TypeError("out dtype must match gather output dtype") + out_tensor.copy_(sparse_values) + result = out if dense_backend == "cupy" else out_tensor + else: + result = _to_backend_like(sparse_values, a) + + if return_time: + return result, execution_time_ms + return result + + +def flagsparse_scatter( + a, + indices, + values, + mode="raise", + block_size=1024, + return_time=False, + reset_output=True, + dtype_policy="auto", + index_fallback_policy="auto", +): + """CuPy-style scatter (put): a[indices] = values (in-place).""" + if mode != "raise": + raise NotImplementedError("Only mode='raise' is currently supported") + + dense_tensor, dense_backend = _to_torch_tensor(a, "a") + values_tensor, _ = _to_torch_tensor(values, "values") + indices_tensor, _ = _to_torch_tensor(indices, "indices") + values_tensor, _, kernel_indices, dense_size, _ = _prepare_scatter_inputs( + values_tensor, + indices_tensor, + dense_size=dense_tensor.numel(), + out=dense_tensor, + dtype_policy=dtype_policy, + return_metadata=True, + ) + + torch.cuda.synchronize() + start_time = time.perf_counter() + _ = _triton_scatter_impl( + values_tensor, + kernel_indices, + dense_size=dense_size, + out=dense_tensor, + block_size=block_size, + reset_output=reset_output, + index_fallback_policy=index_fallback_policy, + ) + torch.cuda.synchronize() + execution_time_ms = (time.perf_counter() - start_time) * 1000.0 + + if dense_backend == "cupy": + # DLPack view updates dense_tensor and cupy array shares memory. + pass + + if return_time: + return execution_time_ms + return None + + +# Backward compatibility wrappers. +def triton_cusparse_gather(dense_vector, indices, block_size=1024): + return flagsparse_gather( + dense_vector, indices, block_size=block_size, return_time=True + ) + + +def triton_cusparse_scatter( + sparse_values, + indices, + dense_size=None, + out=None, + block_size=1024, + reset_output=True, + dtype_policy="auto", + index_fallback_policy="auto", +): + sparse_values_t, sparse_backend = _to_torch_tensor(sparse_values, "sparse_values") + indices_t, _ = _to_torch_tensor(indices, "indices") + if out is None: + if dense_size is None: + dense_size = int(indices_t.max().item()) + 1 if indices_t.numel() > 0 else 0 + out = torch.zeros( + int(dense_size), dtype=sparse_values_t.dtype, device=sparse_values_t.device + ) + elapsed_ms = flagsparse_scatter( + out, + indices_t, + sparse_values_t, + block_size=block_size, + return_time=True, + reset_output=reset_output, + dtype_policy=dtype_policy, + index_fallback_policy=index_fallback_policy, + ) + if sparse_backend == "cupy": + return _to_backend_like(out, sparse_values), elapsed_ms + return out, elapsed_ms + + +def pytorch_index_gather(dense_vector, indices): + """Baseline gather using PyTorch native indexing.""" + dense_vector, indices, _ = _prepare_inputs(dense_vector, indices) + torch.cuda.synchronize() + start_time = time.perf_counter() + sparse_values = dense_vector[indices] + torch.cuda.synchronize() + execution_time_ms = (time.perf_counter() - start_time) * 1000.0 + return sparse_values, execution_time_ms + + +def pytorch_index_scatter( + sparse_values, indices, dense_size=None, out=None, reset_output=True, dtype_policy="auto" +): + """Baseline scatter using PyTorch index_copy_.""" + sparse_values, indices, _, dense_size, _ = _prepare_scatter_inputs( + sparse_values, + indices, + dense_size=dense_size, + out=out, + 
dtype_policy=dtype_policy, + return_metadata=True, + ) + torch.cuda.synchronize() + start_time = time.perf_counter() + dense_values = _pytorch_scatter_impl( + sparse_values, indices, dense_size, out=out, reset_output=reset_output + ) + torch.cuda.synchronize() + execution_time_ms = (time.perf_counter() - start_time) * 1000.0 + return dense_values, execution_time_ms + + +def cusparse_spmv_gather(dense_vector, indices, selector_matrix=None): + """Equivalent gather baseline via cuSPARSE-backed COO SpMV.""" + dense_vector, indices, _ = _prepare_inputs(dense_vector, indices) + _validate_gather_value_dtype(dense_vector, "cusparse_spmv_gather") + skip_reason = _cusparse_baseline_skip_reason(dense_vector.dtype) + if skip_reason: + raise RuntimeError(skip_reason) + + if selector_matrix is None: + selector_matrix = _make_gather_selector_matrix( + indices, dense_vector.numel(), dense_vector.dtype + ) + + try: + torch.cuda.synchronize() + start_time = time.perf_counter() + sparse_values = _cusparse_spmv(selector_matrix, dense_vector) + torch.cuda.synchronize() + execution_time_ms = (time.perf_counter() - start_time) * 1000.0 + except Exception as exc: + raise RuntimeError( + "cuSPARSE gather baseline is unavailable in this PyTorch/CUDA environment" + ) from exc + + return sparse_values, execution_time_ms, selector_matrix + + +def cusparse_spmv_scatter( + sparse_values, indices, dense_size=None, selector_matrix=None, dtype_policy="auto" +): + """Equivalent scatter baseline via cuSPARSE-backed COO SpMV.""" + sparse_values, indices, _, dense_size, _ = _prepare_scatter_inputs( + sparse_values, + indices, + dense_size=dense_size, + out=None, + dtype_policy=dtype_policy, + return_metadata=True, + ) + skip_reason = _cusparse_baseline_skip_reason(sparse_values.dtype) + if skip_reason: + raise RuntimeError(skip_reason) + + if selector_matrix is None: + selector_matrix = _make_scatter_selector_matrix(indices, dense_size, sparse_values.dtype) + + try: + torch.cuda.synchronize() + start_time = time.perf_counter() + dense_values = _cusparse_spmv(selector_matrix, sparse_values) + torch.cuda.synchronize() + execution_time_ms = (time.perf_counter() - start_time) * 1000.0 + except Exception as exc: + raise RuntimeError( + "cuSPARSE scatter baseline is unavailable in this PyTorch/CUDA environment" + ) from exc + + return dense_values, execution_time_ms, selector_matrix diff --git a/src/flagsparse/sparse_operations/sddmm_csr.py b/src/flagsparse/sparse_operations/sddmm_csr.py new file mode 100644 index 0000000..5a9dc35 --- /dev/null +++ b/src/flagsparse/sparse_operations/sddmm_csr.py @@ -0,0 +1,572 @@ +"""CSR SDDMM kernels and helpers.""" + +from ._common import * + +SUPPORTED_SDDMM_VALUE_DTYPES = (torch.float32, torch.float64) +SUPPORTED_SDDMM_DIAGNOSTIC_VARIANTS = ("baseline", "acc64", "acc64_out64", "altreduce") + + +class SDDMMPrepared: + """Prepared CSR pattern metadata for SDDMM.""" + + __slots__ = ( + "indices", + "indptr", + "shape", + "n_rows", + "n_cols", + "nnz", + "row_ids", + "block_k", + "num_warps", + ) + + def __init__(self, indices, indptr, shape, row_ids, block_k, num_warps): + self.indices = indices + self.indptr = indptr + self.shape = (int(shape[0]), int(shape[1])) + self.n_rows = self.shape[0] + self.n_cols = self.shape[1] + self.nnz = int(indices.numel()) + self.row_ids = row_ids + self.block_k = int(block_k) + self.num_warps = int(num_warps) + + +def _resolve_sddmm_launch_config(k): + if k <= 32: + return 32, 2 + if k <= 64: + return 64, 4 + if k <= 128: + return 64, 4 + return 128, 8 + + +def 
_prepare_sddmm_csr_pattern(indices, indptr, shape): + if len(shape) != 2: + raise ValueError("shape must be a 2-tuple") + if indices.ndim != 1 or indptr.ndim != 1: + raise ValueError("indices and indptr must be 1D tensors") + n_rows, n_cols = int(shape[0]), int(shape[1]) + if n_rows < 0 or n_cols < 0: + raise ValueError("shape dimensions must be non-negative") + if indptr.numel() != n_rows + 1: + raise ValueError( + f"indptr length must be n_rows+1={n_rows + 1}, got {indptr.numel()}" + ) + if not indices.is_cuda or not indptr.is_cuda: + raise ValueError("indices and indptr must be CUDA tensors") + if indices.dtype != torch.int32: + raise TypeError("indices dtype must be torch.int32") + if indptr.dtype not in (torch.int32, torch.int64): + raise TypeError("indptr dtype must be torch.int32 or torch.int64") + + indptr64 = indptr.to(torch.int64).contiguous() + indices = indices.contiguous() + nnz = int(indices.numel()) + if indptr64.numel() > 0 and int(indptr64[0].item()) != 0: + raise ValueError("indptr[0] must be 0") + if indptr64.numel() > 0 and int(indptr64[-1].item()) != nnz: + raise ValueError(f"indptr[-1] must equal nnz={nnz}") + if indptr64.numel() > 1 and bool(torch.any(indptr64[1:] < indptr64[:-1]).item()): + raise ValueError("indptr must be nondecreasing") + if nnz > 0: + min_col = int(indices.min().item()) + max_col = int(indices.max().item()) + if min_col < 0 or max_col >= n_cols: + raise IndexError("indices out of range for shape[1]") + return indices, indptr64, (n_rows, n_cols) + + +def _build_row_ids(indptr): + n_rows = int(indptr.numel()) - 1 + if n_rows <= 0: + return torch.empty(0, dtype=torch.int32, device=indptr.device) + row_counts = indptr[1:] - indptr[:-1] + return torch.repeat_interleave( + torch.arange(n_rows, dtype=torch.int32, device=indptr.device), + row_counts, + ) + + +def prepare_sddmm_csr(indices, indptr, shape, k_hint=64): + indices, indptr, shape = _prepare_sddmm_csr_pattern(indices, indptr, shape) + row_ids = _build_row_ids(indptr) + block_k, num_warps = _resolve_sddmm_launch_config(int(k_hint)) + return SDDMMPrepared( + indices=indices, + indptr=indptr, + shape=shape, + row_ids=row_ids, + block_k=block_k, + num_warps=num_warps, + ) + + +@triton.jit +def _sddmm_csr_real_kernel( + indices_ptr, + row_ids_ptr, + x_ptr, + y_ptr, + in_ptr, + out_ptr, + nnz, + k_dim, + stride_xm, + stride_xk, + stride_ym, + stride_yk, + alpha, + beta, + HAS_IN: tl.constexpr, + BLOCK_P: tl.constexpr, + BLOCK_K: tl.constexpr, + ACC_DTYPE: tl.constexpr, +): + pid = tl.program_id(0) + offs_p = pid * BLOCK_P + tl.arange(0, BLOCK_P) + mask_p = offs_p < nnz + + rows = tl.load(row_ids_ptr + offs_p, mask=mask_p, other=0) + cols = tl.load(indices_ptr + offs_p, mask=mask_p, other=0) + acc = tl.zeros([BLOCK_P], dtype=ACC_DTYPE) + + for k0 in tl.range(0, k_dim, BLOCK_K): + offs_k = k0 + tl.arange(0, BLOCK_K) + mask_k = offs_k < k_dim + x_ptrs = x_ptr + rows[:, None] * stride_xm + offs_k[None, :] * stride_xk + y_ptrs = y_ptr + cols[:, None] * stride_ym + offs_k[None, :] * stride_yk + xy_mask = mask_p[:, None] & mask_k[None, :] + x_vals = tl.load(x_ptrs, mask=xy_mask, other=0.0) + y_vals = tl.load(y_ptrs, mask=xy_mask, other=0.0) + acc += tl.sum(x_vals.to(ACC_DTYPE) * y_vals.to(ACC_DTYPE), axis=1) + + out_vals = acc * alpha + if HAS_IN: + in_vals = tl.load(in_ptr + offs_p, mask=mask_p, other=0.0).to(ACC_DTYPE) + out_vals += in_vals * beta + tl.store(out_ptr + offs_p, out_vals, mask=mask_p) + + +@triton.jit +def _sddmm_csr_real_kernel_altreduce( + indices_ptr, + row_ids_ptr, + x_ptr, + y_ptr, + 
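+    # in_ptr supplies the optional sampled input values; their beta-scaled
+    # contribution is added only when HAS_IN is set (the launcher passes a
+    # dummy pointer otherwise).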
in_ptr, + out_ptr, + nnz, + k_dim, + stride_xm, + stride_xk, + stride_ym, + stride_yk, + alpha, + beta, + HAS_IN: tl.constexpr, + BLOCK_P: tl.constexpr, + BLOCK_K: tl.constexpr, + ACC_DTYPE: tl.constexpr, +): + pid = tl.program_id(0) + offs_p = pid * BLOCK_P + tl.arange(0, BLOCK_P) + mask_p = offs_p < nnz + + rows = tl.load(row_ids_ptr + offs_p, mask=mask_p, other=0) + cols = tl.load(indices_ptr + offs_p, mask=mask_p, other=0) + acc = tl.zeros([BLOCK_P], dtype=ACC_DTYPE) + + for k0 in tl.range(0, k_dim, BLOCK_K): + for kk in tl.static_range(0, BLOCK_K): + k_idx = k0 + kk + valid_k = k_idx < k_dim + x_vals = tl.load( + x_ptr + rows * stride_xm + k_idx * stride_xk, + mask=mask_p & valid_k, + other=0.0, + ) + y_vals = tl.load( + y_ptr + cols * stride_ym + k_idx * stride_yk, + mask=mask_p & valid_k, + other=0.0, + ) + acc += x_vals.to(ACC_DTYPE) * y_vals.to(ACC_DTYPE) + + out_vals = acc * alpha + if HAS_IN: + in_vals = tl.load(in_ptr + offs_p, mask=mask_p, other=0.0).to(ACC_DTYPE) + out_vals += in_vals * beta + tl.store(out_ptr + offs_p, out_vals, mask=mask_p) + + +def _validate_sddmm_dense_inputs(data, prepared, x, y): + if x.ndim != 2 or y.ndim != 2: + raise ValueError("x and y must be 2D dense tensors") + if not x.is_cuda or not y.is_cuda: + raise ValueError("x and y must be CUDA tensors") + if x.device != y.device or x.device != prepared.indices.device: + raise ValueError("x, y, and sparse pattern must be on the same CUDA device") + if x.dtype not in SUPPORTED_SDDMM_VALUE_DTYPES: + raise TypeError("x dtype must be torch.float32 or torch.float64") + if y.dtype != x.dtype: + raise TypeError("y dtype must match x dtype") + if data is not None and data.dtype != x.dtype: + raise TypeError("data dtype must match x/y dtype") + if x.shape[0] != prepared.n_rows: + raise ValueError(f"x.shape[0] must be n_rows={prepared.n_rows}, got {x.shape[0]}") + if y.shape[0] != prepared.n_cols: + raise ValueError(f"y.shape[0] must be n_cols={prepared.n_cols}, got {y.shape[0]}") + if x.shape[1] != y.shape[1]: + raise ValueError("x and y must have the same K dimension") + if data is not None and data.numel() != prepared.nnz: + raise ValueError("data length must equal nnz of sparse pattern") + return int(x.shape[1]) + + +def _prepare_validated_sddmm_out(prepared, x, out, out_dtype=None): + nnz = prepared.nnz + target_dtype = x.dtype if out_dtype is None else out_dtype + if out is None: + return torch.empty(nnz, dtype=target_dtype, device=x.device) + if out.ndim != 1 or out.numel() != nnz: + raise ValueError("out must be a 1D tensor with length nnz") + if not out.is_cuda or out.device != x.device: + raise ValueError("out must be a CUDA tensor on the same device as x") + if out.dtype != target_dtype: + raise TypeError("out dtype must match the requested output dtype") + return out + + +def _normalize_sddmm_diagnostic_variant(variant): + if variant is None: + return "baseline" + variant = str(variant).strip().lower() + if variant not in SUPPORTED_SDDMM_DIAGNOSTIC_VARIANTS: + supported = ", ".join(SUPPORTED_SDDMM_DIAGNOSTIC_VARIANTS) + raise ValueError(f"Unsupported SDDMM diagnostic variant {variant!r}; expected one of: {supported}") + return variant + + +def _resolve_sddmm_diagnostic_kernel(variant, value_dtype): + variant = _normalize_sddmm_diagnostic_variant(variant) + if variant == "baseline": + acc_dtype = tl.float64 if value_dtype == torch.float64 else tl.float32 + return _sddmm_csr_real_kernel, acc_dtype + if variant in ("acc64", "acc64_out64"): + return _sddmm_csr_real_kernel, tl.float64 + acc_dtype = 
tl.float64 if value_dtype == torch.float64 else tl.float32 + return _sddmm_csr_real_kernel_altreduce, acc_dtype + + +def _resolve_sddmm_diagnostic_out_dtype(variant, value_dtype): + variant = _normalize_sddmm_diagnostic_variant(variant) + if variant == "acc64_out64": + return torch.float64 + return value_dtype + + +def _run_sddmm_prepared(prepared, x, y, data, alpha, beta, out, allow_fallback=False, variant="baseline", out_dtype=None): + nnz = prepared.nnz + variant = _normalize_sddmm_diagnostic_variant(variant) + target_out_dtype = _resolve_sddmm_diagnostic_out_dtype(variant, x.dtype) if out_dtype is None else out_dtype + out = _prepare_validated_sddmm_out(prepared, x, out, out_dtype=target_out_dtype) + if nnz == 0: + return out, { + "block_k": prepared.block_k, + "num_warps": prepared.num_warps, + "fallback_used": False, + "variant": variant, + "acc_dtype": "float64" if target_out_dtype == torch.float64 else "float32", + "out_dtype": str(target_out_dtype).replace("torch.", ""), + } + + k_dim = int(x.shape[1]) + block_k, num_warps = _resolve_sddmm_launch_config(k_dim) + block_p = 128 + kernel, acc_dtype = _resolve_sddmm_diagnostic_kernel(variant, x.dtype) + grid = (triton.cdiv(nnz, block_p),) + fallback_used = False + if allow_fallback: + try: + kernel[grid]( + prepared.indices, + prepared.row_ids, + x, + y, + data if data is not None else out, + out, + nnz, + k_dim, + x.stride(0), + x.stride(1), + y.stride(0), + y.stride(1), + float(alpha), + float(beta), + HAS_IN=data is not None, + BLOCK_P=block_p, + BLOCK_K=block_k, + ACC_DTYPE=acc_dtype, + num_warps=num_warps, + ) + except Exception: + out.copy_(_sddmm_reference(prepared.indices, prepared.indptr, x, y, data, alpha, beta).to(out.dtype)) + fallback_used = True + else: + kernel[grid]( + prepared.indices, + prepared.row_ids, + x, + y, + data if data is not None else out, + out, + nnz, + k_dim, + x.stride(0), + x.stride(1), + y.stride(0), + y.stride(1), + float(alpha), + float(beta), + HAS_IN=data is not None, + BLOCK_P=block_p, + BLOCK_K=block_k, + ACC_DTYPE=acc_dtype, + num_warps=num_warps, + ) + return out, { + "block_k": block_k, + "num_warps": num_warps, + "fallback_used": fallback_used, + "variant": variant, + "acc_dtype": "float64" if acc_dtype == tl.float64 else "float32", + "out_dtype": str(out.dtype).replace("torch.", ""), + } + + +def flagsparse_sddmm_csr( + data=None, + indices=None, + indptr=None, + x=None, + y=None, + shape=None, + alpha=1.0, + beta=0.0, + prepared=None, + out=None, + return_time=False, + return_meta=False, + allow_fallback=False, +): + """CSR SDDMM: out[p] = alpha * dot(x[row(p)], y[col(p)]) + beta * data[p].""" + prepare_ms = 0.0 + if prepared is None: + if any(v is None for v in (indices, indptr, shape)): + raise ValueError("indices, indptr, and shape are required when prepared is not provided") + torch.cuda.synchronize() + t_prepare0 = time.perf_counter() + k_hint = int(x.shape[1]) if (x is not None and x.ndim == 2) else 64 + prepared = prepare_sddmm_csr(indices, indptr, shape, k_hint=k_hint) + torch.cuda.synchronize() + prepare_ms = (time.perf_counter() - t_prepare0) * 1000.0 + elif not isinstance(prepared, SDDMMPrepared): + raise TypeError("prepared must be a SDDMMPrepared instance") + + if x is None or y is None: + raise ValueError("x and y are required") + if data is None and float(beta) != 0.0: + raise ValueError("data is required when beta is non-zero") + k_dim = _validate_sddmm_dense_inputs(data, prepared, x, y) + if k_dim == 0: + out = _prepare_validated_sddmm_out(prepared, x, out) + if beta == 
0.0 or data is None: + out.zero_() + else: + out.copy_(data * beta) + meta = {"prepare_ms": prepare_ms, "block_k": prepared.block_k, "num_warps": prepared.num_warps, "fallback_used": False} + if return_time and return_meta: + return out, 0.0, meta + if return_time: + return out, 0.0 + if return_meta: + return out, meta + return out + + torch.cuda.synchronize() + t0 = time.perf_counter() + out_tensor, launch_meta = _run_sddmm_prepared( + prepared, + x.contiguous(), + y.contiguous(), + data.contiguous() if data is not None else None, + alpha, + beta, + out, + allow_fallback=allow_fallback, + ) + torch.cuda.synchronize() + elapsed_ms = (time.perf_counter() - t0) * 1000.0 + + if return_time and return_meta: + meta = {"prepare_ms": prepare_ms, **launch_meta} + return out_tensor, elapsed_ms, meta + if return_time: + return out_tensor, elapsed_ms + if return_meta: + meta = {"prepare_ms": prepare_ms, **launch_meta} + return out_tensor, meta + return out_tensor + + +def _sddmm_reference(indices, indptr, x, y, data, alpha, beta): + n_rows = int(indptr.numel()) - 1 + row_ids = torch.repeat_interleave( + torch.arange(n_rows, dtype=torch.int64, device=indices.device), + indptr[1:] - indptr[:-1], + ) + if row_ids.numel() == 0: + return torch.empty(0, dtype=x.dtype, device=x.device) + vals = torch.sum(x[row_ids] * y[indices.to(torch.int64)], dim=1) + vals = alpha * vals + if data is not None: + vals = vals + beta * data + return vals + + +def _cupy_sampled_dot_reference(indices, indptr, x, y, data, alpha, beta, chunk_nnz=262144): + _require_cupy() + n_rows = int(indptr.numel()) - 1 + row_ids = torch.repeat_interleave( + torch.arange(n_rows, dtype=torch.int64, device=indices.device), + indptr[1:] - indptr[:-1], + ) + row_ids_cp = _cupy_from_torch(row_ids) + col_ids_cp = _cupy_from_torch(indices.to(torch.int64)) + x_cp = _cupy_from_torch(x) + y_cp = _cupy_from_torch(y) + nnz = int(indices.numel()) + if nnz == 0: + vals = torch.empty(0, dtype=x.dtype, device=x.device) + if data is not None and beta != 0.0: + vals = vals + data * beta + return vals + + out_cp = cp.empty((nnz,), dtype=x_cp.dtype) + chunk_nnz = max(1, int(chunk_nnz)) + for start in range(0, nnz, chunk_nnz): + end = min(nnz, start + chunk_nnz) + rows = row_ids_cp[start:end] + cols = col_ids_cp[start:end] + out_cp[start:end] = cp.sum(x_cp[rows] * y_cp[cols], axis=1) + out = _torch_from_cupy(out_cp) + out = out * alpha + if data is not None and beta != 0.0: + out = out + data * beta + return out + + +def benchmark_sddmm_case( + n_rows=1024, + n_cols=1024, + nnz=16384, + k_dim=64, + value_dtype=torch.float32, + warmup=10, + iters=30, + alpha=1.0, + beta=0.0, + run_cusparse=False, +): + """Benchmark SDDMM and compare with sampled-dot reference.""" + if value_dtype not in SUPPORTED_SDDMM_VALUE_DTYPES: + raise TypeError("value_dtype must be torch.float32 or torch.float64") + device = torch.device("cuda") + data, indices, indptr = _build_random_csr( + n_rows, n_cols, nnz, value_dtype, torch.int32, device + ) + x = _build_random_dense((n_rows, k_dim), value_dtype, device) + y = _build_random_dense((n_cols, k_dim), value_dtype, device) + + prepared = prepare_sddmm_csr(indices, indptr, (n_rows, n_cols), k_hint=k_dim) + op = lambda: flagsparse_sddmm_csr( + data=data, + x=x, + y=y, + alpha=alpha, + beta=beta, + prepared=prepared, + return_time=False, + ) + triton_values, triton_ms = _benchmark_cuda_op(op, warmup=warmup, iters=iters) + ref_op = lambda: _sddmm_reference(indices, indptr.to(torch.int64), x, y, data, alpha, beta) + ref_values, pytorch_ms = 
_benchmark_cuda_op(ref_op, warmup=warmup, iters=iters) + + atol, rtol = _tolerance_for_dtype(value_dtype) + match = bool(torch.allclose(triton_values, ref_values, atol=atol, rtol=rtol)) + max_abs = ( + float(torch.max(torch.abs(triton_values - ref_values)).item()) + if triton_values.numel() > 0 + else 0.0 + ) + + cusparse_ms = None + cusparse_reason = None + cusparse_match = None + if run_cusparse: + if cp is None: + cusparse_reason = "CuPy is not available" + else: + try: + ref_cu, cusparse_ms = _benchmark_cuda_op( + lambda: _cupy_sampled_dot_reference( + indices=indices, + indptr=indptr.to(torch.int64), + x=x, + y=y, + data=data, + alpha=alpha, + beta=beta, + ), + warmup=warmup, + iters=iters, + ) + cusparse_match = bool(torch.allclose(triton_values, ref_cu, atol=atol, rtol=rtol)) + except Exception as exc: + cusparse_reason = str(exc) + + return { + "parameters": { + "n_rows": n_rows, + "n_cols": n_cols, + "nnz": nnz, + "k_dim": k_dim, + "value_dtype": str(value_dtype), + "warmup": warmup, + "iters": iters, + "alpha": alpha, + "beta": beta, + }, + "performance": { + "triton_ms": triton_ms, + "pytorch_ms": pytorch_ms, + "cusparse_ms": cusparse_ms, + "triton_speedup_vs_pytorch": (pytorch_ms / triton_ms if triton_ms > 0 else None), + "triton_speedup_vs_cusparse": (cusparse_ms / triton_ms if (cusparse_ms and triton_ms > 0) else None), + }, + "verification": { + "triton_match_pytorch": match, + "triton_max_abs_error": max_abs, + "cusparse_match_pytorch": cusparse_match, + }, + "backend_status": { + "cusparse_unavailable_reason": cusparse_reason, + }, + "samples": { + "triton": triton_values, + "pytorch": ref_values, + }, + } diff --git a/src/flagsparse/sparse_operations/spgemm_csr.py b/src/flagsparse/sparse_operations/spgemm_csr.py new file mode 100644 index 0000000..a6378e1 --- /dev/null +++ b/src/flagsparse/sparse_operations/spgemm_csr.py @@ -0,0 +1,1313 @@ +"""CSR SpGEMM (A@B) with two-phase structure/value build.""" + +from ._common import * + +SUPPORTED_SPGEMM_VALUE_DTYPES = (torch.float32, torch.float64) +_SPGEMM_COUNT_MAX_EXPANDED = 2_000_000 +_SPGEMM_FILL_MAX_EXPANDED = 1_200_000 +_SPGEMM_MAX_ROWS_PER_CHUNK = 4096 +_SPGEMM_BUCKET_SHORT = 0 +_SPGEMM_BUCKET_MEDIUM = 1 +_SPGEMM_BUCKET_LONG = 2 +_SPGEMM_BUCKET_ORDER = ( + _SPGEMM_BUCKET_SHORT, + _SPGEMM_BUCKET_MEDIUM, + _SPGEMM_BUCKET_LONG, +) +_SPGEMM_BUCKET_LABELS = { + _SPGEMM_BUCKET_SHORT: "short", + _SPGEMM_BUCKET_MEDIUM: "medium", + _SPGEMM_BUCKET_LONG: "long", +} +_SPGEMM_BUCKET_COUNT_BUDGETS = { + _SPGEMM_BUCKET_SHORT: 4_000_000, + _SPGEMM_BUCKET_MEDIUM: _SPGEMM_COUNT_MAX_EXPANDED, + _SPGEMM_BUCKET_LONG: 300_000, +} +_SPGEMM_BUCKET_FILL_BUDGETS = { + _SPGEMM_BUCKET_SHORT: 2_400_000, + _SPGEMM_BUCKET_MEDIUM: _SPGEMM_FILL_MAX_EXPANDED, + _SPGEMM_BUCKET_LONG: 200_000, +} +_SPGEMM_BUCKET_MAX_ROWS = { + _SPGEMM_BUCKET_SHORT: 8192, + _SPGEMM_BUCKET_MEDIUM: _SPGEMM_MAX_ROWS_PER_CHUNK, + _SPGEMM_BUCKET_LONG: 256, +} +_SPGEMM_LONG_ROW_SLICE_EXPANDED = 200_000 + + +class SpGEMMPrepared: + """Prepared CSR metadata for repeated SpGEMM runs.""" + + __slots__ = ( + "a_data", + "a_indices", + "a_indptr", + "a_shape", + "b_data", + "b_indices", + "b_indptr", + "b_shape", + "n_rows", + "n_inner", + "n_cols", + "a_row_work", + "row_bucket", + "row_work_ready", + "bucket_rows", + "count_chunks", + "fill_chunks", + "count_chunks_by_bucket", + "fill_chunks_by_bucket", + "long_row_slice_expanded", + "long_row_slices_host", + "hash_capacity_hint", + "block_nnz", + ) + + def __init__( + self, + a_data, + a_indices, + a_indptr, + a_shape, + b_data, + 
b_indices, + b_indptr, + b_shape, + a_row_work, + row_bucket, + row_work_ready, + bucket_rows, + count_chunks, + fill_chunks, + count_chunks_by_bucket, + fill_chunks_by_bucket, + long_row_slice_expanded, + long_row_slices_host, + hash_capacity_hint, + block_nnz, + ): + self.a_data = a_data + self.a_indices = a_indices + self.a_indptr = a_indptr + self.a_shape = (int(a_shape[0]), int(a_shape[1])) + self.b_data = b_data + self.b_indices = b_indices + self.b_indptr = b_indptr + self.b_shape = (int(b_shape[0]), int(b_shape[1])) + self.n_rows = self.a_shape[0] + self.n_inner = self.a_shape[1] + self.n_cols = self.b_shape[1] + self.a_row_work = a_row_work + self.row_bucket = row_bucket + self.row_work_ready = bool(row_work_ready) + self.bucket_rows = bucket_rows + self.count_chunks = count_chunks + self.fill_chunks = fill_chunks + self.count_chunks_by_bucket = count_chunks_by_bucket + self.fill_chunks_by_bucket = fill_chunks_by_bucket + self.long_row_slice_expanded = int(long_row_slice_expanded) + self.long_row_slices_host = long_row_slices_host + self.hash_capacity_hint = int(hash_capacity_hint) + self.block_nnz = int(block_nnz) + + +def _validate_csr(data, indices, indptr, shape, tag): + if len(shape) != 2: + raise ValueError(f"{tag}_shape must be a 2-tuple") + if data.ndim != 1 or indices.ndim != 1 or indptr.ndim != 1: + raise ValueError(f"{tag} data/indices/indptr must be 1D tensors") + n_rows, n_cols = int(shape[0]), int(shape[1]) + if n_rows < 0 or n_cols < 0: + raise ValueError(f"{tag}_shape dimensions must be non-negative") + if indptr.numel() != n_rows + 1: + raise ValueError( + f"{tag}_indptr length must be n_rows+1={n_rows + 1}, got {indptr.numel()}" + ) + if data.numel() != indices.numel(): + raise ValueError(f"{tag}_data and {tag}_indices must have the same length") + if not data.is_cuda or not indices.is_cuda or not indptr.is_cuda: + raise ValueError(f"{tag} tensors must be CUDA tensors") + if data.dtype not in SUPPORTED_SPGEMM_VALUE_DTYPES: + raise TypeError(f"{tag}_data dtype must be torch.float32 or torch.float64") + if indices.dtype != torch.int32: + raise TypeError(f"{tag}_indices dtype must be torch.int32") + if indptr.dtype not in (torch.int32, torch.int64): + raise TypeError(f"{tag}_indptr dtype must be torch.int32 or torch.int64") + + nnz = int(data.numel()) + indptr_i64 = indptr.to(torch.int64) + if indptr_i64.numel() > 0 and int(indptr_i64[0].item()) != 0: + raise ValueError(f"{tag}_indptr[0] must be 0") + if indptr_i64.numel() > 0 and int(indptr_i64[-1].item()) != nnz: + raise ValueError(f"{tag}_indptr[-1] must equal nnz={nnz}") + if indptr_i64.numel() > 1 and bool(torch.any(indptr_i64[1:] < indptr_i64[:-1]).item()): + raise ValueError(f"{tag}_indptr must be nondecreasing") + if nnz > 0: + min_col = int(indices.min().item()) + max_col = int(indices.max().item()) + if min_col < 0 or max_col >= n_cols: + raise IndexError(f"{tag}_indices out of range for n_cols={n_cols}") + return n_rows, n_cols, indptr_i64 + + +def _prepare_spgemm_csr_inputs( + a_data, + a_indices, + a_indptr, + a_shape, + b_data, + b_indices, + b_indptr, + b_shape, +): + a_rows, a_cols, a_indptr64 = _validate_csr(a_data, a_indices, a_indptr, a_shape, "a") + b_rows, b_cols, b_indptr64 = _validate_csr(b_data, b_indices, b_indptr, b_shape, "b") + if a_cols != b_rows: + raise ValueError( + f"shape mismatch for A@B: A is {a_rows}x{a_cols}, B is {b_rows}x{b_cols}" + ) + if a_data.device != b_data.device: + raise ValueError("A and B tensors must be on the same CUDA device") + if a_data.dtype != b_data.dtype: + 
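+        # Mixed value dtypes are not supported; cast one operand (e.g. with
+        # .to(torch.float64)) before calling if A and B differ.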
raise TypeError("A and B value dtype must match") + + a_data = a_data.contiguous() + a_indices = a_indices.contiguous() + a_indptr64 = a_indptr64.contiguous() + b_data = b_data.contiguous() + b_indices = b_indices.contiguous() + b_indptr64 = b_indptr64.contiguous() + return ( + a_data, + a_indices, + a_indptr64, + (a_rows, a_cols), + b_data, + b_indices, + b_indptr64, + (b_rows, b_cols), + ) + + +@triton.jit +def _spgemm_row_work_kernel( + a_indptr_ptr, + a_indices_ptr, + b_indptr_ptr, + row_work_ptr, + n_rows, + BLOCK_NNZ: tl.constexpr, +): + row = tl.program_id(0) + if row >= n_rows: + return + start = tl.load(a_indptr_ptr + row) + end = tl.load(a_indptr_ptr + row + 1) + row_nnz = end - start + acc = tl.zeros((), dtype=tl.int32) + for chunk_start in tl.range(0, row_nnz, BLOCK_NNZ): + offs = start + chunk_start + tl.arange(0, BLOCK_NNZ) + mask = offs < end + k = tl.load(a_indices_ptr + offs, mask=mask, other=0) + b_start = tl.load(b_indptr_ptr + k, mask=mask, other=0) + b_end = tl.load(b_indptr_ptr + k + 1, mask=mask, other=0) + contrib = (b_end - b_start).to(tl.int32) + acc += tl.sum(tl.where(mask, contrib, 0)) + tl.store(row_work_ptr + row, acc) + + +def _estimate_hash_capacity(a_row_work): + if a_row_work.numel() == 0: + return 256 + p95 = int(torch.quantile(a_row_work.to(torch.float32), 0.95).item()) + p95 = max(p95, 1) + cap = 1 + while cap < p95: + cap <<= 1 + return max(256, cap) + + +def _build_row_bucket(a_row_work): + # 0: small, 1: medium, 2: long + bucket = torch.zeros_like(a_row_work, dtype=torch.int8) + bucket = torch.where( + a_row_work > 4096, + torch.full_like(bucket, _SPGEMM_BUCKET_LONG), + bucket, + ) + bucket = torch.where( + (a_row_work > 256) & (a_row_work <= 4096), + torch.full_like(bucket, _SPGEMM_BUCKET_MEDIUM), + bucket, + ) + return bucket + + +def _build_long_row_slices_host(a_indptr, a_indices, b_indptr, row_ids, max_expanded): + if row_ids is None or row_ids.numel() == 0: + return {} + + out = {} + row_list = row_ids.to(torch.int64).cpu().tolist() + a_indptr_cpu = a_indptr.to(torch.int64).cpu() + for row in row_list: + start = int(a_indptr_cpu[row].item()) + end = int(a_indptr_cpu[row + 1].item()) + if end <= start: + out[int(row)] = [] + continue + a_cols = a_indices[start:end].to(torch.int64) + b_counts = (b_indptr[a_cols + 1] - b_indptr[a_cols]).to(torch.int64) + counts_host = b_counts.cpu().tolist() + slices = [] + idx = 0 + total = len(counts_host) + while idx < total: + seg_start = idx + acc = 0 + while idx < total: + w = int(counts_host[idx]) + if idx > seg_start and acc + w > int(max_expanded): + break + acc += w + idx += 1 + if idx == seg_start + 1 and w > int(max_expanded): + break + slices.append((start + seg_start, start + idx)) + out[int(row)] = slices + return out + + +def prepare_spgemm_csr( + a_data, + a_indices, + a_indptr, + a_shape, + b_data, + b_indices, + b_indptr, + b_shape, + block_nnz=256, + analyze_rows=True, +): + if block_nnz <= 0: + raise ValueError("block_nnz must be positive") + ( + a_data, + a_indices, + a_indptr, + a_shape, + b_data, + b_indices, + b_indptr, + b_shape, + ) = _prepare_spgemm_csr_inputs( + a_data, + a_indices, + a_indptr, + a_shape, + b_data, + b_indices, + b_indptr, + b_shape, + ) + n_rows = int(a_shape[0]) + if n_rows == 0: + row_work = torch.empty(0, dtype=torch.int32, device=a_data.device) + row_bucket = torch.empty(0, dtype=torch.int8, device=a_data.device) + hash_capacity_hint = 256 + row_work_ready = True + elif analyze_rows: + row_work = torch.empty(n_rows, dtype=torch.int32, device=a_data.device) + try: 
+ _spgemm_row_work_kernel[(n_rows,)]( + a_indptr, + a_indices, + b_indptr, + row_work, + n_rows, + BLOCK_NNZ=int(block_nnz), + ) + except Exception: + b_row_nnz = (b_indptr[1:] - b_indptr[:-1]).to(torch.int32) + for row in range(n_rows): + start = int(a_indptr[row].item()) + end = int(a_indptr[row + 1].item()) + if end <= start: + row_work[row] = 0 + continue + cols = a_indices[start:end].to(torch.int64) + row_work[row] = torch.sum(b_row_nnz[cols]).to(torch.int32) + row_bucket = _build_row_bucket(row_work) + hash_capacity_hint = 256 + row_work_ready = True + else: + row_work = torch.zeros(n_rows, dtype=torch.int32, device=a_data.device) + row_bucket = torch.zeros(n_rows, dtype=torch.int8, device=a_data.device) + hash_capacity_hint = 256 + row_work_ready = False + long_row_slices_host = {} + if row_work_ready: + long_rows = torch.nonzero(row_bucket == _SPGEMM_BUCKET_LONG, as_tuple=False).flatten() + long_row_slices_host = _build_long_row_slices_host( + a_indptr, + a_indices, + b_indptr, + long_rows, + _SPGEMM_LONG_ROW_SLICE_EXPANDED, + ) + return SpGEMMPrepared( + a_data=a_data, + a_indices=a_indices, + a_indptr=a_indptr, + a_shape=a_shape, + b_data=b_data, + b_indices=b_indices, + b_indptr=b_indptr, + b_shape=b_shape, + a_row_work=row_work, + row_bucket=row_bucket, + row_work_ready=row_work_ready, + bucket_rows=None, + count_chunks=None, + fill_chunks=None, + count_chunks_by_bucket=None, + fill_chunks_by_bucket=None, + long_row_slice_expanded=_SPGEMM_LONG_ROW_SLICE_EXPANDED, + long_row_slices_host=long_row_slices_host, + hash_capacity_hint=hash_capacity_hint, + block_nnz=block_nnz, + ) + + +def _ensure_row_work(prepared): + if prepared.row_work_ready: + return + if prepared.a_data.numel() == 0 or prepared.b_data.numel() == 0: + prepared.a_row_work = torch.zeros(prepared.n_rows, dtype=torch.int32, device=prepared.a_data.device) + prepared.row_bucket = torch.zeros(prepared.n_rows, dtype=torch.int8, device=prepared.a_data.device) + prepared.hash_capacity_hint = 256 + prepared.row_work_ready = True + prepared.long_row_slices_host = {} + _clear_runtime_schedules(prepared) + return + row_work = torch.empty(prepared.n_rows, dtype=torch.int32, device=prepared.a_data.device) + _spgemm_row_work_kernel[(prepared.n_rows,)]( + prepared.a_indptr, + prepared.a_indices, + prepared.b_indptr, + row_work, + prepared.n_rows, + BLOCK_NNZ=int(prepared.block_nnz), + ) + prepared.a_row_work = row_work + prepared.row_bucket = _build_row_bucket(row_work) + long_rows = torch.nonzero(prepared.row_bucket == _SPGEMM_BUCKET_LONG, as_tuple=False).flatten() + prepared.long_row_slices_host = _build_long_row_slices_host( + prepared.a_indptr, + prepared.a_indices, + prepared.b_indptr, + long_rows, + prepared.long_row_slice_expanded, + ) + prepared.hash_capacity_hint = 256 + prepared.row_work_ready = True + _clear_runtime_schedules(prepared) + + +def _build_row_chunks(row_work, max_expanded, max_rows_per_chunk): + n_rows = int(row_work.numel()) + if n_rows == 0: + return [] + all_rows = torch.arange(n_rows, device=row_work.device, dtype=torch.int64) + return _build_row_id_chunks(row_work, all_rows, max_expanded, max_rows_per_chunk) + + +def _build_bucket_rows(row_bucket, device): + out = {} + for bucket_id in _SPGEMM_BUCKET_ORDER: + rows = torch.nonzero(row_bucket == bucket_id, as_tuple=False).flatten() + out[bucket_id] = rows.to(device=device, dtype=torch.int64) + return out + + +def _build_row_id_chunks(row_work, row_ids, max_expanded, max_rows_per_chunk): + if row_ids.numel() == 0: + return [] + work_host = 
row_work[row_ids].detach().to("cpu", dtype=torch.int64).tolist() + chunks = [] + idx = 0 + total_rows = len(work_host) + while idx < total_rows: + start_idx = idx + acc = 0 + taken = 0 + while idx < total_rows and taken < int(max_rows_per_chunk): + w = int(work_host[idx]) + if taken > 0 and acc + w > int(max_expanded): + break + acc += w + idx += 1 + taken += 1 + if taken == 1 and w > int(max_expanded): + break + if idx == start_idx: + idx += 1 + chunks.append(row_ids[start_idx:idx].contiguous()) + return chunks + + +def _compose_ordered_chunks(chunks_by_bucket): + ordered = [] + for bucket_id in _SPGEMM_BUCKET_ORDER: + ordered.extend(chunks_by_bucket.get(bucket_id, [])) + return ordered + + +def _clear_runtime_schedules(prepared): + prepared.bucket_rows = None + prepared.count_chunks = None + prepared.fill_chunks = None + prepared.count_chunks_by_bucket = None + prepared.fill_chunks_by_bucket = None + + +def _clear_count_schedule(prepared): + prepared.count_chunks = None + prepared.count_chunks_by_bucket = None + + +def _clear_fill_schedule(prepared): + prepared.fill_chunks = None + prepared.fill_chunks_by_bucket = None + + +def _ensure_bucket_rows(prepared): + if prepared.bucket_rows is None: + prepared.bucket_rows = _build_bucket_rows(prepared.row_bucket, prepared.a_data.device) + return prepared.bucket_rows + + +def _chunk_rows_for_bucket(prepared, bucket_id, max_expanded): + rows = _ensure_bucket_rows(prepared)[bucket_id] + chunks = _build_row_id_chunks( + prepared.a_row_work, + rows, + max_expanded=max_expanded, + max_rows_per_chunk=_SPGEMM_BUCKET_MAX_ROWS[bucket_id], + ) + if bucket_id == _SPGEMM_BUCKET_LONG: + return [chunk.to(torch.int64).cpu().tolist() for chunk in chunks] + return chunks + + +def _ensure_count_chunks(prepared): + if prepared.count_chunks_by_bucket is not None: + return prepared.count_chunks_by_bucket + prepared.count_chunks_by_bucket = {} + for bucket_id in _SPGEMM_BUCKET_ORDER: + prepared.count_chunks_by_bucket[bucket_id] = _chunk_rows_for_bucket( + prepared, + bucket_id, + _SPGEMM_BUCKET_COUNT_BUDGETS[bucket_id], + ) + prepared.count_chunks = _compose_ordered_chunks(prepared.count_chunks_by_bucket) + return prepared.count_chunks_by_bucket + + +def _ensure_fill_chunks(prepared): + if prepared.fill_chunks_by_bucket is not None: + return prepared.fill_chunks_by_bucket + prepared.fill_chunks_by_bucket = {} + for bucket_id in _SPGEMM_BUCKET_ORDER: + prepared.fill_chunks_by_bucket[bucket_id] = _chunk_rows_for_bucket( + prepared, + bucket_id, + _SPGEMM_BUCKET_FILL_BUDGETS[bucket_id], + ) + prepared.fill_chunks = _compose_ordered_chunks(prepared.fill_chunks_by_bucket) + return prepared.fill_chunks_by_bucket + + +def _expand_rows_contrib(prepared, row_ids, need_values): + device = prepared.a_data.device + if row_ids.numel() == 0: + empty_i64 = torch.empty(0, dtype=torch.int64, device=device) + if need_values: + empty_v = torch.empty(0, dtype=prepared.a_data.dtype, device=device) + return empty_i64, empty_v + return empty_i64, None + + row_ids = row_ids.to(torch.int64) + row_nnz = (prepared.a_indptr[row_ids + 1] - prepared.a_indptr[row_ids]).to(torch.int64) + total_a = int(row_nnz.sum().item()) + if total_a == 0: + empty_i64 = torch.empty(0, dtype=torch.int64, device=device) + if need_values: + empty_v = torch.empty(0, dtype=prepared.a_data.dtype, device=device) + return empty_i64, empty_v + return empty_i64, None + + owner = torch.repeat_interleave( + torch.arange(row_ids.numel(), device=device, dtype=torch.int64), + row_nnz, + ) + prefix = torch.cumsum(row_nnz, dim=0) 
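+    # Exclusive prefix sums (base) mark where each row's A-entries start in
+    # the flat expansion; intra then enumerates positions within each row
+    # segment, recovering every (row, A-position) pair without a Python loop.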
+ base = prefix - row_nnz + intra = ( + torch.arange(total_a, device=device, dtype=torch.int64) + - torch.repeat_interleave(base, row_nnz) + ) + row_starts = prepared.a_indptr[row_ids] + a_pos = row_starts[owner] + intra + rows = row_ids[owner] + a_cols = prepared.a_indices[a_pos].to(torch.int64) + b_starts = prepared.b_indptr[a_cols] + b_ends = prepared.b_indptr[a_cols + 1] + b_counts = b_ends - b_starts + total = int(b_counts.sum().item()) + if total == 0: + empty_i64 = torch.empty(0, dtype=torch.int64, device=device) + if need_values: + empty_v = torch.empty(0, dtype=prepared.a_data.dtype, device=device) + return empty_i64, empty_v + return empty_i64, None + + b_owner = torch.repeat_interleave( + torch.arange(a_cols.numel(), device=device, dtype=torch.int64), + b_counts, + ) + rows_expanded = rows[b_owner] + prefix = torch.cumsum(b_counts, dim=0) + base = prefix - b_counts + starts_rep = torch.repeat_interleave(b_starts, b_counts) + intra = ( + torch.arange(total, device=device, dtype=torch.int64) + - torch.repeat_interleave(base, b_counts) + ) + b_pos = starts_rep + intra + cols = prepared.b_indices[b_pos].to(torch.int64) + keys = rows_expanded * max(1, prepared.n_cols) + cols + if not need_values: + return keys, None + + a_vals = prepared.a_data[a_pos] + vals = a_vals[b_owner] * prepared.b_data[b_pos] + return keys, vals + + +def _expand_single_row_slice_contrib(prepared, row, a_ptr_start, a_ptr_end, need_values): + device = prepared.a_data.device + if a_ptr_end <= a_ptr_start: + empty_i64 = torch.empty(0, dtype=torch.int64, device=device) + if need_values: + empty_v = torch.empty(0, dtype=prepared.a_data.dtype, device=device) + return empty_i64, empty_v + return empty_i64, None + + a_pos = torch.arange(a_ptr_start, a_ptr_end, device=device, dtype=torch.int64) + a_cols = prepared.a_indices[a_pos].to(torch.int64) + b_starts = prepared.b_indptr[a_cols] + b_ends = prepared.b_indptr[a_cols + 1] + b_counts = b_ends - b_starts + total = int(b_counts.sum().item()) + if total == 0: + empty_i64 = torch.empty(0, dtype=torch.int64, device=device) + if need_values: + empty_v = torch.empty(0, dtype=prepared.a_data.dtype, device=device) + return empty_i64, empty_v + return empty_i64, None + + owner = torch.repeat_interleave( + torch.arange(a_pos.numel(), device=device, dtype=torch.int64), + b_counts, + ) + prefix = torch.cumsum(b_counts, dim=0) + base = prefix - b_counts + starts_rep = torch.repeat_interleave(b_starts, b_counts) + intra = ( + torch.arange(total, device=device, dtype=torch.int64) + - torch.repeat_interleave(base, b_counts) + ) + b_pos = starts_rep + intra + cols = prepared.b_indices[b_pos].to(torch.int64) + keys = int(row) * max(1, prepared.n_cols) + cols + if not need_values: + return keys, None + a_vals = prepared.a_data[a_pos] + vals = a_vals[owner] * prepared.b_data[b_pos] + return keys, vals + + +def _iter_row_a_slices(prepared, row, max_expanded): + cached = prepared.long_row_slices_host.get(int(row)) if prepared.long_row_slices_host is not None else None + if cached is not None: + return cached + start = int(prepared.a_indptr[row].item()) + end = int(prepared.a_indptr[row + 1].item()) + if end <= start: + if prepared.long_row_slices_host is not None: + prepared.long_row_slices_host[int(row)] = [] + return [] + a_cols = prepared.a_indices[start:end].to(torch.int64) + b_counts = (prepared.b_indptr[a_cols + 1] - prepared.b_indptr[a_cols]).to(torch.int64) + counts_host = b_counts.cpu().tolist() + slices = [] + idx = 0 + total = len(counts_host) + while idx < total: + seg_start = 
idx + acc = 0 + while idx < total: + w = int(counts_host[idx]) + if idx > seg_start and acc + w > int(max_expanded): + break + acc += w + idx += 1 + if idx == seg_start + 1 and w > int(max_expanded): + break + slices.append((start + seg_start, start + idx)) + if prepared.long_row_slices_host is not None: + prepared.long_row_slices_host[int(row)] = slices + return slices + + +def _reduce_sorted_keys_vals(keys_sorted, vals_sorted, out_dtype): + if keys_sorted.numel() == 0: + return keys_sorted, vals_sorted + uniq_keys, counts = torch.unique_consecutive(keys_sorted, return_counts=True) + if uniq_keys.numel() == keys_sorted.numel(): + return uniq_keys, vals_sorted.to(out_dtype) + acc_dtype = torch.float64 if out_dtype == torch.float32 else vals_sorted.dtype + vals_acc = vals_sorted.to(acc_dtype) + prefix = torch.cumsum(vals_acc, dim=0) + end_idx = torch.cumsum(counts.to(torch.int64), dim=0) - 1 + seg_end = prefix[end_idx] + seg_begin = torch.zeros_like(seg_end) + if seg_end.numel() > 1: + seg_begin[1:] = prefix[end_idx[:-1]] + uniq_vals = (seg_end - seg_begin).to(out_dtype) + return uniq_keys, uniq_vals + + +def _sort_reduce_pairs(keys, vals, out_dtype): + if keys.numel() == 0: + return keys, vals + order = torch.argsort(keys) + keys_sorted = keys[order] + vals_sorted = vals[order] + return _reduce_sorted_keys_vals(keys_sorted, vals_sorted, out_dtype) + + +def _spgemm_count_phase(prepared, profile=False): + n_rows = prepared.n_rows + device = prepared.a_data.device + row_nnz_c = torch.zeros(n_rows, dtype=torch.int64, device=device) + bucket_ms = {bucket_id: None for bucket_id in _SPGEMM_BUCKET_ORDER} + long_row_sliced = 0 + chunks_by_bucket = _ensure_count_chunks(prepared) + for bucket_id in _SPGEMM_BUCKET_ORDER: + chunks = chunks_by_bucket.get(bucket_id, []) + for row_ids in chunks: + if profile: + torch.cuda.synchronize() + t_bucket0 = time.perf_counter() + if bucket_id != _SPGEMM_BUCKET_LONG: + keys, _ = _expand_rows_contrib(prepared, row_ids, need_values=False) + if keys.numel() > 0: + uniq_keys = torch.unique(keys, sorted=True) + uniq_rows = torch.div( + uniq_keys, + max(1, prepared.n_cols), + rounding_mode="floor", + ) + rows_unique, counts = torch.unique_consecutive( + uniq_rows, return_counts=True + ) + row_nnz_c[rows_unique] = counts.to(torch.int64) + else: + for row in row_ids: + slices = _iter_row_a_slices( + prepared, + int(row), + max_expanded=prepared.long_row_slice_expanded, + ) + if len(slices) > 1: + long_row_sliced += 1 + keys_parts = [] + for a_start, a_end in slices: + keys, _ = _expand_single_row_slice_contrib( + prepared, + int(row), + a_start, + a_end, + need_values=False, + ) + if keys.numel() == 0: + continue + keys_parts.append(torch.unique(keys, sorted=True)) + if not keys_parts: + row_nnz_c[int(row)] = 0 + continue + uniq_row_keys = torch.unique(torch.cat(keys_parts), sorted=True) + row_nnz_c[int(row)] = int(uniq_row_keys.numel()) + if profile: + torch.cuda.synchronize() + elapsed = (time.perf_counter() - t_bucket0) * 1000.0 + bucket_ms[bucket_id] = ( + elapsed if bucket_ms[bucket_id] is None else bucket_ms[bucket_id] + elapsed + ) + meta = { + "bucket_count_ms_short": bucket_ms[_SPGEMM_BUCKET_SHORT], + "bucket_count_ms_medium": bucket_ms[_SPGEMM_BUCKET_MEDIUM], + "bucket_count_ms_long": bucket_ms[_SPGEMM_BUCKET_LONG], + "long_row_sliced_count_count": int(long_row_sliced), + } + return row_nnz_c, meta + + +def _spgemm_fill_phase(prepared, c_indptr, out_data=None, out_indices=None, profile=False): + nnz_c = int(c_indptr[-1].item()) + device = prepared.a_data.device + 
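+    # Reuse caller-provided output buffers when given; otherwise allocate
+    # fresh value/index arrays sized to the counted nnz of C.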
c_data = ( + out_data + if out_data is not None + else torch.empty(nnz_c, dtype=prepared.a_data.dtype, device=device) + ) + c_indices = ( + out_indices + if out_indices is not None + else torch.empty(nnz_c, dtype=torch.int32, device=device) + ) + bucket_ms = {bucket_id: None for bucket_id in _SPGEMM_BUCKET_ORDER} + long_row_sliced = 0 + if nnz_c == 0: + meta = { + "bucket_fill_ms_short": bucket_ms[_SPGEMM_BUCKET_SHORT], + "bucket_fill_ms_medium": bucket_ms[_SPGEMM_BUCKET_MEDIUM], + "bucket_fill_ms_long": bucket_ms[_SPGEMM_BUCKET_LONG], + "long_row_sliced_count_fill": int(long_row_sliced), + } + return c_data, c_indices, meta + + chunks_by_bucket = _ensure_fill_chunks(prepared) + for bucket_id in _SPGEMM_BUCKET_ORDER: + chunks = chunks_by_bucket.get(bucket_id, []) + for row_ids in chunks: + if profile: + torch.cuda.synchronize() + t_bucket0 = time.perf_counter() + if bucket_id != _SPGEMM_BUCKET_LONG: + keys, vals = _expand_rows_contrib(prepared, row_ids, need_values=True) + if keys.numel() > 0: + uniq_keys, uniq_vals = _sort_reduce_pairs( + keys, + vals, + out_dtype=prepared.a_data.dtype, + ) + uniq_rows = torch.div( + uniq_keys, + max(1, prepared.n_cols), + rounding_mode="floor", + ) + uniq_cols = (uniq_keys - uniq_rows * max(1, prepared.n_cols)).to(torch.int32) + _, row_counts = torch.unique_consecutive(uniq_rows, return_counts=True) + row_offsets = torch.cumsum(row_counts.to(torch.int64), dim=0) - row_counts.to(torch.int64) + local_pos = ( + torch.arange(uniq_keys.numel(), device=device, dtype=torch.int64) + - torch.repeat_interleave(row_offsets, row_counts) + ) + dst = c_indptr[uniq_rows] + local_pos + c_indices[dst] = uniq_cols + c_data[dst] = uniq_vals + else: + for row in row_ids: + row = int(row) + slices = _iter_row_a_slices( + prepared, + row, + max_expanded=prepared.long_row_slice_expanded, + ) + if len(slices) > 1: + long_row_sliced += 1 + key_parts = [] + val_parts = [] + for a_start, a_end in slices: + keys, vals = _expand_single_row_slice_contrib( + prepared, + row, + a_start, + a_end, + need_values=True, + ) + if keys.numel() == 0: + continue + uniq_k, uniq_v = _sort_reduce_pairs( + keys, + vals, + out_dtype=prepared.a_data.dtype, + ) + key_parts.append(uniq_k) + val_parts.append(uniq_v) + row_start = int(c_indptr[row].item()) + row_end = int(c_indptr[row + 1].item()) + row_nnz = row_end - row_start + if row_nnz == 0: + continue + if not key_parts: + raise RuntimeError(f"row {row} expected nnz={row_nnz} but got empty fill") + row_keys = torch.cat(key_parts) + row_vals = torch.cat(val_parts) + row_keys, row_vals = _sort_reduce_pairs( + row_keys, + row_vals, + out_dtype=prepared.a_data.dtype, + ) + if row_keys.numel() != row_nnz: + raise RuntimeError( + f"row {row} fill nnz mismatch: expected {row_nnz}, got {row_keys.numel()}" + ) + row_cols = (row_keys - row * max(1, prepared.n_cols)).to(torch.int32) + c_indices[row_start:row_end] = row_cols + c_data[row_start:row_end] = row_vals + if profile: + torch.cuda.synchronize() + elapsed = (time.perf_counter() - t_bucket0) * 1000.0 + bucket_ms[bucket_id] = ( + elapsed if bucket_ms[bucket_id] is None else bucket_ms[bucket_id] + elapsed + ) + + meta = { + "bucket_fill_ms_short": bucket_ms[_SPGEMM_BUCKET_SHORT], + "bucket_fill_ms_medium": bucket_ms[_SPGEMM_BUCKET_MEDIUM], + "bucket_fill_ms_long": bucket_ms[_SPGEMM_BUCKET_LONG], + "long_row_sliced_count_fill": int(long_row_sliced), + } + return c_data, c_indices, meta + + +def _run_spgemm_prepared(prepared, out=None, profile=False, measure_stage=False): + _ensure_row_work(prepared) + if 
out is not None: + if not isinstance(out, (tuple, list)) or len(out) != 3: + raise TypeError("out must be a tuple/list of (data, indices, indptr)") + out_data, out_indices, out_indptr = out + if not out_data.is_cuda or not out_indices.is_cuda or not out_indptr.is_cuda: + raise ValueError("out data/indices/indptr must be CUDA tensors") + if out_data.device != prepared.a_data.device or out_indices.device != prepared.a_data.device or out_indptr.device != prepared.a_data.device: + raise ValueError("out data/indices/indptr must be on the same CUDA device as computed C") + if out_indptr.shape != (prepared.n_rows + 1,) or out_indptr.dtype != torch.int64: + raise ValueError("out indptr shape/dtype must match computed C indptr") + else: + out_data = out_indices = out_indptr = None + + if measure_stage: + torch.cuda.synchronize() + t_count0 = time.perf_counter() + row_nnz_c, count_meta = _spgemm_count_phase(prepared, profile=profile) + if measure_stage: + torch.cuda.synchronize() + count_ms = (time.perf_counter() - t_count0) * 1000.0 + else: + count_ms = None + _clear_count_schedule(prepared) + + c_indptr = out_indptr + if c_indptr is None: + c_indptr = torch.empty(prepared.n_rows + 1, dtype=torch.int64, device=prepared.a_data.device) + c_indptr[0] = 0 + if prepared.n_rows > 0: + c_indptr[1:] = torch.cumsum(row_nnz_c, dim=0) + nnz_c = int(c_indptr[-1].item()) if c_indptr.numel() > 0 else 0 + + if out_data is not None: + if out_data.shape != (nnz_c,) or out_data.dtype != prepared.a_data.dtype: + raise ValueError("out data shape/dtype must match computed C data") + if out_indices.shape != (nnz_c,) or out_indices.dtype != torch.int32: + raise ValueError("out indices shape/dtype must match computed C indices") + + if measure_stage: + torch.cuda.synchronize() + t_fill0 = time.perf_counter() + c_data, c_indices, fill_meta = _spgemm_fill_phase( + prepared, + c_indptr, + out_data=out_data, + out_indices=out_indices, + profile=profile, + ) + if measure_stage: + torch.cuda.synchronize() + fill_ms = (time.perf_counter() - t_fill0) * 1000.0 + else: + fill_ms = None + _clear_fill_schedule(prepared) + + def _sum_bucket_ms(count_key, fill_key): + count_val = count_meta[count_key] + fill_val = fill_meta[fill_key] + if count_val is None or fill_val is None: + return None + return float(count_val + fill_val) + + return c_data, c_indices, c_indptr, { + "count_ms": count_ms, + "fill_ms": fill_ms, + "bucket_ms_short": _sum_bucket_ms("bucket_count_ms_short", "bucket_fill_ms_short"), + "bucket_ms_medium": _sum_bucket_ms("bucket_count_ms_medium", "bucket_fill_ms_medium"), + "bucket_ms_long": _sum_bucket_ms("bucket_count_ms_long", "bucket_fill_ms_long"), + "bucket_count_ms_short": count_meta["bucket_count_ms_short"], + "bucket_count_ms_medium": count_meta["bucket_count_ms_medium"], + "bucket_count_ms_long": count_meta["bucket_count_ms_long"], + "bucket_fill_ms_short": fill_meta["bucket_fill_ms_short"], + "bucket_fill_ms_medium": fill_meta["bucket_fill_ms_medium"], + "bucket_fill_ms_long": fill_meta["bucket_fill_ms_long"], + "bucket_nrows_short": int(torch.count_nonzero(prepared.row_bucket == _SPGEMM_BUCKET_SHORT).item()), + "bucket_nrows_medium": int(torch.count_nonzero(prepared.row_bucket == _SPGEMM_BUCKET_MEDIUM).item()), + "bucket_nrows_long": int(torch.count_nonzero(prepared.row_bucket == _SPGEMM_BUCKET_LONG).item()), + "long_row_sliced_count": int( + max( + count_meta["long_row_sliced_count_count"], + fill_meta["long_row_sliced_count_fill"], + ) + ), + } + + +def flagsparse_spgemm_csr( + a_data=None, + a_indices=None, 
+ a_indptr=None, + a_shape=None, + b_data=None, + b_indices=None, + b_indptr=None, + b_shape=None, + prepared=None, + out=None, + return_time=False, + return_meta=False, +): + """CSR SpGEMM: C = A @ B with CSR output (Triton-only main path).""" + prepare_ms = 0.0 + internal_prepared = prepared is None + if prepared is None: + if any( + x is None + for x in ( + a_data, + a_indices, + a_indptr, + a_shape, + b_data, + b_indices, + b_indptr, + b_shape, + ) + ): + raise ValueError( + "A/B CSR tensors and shapes are required when prepared is not provided" + ) + if return_meta: + torch.cuda.synchronize() + t_prepare0 = time.perf_counter() + prepared = prepare_spgemm_csr( + a_data, + a_indices, + a_indptr, + a_shape, + b_data, + b_indices, + b_indptr, + b_shape, + ) + if return_meta: + torch.cuda.synchronize() + prepare_ms = (time.perf_counter() - t_prepare0) * 1000.0 + elif not isinstance(prepared, SpGEMMPrepared): + raise TypeError("prepared must be a SpGEMMPrepared instance") + + elapsed_ms = None + if return_time: + torch.cuda.synchronize() + t0 = time.perf_counter() + c_data, c_indices, c_indptr, stage_meta = _run_spgemm_prepared( + prepared, + out=out, + profile=bool(return_meta), + measure_stage=bool(return_meta), + ) + if return_time: + torch.cuda.synchronize() + elapsed_ms = (time.perf_counter() - t0) * 1000.0 + + if internal_prepared: + _clear_runtime_schedules(prepared) + + result = (c_data, c_indices, c_indptr, (prepared.n_rows, prepared.n_cols)) + if return_time and return_meta: + meta = { + "prepare_ms": prepare_ms, + "count_ms": stage_meta["count_ms"], + "fill_ms": stage_meta["fill_ms"], + "triton_ms": elapsed_ms, + "hash_capacity_hint": prepared.hash_capacity_hint, + "bucket_ms_short": stage_meta["bucket_ms_short"], + "bucket_ms_medium": stage_meta["bucket_ms_medium"], + "bucket_ms_long": stage_meta["bucket_ms_long"], + "bucket_count_ms_short": stage_meta["bucket_count_ms_short"], + "bucket_count_ms_medium": stage_meta["bucket_count_ms_medium"], + "bucket_count_ms_long": stage_meta["bucket_count_ms_long"], + "bucket_fill_ms_short": stage_meta["bucket_fill_ms_short"], + "bucket_fill_ms_medium": stage_meta["bucket_fill_ms_medium"], + "bucket_fill_ms_long": stage_meta["bucket_fill_ms_long"], + "bucket_nrows_short": stage_meta["bucket_nrows_short"], + "bucket_nrows_medium": stage_meta["bucket_nrows_medium"], + "bucket_nrows_long": stage_meta["bucket_nrows_long"], + "long_row_sliced_count": stage_meta["long_row_sliced_count"], + } + return result, elapsed_ms, meta + if return_time: + return result, elapsed_ms + if return_meta: + meta = { + "prepare_ms": prepare_ms, + "count_ms": stage_meta["count_ms"], + "fill_ms": stage_meta["fill_ms"], + "hash_capacity_hint": prepared.hash_capacity_hint, + "bucket_ms_short": stage_meta["bucket_ms_short"], + "bucket_ms_medium": stage_meta["bucket_ms_medium"], + "bucket_ms_long": stage_meta["bucket_ms_long"], + "bucket_count_ms_short": stage_meta["bucket_count_ms_short"], + "bucket_count_ms_medium": stage_meta["bucket_count_ms_medium"], + "bucket_count_ms_long": stage_meta["bucket_count_ms_long"], + "bucket_fill_ms_short": stage_meta["bucket_fill_ms_short"], + "bucket_fill_ms_medium": stage_meta["bucket_fill_ms_medium"], + "bucket_fill_ms_long": stage_meta["bucket_fill_ms_long"], + "bucket_nrows_short": stage_meta["bucket_nrows_short"], + "bucket_nrows_medium": stage_meta["bucket_nrows_medium"], + "bucket_nrows_long": stage_meta["bucket_nrows_long"], + "long_row_sliced_count": stage_meta["long_row_sliced_count"], + } + return result, meta + return 
result + + +def _csr_to_sorted_pairs(data, indices, indptr, n_cols): + n_rows = int(indptr.numel()) - 1 + row_counts = indptr[1:] - indptr[:-1] + rows = torch.repeat_interleave( + torch.arange(n_rows, device=data.device, dtype=torch.int64), + row_counts, + ) + cols = indices.to(torch.int64) + keys = rows * max(1, int(n_cols)) + cols + if keys.numel() == 0: + return keys, data + order = torch.argsort(keys) + return keys[order], data[order] + + +def _spgemm_pairwise_summary(candidate, reference, value_dtype): + c_data, c_indices, c_indptr, c_shape = candidate + r_data, r_indices, r_indptr, r_shape = reference + if c_shape != r_shape: + return { + "match": False, + "max_abs_error": float("inf"), + "max_relative_error": float("inf"), + "status": f"shape mismatch {c_shape} vs {r_shape}", + } + c_keys, c_vals = _csr_to_sorted_pairs(c_data, c_indices, c_indptr, c_shape[1]) + r_keys, r_vals = _csr_to_sorted_pairs(r_data, r_indices, r_indptr, r_shape[1]) + if c_keys.numel() != r_keys.numel(): + return { + "match": False, + "max_abs_error": float("inf"), + "max_relative_error": float("inf"), + "status": f"nnz mismatch {c_keys.numel()} vs {r_keys.numel()}", + } + if c_keys.numel() > 0 and not torch.equal(c_keys, r_keys): + return { + "match": False, + "max_abs_error": float("inf"), + "max_relative_error": float("inf"), + "status": "sparsity pattern mismatch", + } + if c_vals.numel() == 0: + return { + "match": True, + "max_abs_error": 0.0, + "max_relative_error": 0.0, + "status": "ok", + } + abs_diff = torch.abs(c_vals - r_vals) + max_abs = float(torch.max(abs_diff).item()) + ref_max = float(torch.max(torch.abs(r_vals)).item()) + max_rel = 0.0 if ref_max == 0.0 else max_abs / ref_max + atol, rtol = _tolerance_for_dtype(value_dtype) + match = bool(torch.allclose(c_vals, r_vals, atol=atol, rtol=rtol)) + return { + "match": match, + "max_abs_error": max_abs, + "max_relative_error": max_rel, + "status": "ok" if match else "value mismatch", + } + + +def _to_torch_csr(data, indices, indptr, shape): + return torch.sparse_csr_tensor( + indptr.to(torch.int64), + indices.to(torch.int64), + data, + size=shape, + device=data.device, + ) + + +def _torch_sparse_to_csr(tensor): + if tensor.layout == torch.sparse_csr: + indptr = tensor.crow_indices().to(torch.int64).contiguous() + indices = tensor.col_indices().to(torch.int32).contiguous() + data = tensor.values().contiguous() + shape = (int(tensor.shape[0]), int(tensor.shape[1])) + return data, indices, indptr, shape + if tensor.layout == torch.sparse_coo: + t = tensor.coalesce() + rows = t.indices()[0].to(torch.int64) + cols = t.indices()[1].to(torch.int64) + vals = t.values() + n_rows, n_cols = int(t.shape[0]), int(t.shape[1]) + if rows.numel() == 0: + return ( + torch.empty(0, dtype=vals.dtype, device=vals.device), + torch.empty(0, dtype=torch.int32, device=vals.device), + torch.zeros(n_rows + 1, dtype=torch.int64, device=vals.device), + (n_rows, n_cols), + ) + key = rows * max(1, n_cols) + cols + order = torch.argsort(key) + rows = rows[order] + cols = cols[order] + vals = vals[order] + row_counts = torch.bincount(rows, minlength=n_rows) + indptr = torch.zeros(n_rows + 1, dtype=torch.int64, device=vals.device) + indptr[1:] = torch.cumsum(row_counts, dim=0) + return vals, cols.to(torch.int32), indptr, (n_rows, n_cols) + raise TypeError(f"Unsupported sparse layout: {tensor.layout}") + + +def benchmark_spgemm_case( + n_rows=1024, + n_inner=1024, + n_cols=1024, + nnz_a=16384, + nnz_b=16384, + value_dtype=torch.float32, + warmup=10, + iters=30, + run_cusparse=True, 
+): + """Benchmark CSR SpGEMM and compare with torch/cuSPARSE baselines.""" + if value_dtype not in SUPPORTED_SPGEMM_VALUE_DTYPES: + raise TypeError("value_dtype must be torch.float32 or torch.float64") + device = torch.device("cuda") + a_data, a_indices, a_indptr = _build_random_csr( + n_rows, n_inner, nnz_a, value_dtype, torch.int32, device + ) + b_data, b_indices, b_indptr = _build_random_csr( + n_inner, n_cols, nnz_b, value_dtype, torch.int32, device + ) + + prepared = prepare_spgemm_csr( + a_data, a_indices, a_indptr, (n_rows, n_inner), + b_data, b_indices, b_indptr, (n_inner, n_cols), + ) + op = lambda: flagsparse_spgemm_csr(prepared=prepared, return_time=False) + triton_result, triton_ms = _benchmark_cuda_op(op, warmup=warmup, iters=iters) + + a_t = _to_torch_csr(a_data, a_indices, a_indptr, (n_rows, n_inner)) + b_t = _to_torch_csr(b_data, b_indices, b_indptr, (n_inner, n_cols)) + + pytorch_reason = None + pytorch_ms = None + pytorch_result = None + try: + torch_op = lambda: torch.sparse.mm(a_t, b_t) + pytorch_sparse, pytorch_ms = _benchmark_cuda_op(torch_op, warmup=warmup, iters=iters) + pytorch_result = _torch_sparse_to_csr(pytorch_sparse) + except Exception as exc: + pytorch_reason = str(exc) + a_coo = a_t.to_sparse_coo().coalesce() + b_coo = b_t.to_sparse_coo().coalesce() + torch_op = lambda: torch.sparse.mm(a_coo, b_coo) + pytorch_sparse, pytorch_ms = _benchmark_cuda_op(torch_op, warmup=warmup, iters=iters) + pytorch_result = _torch_sparse_to_csr(pytorch_sparse) + + triton_summary = _spgemm_pairwise_summary(triton_result, pytorch_result, value_dtype) + + cusparse_ms = None + cusparse_reason = None + cusparse_match = None + if run_cusparse: + if cp is None or cpx_sparse is None: + cusparse_reason = "CuPy/cuSPARSE is not available" + else: + try: + a_cp = cpx_sparse.csr_matrix( + (_cupy_from_torch(a_data), _cupy_from_torch(a_indices.to(torch.int64)), _cupy_from_torch(a_indptr.to(torch.int64))), + shape=(n_rows, n_inner), + ) + b_cp = cpx_sparse.csr_matrix( + (_cupy_from_torch(b_data), _cupy_from_torch(b_indices.to(torch.int64)), _cupy_from_torch(b_indptr.to(torch.int64))), + shape=(n_inner, n_cols), + ) + c_cp, cusparse_ms = _benchmark_cuda_op(lambda: a_cp @ b_cp, warmup=warmup, iters=iters) + c_coo = c_cp.tocoo() + rows = _torch_from_cupy(c_coo.row).to(torch.int64) + cols = _torch_from_cupy(c_coo.col).to(torch.int64) + vals = _torch_from_cupy(c_coo.data).to(value_dtype) + c_t = torch.sparse_coo_tensor( + torch.stack([rows, cols]), vals, (n_rows, n_cols), device=device + ).coalesce() + c_ref = _torch_sparse_to_csr(c_t) + cusparse_match = _spgemm_pairwise_summary(triton_result, c_ref, value_dtype)["match"] + except Exception as exc: + cusparse_reason = str(exc) + + return { + "parameters": { + "n_rows": n_rows, + "n_inner": n_inner, + "n_cols": n_cols, + "nnz_a": nnz_a, + "nnz_b": nnz_b, + "value_dtype": str(value_dtype), + "warmup": warmup, + "iters": iters, + }, + "performance": { + "triton_ms": triton_ms, + "pytorch_ms": pytorch_ms, + "cusparse_ms": cusparse_ms, + "triton_speedup_vs_pytorch": (pytorch_ms / triton_ms if (pytorch_ms and triton_ms > 0) else None), + "triton_speedup_vs_cusparse": (cusparse_ms / triton_ms if (cusparse_ms and triton_ms > 0) else None), + }, + "verification": { + "triton_match_pytorch": triton_summary["match"], + "triton_max_abs_error": triton_summary["max_abs_error"], + "triton_max_relative_error": triton_summary["max_relative_error"], + "cusparse_match_pytorch": cusparse_match, + }, + "backend_status": { + "pytorch_unavailable_reason": 
pytorch_reason, + "cusparse_unavailable_reason": cusparse_reason, + }, + "samples": { + "triton": triton_result, + "pytorch": pytorch_result, + }, + } diff --git a/src/flagsparse/sparse_operations/spmm_coo.py b/src/flagsparse/sparse_operations/spmm_coo.py new file mode 100644 index 0000000..1943aef --- /dev/null +++ b/src/flagsparse/sparse_operations/spmm_coo.py @@ -0,0 +1,1172 @@ +"""Native COO SpMM kernels, route helpers, and internal benchmark entry points.""" + +from ._common import * +from .spmm_csr import ( + SUPPORTED_SPMM_VALUE_DTYPES, + _select_spmm_alg1_warp_and_factor, + _spmm_coo_reference_tolerance, + _spmm_relative_threshold, + _spmm_validation_metrics, +) +def _spmm_coo_compute_dtype(value_dtype): + if _is_complex_dtype(value_dtype): + return torch.complex128 if value_dtype == torch.complex64 else value_dtype + if value_dtype in (torch.float16, torch.bfloat16): + return torch.float32 + if value_dtype == torch.float32: + return torch.float64 + return value_dtype + + +def _sort_coo_lex_inplace(data, row, col, n_cols): + row64 = row.to(torch.int64) + col64 = col.to(torch.int64) + if data.numel() == 0: + return data.contiguous(), row64, col64 + key = row64 * max(1, int(n_cols)) + col64 + order = torch.argsort(key) + return ( + data[order].contiguous(), + row64[order].contiguous(), + col64[order].contiguous(), + ) + + +def _coalesce_coo_entries(data, row, col, shape): + """Merge duplicate (row, col) by summing values (PyTorch COO coalesce).""" + n_rows, n_cols = int(shape[0]), int(shape[1]) + if data.numel() == 0: + z = torch.empty(0, dtype=torch.int64, device=data.device) + return data.contiguous(), z, z.clone() + row64 = row.to(torch.int64) + col64 = col.to(torch.int64) + coo = torch.sparse_coo_tensor( + torch.stack([row64, col64]), + data, + size=(n_rows, n_cols), + device=data.device, + dtype=data.dtype, + ).coalesce() + idx = coo.indices() + return coo.values().contiguous(), idx[0].contiguous(), idx[1].contiguous() + + +def _build_torch_sparse_coo(data, row, col, shape): + """Coalesced CUDA COO tensor for ``torch.sparse.mm`` (indices int64).""" + n_rows, n_cols = int(shape[0]), int(shape[1]) + if data.numel() == 0: + empty_idx = torch.empty((2, 0), dtype=torch.int64, device=data.device) + return torch.sparse_coo_tensor( + empty_idx, + data, + size=(n_rows, n_cols), + device=data.device, + dtype=data.dtype, + ) + row_i = row.to(torch.int64) + col_i = col.to(torch.int64) + indices = torch.stack([row_i, col_i]) + return torch.sparse_coo_tensor( + indices, + data, + size=(n_rows, n_cols), + device=data.device, + dtype=data.dtype, + ).coalesce() + + +def _prepare_spmm_coo_canonical_prepared(data, row, col, B, n_rows, n_cols, n_dense_cols): + output_dtype = data.dtype + compute_dtype = _spmm_coo_compute_dtype(output_dtype) + data_compute = data if compute_dtype == output_dtype else data.to(compute_dtype) + B_compute = B if compute_dtype == output_dtype else B.to(compute_dtype) + canonical_data, canonical_row, canonical_col = _coalesce_coo_entries( + data_compute, + row, + col, + (n_rows, n_cols), + ) + canonical_data, canonical_row, canonical_col = _sort_coo_lex_inplace( + canonical_data, + canonical_row, + canonical_col, + n_cols, + ) + return ( + canonical_data, + canonical_row, + canonical_col, + B_compute, + n_rows, + n_cols, + n_dense_cols, + output_dtype, + compute_dtype, + ) + + +def _prepare_spmm_coo_canonical_inputs(data, row, col, B, shape): + data, kernel_row, kernel_col, B, n_rows, n_cols, n_dense_cols = _prepare_spmm_coo_inputs( + data, row, col, B, shape + ) + return 
_prepare_spmm_coo_canonical_prepared( + data, + kernel_row, + kernel_col, + B, + n_rows, + n_cols, + n_dense_cols, + ) +def _seg_starts_from_sorted_rows(row_i32, nnz, device): + if nnz == 0: + return None + diff = row_i32[1:] != row_i32[:-1] + breaks = torch.nonzero(diff, as_tuple=False).flatten().to(torch.int32) + 1 + return torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=device), + breaks, + torch.tensor([nnz], dtype=torch.int32, device=device), + ] + ) + + +@triton.jit +def _spmm_coo_rowrun_real_kernel( + data_ptr, + row_ptr, + col_ptr, + b_ptr, + c_ptr, + seg_starts_ptr, + n_segs, + n_dense_cols, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_N: tl.constexpr, + BLOCK_NNZ: tl.constexpr, + ACC_DTYPE: tl.constexpr, +): + seg = tl.program_id(0) + pid_n = tl.program_id(1) + if seg >= n_segs: + return + + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_n = offs_n < n_dense_cols + start = tl.load(seg_starts_ptr + seg) + end = tl.load(seg_starts_ptr + seg + 1) + row_nnz = end - start + row_id = tl.load(row_ptr + start) + acc = tl.zeros([BLOCK_N], dtype=ACC_DTYPE) + + for chunk_start in tl.range(0, row_nnz, BLOCK_NNZ): + for kk in tl.static_range(0, BLOCK_NNZ): + idx = start + chunk_start + kk + valid = idx < end + a_val = tl.load(data_ptr + idx, mask=valid, other=0.0) + a_col = tl.load(col_ptr + idx, mask=valid, other=0) + b_vals = tl.load( + b_ptr + a_col * stride_bk + offs_n * stride_bn, + mask=mask_n & valid, + other=0.0, + ) + acc = acc + a_val.to(ACC_DTYPE) * b_vals.to(ACC_DTYPE) + + tl.store(c_ptr + row_id * stride_cm + offs_n * stride_cn, acc, mask=mask_n) + + +@triton.jit +def _spmm_coo_rowrun_complex_kernel( + data_ri_ptr, + row_ptr, + col_ptr, + b_ri_ptr, + c_ri_ptr, + seg_starts_ptr, + n_segs, + n_dense_cols, + stride_bk, + stride_bn, + stride_br, + stride_cm, + stride_cn, + stride_cr, + BLOCK_N: tl.constexpr, + BLOCK_NNZ: tl.constexpr, + ACC_DTYPE: tl.constexpr, +): + seg = tl.program_id(0) + pid_n = tl.program_id(1) + if seg >= n_segs: + return + + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_n = offs_n < n_dense_cols + start = tl.load(seg_starts_ptr + seg) + end = tl.load(seg_starts_ptr + seg + 1) + row_nnz = end - start + row_id = tl.load(row_ptr + start) + acc_re = tl.zeros([BLOCK_N], dtype=ACC_DTYPE) + acc_im = tl.zeros([BLOCK_N], dtype=ACC_DTYPE) + + for chunk_start in tl.range(0, row_nnz, BLOCK_NNZ): + for kk in tl.static_range(0, BLOCK_NNZ): + idx = start + chunk_start + kk + valid = idx < end + a_re = tl.load(data_ri_ptr + idx * 2, mask=valid, other=0.0) + a_im = tl.load(data_ri_ptr + idx * 2 + 1, mask=valid, other=0.0) + a_col = tl.load(col_ptr + idx, mask=valid, other=0) + b_re = tl.load( + b_ri_ptr + a_col * stride_bk + offs_n * stride_bn, + mask=mask_n & valid, + other=0.0, + ) + b_im = tl.load( + b_ri_ptr + a_col * stride_bk + offs_n * stride_bn + stride_br, + mask=mask_n & valid, + other=0.0, + ) + acc_re = acc_re + a_re.to(ACC_DTYPE) * b_re.to(ACC_DTYPE) - a_im.to(ACC_DTYPE) * b_im.to(ACC_DTYPE) + acc_im = acc_im + a_re.to(ACC_DTYPE) * b_im.to(ACC_DTYPE) + a_im.to(ACC_DTYPE) * b_re.to(ACC_DTYPE) + + tl.store(c_ri_ptr + row_id * stride_cm + offs_n * stride_cn, acc_re, mask=mask_n) + tl.store( + c_ri_ptr + row_id * stride_cm + offs_n * stride_cn + stride_cr, + acc_im, + mask=mask_n, + ) + +@triton.jit +def _spmm_coo_atomic_real_kernel( + data_ptr, + row_ptr, + col_ptr, + b_ptr, + c_ptr, + nnz, + n_dense_cols, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + ACC_DTYPE: tl.constexpr, +): + idx = tl.program_id(0) + dense_col 
= tl.program_id(1) + if idx >= nnz or dense_col >= n_dense_cols: + return + + row_k = tl.load(row_ptr + idx) + col_k = tl.load(col_ptr + idx) + val_k = tl.load(data_ptr + idx) + b_val = tl.load(b_ptr + col_k * stride_bk + dense_col * stride_bn) + tl.atomic_add( + c_ptr + row_k * stride_cm + dense_col * stride_cn, + val_k.to(ACC_DTYPE) * b_val.to(ACC_DTYPE), + sem="relaxed", + ) +@triton.jit +def _spmm_coo_atomic_complex_kernel( + data_ri_ptr, + row_ptr, + col_ptr, + b_ri_ptr, + c_ri_ptr, + nnz, + n_dense_cols, + stride_bk, + stride_bn, + stride_br, + stride_cm, + stride_cn, + stride_cr, + ACC_DTYPE: tl.constexpr, +): + idx = tl.program_id(0) + dense_col = tl.program_id(1) + if idx >= nnz or dense_col >= n_dense_cols: + return + + row_k = tl.load(row_ptr + idx) + col_k = tl.load(col_ptr + idx) + a_re = tl.load(data_ri_ptr + idx * 2) + a_im = tl.load(data_ri_ptr + idx * 2 + 1) + b_re = tl.load(b_ri_ptr + col_k * stride_bk + dense_col * stride_bn) + b_im = tl.load(b_ri_ptr + col_k * stride_bk + dense_col * stride_bn + stride_br) + contrib_re = a_re.to(ACC_DTYPE) * b_re.to(ACC_DTYPE) - a_im.to(ACC_DTYPE) * b_im.to(ACC_DTYPE) + contrib_im = a_re.to(ACC_DTYPE) * b_im.to(ACC_DTYPE) + a_im.to(ACC_DTYPE) * b_re.to(ACC_DTYPE) + tl.atomic_add( + c_ri_ptr + row_k * stride_cm + dense_col * stride_cn, + contrib_re, + sem="relaxed", + ) + tl.atomic_add( + c_ri_ptr + row_k * stride_cm + dense_col * stride_cn + stride_cr, + contrib_im, + sem="relaxed", + ) +def _prepare_spmm_coo_inputs(data, row, col, B, shape): + if len(shape) != 2: + raise ValueError("shape must be a 2-tuple: (n_rows, n_cols)") + if data.ndim != 1 or row.ndim != 1 or col.ndim != 1: + raise ValueError("data, row, and col must be 1D tensors") + if B.ndim != 2: + raise ValueError("B must be a 2D dense tensor") + + n_rows, n_cols = int(shape[0]), int(shape[1]) + if n_rows < 0 or n_cols < 0: + raise ValueError("shape dimensions must be non-negative") + if data.numel() != row.numel() or data.numel() != col.numel(): + raise ValueError("data, row, and col must have the same length (nnz)") + if B.shape[0] != n_cols: + raise ValueError(f"B.shape[0] must be n_cols={n_cols}, got {B.shape[0]}") + + if not all(t.is_cuda for t in (data, row, col, B)): + raise ValueError("data, row, col, and B must be CUDA tensors") + if not all(t.device == data.device for t in (row, col, B)): + raise ValueError("data, row, col, and B must be on the same CUDA device") + if data.dtype not in SUPPORTED_SPMM_VALUE_DTYPES: + raise TypeError( + "data dtype must be one of: float16, bfloat16, float32, float64, complex64, complex128" + ) + if B.dtype != data.dtype: + raise TypeError("B dtype must match data dtype") + if row.dtype not in SUPPORTED_INDEX_DTYPES: + raise TypeError("row dtype must be torch.int32 or torch.int64") + if col.dtype not in SUPPORTED_INDEX_DTYPES: + raise TypeError("col dtype must be torch.int32 or torch.int64") + + nnz = data.numel() + if nnz > _INDEX_LIMIT_INT32: + raise ValueError("nnz exceeds the int32 range supported by the Triton COO kernel") + if nnz > 0: + min_row = int(row.min().item()) + max_row = int(row.max().item()) + min_col = int(col.min().item()) + max_col = int(col.max().item()) + if min_row < 0 or max_row >= n_rows: + raise IndexError("row indices out of range for n_rows") + if min_col < 0 or max_col >= n_cols: + raise IndexError("col indices out of range for n_cols") + if max_row > _INDEX_LIMIT_INT32: + raise ValueError( + "row indices exceed the int32 range supported by the Triton kernel" + ) + if max_col > _INDEX_LIMIT_INT32: + raise 
ValueError( + "column indices exceed the int32 range supported by the Triton kernel" + ) + + data = data.contiguous() + row = row.contiguous() + col = col.contiguous() + B = B.contiguous() + + kernel_row = row.to(torch.int32) if row.dtype == torch.int64 else row + kernel_col = col.to(torch.int32) if col.dtype == torch.int64 else col + return data, kernel_row, kernel_col, B, n_rows, n_cols, int(B.shape[1]) + + +def _resolve_spmm_coo_launch_config(n_dense_cols, nnz, block_n=None, block_nnz=256): + warp_size, factor = _select_spmm_alg1_warp_and_factor(n_dense_cols) + + if block_n is None: + block_n = warp_size * factor + if block_nnz is None: + block_nnz = 256 + + if block_n <= 0 or block_nnz <= 0: + raise ValueError("block_n and block_nnz must be positive when provided") + + return { + "block_n": int(block_n), + "block_nnz": int(block_nnz), + "required_nnz_tiles": int(triton.cdiv(nnz, block_nnz) if nnz > 0 else 0), + "heuristic_warp_size": int(warp_size), + "heuristic_factor": int(factor), + } + + +def _triton_spmm_coo_rowrun_impl( + data, + row, + col, + B, + n_rows, + n_dense_cols, + block_n, + block_nnz, + output_dtype, +): + device = data.device + dtype = data.dtype + if n_rows == 0 or n_dense_cols == 0 or B.shape[0] == 0 or data.numel() == 0: + return torch.zeros((n_rows, n_dense_cols), dtype=output_dtype, device=device) + + seg_starts = _seg_starts_from_sorted_rows(row, int(data.numel()), device) + n_segs = int(seg_starts.numel()) - 1 if seg_starts is not None else 0 + if n_segs == 0: + return torch.zeros((n_rows, n_dense_cols), dtype=output_dtype, device=device) + + grid = (n_segs, triton.cdiv(n_dense_cols, block_n)) + if not _is_complex_dtype(dtype): + C_compute = torch.zeros((n_rows, n_dense_cols), dtype=dtype, device=device) + acc_dtype = tl.float64 if dtype == torch.float64 else tl.float32 + _spmm_coo_rowrun_real_kernel[grid]( + data, + row, + col, + B, + C_compute, + seg_starts, + n_segs, + n_dense_cols, + B.stride(0), + B.stride(1), + C_compute.stride(0), + C_compute.stride(1), + BLOCK_N=block_n, + BLOCK_NNZ=block_nnz, + ACC_DTYPE=acc_dtype, + ) + return C_compute if dtype == output_dtype else C_compute.to(output_dtype) + + data_ri = torch.view_as_real(data).contiguous().reshape(-1) + B_ri = torch.view_as_real(B).contiguous() + C_ri = torch.zeros((n_rows, n_dense_cols, 2), dtype=B_ri.dtype, device=device) + acc_dtype = tl.float64 if B_ri.dtype == torch.float64 else tl.float32 + _spmm_coo_rowrun_complex_kernel[grid]( + data_ri, + row, + col, + B_ri, + C_ri, + seg_starts, + n_segs, + n_dense_cols, + B_ri.stride(0), + B_ri.stride(1), + B_ri.stride(2), + C_ri.stride(0), + C_ri.stride(1), + C_ri.stride(2), + BLOCK_N=block_n, + BLOCK_NNZ=block_nnz, + ACC_DTYPE=acc_dtype, + ) + C = torch.view_as_complex(C_ri.contiguous()) + return C if dtype == output_dtype else C.to(output_dtype) + +def _triton_spmm_coo_atomic_impl( + data, + row, + col, + B, + n_rows, + n_dense_cols, + block_n, + block_nnz, + output_dtype, +): + device = data.device + dtype = data.dtype + if n_rows == 0 or n_dense_cols == 0 or B.shape[0] == 0 or data.numel() == 0: + return torch.zeros((n_rows, n_dense_cols), dtype=output_dtype, device=device) + + nnz = int(data.numel()) + if nnz == 0: + return torch.zeros((n_rows, n_dense_cols), dtype=output_dtype, device=device) + + if not _is_complex_dtype(dtype): + C_compute = torch.zeros((n_rows, n_dense_cols), dtype=dtype, device=device) + acc_dtype = tl.float64 if dtype == torch.float64 else tl.float32 + _spmm_coo_atomic_real_kernel[(nnz, n_dense_cols)]( + data, + row, + col, + 
B, + C_compute, + nnz, + n_dense_cols, + B.stride(0), + B.stride(1), + C_compute.stride(0), + C_compute.stride(1), + ACC_DTYPE=acc_dtype, + ) + return C_compute if dtype == output_dtype else C_compute.to(output_dtype) + + data_ri = torch.view_as_real(data).contiguous().reshape(-1) + B_ri = torch.view_as_real(B).contiguous() + C_ri = torch.zeros((n_rows, n_dense_cols, 2), dtype=B_ri.dtype, device=device) + acc_dtype = tl.float64 if B_ri.dtype == torch.float64 else tl.float32 + _spmm_coo_atomic_complex_kernel[(nnz, n_dense_cols)]( + data_ri, + row, + col, + B_ri, + C_ri, + nnz, + n_dense_cols, + B_ri.stride(0), + B_ri.stride(1), + B_ri.stride(2), + C_ri.stride(0), + C_ri.stride(1), + C_ri.stride(2), + ACC_DTYPE=acc_dtype, + ) + C = torch.view_as_complex(C_ri.contiguous()) + return C if dtype == output_dtype else C.to(output_dtype) + +def _normalize_spmm_coo_route(route): + route = "rowrun" if route is None else str(route).lower() + if route not in ("rowrun", "atomic"): + raise ValueError("route must be 'rowrun' or 'atomic'") + return route + + +def _triton_spmm_coo_impl( + data, + row, + col, + B, + n_rows, + n_dense_cols, + block_n, + block_nnz, + route="rowrun", + output_dtype=None, +): + route = _normalize_spmm_coo_route(route) + resolved_output_dtype = output_dtype if output_dtype is not None else data.dtype + if route == "rowrun": + return _triton_spmm_coo_rowrun_impl( + data, + row, + col, + B, + n_rows, + n_dense_cols, + block_n, + block_nnz, + output_dtype=resolved_output_dtype, + ) + return _triton_spmm_coo_atomic_impl( + data, + row, + col, + B, + n_rows, + n_dense_cols, + block_n, + block_nnz, + output_dtype=resolved_output_dtype, + ) + + +def _run_spmm_coo_canonical_route( + canonical_data, + canonical_row, + canonical_col, + canonical_B, + n_rows, + n_dense_cols, + output_dtype, + block_n=None, + block_nnz=256, + out=None, + return_time=False, + route="rowrun", +): + route = _normalize_spmm_coo_route(route) + launch = _resolve_spmm_coo_launch_config( + n_dense_cols, + canonical_data.numel(), + block_n=block_n, + block_nnz=block_nnz, + ) + + if out is not None: + if not out.is_cuda: + raise ValueError("out must be a CUDA tensor") + if out.device != canonical_data.device: + raise ValueError("out must be on the same CUDA device as the inputs") + if out.shape != (n_rows, n_dense_cols) or out.dtype != output_dtype: + raise ValueError("out shape/dtype must match result") + + torch.cuda.synchronize() + t0 = time.perf_counter() + C = _triton_spmm_coo_impl( + canonical_data, + canonical_row, + canonical_col, + canonical_B, + n_rows, + n_dense_cols, + block_n=launch["block_n"], + block_nnz=launch["block_nnz"], + route=route, + output_dtype=output_dtype, + ) + torch.cuda.synchronize() + elapsed_ms = (time.perf_counter() - t0) * 1000.0 + + if out is not None: + out.copy_(C) + C = out + if return_time: + return C, elapsed_ms + return C + +def _run_spmm_coo_route( + data, + row, + col, + B, + shape, + block_n=None, + block_nnz=256, + out=None, + return_time=False, + route="rowrun", +): + route = _normalize_spmm_coo_route(route) + if block_n is not None and block_n <= 0: + raise ValueError("block_n must be positive when provided") + if block_nnz is not None and block_nnz <= 0: + raise ValueError("block_nnz must be positive when provided") + + ( + canonical_data, + canonical_row, + canonical_col, + canonical_B, + n_rows, + _, + n_dense_cols, + output_dtype, + _, + ) = _prepare_spmm_coo_canonical_inputs(data, row, col, B, shape) + + return _run_spmm_coo_canonical_route( + canonical_data, + 
canonical_row, + canonical_col, + canonical_B, + n_rows, + n_dense_cols, + output_dtype, + block_n=block_n, + block_nnz=block_nnz, + out=out, + return_time=return_time, + route=route, + ) + +def flagsparse_spmm_coo( + data, + row, + col, + B, + shape, + block_n=None, + block_nnz=256, + out=None, + return_time=False, +): + """COO SpMM: C = A @ B using a native Triton COO row-run kernel by default.""" + return _run_spmm_coo_route( + data, + row, + col, + B, + shape, + block_n=block_n, + block_nnz=block_nnz, + out=out, + return_time=return_time, + route="rowrun", + ) + + +def _build_spmm_coo_pytorch_reference_from_canonical( + canonical_data, + canonical_row, + canonical_col, + canonical_B, + shape, + output_dtype, +): + canonical_coo = _build_torch_sparse_coo( + canonical_data, + canonical_row, + canonical_col, + shape, + ) + expected = torch.sparse.mm(canonical_coo, canonical_B) + return expected if expected.dtype == output_dtype else expected.to(output_dtype) + + + +def _build_spmm_coo_pytorch_reference(data, row, col, B, shape): + native_data, native_row, native_col, native_B, n_rows, n_cols, n_dense_cols = _prepare_spmm_coo_inputs( + data, row, col, B, shape + ) + ( + canonical_data, + canonical_row, + canonical_col, + canonical_B, + _, + _, + _, + output_dtype, + _, + ) = _prepare_spmm_coo_canonical_prepared( + native_data, + native_row, + native_col, + native_B, + n_rows, + n_cols, + n_dense_cols, + ) + native_coo = _build_torch_sparse_coo(native_data, native_row, native_col, shape) + pytorch_format = "COO" + pytorch_reason = None + pytorch_op = lambda: torch.sparse.mm(native_coo, native_B) + expected = _build_spmm_coo_pytorch_reference_from_canonical( + canonical_data, + canonical_row, + canonical_col, + canonical_B, + shape, + output_dtype, + ) + return expected, pytorch_op, pytorch_format, pytorch_reason + + +def _benchmark_spmm_coo_canonical_route( + canonical_data, + canonical_row, + canonical_col, + canonical_B, + n_rows, + n_dense_cols, + output_dtype, + warmup, + iters, + block_n, + block_nnz, + route, +): + route = _normalize_spmm_coo_route(route) + op = lambda: _run_spmm_coo_canonical_route( + canonical_data, + canonical_row, + canonical_col, + canonical_B, + n_rows, + n_dense_cols, + output_dtype, + block_n=block_n, + block_nnz=block_nnz, + return_time=False, + route=route, + ) + + torch.cuda.synchronize() + t0 = time.perf_counter() + _ = op() + torch.cuda.synchronize() + first_call_ms = (time.perf_counter() - t0) * 1000.0 + values, steady_ms = _benchmark_cuda_op(op, warmup=warmup, iters=iters) + return values, steady_ms, first_call_ms + + + +def _benchmark_spmm_coo_route( + data, + row, + col, + B, + shape, + warmup, + iters, + block_n, + block_nnz, + route, +): + ( + canonical_data, + canonical_row, + canonical_col, + canonical_B, + n_rows, + _, + n_dense_cols, + output_dtype, + _, + ) = _prepare_spmm_coo_canonical_inputs(data, row, col, B, shape) + return _benchmark_spmm_coo_canonical_route( + canonical_data, + canonical_row, + canonical_col, + canonical_B, + n_rows, + n_dense_cols, + output_dtype, + warmup, + iters, + block_n, + block_nnz, + route, + ) +def _spmm_coo_pairwise_summary(candidate, reference, value_dtype): + metrics = _spmm_validation_metrics(candidate, reference) + atol, rtol = _spmm_coo_reference_tolerance(value_dtype) + if candidate.numel() == 0: + error_ratio = 0.0 + else: + diff = torch.abs(candidate - reference) + denom = atol + rtol * torch.abs(reference) + error_ratio = float(torch.max(diff / denom).item()) + return { + "match": 
torch.allclose(candidate, reference, atol=atol, rtol=rtol), + "error_ratio": error_ratio, + "max_abs_error": metrics["max_abs_error"], + "max_relative_error": metrics["max_relative_error"], + "sum_relative_error": metrics["sum_relative_error"], + } + + +def benchmark_spmm_coo_case( + n_rows=4096, + n_cols=4096, + nnz=65536, + n_dense_cols=32, + value_dtype=torch.float32, + index_dtype=torch.int32, + warmup=20, + iters=200, + block_n=None, + block_nnz=256, + run_cusparse=True, + route="rowrun", + compare_routes=False, +): + """Benchmark native COO SpMM vs PyTorch COO sparse.mm and CuPy/cuSPARSE COO @ dense.""" + selected_route = _normalize_spmm_coo_route(route) + device = torch.device("cuda") + data, row, col = _build_random_coo( + n_rows, n_cols, nnz, value_dtype, index_dtype, device + ) + B = _build_random_dense((n_cols, n_dense_cols), value_dtype, device) + shape = (n_rows, n_cols) + + native_data, native_row, native_col, native_B, _, _, _ = _prepare_spmm_coo_inputs( + data, row, col, B, shape + ) + ( + canonical_data, + canonical_row, + canonical_col, + canonical_B, + n_rows, + n_cols, + n_dense_cols, + output_dtype, + _, + ) = _prepare_spmm_coo_canonical_prepared( + native_data, + native_row, + native_col, + native_B, + n_rows, + n_cols, + native_B.shape[1], + ) + launch = _resolve_spmm_coo_launch_config( + n_dense_cols, + canonical_data.numel(), + block_n=block_n, + block_nnz=block_nnz, + ) + seg_starts = _seg_starts_from_sorted_rows(canonical_row, canonical_data.numel(), device) + n_row_runs = int(seg_starts.numel()) - 1 if seg_starts is not None else 0 + + expected = _build_spmm_coo_pytorch_reference_from_canonical( + canonical_data, + canonical_row, + canonical_col, + canonical_B, + shape, + output_dtype, + ) + pytorch_coo = _build_torch_sparse_coo(native_data, native_row, native_col, shape) + pytorch_op = lambda: torch.sparse.mm(pytorch_coo, native_B) + pytorch_format = "COO" + pytorch_reason = None + + triton_C, triton_ms, triton_first_call_ms = _benchmark_spmm_coo_canonical_route( + canonical_data, + canonical_row, + canonical_col, + canonical_B, + n_rows, + n_dense_cols, + output_dtype, + warmup, + iters, + launch["block_n"], + launch["block_nnz"], + selected_route, + ) + triton_summary = _spmm_coo_pairwise_summary(triton_C, expected, value_dtype) + triton_match = triton_summary["match"] + + pytorch_values = expected + pytorch_ms = None + try: + pytorch_values, pytorch_ms = _benchmark_cuda_op( + pytorch_op, warmup=warmup, iters=iters + ) + except Exception as exc: + pytorch_reason = str(exc) if pytorch_reason is None else f"{pytorch_reason}; timing: {exc}" + + cusparse_ms = None + cusparse_match = None + cusparse_reason = None + cusparse_values = None + cusparse_summary = None + _cupy_supported_dtypes = ( + torch.float32, + torch.float64, + torch.complex64, + torch.complex128, + ) + if run_cusparse: + if cp is None or cpx_sparse is None: + cusparse_reason = "CuPy/cuSPARSE is not available" + elif value_dtype not in _cupy_supported_dtypes: + cusparse_reason = "float16/bfloat16 not supported by CuPy sparse; skipped" + else: + try: + data_cp = _cupy_from_torch(native_data) + row_cp = _cupy_from_torch(native_row.to(torch.int64)) + col_cp = _cupy_from_torch(native_col.to(torch.int64)) + B_cp = _cupy_from_torch(native_B) + A_coo = cpx_sparse.coo_matrix( + (data_cp, (row_cp, col_cp)), shape=shape + ) + cusparse_values_cp, cusparse_ms = _benchmark_cuda_op( + lambda: A_coo @ B_cp, warmup=warmup, iters=iters + ) + cusparse_values = _torch_from_cupy(cusparse_values_cp) + cusparse_summary 
= _spmm_coo_pairwise_summary(cusparse_values, expected, value_dtype) + cusparse_match = cusparse_summary["match"] + except Exception as exc: + cusparse_reason = str(exc) + + route_results = None + parity = None + route_samples = None + if compare_routes: + route_outputs = {selected_route: triton_C} + route_results = { + selected_route: { + "route": selected_route, + "ms": triton_ms, + "first_call_ms": triton_first_call_ms, + "match_reference": triton_summary["match"], + "error_ratio": triton_summary["error_ratio"], + "max_abs_error": triton_summary["max_abs_error"], + "max_relative_error": triton_summary["max_relative_error"], + "match_cusparse": ( + None if cusparse_values is None else torch.allclose( + triton_C, + cusparse_values, + atol=_spmm_coo_reference_tolerance(value_dtype)[0], + rtol=_spmm_coo_reference_tolerance(value_dtype)[1], + ) + ), + "error": None, + } + } + + for extra_route in ("rowrun", "atomic"): + if extra_route in route_outputs: + continue + try: + extra_values, extra_ms, extra_first_call_ms = _benchmark_spmm_coo_canonical_route( + canonical_data, + canonical_row, + canonical_col, + canonical_B, + n_rows, + n_dense_cols, + output_dtype, + warmup, + iters, + launch["block_n"], + launch["block_nnz"], + extra_route, + ) + extra_summary = _spmm_coo_pairwise_summary(extra_values, expected, value_dtype) + route_outputs[extra_route] = extra_values + route_results[extra_route] = { + "route": extra_route, + "ms": extra_ms, + "first_call_ms": extra_first_call_ms, + "match_reference": extra_summary["match"], + "error_ratio": extra_summary["error_ratio"], + "max_abs_error": extra_summary["max_abs_error"], + "max_relative_error": extra_summary["max_relative_error"], + "match_cusparse": ( + None if cusparse_values is None else torch.allclose( + extra_values, + cusparse_values, + atol=_spmm_coo_reference_tolerance(value_dtype)[0], + rtol=_spmm_coo_reference_tolerance(value_dtype)[1], + ) + ), + "error": None, + } + except Exception as exc: + route_results[extra_route] = { + "route": extra_route, + "ms": None, + "first_call_ms": None, + "match_reference": False, + "error_ratio": None, + "max_abs_error": None, + "max_relative_error": None, + "match_cusparse": None, + "error": str(exc), + } + + def _safe_parity(lhs, rhs): + if lhs in route_outputs and rhs in route_outputs: + return _spmm_coo_pairwise_summary(route_outputs[lhs], route_outputs[rhs], value_dtype) + return { + "match": None, + "error_ratio": None, + "max_abs_error": None, + "max_relative_error": None, + "sum_relative_error": None, + } + + parity = { + "rowrun_vs_atomic": _safe_parity("rowrun", "atomic"), + } + route_samples = route_outputs + triton_speedup_vs_pytorch = ( + pytorch_ms / triton_ms if (pytorch_ms is not None and triton_ms > 0) else None + ) + triton_speedup_vs_cusparse = ( + cusparse_ms / triton_ms if (cusparse_ms is not None and triton_ms > 0) else None + ) + threshold = _spmm_relative_threshold(value_dtype) + return { + "parameters": { + "format": "coo", + "internal_format": f"native-{selected_route}", + "route": selected_route, + "compare_routes": bool(compare_routes), + "n_rows": n_rows, + "n_cols": n_cols, + "nnz": nnz, + "n_dense_cols": n_dense_cols, + "value_dtype": str(value_dtype), + "index_dtype": str(index_dtype), + "warmup": warmup, + "iters": iters, + "block_n": launch["block_n"], + "block_nnz": launch["block_nnz"], + "required_nnz_tiles": launch["required_nnz_tiles"], + "heuristic_warp_size": launch["heuristic_warp_size"], + "heuristic_factor": launch["heuristic_factor"], + "n_row_runs": 
n_row_runs, + "run_cusparse": run_cusparse, + }, + "performance": { + "pytorch_ms": pytorch_ms, + "triton_ms": triton_ms, + "triton_first_call_ms": triton_first_call_ms, + "cusparse_ms": cusparse_ms, + "triton_speedup_vs_pytorch": triton_speedup_vs_pytorch, + "triton_speedup_vs_cusparse": triton_speedup_vs_cusparse, + }, + "verification": { + "triton_match_reference": triton_match, + "triton_match_pytorch": triton_match, + "triton_max_error": triton_summary["max_abs_error"], + "triton_max_abs_error": triton_summary["max_abs_error"], + "triton_max_relative_error": triton_summary["max_relative_error"], + "triton_sum_relative_error": triton_summary["sum_relative_error"], + "triton_relative_threshold": threshold, + "triton_strict_allclose_match": triton_match, + "pytorch_match_reference": True, + "pytorch_max_error": 0.0, + "pytorch_max_abs_error": 0.0, + "pytorch_max_relative_error": 0.0, + "pytorch_sum_relative_error": 0.0, + "pytorch_relative_threshold": threshold, + "cusparse_match_reference": cusparse_match, + "cusparse_match_pytorch": cusparse_match, + "cusparse_max_error": (cusparse_summary["max_abs_error"] if cusparse_summary is not None else None), + "cusparse_max_abs_error": (cusparse_summary["max_abs_error"] if cusparse_summary is not None else None), + "cusparse_max_relative_error": (cusparse_summary["max_relative_error"] if cusparse_summary is not None else None), + "cusparse_sum_relative_error": (cusparse_summary["sum_relative_error"] if cusparse_summary is not None else None), + "cusparse_relative_threshold": threshold, + "cusparse_strict_allclose_match": cusparse_match, + }, + "backend_status": { + "pytorch_unavailable_reason": pytorch_reason, + "pytorch_sparse_format": pytorch_format, + "cusparse_unavailable_reason": cusparse_reason, + "flagsparse_internal_route": f"coo-native-{selected_route}", + }, + "samples": { + "pytorch": pytorch_values, + "triton": triton_C, + "reference": expected, + "cusparse": cusparse_values, + }, + "route_results": route_results, + "parity": parity, + "route_samples": route_samples, + } + +def comprehensive_spmm_coo_test( + n_rows=4096, + n_cols=4096, + nnz=65536, + n_dense_cols=32, + dtype=torch.float32, + index_dtype=torch.int32, + warmup=20, + iters=200, + block_n=None, + block_nnz=256, + run_cusparse=True, +): + """Full COO SpMM benchmark entry for one configuration.""" + return benchmark_spmm_coo_case( + n_rows=n_rows, + n_cols=n_cols, + nnz=nnz, + n_dense_cols=n_dense_cols, + value_dtype=dtype, + index_dtype=index_dtype, + warmup=warmup, + iters=iters, + block_n=block_n, + block_nnz=block_nnz, + run_cusparse=run_cusparse, + ) \ No newline at end of file diff --git a/src/flagsparse/sparse_operations/spmm_csr.py b/src/flagsparse/sparse_operations/spmm_csr.py new file mode 100644 index 0000000..c9f767c --- /dev/null +++ b/src/flagsparse/sparse_operations/spmm_csr.py @@ -0,0 +1,2394 @@ +"""CSR SpMM kernels, helpers, and benchmark entry points.""" + +from ._common import * + +SUPPORTED_SPMM_VALUE_DTYPES = SUPPORTED_VALUE_DTYPES +def _spmm_relative_threshold(value_dtype): + if value_dtype == torch.float16: + return 5e-3 + if value_dtype == torch.bfloat16: + return 1e-2 + if value_dtype in (torch.float32, torch.complex64): + return 1e-6 + if value_dtype in (torch.float64, torch.complex128): + return 1e-12 + return 1e-6 + + +def _spmm_coo_reference_tolerance(value_dtype): + if value_dtype == torch.float16: + return 2e-3, 2e-3 + if value_dtype == torch.bfloat16: + return 1e-1, 1e-1 + if value_dtype in (torch.float32, torch.complex64): + return 1e-4, 
1e-2 + if value_dtype in (torch.float64, torch.complex128): + return 1e-12, 1e-10 + return 1e-6, 1e-5 + + +def _spmm_error_metrics(candidate, reference): + if candidate.shape != reference.shape: + raise ValueError( + f"candidate and reference must have the same shape, got {candidate.shape} vs {reference.shape}" + ) + + if candidate.numel() == 0: + return { + "max_abs_error": 0.0, + "max_relative_error": 0.0, + "sum_relative_error": 0.0, + "reference_max_magnitude": 0.0, + "reference_sum_magnitude": 0.0, + } + + if _is_complex_dtype(reference.dtype): + candidate_compare = torch.abs(candidate) + reference_compare = torch.abs(reference) + abs_diff = torch.abs(candidate_compare - reference_compare) + else: + reference_compare = torch.abs(reference) + abs_diff = torch.abs(candidate - reference) + + max_abs_error = float(torch.max(abs_diff).item()) + reference_max_magnitude = float(torch.max(reference_compare).item()) + sum_abs_error = float(torch.sum(abs_diff).item()) + reference_sum_magnitude = float(torch.sum(reference_compare).item()) + + if reference_max_magnitude == 0.0: + max_relative_error = 0.0 if max_abs_error == 0.0 else float("inf") + else: + max_relative_error = max_abs_error / reference_max_magnitude + + if reference_sum_magnitude == 0.0: + sum_relative_error = 0.0 if sum_abs_error == 0.0 else float("inf") + else: + sum_relative_error = sum_abs_error / reference_sum_magnitude + + return { + "max_abs_error": max_abs_error, + "max_relative_error": max_relative_error, + "sum_relative_error": sum_relative_error, + "reference_max_magnitude": reference_max_magnitude, + "reference_sum_magnitude": reference_sum_magnitude, + } + + +def _spmm_validation_metrics(candidate, reference): + metrics = _spmm_error_metrics(candidate, reference) + threshold = _spmm_relative_threshold(reference.dtype) + atol, rtol = _tolerance_for_dtype(reference.dtype) + metrics.update( + { + "relative_threshold": threshold, + "matches_threshold": metrics["max_relative_error"] <= threshold, + "strict_allclose_match": torch.allclose( + candidate, reference, atol=atol, rtol=rtol + ), + } + ) + return metrics + +@triton.jit +def _spmm_csr_real_kernel( + data_ptr, + indices_ptr, + indptr_ptr, + b_ptr, + c_ptr, + n_rows, + n_dense_cols, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_N: tl.constexpr, + BLOCK_NNZ: tl.constexpr, + ACC_DTYPE: tl.constexpr, +): + row = tl.program_id(0) + pid_n = tl.program_id(1) + if row >= n_rows: + return + + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_n = offs_n < n_dense_cols + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) + row_nnz = end - start + acc = tl.zeros([BLOCK_N], dtype=ACC_DTYPE) + + for chunk_start in tl.range(0, row_nnz, BLOCK_NNZ): + for kk in tl.static_range(0, BLOCK_NNZ): + idx = start + chunk_start + kk + valid = idx < end + a_val = tl.load(data_ptr + idx, mask=valid, other=0.0) + a_col = tl.load(indices_ptr + idx, mask=valid, other=0) + b_vals = tl.load( + b_ptr + a_col * stride_bk + offs_n * stride_bn, + mask=mask_n & valid, + other=0.0, + ) + acc = acc + a_val.to(ACC_DTYPE) * b_vals.to(ACC_DTYPE) + + tl.store(c_ptr + row * stride_cm + offs_n * stride_cn, acc, mask=mask_n) + + +# Complex-path variant of the same AlphaSparse CSR ALG1 mapping. 
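+# Complex values reach this kernel as flattened torch.view_as_real buffers, so
+# nonzero k stores its real part at offset 2*k and its imaginary part at
+# 2*k + 1, while B and C carry an explicit real/imag stride (stride_br,
+# stride_cr). Each accumulation step expands the complex product
+#     (a_re + i*a_im)(b_re + i*b_im)
+#         = (a_re*b_re - a_im*b_im) + i*(a_re*b_im + a_im*b_re)
+# into separate real and imaginary accumulators.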
+@triton.jit +def _spmm_csr_complex_kernel( + data_ri_ptr, + indices_ptr, + indptr_ptr, + b_ri_ptr, + c_ri_ptr, + n_rows, + n_dense_cols, + stride_bk, + stride_bn, + stride_br, + stride_cm, + stride_cn, + stride_cr, + BLOCK_N: tl.constexpr, + BLOCK_NNZ: tl.constexpr, + ACC_DTYPE: tl.constexpr, +): + row = tl.program_id(0) + pid_n = tl.program_id(1) + if row >= n_rows: + return + + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_n = offs_n < n_dense_cols + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) + row_nnz = end - start + acc_re = tl.zeros([BLOCK_N], dtype=ACC_DTYPE) + acc_im = tl.zeros([BLOCK_N], dtype=ACC_DTYPE) + + for chunk_start in tl.range(0, row_nnz, BLOCK_NNZ): + for kk in tl.static_range(0, BLOCK_NNZ): + idx = start + chunk_start + kk + valid = idx < end + a_re = tl.load(data_ri_ptr + idx * 2, mask=valid, other=0.0) + a_im = tl.load(data_ri_ptr + idx * 2 + 1, mask=valid, other=0.0) + a_col = tl.load(indices_ptr + idx, mask=valid, other=0) + b_re = tl.load( + b_ri_ptr + a_col * stride_bk + offs_n * stride_bn, + mask=mask_n & valid, + other=0.0, + ) + b_im = tl.load( + b_ri_ptr + a_col * stride_bk + offs_n * stride_bn + stride_br, + mask=mask_n & valid, + other=0.0, + ) + acc_re = acc_re + a_re.to(ACC_DTYPE) * b_re.to(ACC_DTYPE) - a_im.to(ACC_DTYPE) * b_im.to(ACC_DTYPE) + acc_im = acc_im + a_re.to(ACC_DTYPE) * b_im.to(ACC_DTYPE) + a_im.to(ACC_DTYPE) * b_re.to(ACC_DTYPE) + + tl.store(c_ri_ptr + row * stride_cm + offs_n * stride_cn, acc_re, mask=mask_n) + tl.store( + c_ri_ptr + row * stride_cm + offs_n * stride_cn + stride_cr, + acc_im, + mask=mask_n, + ) + +def _prepare_spmm_csr_matrix(data, indices, indptr, shape): + if len(shape) != 2: + raise ValueError("shape must be a 2-tuple: (n_rows, n_cols)") + if data.ndim != 1 or indices.ndim != 1 or indptr.ndim != 1: + raise ValueError("data, indices, and indptr must be 1D tensors") + + n_rows, n_cols = int(shape[0]), int(shape[1]) + if n_rows < 0 or n_cols < 0: + raise ValueError("shape dimensions must be non-negative") + if indptr.numel() != n_rows + 1: + raise ValueError( + f"indptr length must be n_rows+1={n_rows + 1}, got {indptr.numel()}" + ) + if data.numel() != indices.numel(): + raise ValueError("data and indices must have the same length (nnz)") + + if not all(t.is_cuda for t in (data, indices, indptr)): + raise ValueError("data, indices, and indptr must be CUDA tensors") + if not all(t.device == data.device for t in (indices, indptr)): + raise ValueError("data, indices, and indptr must be on the same CUDA device") + if data.dtype not in SUPPORTED_SPMM_VALUE_DTYPES: + raise TypeError( + "data dtype must be one of: float16, bfloat16, float32, float64, complex64, complex128" + ) + if indices.dtype not in SUPPORTED_INDEX_DTYPES: + raise TypeError("indices dtype must be torch.int32 or torch.int64") + if indptr.dtype not in SUPPORTED_INDEX_DTYPES: + raise TypeError("indptr dtype must be torch.int32 or torch.int64") + + nnz = data.numel() + if indptr.numel() > 0 and int(indptr[0].item()) != 0: + raise ValueError("indptr[0] must be 0") + if indptr.numel() > 0 and int(indptr[-1].item()) != nnz: + raise ValueError(f"indptr[-1] must equal nnz={nnz}, got {int(indptr[-1].item())}") + if indptr.numel() > 1 and bool(torch.any(indptr[1:] < indptr[:-1]).item()): + raise ValueError("indptr must be nondecreasing") + if nnz > 0: + min_col = int(indices.min().item()) + max_col = int(indices.max().item()) + if min_col < 0 or max_col >= n_cols: + raise IndexError("indices out of range for n_cols") + if max_col > 
_INDEX_LIMIT_INT32: + raise ValueError( + "column indices exceed the int32 range supported by the Triton kernel" + ) + + data = data.contiguous() + indices = indices.contiguous() + indptr = indptr.contiguous() + + kernel_indices = indices.to(torch.int32) if indices.dtype == torch.int64 else indices + kernel_indptr = indptr.to(torch.int64) + row_lengths = ( + kernel_indptr[1:] - kernel_indptr[:-1] + if n_rows > 0 + else kernel_indptr.new_empty((0,)) + ) + max_row_nnz = int(row_lengths.max().item()) if n_rows > 0 else 0 + return data, kernel_indices, kernel_indptr, n_rows, n_cols, row_lengths, max_row_nnz + + +def _prepare_spmm_csr_inputs(data, indices, indptr, B, shape): + if B.ndim != 2: + raise ValueError("B must be a 2D dense tensor") + + ( + data, + kernel_indices, + kernel_indptr, + n_rows, + n_cols, + _row_lengths, + _max_row_nnz, + ) = _prepare_spmm_csr_matrix(data, indices, indptr, shape) + if B.shape[0] != n_cols: + raise ValueError(f"B.shape[0] must be n_cols={n_cols}, got {B.shape[0]}") + if not B.is_cuda: + raise ValueError("B must be a CUDA tensor") + if B.device != data.device: + raise ValueError("B must be on the same CUDA device as the sparse matrix") + if B.dtype != data.dtype: + raise TypeError("B dtype must match data dtype") + + B = B.contiguous() + return data, kernel_indices, kernel_indptr, B, n_rows, n_cols, int(B.shape[1]) + + +def _select_spmm_alg1_warp_and_factor(n_dense_cols): + # Mirrors AlphaSparse CSR ALG1 row-major heuristics without exposing warp details publicly. + if n_dense_cols > 64: + return 32, 4 + if n_dense_cols > 32: + return 32, 2 + if n_dense_cols > 16: + return 32, 1 + if n_dense_cols > 8: + return 16, 1 + if n_dense_cols > 4: + return 8, 1 + return 4, 1 + + +def _resolve_spmm_alg1_launch_config( + n_dense_cols, + max_row_nnz, + block_n=None, + block_nnz=None, + max_segments=None, +): + warp_size, factor = _select_spmm_alg1_warp_and_factor(n_dense_cols) + + if block_n is None: + block_n = warp_size * factor + if block_nnz is None: + block_nnz = warp_size + + if block_n <= 0 or block_nnz <= 0: + raise ValueError("block_n and block_nnz must be positive when provided") + if max_segments is not None and max_segments <= 0: + raise ValueError("max_segments must be positive when provided") + + required_segments = triton.cdiv(max_row_nnz, block_nnz) if max_row_nnz > 0 else 0 + if max_segments is not None and required_segments > int(max_segments): + raise ValueError( + "row nnz requires more CSR segments than the explicit max_segments override allows: " + f"required {required_segments}, provided {int(max_segments)}" + ) + + return { + "block_n": int(block_n), + "block_nnz": int(block_nnz), + "max_segments": (None if max_segments is None else int(max_segments)), + "required_segments": int(required_segments), + "warp_size": int(warp_size), + "factor": int(factor), + "max_row_nnz": int(max_row_nnz), + "auto_max_segments": max_segments is None, + } + + +class PreparedCsrSpmmOpt: + """Cached CSR metadata for repeated native SpMM-opt calls on the same sparse matrix.""" + + __slots__ = ( + "data", + "kernel_indices", + "kernel_indptr", + "shape", + "n_rows", + "n_cols", + "row_lengths", + "max_row_nnz", + "row_buckets", + "supports_opt", + "long_part_rows", + "long_part_starts", + "long_part_ends", + "long_row_ids", + "long_row_part_ptr", + "long_row_fallback_only", + ) + + def __init__( + self, + data, + kernel_indices, + kernel_indptr, + shape, + n_rows, + n_cols, + row_lengths, + max_row_nnz, + row_buckets, + long_part_rows, + long_part_starts, + 
long_part_ends, + long_row_ids, + long_row_part_ptr, + long_row_fallback_only, + ): + self.data = data + self.kernel_indices = kernel_indices + self.kernel_indptr = kernel_indptr + self.shape = (int(shape[0]), int(shape[1])) + self.n_rows = int(n_rows) + self.n_cols = int(n_cols) + self.row_lengths = row_lengths + self.max_row_nnz = int(max_row_nnz) + self.row_buckets = row_buckets + self.supports_opt = data.dtype in (torch.float32, torch.float64) + self.long_part_rows = long_part_rows + self.long_part_starts = long_part_starts + self.long_part_ends = long_part_ends + self.long_row_ids = long_row_ids + self.long_row_part_ptr = long_row_part_ptr + self.long_row_fallback_only = bool(long_row_fallback_only) + + +_SPMM_OPT_BUCKET_SPECS = ( + {"max_row_nnz": 32, "kind": "batched", "batch_rows": 8, "block_nnz": 32}, + {"max_row_nnz": 128, "kind": "batched", "batch_rows": 4, "block_nnz": 64}, + {"max_row_nnz": 512, "kind": "vector", "batch_rows": 1, "block_nnz": 128}, + {"max_row_nnz": 2048, "kind": "vector", "batch_rows": 1, "block_nnz": 128}, + {"max_row_nnz": None, "kind": "split", "batch_rows": 1, "block_nnz": 256}, +) +_SPMM_OPT_LONG_ROW_THRESHOLD = 2048 + + +def _select_spmm_opt_block_n(n_dense_cols): + if n_dense_cols <= 8: + return 8 + if n_dense_cols <= 16: + return 16 + if n_dense_cols <= 32: + return 32 + if n_dense_cols <= 64: + return 64 + return 128 + + +def _build_spmm_opt_split_metadata(kernel_indptr, long_rows, part_block_nnz): + device = kernel_indptr.device + row_values = [] + part_starts = [] + part_ends = [] + row_part_ptr = [0] + long_rows_cpu = long_rows.to(torch.int64).cpu().tolist() + indptr_cpu = kernel_indptr.to(torch.int64).cpu().tolist() + for row in long_rows_cpu: + start = int(indptr_cpu[row]) + end = int(indptr_cpu[row + 1]) + cursor = start + while cursor < end: + row_values.append(row) + part_starts.append(cursor) + cursor = min(cursor + int(part_block_nnz), end) + part_ends.append(cursor) + row_part_ptr.append(len(row_values)) + row_dtype = long_rows.dtype if long_rows.numel() > 0 else torch.int32 + return ( + torch.tensor(row_values, dtype=row_dtype, device=device), + torch.tensor(part_starts, dtype=torch.int64, device=device), + torch.tensor(part_ends, dtype=torch.int64, device=device), + torch.tensor(row_part_ptr, dtype=torch.int64, device=device), + ) + + +def _build_spmm_opt_buckets(row_lengths): + device = row_lengths.device + max_row_index = int(row_lengths.numel()) - 1 + row_index_dtype = ( + torch.int32 if max_row_index <= _INDEX_LIMIT_INT32 else torch.int64 + ) + buckets = [] + lower = 0 + long_rows = torch.empty((0,), dtype=row_index_dtype, device=device) + for spec in _SPMM_OPT_BUCKET_SPECS: + upper = spec["max_row_nnz"] + if upper is None: + mask = row_lengths > lower + elif lower == 0: + mask = row_lengths <= upper + else: + mask = (row_lengths > lower) & (row_lengths <= upper) + rows = torch.nonzero(mask, as_tuple=False).flatten().to(row_index_dtype) + if rows.numel() == 0: + if upper is not None: + lower = upper + continue + bucket = { + "kind": spec["kind"], + "rows": rows, + "batch_rows": int(spec["batch_rows"]), + "block_nnz": int(spec["block_nnz"]), + } + buckets.append(bucket) + if spec["kind"] == "split": + long_rows = rows + if upper is not None: + lower = upper + return buckets, long_rows + + +def prepare_spmm_csr_opt(data, indices, indptr, shape): + ( + data, + kernel_indices, + kernel_indptr, + n_rows, + n_cols, + row_lengths, + max_row_nnz, + ) = _prepare_spmm_csr_matrix(data, indices, indptr, shape) + row_buckets, long_rows = 
_build_spmm_opt_buckets(row_lengths) + long_part_rows = torch.empty((0,), dtype=long_rows.dtype, device=data.device) + long_part_starts = torch.empty((0,), dtype=torch.int64, device=data.device) + long_part_ends = torch.empty((0,), dtype=torch.int64, device=data.device) + long_row_part_ptr = torch.zeros(1, dtype=torch.int64, device=data.device) + if long_rows.numel() > 0: + ( + long_part_rows, + long_part_starts, + long_part_ends, + long_row_part_ptr, + ) = _build_spmm_opt_split_metadata( + kernel_indptr, + long_rows, + part_block_nnz=256, + ) + return PreparedCsrSpmmOpt( + data=data, + kernel_indices=kernel_indices, + kernel_indptr=kernel_indptr, + shape=shape, + n_rows=n_rows, + n_cols=n_cols, + row_lengths=row_lengths, + max_row_nnz=max_row_nnz, + row_buckets=row_buckets, + long_part_rows=long_part_rows, + long_part_starts=long_part_starts, + long_part_ends=long_part_ends, + long_row_ids=long_rows, + long_row_part_ptr=long_row_part_ptr, + long_row_fallback_only=False, + ) + + +def _triton_spmm_csr_impl( + data, + indices, + indptr, + B, + n_rows, + n_dense_cols, + block_n, + block_nnz, +): + device = data.device + dtype = data.dtype + if n_rows == 0 or n_dense_cols == 0 or B.shape[0] == 0: + return torch.zeros((n_rows, n_dense_cols), dtype=dtype, device=device) + + if not _is_complex_dtype(dtype): + compute_dtype = dtype + data_in = data + B_in = B + if dtype in (torch.float16, torch.bfloat16): + compute_dtype = torch.float32 + data_in = data.to(torch.float32) + B_in = B.to(torch.float32) + elif dtype == torch.float32: + compute_dtype = torch.float64 + data_in = data.to(torch.float64) + B_in = B.to(torch.float64) + + C_compute = torch.empty((n_rows, n_dense_cols), dtype=compute_dtype, device=device) + grid = (n_rows, triton.cdiv(n_dense_cols, block_n)) + acc_dtype = tl.float64 if compute_dtype == torch.float64 else tl.float32 + _spmm_csr_real_kernel[grid]( + data_in, + indices, + indptr, + B_in, + C_compute, + n_rows, + n_dense_cols, + B_in.stride(0), + B_in.stride(1), + C_compute.stride(0), + C_compute.stride(1), + BLOCK_N=block_n, + BLOCK_NNZ=block_nnz, + ACC_DTYPE=acc_dtype, + ) + if compute_dtype != dtype: + C_compute = C_compute.to(dtype) + return C_compute + + data_ri = torch.view_as_real(data).contiguous().reshape(-1) + B_ri = torch.view_as_real(B).contiguous() + C_ri = torch.empty((n_rows, n_dense_cols, 2), dtype=B_ri.dtype, device=device) + grid = (n_rows, triton.cdiv(n_dense_cols, block_n)) + acc_dtype = tl.float64 if B_ri.dtype == torch.float64 else tl.float32 + _spmm_csr_complex_kernel[grid]( + data_ri, + indices, + indptr, + B_ri, + C_ri, + n_rows, + n_dense_cols, + B_ri.stride(0), + B_ri.stride(1), + B_ri.stride(2), + C_ri.stride(0), + C_ri.stride(1), + C_ri.stride(2), + BLOCK_N=block_n, + BLOCK_NNZ=block_nnz, + ACC_DTYPE=acc_dtype, + ) + return torch.view_as_complex(C_ri.contiguous()) + + +@triton.jit +def _spmm_csr_batched_rows_f32_kernel( + data_ptr, + indices_ptr, + indptr_ptr, + b_ptr, + c_ptr, + rows_ptr, + n_bucket_rows, + n_dense_cols, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BATCH: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_NNZ: tl.constexpr, +): + pid_row = tl.program_id(0) + pid_n = tl.program_id(1) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_n = offs_n < n_dense_cols + for batch_idx in tl.static_range(0, BATCH): + ridx = pid_row * BATCH + batch_idx + active = ridx < n_bucket_rows + row = tl.load(rows_ptr + ridx, mask=active, other=0) + start = tl.load(indptr_ptr + row, mask=active, other=0) + end = tl.load(indptr_ptr + row + 1, 
mask=active, other=0) + row_nnz = end - start + acc = tl.zeros([BLOCK_N], dtype=tl.float32) + for chunk_start in tl.range(0, row_nnz, BLOCK_NNZ): + for kk in tl.static_range(0, BLOCK_NNZ): + idx = start + chunk_start + kk + valid = active & (idx < end) + a_val = tl.load(data_ptr + idx, mask=valid, other=0.0) + a_col = tl.load(indices_ptr + idx, mask=valid, other=0) + b_vals = tl.load( + b_ptr + a_col * stride_bk + offs_n * stride_bn, + mask=mask_n & valid, + other=0.0, + ) + acc = acc + a_val * b_vals + tl.store(c_ptr + row * stride_cm + offs_n * stride_cn, acc, mask=mask_n & active) + + +@triton.jit +def _spmm_csr_batched_rows_f64_kernel( + data_ptr, + indices_ptr, + indptr_ptr, + b_ptr, + c_ptr, + rows_ptr, + n_bucket_rows, + n_dense_cols, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BATCH: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_NNZ: tl.constexpr, +): + pid_row = tl.program_id(0) + pid_n = tl.program_id(1) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_n = offs_n < n_dense_cols + for batch_idx in tl.static_range(0, BATCH): + ridx = pid_row * BATCH + batch_idx + active = ridx < n_bucket_rows + row = tl.load(rows_ptr + ridx, mask=active, other=0) + start = tl.load(indptr_ptr + row, mask=active, other=0) + end = tl.load(indptr_ptr + row + 1, mask=active, other=0) + row_nnz = end - start + acc = tl.zeros([BLOCK_N], dtype=tl.float64) + for chunk_start in tl.range(0, row_nnz, BLOCK_NNZ): + for kk in tl.static_range(0, BLOCK_NNZ): + idx = start + chunk_start + kk + valid = active & (idx < end) + a_val = tl.load(data_ptr + idx, mask=valid, other=0.0) + a_col = tl.load(indices_ptr + idx, mask=valid, other=0) + b_vals = tl.load( + b_ptr + a_col * stride_bk + offs_n * stride_bn, + mask=mask_n & valid, + other=0.0, + ) + acc = acc + a_val * b_vals + tl.store(c_ptr + row * stride_cm + offs_n * stride_cn, acc, mask=mask_n & active) + + +@triton.jit +def _spmm_csr_vector_rows_f32_kernel( + data_ptr, + indices_ptr, + indptr_ptr, + b_ptr, + c_ptr, + rows_ptr, + n_bucket_rows, + n_dense_cols, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_N: tl.constexpr, + BLOCK_NNZ: tl.constexpr, +): + pid_row = tl.program_id(0) + pid_n = tl.program_id(1) + if pid_row >= n_bucket_rows: + return + row = tl.load(rows_ptr + pid_row) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_n = offs_n < n_dense_cols + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) + row_nnz = end - start + acc = tl.zeros([BLOCK_N], dtype=tl.float32) + for chunk_start in tl.range(0, row_nnz, BLOCK_NNZ): + for kk in tl.static_range(0, BLOCK_NNZ): + idx = start + chunk_start + kk + valid = idx < end + a_val = tl.load(data_ptr + idx, mask=valid, other=0.0) + a_col = tl.load(indices_ptr + idx, mask=valid, other=0) + b_vals = tl.load( + b_ptr + a_col * stride_bk + offs_n * stride_bn, + mask=mask_n & valid, + other=0.0, + ) + acc = acc + a_val * b_vals + tl.store(c_ptr + row * stride_cm + offs_n * stride_cn, acc, mask=mask_n) + + +@triton.jit +def _spmm_csr_vector_rows_f64_kernel( + data_ptr, + indices_ptr, + indptr_ptr, + b_ptr, + c_ptr, + rows_ptr, + n_bucket_rows, + n_dense_cols, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_N: tl.constexpr, + BLOCK_NNZ: tl.constexpr, +): + pid_row = tl.program_id(0) + pid_n = tl.program_id(1) + if pid_row >= n_bucket_rows: + return + row = tl.load(rows_ptr + pid_row) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_n = offs_n < n_dense_cols + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) + row_nnz 
= end - start + acc = tl.zeros([BLOCK_N], dtype=tl.float64) + for chunk_start in tl.range(0, row_nnz, BLOCK_NNZ): + for kk in tl.static_range(0, BLOCK_NNZ): + idx = start + chunk_start + kk + valid = idx < end + a_val = tl.load(data_ptr + idx, mask=valid, other=0.0) + a_col = tl.load(indices_ptr + idx, mask=valid, other=0) + b_vals = tl.load( + b_ptr + a_col * stride_bk + offs_n * stride_bn, + mask=mask_n & valid, + other=0.0, + ) + acc = acc + a_val * b_vals + tl.store(c_ptr + row * stride_cm + offs_n * stride_cn, acc, mask=mask_n) + + +@triton.jit +def _spmm_csr_split_part_f32_kernel( + data_ptr, + indices_ptr, + b_ptr, + workspace_ptr, + part_starts_ptr, + part_ends_ptr, + n_parts, + n_dense_cols, + stride_bk, + stride_bn, + stride_wm, + stride_wn, + BLOCK_N: tl.constexpr, +): + pid_part = tl.program_id(0) + pid_n = tl.program_id(1) + if pid_part >= n_parts: + return + start = tl.load(part_starts_ptr + pid_part) + end = tl.load(part_ends_ptr + pid_part) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_n = offs_n < n_dense_cols + acc = tl.zeros([BLOCK_N], dtype=tl.float32) + for idx in tl.range(start, end): + a_val = tl.load(data_ptr + idx) + a_col = tl.load(indices_ptr + idx) + b_vals = tl.load( + b_ptr + a_col * stride_bk + offs_n * stride_bn, + mask=mask_n, + other=0.0, + ) + acc = acc + a_val * b_vals + tl.store(workspace_ptr + pid_part * stride_wm + offs_n * stride_wn, acc, mask=mask_n) + + +@triton.jit +def _spmm_csr_split_part_f64_kernel( + data_ptr, + indices_ptr, + b_ptr, + workspace_ptr, + part_starts_ptr, + part_ends_ptr, + n_parts, + n_dense_cols, + stride_bk, + stride_bn, + stride_wm, + stride_wn, + BLOCK_N: tl.constexpr, +): + pid_part = tl.program_id(0) + pid_n = tl.program_id(1) + if pid_part >= n_parts: + return + start = tl.load(part_starts_ptr + pid_part) + end = tl.load(part_ends_ptr + pid_part) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_n = offs_n < n_dense_cols + acc = tl.zeros([BLOCK_N], dtype=tl.float64) + for idx in tl.range(start, end): + a_val = tl.load(data_ptr + idx) + a_col = tl.load(indices_ptr + idx) + b_vals = tl.load( + b_ptr + a_col * stride_bk + offs_n * stride_bn, + mask=mask_n, + other=0.0, + ) + acc = acc + a_val * b_vals + tl.store(workspace_ptr + pid_part * stride_wm + offs_n * stride_wn, acc, mask=mask_n) + + +@triton.jit +def _spmm_csr_split_reduce_f32_kernel( + workspace_ptr, + out_ptr, + long_rows_ptr, + row_part_ptr_ptr, + n_long_rows, + n_dense_cols, + stride_wm, + stride_wn, + stride_cm, + stride_cn, + BLOCK_N: tl.constexpr, +): + pid_row = tl.program_id(0) + pid_n = tl.program_id(1) + if pid_row >= n_long_rows: + return + row = tl.load(long_rows_ptr + pid_row) + part_start = tl.load(row_part_ptr_ptr + pid_row) + part_end = tl.load(row_part_ptr_ptr + pid_row + 1) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_n = offs_n < n_dense_cols + acc = tl.zeros([BLOCK_N], dtype=tl.float32) + for rel_part in tl.range(0, part_end - part_start): + part_id = part_start + rel_part + vals = tl.load( + workspace_ptr + part_id * stride_wm + offs_n * stride_wn, + mask=mask_n, + other=0.0, + ) + acc = acc + vals + tl.store(out_ptr + row * stride_cm + offs_n * stride_cn, acc, mask=mask_n) + + +@triton.jit +def _spmm_csr_split_reduce_f64_kernel( + workspace_ptr, + out_ptr, + long_rows_ptr, + row_part_ptr_ptr, + n_long_rows, + n_dense_cols, + stride_wm, + stride_wn, + stride_cm, + stride_cn, + BLOCK_N: tl.constexpr, +): + pid_row = tl.program_id(0) + pid_n = tl.program_id(1) + if pid_row >= n_long_rows: + return + row = 
tl.load(long_rows_ptr + pid_row) + part_start = tl.load(row_part_ptr_ptr + pid_row) + part_end = tl.load(row_part_ptr_ptr + pid_row + 1) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_n = offs_n < n_dense_cols + acc = tl.zeros([BLOCK_N], dtype=tl.float64) + for rel_part in tl.range(0, part_end - part_start): + part_id = part_start + rel_part + vals = tl.load( + workspace_ptr + part_id * stride_wm + offs_n * stride_wn, + mask=mask_n, + other=0.0, + ) + acc = acc + vals + tl.store(out_ptr + row * stride_cm + offs_n * stride_cn, acc, mask=mask_n) + + +@triton.jit +def _spmm_csr_stable_batched_f32_kernel( + data_ptr, + indices_ptr, + indptr_ptr, + b_ptr, + c_ptr, + rows_ptr, + n_bucket_rows, + n_dense_cols, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BATCH: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_NNZ: tl.constexpr, +): + pid_row = tl.program_id(0) + pid_n = tl.program_id(1) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_n = offs_n < n_dense_cols + for batch_idx in tl.static_range(0, BATCH): + ridx = pid_row * BATCH + batch_idx + active = ridx < n_bucket_rows + row = tl.load(rows_ptr + ridx, mask=active, other=0) + start = tl.load(indptr_ptr + row, mask=active, other=0) + end = tl.load(indptr_ptr + row + 1, mask=active, other=0) + row_nnz = end - start + acc = tl.zeros([BLOCK_N], dtype=tl.float32) + comp = tl.zeros([BLOCK_N], dtype=tl.float32) + for chunk_start in tl.range(0, row_nnz, BLOCK_NNZ): + for kk in tl.static_range(0, BLOCK_NNZ): + idx = start + chunk_start + kk + valid = active & (idx < end) + a_val = tl.load(data_ptr + idx, mask=valid, other=0.0) + a_col = tl.load(indices_ptr + idx, mask=valid, other=0) + b_vals = tl.load( + b_ptr + a_col * stride_bk + offs_n * stride_bn, + mask=mask_n & valid, + other=0.0, + ) + prod = a_val.to(tl.float32) * b_vals.to(tl.float32) + y = prod - comp + t = acc + y + comp = (t - acc) - y + acc = t + tl.store(c_ptr + row * stride_cm + offs_n * stride_cn, acc, mask=mask_n & active) + + +@triton.jit +def _spmm_csr_stable_batched_f64_kernel( + data_ptr, + indices_ptr, + indptr_ptr, + b_ptr, + c_ptr, + rows_ptr, + n_bucket_rows, + n_dense_cols, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BATCH: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_NNZ: tl.constexpr, +): + pid_row = tl.program_id(0) + pid_n = tl.program_id(1) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_n = offs_n < n_dense_cols + for batch_idx in tl.static_range(0, BATCH): + ridx = pid_row * BATCH + batch_idx + active = ridx < n_bucket_rows + row = tl.load(rows_ptr + ridx, mask=active, other=0) + start = tl.load(indptr_ptr + row, mask=active, other=0) + end = tl.load(indptr_ptr + row + 1, mask=active, other=0) + row_nnz = end - start + acc = tl.zeros([BLOCK_N], dtype=tl.float64) + comp = tl.zeros([BLOCK_N], dtype=tl.float64) + for chunk_start in tl.range(0, row_nnz, BLOCK_NNZ): + for kk in tl.static_range(0, BLOCK_NNZ): + idx = start + chunk_start + kk + valid = active & (idx < end) + a_val = tl.load(data_ptr + idx, mask=valid, other=0.0) + a_col = tl.load(indices_ptr + idx, mask=valid, other=0) + b_vals = tl.load( + b_ptr + a_col * stride_bk + offs_n * stride_bn, + mask=mask_n & valid, + other=0.0, + ) + prod = a_val.to(tl.float64) * b_vals.to(tl.float64) + y = prod - comp + t = acc + y + comp = (t - acc) - y + acc = t + tl.store(c_ptr + row * stride_cm + offs_n * stride_cn, acc, mask=mask_n & active) + + +@triton.jit +def _spmm_csr_stable_vector_f32_kernel( + data_ptr, + indices_ptr, + indptr_ptr, + b_ptr, + c_ptr, + rows_ptr, + 
n_bucket_rows, + n_dense_cols, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_N: tl.constexpr, + BLOCK_NNZ: tl.constexpr, +): + pid_row = tl.program_id(0) + pid_n = tl.program_id(1) + if pid_row >= n_bucket_rows: + return + row = tl.load(rows_ptr + pid_row) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_n = offs_n < n_dense_cols + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) + row_nnz = end - start + acc = tl.zeros([BLOCK_N], dtype=tl.float64) + for chunk_start in tl.range(0, row_nnz, BLOCK_NNZ): + for kk in tl.static_range(0, BLOCK_NNZ): + idx = start + chunk_start + kk + valid = idx < end + a_val = tl.load(data_ptr + idx, mask=valid, other=0.0) + a_col = tl.load(indices_ptr + idx, mask=valid, other=0) + b_vals = tl.load( + b_ptr + a_col * stride_bk + offs_n * stride_bn, + mask=mask_n & valid, + other=0.0, + ) + acc = acc + a_val.to(tl.float64) * b_vals.to(tl.float64) + tl.store(c_ptr + row * stride_cm + offs_n * stride_cn, acc.to(tl.float32), mask=mask_n) + + +@triton.jit +def _spmm_csr_stable_vector_f64_kernel( + data_ptr, + indices_ptr, + indptr_ptr, + b_ptr, + c_ptr, + rows_ptr, + n_bucket_rows, + n_dense_cols, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_N: tl.constexpr, + BLOCK_NNZ: tl.constexpr, +): + pid_row = tl.program_id(0) + pid_n = tl.program_id(1) + if pid_row >= n_bucket_rows: + return + row = tl.load(rows_ptr + pid_row) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_n = offs_n < n_dense_cols + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) + row_nnz = end - start + acc = tl.zeros([BLOCK_N], dtype=tl.float64) + for chunk_start in tl.range(0, row_nnz, BLOCK_NNZ): + for kk in tl.static_range(0, BLOCK_NNZ): + idx = start + chunk_start + kk + valid = idx < end + a_val = tl.load(data_ptr + idx, mask=valid, other=0.0) + a_col = tl.load(indices_ptr + idx, mask=valid, other=0) + b_vals = tl.load( + b_ptr + a_col * stride_bk + offs_n * stride_bn, + mask=mask_n & valid, + other=0.0, + ) + acc = acc + a_val.to(tl.float64) * b_vals.to(tl.float64) + tl.store(c_ptr + row * stride_cm + offs_n * stride_cn, acc, mask=mask_n) + + +@triton.jit +def _spmm_csr_stable_split_part_f32_kernel( + data_ptr, + indices_ptr, + b_ptr, + workspace_ptr, + part_starts_ptr, + part_ends_ptr, + n_parts, + n_dense_cols, + stride_bk, + stride_bn, + stride_wm, + stride_wn, + BLOCK_N: tl.constexpr, +): + pid_part = tl.program_id(0) + pid_n = tl.program_id(1) + if pid_part >= n_parts: + return + start = tl.load(part_starts_ptr + pid_part) + end = tl.load(part_ends_ptr + pid_part) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_n = offs_n < n_dense_cols + acc = tl.zeros([BLOCK_N], dtype=tl.float64) + for idx in tl.range(start, end): + a_val = tl.load(data_ptr + idx) + a_col = tl.load(indices_ptr + idx) + b_vals = tl.load( + b_ptr + a_col * stride_bk + offs_n * stride_bn, + mask=mask_n, + other=0.0, + ) + acc = acc + a_val.to(tl.float64) * b_vals.to(tl.float64) + tl.store(workspace_ptr + pid_part * stride_wm + offs_n * stride_wn, acc, mask=mask_n) + + +@triton.jit +def _spmm_csr_stable_split_part_f64_kernel( + data_ptr, + indices_ptr, + b_ptr, + workspace_ptr, + part_starts_ptr, + part_ends_ptr, + n_parts, + n_dense_cols, + stride_bk, + stride_bn, + stride_wm, + stride_wn, + BLOCK_N: tl.constexpr, +): + pid_part = tl.program_id(0) + pid_n = tl.program_id(1) + if pid_part >= n_parts: + return + start = tl.load(part_starts_ptr + pid_part) + end = tl.load(part_ends_ptr + pid_part) + offs_n = pid_n * BLOCK_N + 
tl.arange(0, BLOCK_N) + mask_n = offs_n < n_dense_cols + acc = tl.zeros([BLOCK_N], dtype=tl.float64) + for idx in tl.range(start, end): + a_val = tl.load(data_ptr + idx) + a_col = tl.load(indices_ptr + idx) + b_vals = tl.load( + b_ptr + a_col * stride_bk + offs_n * stride_bn, + mask=mask_n, + other=0.0, + ) + acc = acc + a_val.to(tl.float64) * b_vals.to(tl.float64) + tl.store(workspace_ptr + pid_part * stride_wm + offs_n * stride_wn, acc, mask=mask_n) + + +@triton.jit +def _spmm_csr_stable_split_reduce_f32_kernel( + workspace_ptr, + out_ptr, + long_rows_ptr, + row_part_ptr_ptr, + n_long_rows, + n_dense_cols, + stride_wm, + stride_wn, + stride_cm, + stride_cn, + BLOCK_N: tl.constexpr, +): + pid_row = tl.program_id(0) + pid_n = tl.program_id(1) + if pid_row >= n_long_rows: + return + row = tl.load(long_rows_ptr + pid_row) + part_start = tl.load(row_part_ptr_ptr + pid_row) + part_end = tl.load(row_part_ptr_ptr + pid_row + 1) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_n = offs_n < n_dense_cols + acc = tl.zeros([BLOCK_N], dtype=tl.float64) + for rel_part in tl.range(0, part_end - part_start): + part_id = part_start + rel_part + vals = tl.load( + workspace_ptr + part_id * stride_wm + offs_n * stride_wn, + mask=mask_n, + other=0.0, + ) + acc = acc + vals.to(tl.float64) + tl.store(out_ptr + row * stride_cm + offs_n * stride_cn, acc.to(tl.float32), mask=mask_n) + + +@triton.jit +def _spmm_csr_stable_split_reduce_f64_kernel( + workspace_ptr, + out_ptr, + long_rows_ptr, + row_part_ptr_ptr, + n_long_rows, + n_dense_cols, + stride_wm, + stride_wn, + stride_cm, + stride_cn, + BLOCK_N: tl.constexpr, +): + pid_row = tl.program_id(0) + pid_n = tl.program_id(1) + if pid_row >= n_long_rows: + return + row = tl.load(long_rows_ptr + pid_row) + part_start = tl.load(row_part_ptr_ptr + pid_row) + part_end = tl.load(row_part_ptr_ptr + pid_row + 1) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_n = offs_n < n_dense_cols + acc = tl.zeros([BLOCK_N], dtype=tl.float64) + for rel_part in tl.range(0, part_end - part_start): + part_id = part_start + rel_part + vals = tl.load( + workspace_ptr + part_id * stride_wm + offs_n * stride_wn, + mask=mask_n, + other=0.0, + ) + acc = acc + vals.to(tl.float64) + tl.store(out_ptr + row * stride_cm + offs_n * stride_cn, acc, mask=mask_n) + + +def _run_spmm_opt_bucket(prepared, bucket, B, C_out, block_n): + rows = bucket["rows"] + if rows.numel() == 0: + return + dtype = prepared.data.dtype + kernel_map = { + ("batched", torch.float32): _spmm_csr_batched_rows_f32_kernel, + ("batched", torch.float64): _spmm_csr_batched_rows_f64_kernel, + ("vector", torch.float32): _spmm_csr_vector_rows_f32_kernel, + ("vector", torch.float64): _spmm_csr_vector_rows_f64_kernel, + } + if bucket["kind"] == "batched": + batch_rows = int(bucket["batch_rows"]) + grid = (triton.cdiv(rows.numel(), batch_rows), triton.cdiv(B.shape[1], block_n)) + kernel = kernel_map[(bucket["kind"], dtype)] + kernel[grid]( + prepared.data, + prepared.kernel_indices, + prepared.kernel_indptr, + B, + C_out, + rows, + rows.numel(), + B.shape[1], + B.stride(0), + B.stride(1), + C_out.stride(0), + C_out.stride(1), + BATCH=batch_rows, + BLOCK_N=block_n, + BLOCK_NNZ=bucket["block_nnz"], + ) + return + + grid = (rows.numel(), triton.cdiv(B.shape[1], block_n)) + kernel = kernel_map[(bucket["kind"], dtype)] + kernel[grid]( + prepared.data, + prepared.kernel_indices, + prepared.kernel_indptr, + B, + C_out, + rows, + rows.numel(), + B.shape[1], + B.stride(0), + B.stride(1), + C_out.stride(0), + C_out.stride(1), + 
BLOCK_N=block_n, + BLOCK_NNZ=bucket["block_nnz"], + ) + + +def _run_spmm_opt_split_bucket(prepared, B, C_out, block_n): + if prepared.long_part_rows.numel() == 0: + return False + workspace = torch.empty( + (prepared.long_part_rows.numel(), B.shape[1]), + dtype=B.dtype, + device=B.device, + ) + split_kernel = ( + _spmm_csr_split_part_f64_kernel + if B.dtype == torch.float64 + else _spmm_csr_split_part_f32_kernel + ) + reduce_kernel = ( + _spmm_csr_split_reduce_f64_kernel + if B.dtype == torch.float64 + else _spmm_csr_split_reduce_f32_kernel + ) + split_grid = ( + prepared.long_part_rows.numel(), + triton.cdiv(B.shape[1], block_n), + ) + split_kernel[split_grid]( + prepared.data, + prepared.kernel_indices, + B, + workspace, + prepared.long_part_starts, + prepared.long_part_ends, + prepared.long_part_rows.numel(), + B.shape[1], + B.stride(0), + B.stride(1), + workspace.stride(0), + workspace.stride(1), + BLOCK_N=block_n, + ) + reduce_grid = ( + prepared.long_row_ids.numel(), + triton.cdiv(B.shape[1], block_n), + ) + reduce_kernel[reduce_grid]( + workspace, + C_out, + prepared.long_row_ids, + prepared.long_row_part_ptr, + prepared.long_row_ids.numel(), + B.shape[1], + workspace.stride(0), + workspace.stride(1), + C_out.stride(0), + C_out.stride(1), + BLOCK_N=block_n, + ) + return False + + +def _run_spmm_opt_bucket_stable(prepared, bucket, B, C_out, block_n): + rows = bucket["rows"] + if rows.numel() == 0: + return + dtype = prepared.data.dtype + kernel_map = { + ("vector", torch.float32): _spmm_csr_stable_vector_f32_kernel, + ("vector", torch.float64): _spmm_csr_stable_vector_f64_kernel, + } + if bucket["kind"] == "batched": + # Diagnose-only stable path: run short-row buckets through the + # row-per-program fp64-accum vector kernel instead of the merged + # batched kernel. This favors numerical stability over throughput. 
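+        # (The stable kernel_map above has no "batched" entries on purpose;
+        # remapping short-row buckets to "vector" here is what keeps the
+        # lookup below from raising KeyError.)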
+ kind = "vector" + else: + kind = bucket["kind"] + + grid = (rows.numel(), triton.cdiv(B.shape[1], block_n)) + kernel = kernel_map[(kind, dtype)] + kernel[grid]( + prepared.data, + prepared.kernel_indices, + prepared.kernel_indptr, + B, + C_out, + rows, + rows.numel(), + B.shape[1], + B.stride(0), + B.stride(1), + C_out.stride(0), + C_out.stride(1), + BLOCK_N=block_n, + BLOCK_NNZ=bucket["block_nnz"], + ) + + +def _run_spmm_opt_split_bucket_stable(prepared, B, C_out, block_n): + if prepared.long_part_rows.numel() == 0: + return False + workspace = torch.empty( + (prepared.long_part_rows.numel(), B.shape[1]), + dtype=torch.float64, + device=B.device, + ) + split_kernel = ( + _spmm_csr_stable_split_part_f64_kernel + if B.dtype == torch.float64 + else _spmm_csr_stable_split_part_f32_kernel + ) + reduce_kernel = ( + _spmm_csr_stable_split_reduce_f64_kernel + if B.dtype == torch.float64 + else _spmm_csr_stable_split_reduce_f32_kernel + ) + split_grid = ( + prepared.long_part_rows.numel(), + triton.cdiv(B.shape[1], block_n), + ) + split_kernel[split_grid]( + prepared.data, + prepared.kernel_indices, + B, + workspace, + prepared.long_part_starts, + prepared.long_part_ends, + prepared.long_part_rows.numel(), + B.shape[1], + B.stride(0), + B.stride(1), + workspace.stride(0), + workspace.stride(1), + BLOCK_N=block_n, + ) + reduce_grid = ( + prepared.long_row_ids.numel(), + triton.cdiv(B.shape[1], block_n), + ) + reduce_kernel[reduce_grid]( + workspace, + C_out, + prepared.long_row_ids, + prepared.long_row_part_ptr, + prepared.long_row_ids.numel(), + B.shape[1], + workspace.stride(0), + workspace.stride(1), + C_out.stride(0), + C_out.stride(1), + BLOCK_N=block_n, + ) + return False + + +_SPMM_OPT_CANDIDATE_SHORT_BUCKET_SPECS = ( + {"label": "batched_16", "min_row_nnz": 0, "max_row_nnz": 16, "block_nnz": 16, "segments": 2}, + {"label": "batched_32", "min_row_nnz": 17, "max_row_nnz": 32, "block_nnz": 32, "segments": 4}, + {"label": "batched_64", "min_row_nnz": 33, "max_row_nnz": 64, "block_nnz": 64, "segments": 4}, + {"label": "batched_128", "min_row_nnz": 65, "max_row_nnz": 128, "block_nnz": 128, "segments": 8}, +) + +_SPMM_OPT_CANDIDATE_VECTOR_BUCKET_SPECS = ( + {"label": "vector_256", "min_row_nnz": 129, "max_row_nnz": 256, "block_nnz": 64, "segments": 4}, + {"label": "vector_512", "min_row_nnz": 257, "max_row_nnz": 512, "block_nnz": 64, "segments": 8}, + {"label": "vector_1024", "min_row_nnz": 513, "max_row_nnz": 1024, "block_nnz": 32, "segments": 8}, +) + +@triton.jit +def _spmm_csr_candidate_rows_kernel( + data_ptr, + indices_ptr, + indptr_ptr, + b_ptr, + c_ptr, + rows_ptr, + n_bucket_rows, + n_dense_cols, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_N: tl.constexpr, + BLOCK_NNZ: tl.constexpr, + SEGMENTS: tl.constexpr, + ACC_DTYPE: tl.constexpr, +): + pid_row = tl.program_id(0) + pid_n = tl.program_id(1) + if pid_row >= n_bucket_rows: + return + row = tl.load(rows_ptr + pid_row) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask_n = offs_n < n_dense_cols + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) + row_nnz = end - start + seg_span = (row_nnz + SEGMENTS - 1) // SEGMENTS + acc0 = tl.zeros([BLOCK_N], dtype=ACC_DTYPE) + acc1 = tl.zeros([BLOCK_N], dtype=ACC_DTYPE) + acc2 = tl.zeros([BLOCK_N], dtype=ACC_DTYPE) + acc3 = tl.zeros([BLOCK_N], dtype=ACC_DTYPE) + acc4 = tl.zeros([BLOCK_N], dtype=ACC_DTYPE) + acc5 = tl.zeros([BLOCK_N], dtype=ACC_DTYPE) + acc6 = tl.zeros([BLOCK_N], dtype=ACC_DTYPE) + acc7 = tl.zeros([BLOCK_N], dtype=ACC_DTYPE) + + if 
SEGMENTS > 0: + seg_start = start + seg_end = tl.minimum(end, seg_start + seg_span) + for chunk_offset in tl.range(0, seg_end - seg_start, BLOCK_NNZ): + chunk_acc = tl.zeros([BLOCK_N], dtype=ACC_DTYPE) + for kk in tl.static_range(0, BLOCK_NNZ): + idx = seg_start + chunk_offset + kk + valid = idx < seg_end + a_val = tl.load(data_ptr + idx, mask=valid, other=0.0) + a_col = tl.load(indices_ptr + idx, mask=valid, other=0) + b_vals = tl.load( + b_ptr + a_col * stride_bk + offs_n * stride_bn, + mask=mask_n & valid, + other=0.0, + ) + chunk_acc = chunk_acc + a_val * b_vals + acc0 = acc0 + chunk_acc + + if SEGMENTS > 1: + seg_start = start + seg_span + seg_end = tl.minimum(end, seg_start + seg_span) + for chunk_offset in tl.range(0, seg_end - seg_start, BLOCK_NNZ): + chunk_acc = tl.zeros([BLOCK_N], dtype=ACC_DTYPE) + for kk in tl.static_range(0, BLOCK_NNZ): + idx = seg_start + chunk_offset + kk + valid = idx < seg_end + a_val = tl.load(data_ptr + idx, mask=valid, other=0.0) + a_col = tl.load(indices_ptr + idx, mask=valid, other=0) + b_vals = tl.load( + b_ptr + a_col * stride_bk + offs_n * stride_bn, + mask=mask_n & valid, + other=0.0, + ) + chunk_acc = chunk_acc + a_val * b_vals + acc1 = acc1 + chunk_acc + + if SEGMENTS > 2: + seg_start = start + 2 * seg_span + seg_end = tl.minimum(end, seg_start + seg_span) + for chunk_offset in tl.range(0, seg_end - seg_start, BLOCK_NNZ): + chunk_acc = tl.zeros([BLOCK_N], dtype=ACC_DTYPE) + for kk in tl.static_range(0, BLOCK_NNZ): + idx = seg_start + chunk_offset + kk + valid = idx < seg_end + a_val = tl.load(data_ptr + idx, mask=valid, other=0.0) + a_col = tl.load(indices_ptr + idx, mask=valid, other=0) + b_vals = tl.load( + b_ptr + a_col * stride_bk + offs_n * stride_bn, + mask=mask_n & valid, + other=0.0, + ) + chunk_acc = chunk_acc + a_val * b_vals + acc2 = acc2 + chunk_acc + + if SEGMENTS > 3: + seg_start = start + 3 * seg_span + seg_end = tl.minimum(end, seg_start + seg_span) + for chunk_offset in tl.range(0, seg_end - seg_start, BLOCK_NNZ): + chunk_acc = tl.zeros([BLOCK_N], dtype=ACC_DTYPE) + for kk in tl.static_range(0, BLOCK_NNZ): + idx = seg_start + chunk_offset + kk + valid = idx < seg_end + a_val = tl.load(data_ptr + idx, mask=valid, other=0.0) + a_col = tl.load(indices_ptr + idx, mask=valid, other=0) + b_vals = tl.load( + b_ptr + a_col * stride_bk + offs_n * stride_bn, + mask=mask_n & valid, + other=0.0, + ) + chunk_acc = chunk_acc + a_val * b_vals + acc3 = acc3 + chunk_acc + + if SEGMENTS > 4: + seg_start = start + 4 * seg_span + seg_end = tl.minimum(end, seg_start + seg_span) + for chunk_offset in tl.range(0, seg_end - seg_start, BLOCK_NNZ): + chunk_acc = tl.zeros([BLOCK_N], dtype=ACC_DTYPE) + for kk in tl.static_range(0, BLOCK_NNZ): + idx = seg_start + chunk_offset + kk + valid = idx < seg_end + a_val = tl.load(data_ptr + idx, mask=valid, other=0.0) + a_col = tl.load(indices_ptr + idx, mask=valid, other=0) + b_vals = tl.load( + b_ptr + a_col * stride_bk + offs_n * stride_bn, + mask=mask_n & valid, + other=0.0, + ) + chunk_acc = chunk_acc + a_val * b_vals + acc4 = acc4 + chunk_acc + + if SEGMENTS > 5: + seg_start = start + 5 * seg_span + seg_end = tl.minimum(end, seg_start + seg_span) + for chunk_offset in tl.range(0, seg_end - seg_start, BLOCK_NNZ): + chunk_acc = tl.zeros([BLOCK_N], dtype=ACC_DTYPE) + for kk in tl.static_range(0, BLOCK_NNZ): + idx = seg_start + chunk_offset + kk + valid = idx < seg_end + a_val = tl.load(data_ptr + idx, mask=valid, other=0.0) + a_col = tl.load(indices_ptr + idx, mask=valid, other=0) + b_vals = tl.load( + b_ptr + 
a_col * stride_bk + offs_n * stride_bn, + mask=mask_n & valid, + other=0.0, + ) + chunk_acc = chunk_acc + a_val * b_vals + acc5 = acc5 + chunk_acc + + if SEGMENTS > 6: + seg_start = start + 6 * seg_span + seg_end = tl.minimum(end, seg_start + seg_span) + for chunk_offset in tl.range(0, seg_end - seg_start, BLOCK_NNZ): + chunk_acc = tl.zeros([BLOCK_N], dtype=ACC_DTYPE) + for kk in tl.static_range(0, BLOCK_NNZ): + idx = seg_start + chunk_offset + kk + valid = idx < seg_end + a_val = tl.load(data_ptr + idx, mask=valid, other=0.0) + a_col = tl.load(indices_ptr + idx, mask=valid, other=0) + b_vals = tl.load( + b_ptr + a_col * stride_bk + offs_n * stride_bn, + mask=mask_n & valid, + other=0.0, + ) + chunk_acc = chunk_acc + a_val * b_vals + acc6 = acc6 + chunk_acc + + if SEGMENTS > 7: + seg_start = start + 7 * seg_span + seg_end = tl.minimum(end, seg_start + seg_span) + for chunk_offset in tl.range(0, seg_end - seg_start, BLOCK_NNZ): + chunk_acc = tl.zeros([BLOCK_N], dtype=ACC_DTYPE) + for kk in tl.static_range(0, BLOCK_NNZ): + idx = seg_start + chunk_offset + kk + valid = idx < seg_end + a_val = tl.load(data_ptr + idx, mask=valid, other=0.0) + a_col = tl.load(indices_ptr + idx, mask=valid, other=0) + b_vals = tl.load( + b_ptr + a_col * stride_bk + offs_n * stride_bn, + mask=mask_n & valid, + other=0.0, + ) + chunk_acc = chunk_acc + a_val * b_vals + acc7 = acc7 + chunk_acc + + if SEGMENTS == 1: + acc = acc0 + elif SEGMENTS == 2: + acc = acc0 + acc1 + elif SEGMENTS <= 4: + acc_left = acc0 + acc1 + acc_right = acc2 + acc3 + acc = acc_left + acc_right + else: + acc01 = acc0 + acc1 + acc23 = acc2 + acc3 + acc45 = acc4 + acc5 + acc67 = acc6 + acc7 + acc = (acc01 + acc23) + (acc45 + acc67) + tl.store(c_ptr + row * stride_cm + offs_n * stride_cn, acc, mask=mask_n) + + +def _build_spmm_opt_candidate_buckets(prepared): + row_lengths = prepared.row_lengths + device = row_lengths.device + row_count = int(row_lengths.numel()) + max_row_index = row_count - 1 + row_index_dtype = torch.int32 if max_row_index <= _INDEX_LIMIT_INT32 else torch.int64 + all_rows = torch.arange(row_count, device=device, dtype=row_index_dtype) + buckets = [] + for spec in _SPMM_OPT_CANDIDATE_SHORT_BUCKET_SPECS: + lower = int(spec["min_row_nnz"]) + upper = int(spec["max_row_nnz"]) + if lower <= 0: + mask = row_lengths <= upper + else: + mask = (row_lengths >= lower) & (row_lengths <= upper) + rows = all_rows[mask] + if rows.numel() == 0: + continue + buckets.append( + { + "label": spec["label"], + "kind": "short", + "rows": rows, + "block_nnz": int(spec["block_nnz"]), + "segments": int(spec["segments"]), + } + ) + + for spec in _SPMM_OPT_CANDIDATE_VECTOR_BUCKET_SPECS: + lower = int(spec["min_row_nnz"]) + upper = int(spec["max_row_nnz"]) + mask = (row_lengths >= lower) & (row_lengths <= upper) + rows = all_rows[mask] + if rows.numel() == 0: + continue + buckets.append( + { + "label": spec["label"], + "kind": "vector", + "rows": rows, + "block_nnz": int(spec["block_nnz"]), + "segments": int(spec["segments"]), + } + ) + + split_rows = all_rows[row_lengths > 1024] + if split_rows.numel() > 0: + buckets.append({"label": "split", "kind": "split", "rows": split_rows, "block_nnz": 256}) + return buckets + + +def _run_spmm_opt_bucket_candidate(prepared, bucket, B, C_out, block_n): + rows = bucket["rows"] + if rows.numel() == 0: + return + dtype = prepared.data.dtype + acc_dtype = tl.float64 if dtype == torch.float64 else tl.float32 + grid = (rows.numel(), triton.cdiv(B.shape[1], block_n)) + _spmm_csr_candidate_rows_kernel[grid]( + 
prepared.data, + prepared.kernel_indices, + prepared.kernel_indptr, + B, + C_out, + rows, + rows.numel(), + B.shape[1], + B.stride(0), + B.stride(1), + C_out.stride(0), + C_out.stride(1), + BLOCK_N=block_n, + BLOCK_NNZ=bucket["block_nnz"], + SEGMENTS=bucket["segments"], + ACC_DTYPE=acc_dtype, + ) + + +def _run_spmm_opt_split_bucket_candidate(prepared, split_rows, B, C_out, block_n): + if split_rows.numel() == 0: + return False + ( + long_part_rows, + long_part_starts, + long_part_ends, + long_row_part_ptr, + ) = _build_spmm_opt_split_metadata( + prepared.kernel_indptr, + split_rows.to(prepared.kernel_indptr.device), + part_block_nnz=256, + ) + if long_part_rows.numel() == 0: + return False + workspace = torch.empty((long_part_rows.numel(), B.shape[1]), dtype=B.dtype, device=B.device) + split_kernel = _spmm_csr_split_part_f64_kernel if B.dtype == torch.float64 else _spmm_csr_split_part_f32_kernel + reduce_kernel = _spmm_csr_split_reduce_f64_kernel if B.dtype == torch.float64 else _spmm_csr_split_reduce_f32_kernel + split_grid = (long_part_rows.numel(), triton.cdiv(B.shape[1], block_n)) + split_kernel[split_grid]( + prepared.data, + prepared.kernel_indices, + B, + workspace, + long_part_starts, + long_part_ends, + long_part_rows.numel(), + B.shape[1], + B.stride(0), + B.stride(1), + workspace.stride(0), + workspace.stride(1), + BLOCK_N=block_n, + ) + reduce_grid = (split_rows.numel(), triton.cdiv(B.shape[1], block_n)) + reduce_kernel[reduce_grid]( + workspace, + C_out, + split_rows, + long_row_part_ptr, + split_rows.numel(), + B.shape[1], + workspace.stride(0), + workspace.stride(1), + C_out.stride(0), + C_out.stride(1), + BLOCK_N=block_n, + ) + return False + + +def _triton_spmm_csr_impl_opt_prepared_candidate(prepared, B): + if not prepared.supports_opt: + raise TypeError("spmm opt only supports float32 and float64") + if not B.is_cuda: + raise ValueError("B must be a CUDA tensor") + if B.device != prepared.data.device: + raise ValueError("B must be on the same CUDA device as the sparse matrix") + if B.dtype != prepared.data.dtype: + raise TypeError("B dtype must match sparse matrix dtype") + if B.shape[0] != prepared.n_cols: + raise ValueError(f"B.shape[0] must be n_cols={prepared.n_cols}, got {B.shape[0]}") + if B.ndim != 2: + raise ValueError("B must be a 2D dense tensor") + B = B.contiguous() + block_n = _select_spmm_opt_block_n(int(B.shape[1])) + C_out = torch.zeros((prepared.n_rows, int(B.shape[1])), dtype=B.dtype, device=B.device) + long_row_fallback_used = False + for bucket in _build_spmm_opt_candidate_buckets(prepared): + if bucket["kind"] == "split": + long_row_fallback_used = _run_spmm_opt_split_bucket_candidate( + prepared, + bucket["rows"], + B, + C_out, + block_n, + ) + continue + _run_spmm_opt_bucket_candidate(prepared, bucket, B, C_out, block_n) + return C_out, long_row_fallback_used + + +def _flagsparse_spmm_csr_opt_candidate_for_diagnose(prepared, B): + C_out, _ = _triton_spmm_csr_impl_opt_prepared_candidate(prepared, B) + return C_out + + +def _triton_spmm_csr_impl_opt_prepared(prepared, B): + if not prepared.supports_opt: + raise TypeError("spmm opt only supports float32 and float64") + if not B.is_cuda: + raise ValueError("B must be a CUDA tensor") + if B.device != prepared.data.device: + raise ValueError("B must be on the same CUDA device as the sparse matrix") + if B.dtype != prepared.data.dtype: + raise TypeError("B dtype must match sparse matrix dtype") + if B.shape[0] != prepared.n_cols: + raise ValueError(f"B.shape[0] must be n_cols={prepared.n_cols}, got 
{B.shape[0]}") + if B.ndim != 2: + raise ValueError("B must be a 2D dense tensor") + B = B.contiguous() + block_n = _select_spmm_opt_block_n(int(B.shape[1])) + C_out = torch.zeros((prepared.n_rows, int(B.shape[1])), dtype=B.dtype, device=B.device) + long_row_fallback_used = False + for bucket in prepared.row_buckets: + if bucket["kind"] == "split": + long_row_fallback_used = _run_spmm_opt_split_bucket(prepared, B, C_out, block_n) + continue + _run_spmm_opt_bucket(prepared, bucket, B, C_out, block_n) + return C_out, long_row_fallback_used + + +def _triton_spmm_csr_impl_opt_prepared_stable(prepared, B): + if not prepared.supports_opt: + raise TypeError("spmm opt only supports float32 and float64") + if not B.is_cuda: + raise ValueError("B must be a CUDA tensor") + if B.device != prepared.data.device: + raise ValueError("B must be on the same CUDA device as the sparse matrix") + if B.dtype != prepared.data.dtype: + raise TypeError("B dtype must match sparse matrix dtype") + if B.shape[0] != prepared.n_cols: + raise ValueError(f"B.shape[0] must be n_cols={prepared.n_cols}, got {B.shape[0]}") + if B.ndim != 2: + raise ValueError("B must be a 2D dense tensor") + B = B.contiguous() + block_n = _select_spmm_opt_block_n(int(B.shape[1])) + C_out = torch.zeros((prepared.n_rows, int(B.shape[1])), dtype=B.dtype, device=B.device) + long_row_fallback_used = False + for bucket in prepared.row_buckets: + if bucket["kind"] == "split": + long_row_fallback_used = _run_spmm_opt_split_bucket_stable(prepared, B, C_out, block_n) + continue + _run_spmm_opt_bucket_stable(prepared, bucket, B, C_out, block_n) + return C_out, long_row_fallback_used + + +def _flagsparse_spmm_csr_opt_stable_for_diagnose(prepared, B): + C_out, _ = _triton_spmm_csr_impl_opt_prepared_stable(prepared, B) + return C_out + +def flagsparse_spmm_csr( + data, + indices, + indptr, + B, + shape, + block_n=None, + block_nnz=None, + max_segments=None, + out=None, + return_time=False, +): + """CSR SpMM: C = A @ B using Triton. + + A is provided as CSR arrays; B is a dense CUDA tensor with shape (n_cols, n_dense_cols). + This staged implementation is the row-major, non-transpose subset of + AlphaSparse CSR ALG1 (`csrspmm_rb_sr`) expressed in Triton. 
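+
+    Minimal call sketch (illustrative shapes; all tensors are assumed to be
+    CUDA-resident, with B.dtype matching data.dtype):
+
+        C = flagsparse_spmm_csr(data, indices, indptr, B, shape=(n_rows, n_cols))
+        C, ms = flagsparse_spmm_csr(data, indices, indptr, B, shape,
+                                    return_time=True)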
+ """ + if block_n is not None and block_n <= 0: + raise ValueError("block_n must be positive when provided") + if block_nnz is not None and block_nnz <= 0: + raise ValueError("block_nnz must be positive when provided") + if max_segments is not None and max_segments <= 0: + raise ValueError("max_segments must be positive when provided") + + data, kernel_indices, kernel_indptr, B, n_rows, _, n_dense_cols = _prepare_spmm_csr_inputs( + data, indices, indptr, B, shape + ) + max_row_nnz = ( + int(torch.max(kernel_indptr[1:] - kernel_indptr[:-1]).item()) + if n_rows > 0 + else 0 + ) + launch = _resolve_spmm_alg1_launch_config( + n_dense_cols, + max_row_nnz, + block_n=block_n, + block_nnz=block_nnz, + max_segments=max_segments, + ) + + if out is not None: + if not out.is_cuda: + raise ValueError("out must be a CUDA tensor") + if out.device != data.device: + raise ValueError("out must be on the same CUDA device as the inputs") + if out.shape != (n_rows, n_dense_cols) or out.dtype != data.dtype: + raise ValueError("out shape/dtype must match result") + + torch.cuda.synchronize() + t0 = time.perf_counter() + C = _triton_spmm_csr_impl( + data, + kernel_indices, + kernel_indptr, + B, + n_rows, + n_dense_cols, + block_n=launch["block_n"], + block_nnz=launch["block_nnz"], + ) + torch.cuda.synchronize() + elapsed_ms = (time.perf_counter() - t0) * 1000.0 + + if out is not None: + out.copy_(C) + C = out + if return_time: + return C, elapsed_ms + return C + + +def flagsparse_spmm_csr_opt( + data=None, + indices=None, + indptr=None, + B=None, + shape=None, + prepared=None, + out=None, + return_time=False, +): + """CSR SpMM-opt: native float32/float64 bucketed path for CSR @ dense.""" + if prepared is not None and not isinstance(prepared, PreparedCsrSpmmOpt): + raise TypeError("prepared must be a PreparedCsrSpmmOpt instance") + if prepared is None: + if any(arg is None for arg in (data, indices, indptr, shape)): + raise ValueError( + "data, indices, indptr, and shape are required when prepared is not provided" + ) + prepared = prepare_spmm_csr_opt(data, indices, indptr, shape) + elif shape is not None: + resolved_shape = (int(shape[0]), int(shape[1])) + if resolved_shape != prepared.shape: + raise ValueError( + f"shape {resolved_shape} does not match prepared.shape {prepared.shape}" + ) + if B is None: + raise ValueError("B is required") + if not prepared.supports_opt: + raise TypeError("flagsparse_spmm_csr_opt only supports float32 and float64") + if not B.is_cuda: + raise ValueError("B must be a CUDA tensor") + if B.device != prepared.data.device: + raise ValueError("B must be on the same CUDA device as sparse matrix data") + if out is not None: + if not out.is_cuda: + raise ValueError("out must be a CUDA tensor") + if out.device != prepared.data.device: + raise ValueError("out must be on the same CUDA device as sparse matrix data") + if out.shape != (prepared.n_rows, int(B.shape[1])) or out.dtype != prepared.data.dtype: + raise ValueError("out shape/dtype must match result") + t0 = None + if return_time: + torch.cuda.synchronize() + t0 = time.perf_counter() + C, _long_row_fallback_used = _triton_spmm_csr_impl_opt_prepared(prepared, B) + elapsed_ms = None + if return_time: + torch.cuda.synchronize() + elapsed_ms = (time.perf_counter() - t0) * 1000.0 + if out is not None: + out.copy_(C) + C = out + if return_time: + return C, elapsed_ms + return C + + +def _spmm_opt_reference_error(candidate, reference, value_dtype): + atol, rtol = _spmm_coo_reference_tolerance(value_dtype) + if candidate.numel() == 0: + 
return 0.0 + diff = torch.abs(candidate - reference) + denom = atol + rtol * torch.abs(reference) + return float(torch.max(diff / denom).item()) + + +def benchmark_spmm_opt_case( + n_rows=4096, + n_cols=4096, + nnz=65536, + n_dense_cols=32, + value_dtype=torch.float32, + index_dtype=torch.int32, + warmup=20, + iters=200, + run_cusparse=True, +): + """Benchmark SpMM base vs opt against the same high-precision PyTorch reference.""" + if value_dtype not in (torch.float32, torch.float64): + raise TypeError("benchmark_spmm_opt_case only supports float32 and float64") + device = torch.device("cuda") + data, indices, indptr = _build_random_csr( + n_rows, n_cols, nnz, value_dtype, index_dtype, device + ) + B = _build_random_dense((n_cols, n_dense_cols), value_dtype, device) + shape = (n_rows, n_cols) + prepared = prepare_spmm_csr_opt(data, indices, indptr, shape) + + torch.cuda.synchronize() + t0 = time.perf_counter() + base_values = flagsparse_spmm_csr(data, indices, indptr, B, shape) + torch.cuda.synchronize() + base_first_call_ms = (time.perf_counter() - t0) * 1000.0 + base_values, base_ms = _benchmark_cuda_op( + lambda: flagsparse_spmm_csr(data, indices, indptr, B, shape), + warmup=warmup, + iters=iters, + ) + + torch.cuda.synchronize() + t0 = time.perf_counter() + opt_values = flagsparse_spmm_csr_opt(B=B, prepared=prepared) + torch.cuda.synchronize() + opt_first_call_ms = (time.perf_counter() - t0) * 1000.0 + opt_values, opt_ms = _benchmark_cuda_op( + lambda: flagsparse_spmm_csr_opt(B=B, prepared=prepared), + warmup=warmup, + iters=iters, + ) + + indptr64 = indptr.to(torch.int64) + indices64 = indices.to(torch.int64) + csr_ref = torch.sparse_csr_tensor( + indptr64, + indices64, + data.to(torch.float64 if value_dtype == torch.float32 else value_dtype), + size=shape, + device=device, + ) + ref = torch.sparse.mm( + csr_ref, + B.to(torch.float64 if value_dtype == torch.float32 else value_dtype), + ).to(value_dtype) + + pt_ms = None + try: + pt_sparse = torch.sparse_csr_tensor( + indptr64, + indices64, + data, + size=shape, + device=device, + ) + pt_op = lambda: torch.sparse.mm(pt_sparse, B) + _, pt_ms = _benchmark_cuda_op(pt_op, warmup=warmup, iters=iters) + except Exception: + pt_ms = None + + cu_ms = None + if run_cusparse and cp is not None and cpx_sparse is not None: + try: + data_cp = _cupy_from_torch(data) + indices_cp = _cupy_from_torch(indices.to(torch.int64)) + indptr_cp = _cupy_from_torch(indptr) + B_cp = _cupy_from_torch(B) + A_csr = cpx_sparse.csr_matrix((data_cp, indices_cp, indptr_cp), shape=shape) + _, cu_ms = _benchmark_cuda_op(lambda: A_csr @ B_cp, warmup=warmup, iters=iters) + except Exception: + cu_ms = None + + err_base = _spmm_opt_reference_error(base_values, ref, value_dtype) + err_opt = _spmm_opt_reference_error(opt_values, ref, value_dtype) + return { + "parameters": { + "n_rows": n_rows, + "n_cols": n_cols, + "nnz": nnz, + "n_dense_cols": n_dense_cols, + "value_dtype": str(value_dtype), + "index_dtype": str(index_dtype), + }, + "performance": { + "base_ms": base_ms, + "base_first_call_ms": base_first_call_ms, + "opt_ms": opt_ms, + "opt_first_call_ms": opt_first_call_ms, + "pt_ms": pt_ms, + "cu_ms": cu_ms, + "opt_vs_base": (base_ms / opt_ms if opt_ms and opt_ms > 0 else None), + "opt_vs_pt": (pt_ms / opt_ms if pt_ms is not None and opt_ms > 0 else None), + "opt_vs_cu": (cu_ms / opt_ms if cu_ms is not None and opt_ms > 0 else None), + }, + "verification": { + "err_base": err_base, + "err_opt": err_opt, + "base_ok": err_base <= 1.0, + "opt_ok": err_opt <= 1.0, + "status": 
("PASS" if err_opt <= 1.0 else "FAIL"), + }, + "backend_status": { + "long_row_fallback_used": bool(prepared.long_row_fallback_only), + "flagsparse_internal_route": "csr-opt-bucketed", + }, + "samples": { + "base": base_values, + "opt": opt_values, + "reference": ref, + }, + } + + +def benchmark_spmm_case( + n_rows=4096, + n_cols=4096, + nnz=65536, + n_dense_cols=32, + value_dtype=torch.float32, + index_dtype=torch.int32, + warmup=20, + iters=200, + block_n=None, + block_nnz=None, + max_segments=None, + run_cusparse=True, +): + """Benchmark Triton CSR SpMM vs PyTorch sparse.mm and CuPy/cuSPARSE CSR @ dense.""" + device = torch.device("cuda") + data, indices, indptr = _build_random_csr( + n_rows, n_cols, nnz, value_dtype, index_dtype, device + ) + B = _build_random_dense((n_cols, n_dense_cols), value_dtype, device) + shape = (n_rows, n_cols) + max_row_nnz = int(torch.max(indptr[1:] - indptr[:-1]).item()) if n_rows > 0 else 0 + launch = _resolve_spmm_alg1_launch_config( + n_dense_cols, + max_row_nnz, + block_n=block_n, + block_nnz=block_nnz, + max_segments=max_segments, + ) + + triton_kwargs = { + "data": data, + "indices": indices, + "indptr": indptr, + "B": B, + "shape": shape, + "block_n": launch["block_n"], + "block_nnz": launch["block_nnz"], + "max_segments": launch["max_segments"], + "return_time": False, + } + + torch.cuda.synchronize() + t0 = time.perf_counter() + _ = flagsparse_spmm_csr(**triton_kwargs) + torch.cuda.synchronize() + triton_first_call_ms = (time.perf_counter() - t0) * 1000.0 + triton_C, triton_ms = _benchmark_cuda_op( + lambda: flagsparse_spmm_csr(**triton_kwargs), + warmup=warmup, + iters=iters, + ) + + indptr64 = indptr.to(torch.int64) + indices64 = indices.to(torch.int64) + row_indices = torch.repeat_interleave( + torch.arange(n_rows, device=device, dtype=torch.int64), + indptr64[1:] - indptr64[:-1], + ) + + pytorch_reason = None + pytorch_values = None + pytorch_ms = None + pytorch_format = "CSR" + try: + csr_pt = torch.sparse_csr_tensor(indptr64, indices64, data, size=shape, device=device) + pytorch_op = lambda: torch.sparse.mm(csr_pt, B) + if value_dtype in (torch.float16, torch.bfloat16): + csr_ref = torch.sparse_csr_tensor(indptr64, indices64, data.to(torch.float32), size=shape, device=device) + expected = torch.sparse.mm(csr_ref, B.to(torch.float32)).to(value_dtype) + elif value_dtype == torch.float32: + csr_ref = torch.sparse_csr_tensor(indptr64, indices64, data.to(torch.float64), size=shape, device=device) + expected = torch.sparse.mm(csr_ref, B.to(torch.float64)).to(value_dtype) + elif value_dtype == torch.complex64: + csr_ref = torch.sparse_csr_tensor(indptr64, indices64, data.to(torch.complex128), size=shape, device=device) + expected = torch.sparse.mm(csr_ref, B.to(torch.complex128)).to(value_dtype) + else: + expected = torch.sparse.mm(csr_pt, B) + except Exception as exc: + pytorch_format = "COO" + pytorch_reason = f"CSR fallback: {exc}" + coo = torch.sparse_coo_tensor( + torch.stack([row_indices, indices64]), + data, + shape, + device=device, + ).coalesce() + pytorch_op = lambda: torch.sparse.mm(coo, B) + if value_dtype in (torch.float16, torch.bfloat16): + expected = torch.sparse.mm(coo.to(torch.float32), B.to(torch.float32)).to(value_dtype) + elif value_dtype == torch.float32: + expected = torch.sparse.mm(coo.to(torch.float64), B.to(torch.float64)).to(value_dtype) + elif value_dtype == torch.complex64: + expected = torch.sparse.mm(coo.to(torch.complex128), B.to(torch.complex128)).to(value_dtype) + else: + expected = torch.sparse.mm(coo, B) + + 
pytorch_values = expected + try: + pytorch_values, pytorch_ms = _benchmark_cuda_op( + pytorch_op, warmup=warmup, iters=iters + ) + except Exception as exc: + pytorch_reason = str(exc) if pytorch_reason is None else f"{pytorch_reason}; timing: {exc}" + + triton_metrics = _spmm_validation_metrics(triton_C, expected) + triton_match = triton_metrics["strict_allclose_match"] + + cusparse_ms = None + cusparse_match = None + cusparse_reason = None + cusparse_values = None + cusparse_metrics = None + _cupy_supported_dtypes = ( + torch.float32, + torch.float64, + torch.complex64, + torch.complex128, + ) + if run_cusparse: + if cp is None or cpx_sparse is None: + cusparse_reason = "CuPy/cuSPARSE is not available" + elif value_dtype not in _cupy_supported_dtypes: + cusparse_reason = "float16/bfloat16 not supported by CuPy sparse; skipped" + else: + try: + data_cp = _cupy_from_torch(data) + indices_cp = _cupy_from_torch(indices.to(torch.int64)) + indptr_cp = _cupy_from_torch(indptr) + B_cp = _cupy_from_torch(B) + A_csr = cpx_sparse.csr_matrix( + (data_cp, indices_cp, indptr_cp), shape=shape + ) + cusparse_values_cp, cusparse_ms = _benchmark_cuda_op( + lambda: A_csr @ B_cp, warmup=warmup, iters=iters + ) + cusparse_values = _torch_from_cupy(cusparse_values_cp) + cusparse_metrics = _spmm_validation_metrics(cusparse_values, expected) + cusparse_match = cusparse_metrics["strict_allclose_match"] + except Exception as exc: + cusparse_reason = str(exc) + + triton_speedup_vs_pytorch = ( + pytorch_ms / triton_ms if (pytorch_ms is not None and triton_ms > 0) else None + ) + triton_speedup_vs_cusparse = ( + cusparse_ms / triton_ms if (cusparse_ms is not None and triton_ms > 0) else None + ) + threshold = _spmm_relative_threshold(value_dtype) + return { + "parameters": { + "format": "csr", + "n_rows": n_rows, + "n_cols": n_cols, + "nnz": nnz, + "n_dense_cols": n_dense_cols, + "value_dtype": str(value_dtype), + "index_dtype": str(index_dtype), + "warmup": warmup, + "iters": iters, + "block_n": launch["block_n"], + "block_nnz": launch["block_nnz"], + "max_segments": launch["max_segments"], + "required_segments": launch["required_segments"], + "alg1_warp_size": launch["warp_size"], + "alg1_factor": launch["factor"], + "auto_max_segments": launch["auto_max_segments"], + "run_cusparse": run_cusparse, + }, + "performance": { + "pytorch_ms": pytorch_ms, + "triton_ms": triton_ms, + "triton_first_call_ms": triton_first_call_ms, + "cusparse_ms": cusparse_ms, + "triton_speedup_vs_pytorch": triton_speedup_vs_pytorch, + "triton_speedup_vs_cusparse": triton_speedup_vs_cusparse, + }, + "verification": { + "triton_match_reference": triton_match, + "triton_match_pytorch": triton_match, + "triton_max_error": triton_metrics["max_abs_error"], + "triton_max_abs_error": triton_metrics["max_abs_error"], + "triton_max_relative_error": triton_metrics["max_relative_error"], + "triton_sum_relative_error": triton_metrics["sum_relative_error"], + "triton_relative_threshold": triton_metrics["relative_threshold"], + "triton_strict_allclose_match": triton_metrics["strict_allclose_match"], + "pytorch_match_reference": True, + "pytorch_max_error": 0.0, + "pytorch_max_abs_error": 0.0, + "pytorch_max_relative_error": 0.0, + "pytorch_sum_relative_error": 0.0, + "pytorch_relative_threshold": threshold, + "cusparse_match_reference": cusparse_match, + "cusparse_match_pytorch": cusparse_match, + "cusparse_max_error": (cusparse_metrics["max_abs_error"] if cusparse_metrics is not None else None), + "cusparse_max_abs_error": 
(cusparse_metrics["max_abs_error"] if cusparse_metrics is not None else None),
+            "cusparse_max_relative_error": (cusparse_metrics["max_relative_error"] if cusparse_metrics is not None else None),
+            "cusparse_sum_relative_error": (cusparse_metrics["sum_relative_error"] if cusparse_metrics is not None else None),
+            "cusparse_relative_threshold": (cusparse_metrics["relative_threshold"] if cusparse_metrics is not None else threshold),
+            "cusparse_strict_allclose_match": (cusparse_metrics["strict_allclose_match"] if cusparse_metrics is not None else None),
+        },
+        "backend_status": {
+            "pytorch_unavailable_reason": pytorch_reason,
+            "pytorch_sparse_format": pytorch_format,
+            "cusparse_unavailable_reason": cusparse_reason,
+            "flagsparse_internal_route": "csr-alg1",
+        },
+        "samples": {
+            "pytorch": pytorch_values,
+            "triton": triton_C,
+            "reference": expected,
+            "cusparse": cusparse_values,
+        },
+    }
+
+
+def comprehensive_spmm_test(
+    n_rows=4096,
+    n_cols=4096,
+    nnz=65536,
+    n_dense_cols=32,
+    dtype=torch.float32,
+    index_dtype=torch.int32,
+    warmup=20,
+    iters=200,
+    block_n=None,
+    block_nnz=None,
+    max_segments=None,
+    run_cusparse=True,
+):
+    """Full SpMM benchmark entry for one configuration."""
+    return benchmark_spmm_case(
+        n_rows=n_rows,
+        n_cols=n_cols,
+        nnz=nnz,
+        n_dense_cols=n_dense_cols,
+        value_dtype=dtype,
+        index_dtype=index_dtype,
+        warmup=warmup,
+        iters=iters,
+        block_n=block_n,
+        block_nnz=block_nnz,
+        max_segments=max_segments,
+        run_cusparse=run_cusparse,
+    )
diff --git a/src/flagsparse/sparse_operations/spmv_coo.py b/src/flagsparse/sparse_operations/spmv_coo.py
new file mode 100644
index 0000000..76fda9d
--- /dev/null
+++ b/src/flagsparse/sparse_operations/spmv_coo.py
@@ -0,0 +1,348 @@
+"""COO SpMV **without CSR / indptr** (fp32/fp64).
+
+- ``sort_by_row=True`` (default): lex-sort (row,col), build compact **row-run** offsets
+  ``seg_starts`` (length #runs+1, not ``n_rows+1``), one Triton program per run — register
+  reduction + single ``tl.store`` per output row (no atomics on ``y``).
+- ``sort_by_row=False``: grid over NNZ with ``tl.atomic_add`` (slower; atomics on ``y`` contend).
+
+Storage: sorted ``data, row, col`` plus optional ``seg_starts`` int32 vector — never ``indptr``.
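+
+Worked ``seg_starts`` example (hypothetical input): sorted rows ``[0, 0, 2, 2, 2]``
+(nnz=5) hold two row runs, so ``seg_starts == [0, 2, 5]`` and two programs are
+launched; rows that own no nonzeros have no run, and their ``y`` entries stay 0.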
+""" + +from ._common import * + +import time + +import triton +import triton.language as tl + + +class PreparedCoo: + """Sorted COO + optional row-run bounds ``seg_starts``; no CSR indptr.""" + + __slots__ = ( + "data", + "row", + "col", + "shape", + "n_rows", + "n_cols", + "nnz", + "seg_starts", + "n_segs", + "use_seg_kernel", + ) + + def __init__(self, data, row, col, shape, seg_starts=None): + self.data = data + self.row = row + self.col = col + self.shape = (int(shape[0]), int(shape[1])) + self.n_rows, self.n_cols = self.shape + self.nnz = int(data.numel()) + self.seg_starts = seg_starts + if seg_starts is None: + self.n_segs = 0 + self.use_seg_kernel = False + else: + self.n_segs = int(seg_starts.numel()) - 1 + self.use_seg_kernel = self.n_segs > 0 + + +@triton.jit +def _spmv_coo_seg_f32( + data_ptr, + col_ptr, + row_ptr, + x_ptr, + y_ptr, + seg_starts_ptr, + n_segs, + BLOCK_INNER: tl.constexpr, +): + seg = tl.program_id(0) + if seg >= n_segs: + return + start = tl.load(seg_starts_ptr + seg) + end = tl.load(seg_starts_ptr + seg + 1) + row_id = tl.load(row_ptr + start) + acc = tl.zeros((), dtype=tl.float32) + pos = start + while pos < end: + offs = pos + tl.arange(0, BLOCK_INNER) + m = offs < end + v = tl.load(data_ptr + offs, mask=m, other=0.0) + c = tl.load(col_ptr + offs, mask=m, other=0) + xv = tl.load(x_ptr + c, mask=m, other=0.0) + acc += tl.sum(tl.where(m, v * xv, 0.0)) + pos += BLOCK_INNER + tl.store(y_ptr + row_id, acc) + + +@triton.jit +def _spmv_coo_seg_f64( + data_ptr, + col_ptr, + row_ptr, + x_ptr, + y_ptr, + seg_starts_ptr, + n_segs, + BLOCK_INNER: tl.constexpr, +): + seg = tl.program_id(0) + if seg >= n_segs: + return + start = tl.load(seg_starts_ptr + seg) + end = tl.load(seg_starts_ptr + seg + 1) + row_id = tl.load(row_ptr + start) + acc = tl.zeros((), dtype=tl.float64) + pos = start + while pos < end: + offs = pos + tl.arange(0, BLOCK_INNER) + m = offs < end + v = tl.load(data_ptr + offs, mask=m, other=0.0) + c = tl.load(col_ptr + offs, mask=m, other=0) + xv = tl.load(x_ptr + c, mask=m, other=0.0) + acc += tl.sum(tl.where(m, v * xv, 0.0)) + pos += BLOCK_INNER + tl.store(y_ptr + row_id, acc) + + +@triton.jit +def _spmv_coo_atomic_f32( + data_ptr, + row_ptr, + col_ptr, + x_ptr, + y_ptr, + nnz, + BLOCK: tl.constexpr, +): + pid = tl.program_id(0) + offs = pid * BLOCK + tl.arange(0, BLOCK) + m = offs < nnz + r = tl.load(row_ptr + offs, mask=m, other=0) + c = tl.load(col_ptr + offs, mask=m, other=0) + v = tl.load(data_ptr + offs, mask=m, other=0.0) + xv = tl.load(x_ptr + c, mask=m, other=0.0) + contrib = tl.where(m, v * xv, 0.0).to(tl.float32) + tl.atomic_add(y_ptr + r, contrib, mask=m, sem="relaxed") + + +@triton.jit +def _spmv_coo_atomic_f64( + data_ptr, + row_ptr, + col_ptr, + x_ptr, + y_ptr, + nnz, + BLOCK: tl.constexpr, +): + pid = tl.program_id(0) + offs = pid * BLOCK + tl.arange(0, BLOCK) + m = offs < nnz + r = tl.load(row_ptr + offs, mask=m, other=0) + c = tl.load(col_ptr + offs, mask=m, other=0) + v = tl.load(data_ptr + offs, mask=m, other=0.0) + xv = tl.load(x_ptr + c, mask=m, other=0.0) + contrib = tl.where(m, v * xv, 0.0).to(tl.float64) + tl.atomic_add(y_ptr + r, contrib, mask=m, sem="relaxed") + + +def _sort_coo_lex_inplace(data, row, col, n_cols): + row64 = row.to(torch.int64) + col64 = col.to(torch.int64) + if data.numel() == 0: + return data.contiguous(), row64, col64 + key = row64 * max(1, int(n_cols)) + col64 + order = torch.argsort(key) + return ( + data[order].contiguous(), + row64[order].contiguous(), + col64[order].contiguous(), + ) + + +def 
_seg_starts_from_sorted_rows(row_i32, nnz, device): + """Boundaries of constant-row runs in sorted COO → int32[n_runs+1].""" + if nnz == 0: + return None + diff = row_i32[1:] != row_i32[:-1] + breaks = torch.nonzero(diff, as_tuple=False).flatten().to(torch.int32) + 1 + return torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=device), + breaks, + torch.tensor([nnz], dtype=torch.int32, device=device), + ] + ) + + +def _prepare_coo_tensors(data, row, col, shape, sort_by_row): + if not all(torch.is_tensor(t) for t in (data, row, col)): + raise TypeError("data, row, col must all be torch.Tensor") + if data.ndim != 1 or row.ndim != 1 or col.ndim != 1: + raise ValueError("data, row, col must be 1D") + if not all(t.is_cuda for t in (data, row, col)): + raise ValueError("data, row, col must be CUDA tensors") + n_rows, n_cols = int(shape[0]), int(shape[1]) + if data.numel() != row.numel() or data.numel() != col.numel(): + raise ValueError("data, row, col must have the same length") + if data.dtype not in (torch.float32, torch.float64): + raise TypeError( + "this COO SpMV path supports float32/float64 only (use CSR for other dtypes)" + ) + if sort_by_row: + data, row64, col64 = _sort_coo_lex_inplace(data, row, col, n_cols) + seg = _seg_starts_from_sorted_rows( + row64.to(torch.int32), data.numel(), data.device + ) + else: + data = data.contiguous() + row64 = row.to(torch.int64).contiguous() + col64 = col.to(torch.int64).contiguous() + seg = None + if data.numel() > 0: + if int(row64.min().item()) < 0 or int(row64.max().item()) >= n_rows: + raise IndexError("row indices out of range") + if int(col64.min().item()) < 0 or int(col64.max().item()) >= n_cols: + raise IndexError("col indices out of range") + if int(row64.max().item()) > _INDEX_LIMIT_INT32 or int( + col64.max().item() + ) > _INDEX_LIMIT_INT32: + raise ValueError("indices exceed int32 Triton kernel range") + kr = row64.to(torch.int32) + kc = col64.to(torch.int32) + return data, kr, kc, seg + + +def prepare_spmv_coo(data, row, col, shape, sort_by_row=True): + """Cache sorted COO + row-run ``seg_starts`` when ``sort_by_row``. 
No ``indptr``.""" + d, kr, kc, seg = _prepare_coo_tensors( + data, row, col, shape, sort_by_row + ) + return PreparedCoo(d, kr, kc, shape, seg_starts=seg) + + +def _validate_x_coo(x, prepared): + if x is None or not torch.is_tensor(x): + raise TypeError("x must be a torch.Tensor") + if x.ndim != 1: + raise ValueError("x must be a 1D tensor") + if not x.is_cuda: + raise ValueError("x must be a CUDA tensor") + if x.dtype != prepared.data.dtype: + raise TypeError("x dtype must match sparse matrix dtype") + if x.numel() != prepared.n_cols: + raise ValueError( + f"x length must be n_cols={prepared.n_cols}, got {x.numel()}" + ) + if x.device != prepared.data.device: + raise ValueError("x must be on the same device as sparse matrix data") + return x.contiguous() + + +def _triton_spmv_coo_kernel( + prepared, x, block_size, num_warps, block_inner +): + dtype = prepared.data.dtype + y = torch.zeros(prepared.n_rows, dtype=dtype, device=prepared.data.device) + nnz = prepared.nnz + if nnz == 0: + return y + if prepared.use_seg_kernel: + ker = _spmv_coo_seg_f64 if dtype == torch.float64 else _spmv_coo_seg_f32 + grid = (prepared.n_segs,) + ker[grid]( + prepared.data, + prepared.col, + prepared.row, + x, + y, + prepared.seg_starts, + prepared.n_segs, + BLOCK_INNER=block_inner, + num_warps=1, + ) + return y + ker = _spmv_coo_atomic_f64 if dtype == torch.float64 else _spmv_coo_atomic_f32 + grid = (triton.cdiv(nnz, block_size),) + ker[grid]( + prepared.data, + prepared.row, + prepared.col, + x, + y, + nnz, + BLOCK=block_size, + num_warps=num_warps, + ) + return y + + +def flagsparse_spmv_coo( + data=None, + row=None, + col=None, + x=None, + shape=None, + out=None, + return_time=False, + prepared=None, + sort_by_row=True, + block_size=256, + num_warps=4, + block_inner=128, +): + """COO SpMV with no CSR indptr. See module docstring. + + ``block_inner``: tile for the row-run kernel (``sort_by_row=True``). + ``block_size`` / ``num_warps``: grid over NNZ when ``sort_by_row=False`` (atomics). 
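+
+    Steady-state sketch (hypothetical fp32 CUDA tensors; the sort and the
+    ``seg_starts`` build happen once inside ``prepare_spmv_coo``):
+
+        prep = prepare_spmv_coo(data, row, col, shape)
+        y = flagsparse_spmv_coo(x=x, prepared=prep)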
+ """ + if prepared is None: + if any(a is None for a in (data, row, col, x, shape)): + raise ValueError( + "data, row, col, x, shape required when prepared is None" + ) + prepared = prepare_spmv_coo( + data, row, col, shape, sort_by_row=sort_by_row + ) + else: + if x is None: + raise TypeError("x is required when prepared is set") + if shape is None: + shape = prepared.shape + sh = (int(shape[0]), int(shape[1])) + if sh != prepared.shape: + raise ValueError( + f"shape {sh} does not match prepared.shape {prepared.shape}" + ) + x = _validate_x_coo(x, prepared) + if num_warps not in (1, 2, 4, 8, 16, 32): + raise ValueError("num_warps must be a power of 2 in [1, 32]") + if block_inner <= 0 or (block_inner & (block_inner - 1)) != 0: + raise ValueError("block_inner must be a positive power of 2") + t0 = None + if return_time: + torch.cuda.synchronize() + t0 = time.perf_counter() + y = _triton_spmv_coo_kernel( + prepared, + x, + block_size=block_size, + num_warps=num_warps, + block_inner=block_inner, + ) + elapsed_ms = None + if return_time: + torch.cuda.synchronize() + elapsed_ms = (time.perf_counter() - t0) * 1000.0 + if out is not None: + if out.shape != y.shape or out.dtype != y.dtype: + raise ValueError("out shape/dtype must match result") + out.copy_(y) + y = out + if return_time: + return y, elapsed_ms + return y \ No newline at end of file diff --git a/src/flagsparse/sparse_operations/spmv_csr.py b/src/flagsparse/sparse_operations/spmv_csr.py new file mode 100644 index 0000000..929d699 --- /dev/null +++ b/src/flagsparse/sparse_operations/spmv_csr.py @@ -0,0 +1,758 @@ +"""CSR SpMV: Triton baseline kernels + optimised CSR-Vector buckets.""" + +from ._common import * + +import time +import triton +import triton.language as tl + +class PreparedCsrSpmv: + """Cached CSR metadata for repeated SpMV calls on the same sparse matrix.""" + + __slots__ = ( + "data", + "kernel_indices", + "kernel_indptr", + "shape", + "n_rows", + "n_cols", + "block_nnz", + "max_segments", + "max_row_nnz", + "opt_buckets", + "supports_opt", + "_baseline_compute_dtype", + "_baseline_data", + ) + + def __init__( + self, + data, + kernel_indices, + kernel_indptr, + shape, + n_rows, + n_cols, + block_nnz, + max_segments, + max_row_nnz, + opt_buckets, + ): + self.data = data + self.kernel_indices = kernel_indices + self.kernel_indptr = kernel_indptr + self.shape = (int(shape[0]), int(shape[1])) + self.n_rows = n_rows + self.n_cols = n_cols + self.block_nnz = block_nnz + self.max_segments = max_segments + self.max_row_nnz = max_row_nnz + self.opt_buckets = opt_buckets + self.supports_opt = data.dtype in (torch.float32, torch.float64) + if data.dtype in (torch.float16, torch.bfloat16): + self._baseline_compute_dtype = torch.float32 + elif data.dtype == torch.float32: + self._baseline_compute_dtype = torch.float64 + else: + self._baseline_compute_dtype = data.dtype + self._baseline_data = None + + +# Performance-first CSR-Vector buckets. num_warps*32 >= block_size. +# First bucket uses batch_rows>1: one program processes several short rows +# (fewer blocks → better occupancy on graphs with millions of low-degree rows). 
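+# Worked example (hypothetical row): nnz=300 lands in the 512-nnz bucket below
+# (block_size=256), so each program sweeps MAX_SEGS = ceil(512 / 256) = 2
+# lane-sized segments; rows with nnz <= 64 are instead packed batch_rows per
+# program by the batched short-row kernel.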
+_SPMV_OPT_BUCKET_CONFIGS = ( + { + "max_row_nnz": 64, + "block_size": 32, + "num_warps": 1, + "num_stages": 2, + "batch_rows": 16, + }, + {"max_row_nnz": 512, "block_size": 256, "num_warps": 8, "num_stages": 2}, + {"max_row_nnz": 4096, "block_size": 512, "num_warps": 16, "num_stages": 2}, + {"max_row_nnz": None, "block_size": 1024, "num_warps": 32, "num_stages": 3}, +) +# fp64: extra row-length tiers + smaller tiles vs f32; batch_rows=4 for short-row kernel. +_SPMV_OPT_BUCKET_CONFIGS_FP64 = ( + { + "max_row_nnz": 64, + "block_size": 32, + "num_warps": 1, + "num_stages": 2, + "batch_rows": 4, + }, + {"max_row_nnz": 256, "block_size": 64, "num_warps": 2, "num_stages": 2}, + {"max_row_nnz": 2048, "block_size": 128, "num_warps": 4, "num_stages": 2}, + {"max_row_nnz": 8192, "block_size": 256, "num_warps": 8, "num_stages": 2}, + {"max_row_nnz": None, "block_size": 512, "num_warps": 16, "num_stages": 1}, +) +_SPMV_OPT_ACC_MODES = ("fast", "mixed", "accurate") + + + +@triton.jit +def _spmv_csr_real_kernel( + data_ptr, + indices_ptr, + indptr_ptr, + x_ptr, + y_ptr, + n_rows, + BLOCK_NNZ: tl.constexpr, + MAX_SEGMENTS: tl.constexpr, +): + row = tl.program_id(0) + if row >= n_rows: + return + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) + acc = tl.load(data_ptr + start, mask=start < end, other=0.0) * 0 + for seg in range(MAX_SEGMENTS): + idx = start + seg * BLOCK_NNZ + offsets = idx + tl.arange(0, BLOCK_NNZ) + mask = offsets < end + a = tl.load(data_ptr + offsets, mask=mask, other=0.0) + col = tl.load(indices_ptr + offsets, mask=mask, other=0) + x_vals = tl.load(x_ptr + col, mask=mask, other=0.0) + part = tl.where(mask, a * x_vals, 0.0) + acc = acc + tl.sum(part) + tl.store(y_ptr + row, acc) + + +@triton.jit +def _spmv_csr_complex_kernel( + data_ri_ptr, + indices_ptr, + indptr_ptr, + x_ri_ptr, + y_ri_ptr, + n_rows, + BLOCK_NNZ: tl.constexpr, + MAX_SEGMENTS: tl.constexpr, +): + row = tl.program_id(0) + if row >= n_rows: + return + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) + acc_re = tl.load(data_ri_ptr + start * 2, mask=start < end, other=0.0) * 0 + acc_im = tl.load(data_ri_ptr + start * 2 + 1, mask=start < end, other=0.0) * 0 + for seg in range(MAX_SEGMENTS): + idx = start + seg * BLOCK_NNZ + offsets = idx + tl.arange(0, BLOCK_NNZ) + mask = offsets < end + a_re = tl.load(data_ri_ptr + offsets * 2, mask=mask, other=0.0) + a_im = tl.load(data_ri_ptr + offsets * 2 + 1, mask=mask, other=0.0) + col = tl.load(indices_ptr + offsets, mask=mask, other=0) + x_re = tl.load(x_ri_ptr + col * 2, mask=mask, other=0.0) + x_im = tl.load(x_ri_ptr + col * 2 + 1, mask=mask, other=0.0) + prod_re = tl.where(mask, a_re * x_re - a_im * x_im, 0.0) + prod_im = tl.where(mask, a_re * x_im + a_im * x_re, 0.0) + acc_re = acc_re + tl.sum(prod_re) + acc_im = acc_im + tl.sum(prod_im) + tl.store(y_ri_ptr + row * 2, acc_re) + tl.store(y_ri_ptr + row * 2 + 1, acc_im) + + +# ── Optimised SpMV (CSR-Vector, perf-oriented, no CuPy) ───────────── +# fp32 / fp64 native lane accum. Batched kernel for many short rows per program. 
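+# Dispatch recap (see _triton_spmv_csr_impl_opt_prepared below): buckets with
+# batch_rows > 1 run _spmv_csr_batched_short_*; every other bucket runs
+# _spmv_csr_vector_rows_* with one program per row.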
+ +@triton.jit +def _spmv_csr_batched_short_f32( + data_ptr, + indices_ptr, + indptr_ptr, + x_ptr, + y_ptr, + rows_ptr, + n_bucket_rows, + BATCH: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + MAX_SEGS: tl.constexpr, +): + pid = tl.program_id(0) + lane = tl.arange(0, BLOCK_SIZE) + for b in range(BATCH): + ridx = pid * BATCH + b + active = ridx < n_bucket_rows + row = tl.load(rows_ptr + ridx, mask=active, other=0) + start = tl.load(indptr_ptr + row, mask=active, other=0) + end = tl.load(indptr_ptr + row + 1, mask=active, other=0) + acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) + for seg in range(MAX_SEGS): + offs = start + seg * BLOCK_SIZE + lane + mask = offs < end + a = tl.load(data_ptr + offs, mask=mask, other=0.0) + col = tl.load(indices_ptr + offs, mask=mask, other=0) + xv = tl.load(x_ptr + col, mask=mask, other=0.0) + acc += tl.where(mask, a * xv, 0.0) + tl.store(y_ptr + row, tl.sum(acc), mask=active) + +@triton.jit +def _spmv_csr_batched_short_f64( + data_ptr, + indices_ptr, + indptr_ptr, + x_ptr, + y_ptr, + rows_ptr, + n_bucket_rows, + BATCH: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + MAX_SEGS: tl.constexpr, +): + pid = tl.program_id(0) + lane = tl.arange(0, BLOCK_SIZE) + for b in range(BATCH): + ridx = pid * BATCH + b + active = ridx < n_bucket_rows + row = tl.load(rows_ptr + ridx, mask=active, other=0) + start = tl.load(indptr_ptr + row, mask=active, other=0) + end = tl.load(indptr_ptr + row + 1, mask=active, other=0) + acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float64) + for seg in range(MAX_SEGS): + offs = start + seg * BLOCK_SIZE + lane + mask = offs < end + a = tl.load(data_ptr + offs, mask=mask, other=0.0) + col = tl.load(indices_ptr + offs, mask=mask, other=0) + xv = tl.load(x_ptr + col, mask=mask, other=0.0) + acc += tl.where(mask, a * xv, 0.0) + tl.store(y_ptr + row, tl.sum(acc), mask=active) + +@triton.jit +def _spmv_csr_vector_rows_f32( + data_ptr, + indices_ptr, + indptr_ptr, + x_ptr, + y_ptr, + rows_ptr, + n_bucket_rows, + BLOCK_SIZE: tl.constexpr, + MAX_SEGS: tl.constexpr, +): + pid = tl.program_id(0) + if pid >= n_bucket_rows: + return + row = tl.load(rows_ptr + pid) + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) + lane = tl.arange(0, BLOCK_SIZE) + acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) + for seg in range(MAX_SEGS): + offs = start + seg * BLOCK_SIZE + lane + mask = offs < end + a = tl.load(data_ptr + offs, mask=mask, other=0.0) + col = tl.load(indices_ptr + offs, mask=mask, other=0) + xv = tl.load(x_ptr + col, mask=mask, other=0.0) + acc = tl.where(mask, acc + a * xv, acc) + tl.store(y_ptr + row, tl.sum(acc)) + +@triton.jit +def _spmv_csr_vector_rows_f64( + data_ptr, + indices_ptr, + indptr_ptr, + x_ptr, + y_ptr, + rows_ptr, + n_bucket_rows, + BLOCK_SIZE: tl.constexpr, + MAX_SEGS: tl.constexpr, +): + pid = tl.program_id(0) + if pid >= n_bucket_rows: + return + row = tl.load(rows_ptr + pid) + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) + lane = tl.arange(0, BLOCK_SIZE) + acc = tl.zeros((BLOCK_SIZE,), dtype=tl.float64) + for seg in range(MAX_SEGS): + offs = start + seg * BLOCK_SIZE + lane + mask = offs < end + a = tl.load(data_ptr + offs, mask=mask, other=0.0) + col = tl.load(indices_ptr + offs, mask=mask, other=0) + xv = tl.load(x_ptr + col, mask=mask, other=0.0) + acc = tl.where(mask, acc + a * xv, acc) + tl.store(y_ptr + row, tl.sum(acc)) + +def _build_spmv_opt_buckets( + row_lengths, + max_row_nnz, + row_index_dtype, + max_segments=None, + fp64=False, +): + buckets = [] + lower_bound = 0 + configs = 
_SPMV_OPT_BUCKET_CONFIGS_FP64 if fp64 else _SPMV_OPT_BUCKET_CONFIGS + for spec in configs: + upper_bound = spec["max_row_nnz"] + if upper_bound is None: + mask = row_lengths > lower_bound + bucket_max_row_nnz = max_row_nnz + elif lower_bound == 0: + # Include nnz==0 rows in the first bucket (they still need y[i]=0). + mask = row_lengths <= upper_bound + bucket_max_row_nnz = upper_bound + else: + mask = (row_lengths > lower_bound) & (row_lengths <= upper_bound) + bucket_max_row_nnz = upper_bound + rows = torch.nonzero(mask, as_tuple=False).flatten() + if rows.numel() == 0: + if upper_bound is not None: + lower_bound = upper_bound + continue + if max_segments is None: + max_segs = max( + (bucket_max_row_nnz + spec["block_size"] - 1) // spec["block_size"], + 1, + ) + else: + max_segs = max_segments + buckets.append( + { + "rows": rows.to(row_index_dtype), + "block_size": spec["block_size"], + "max_segs": max_segs, + "num_warps": spec["num_warps"], + "num_stages": spec["num_stages"], + "batch_rows": int(spec.get("batch_rows", 1)), + } + ) + if upper_bound is not None: + lower_bound = upper_bound + return buckets + +def _triton_spmv_csr_impl_opt_prepared(prepared, x): + # First bucket includes nnz==0 rows; every row gets exactly one store. + dtype = prepared.data.dtype + y = torch.empty(prepared.n_rows, dtype=dtype, device=prepared.data.device) + if prepared.n_rows == 0: + return y + vec_f32 = _spmv_csr_vector_rows_f32 + vec_f64 = _spmv_csr_vector_rows_f64 + bat_f32 = _spmv_csr_batched_short_f32 + bat_f64 = _spmv_csr_batched_short_f64 + for bucket in prepared.opt_buckets: + rows = bucket["rows"] + br = max(1, int(bucket.get("batch_rows", 1))) + n_r = rows.numel() + if br > 1: + kernel = bat_f64 if dtype == torch.float64 else bat_f32 + grid = (triton.cdiv(n_r, br),) + kernel[grid]( + prepared.data, + prepared.kernel_indices, + prepared.kernel_indptr, + x, + y, + rows, + n_bucket_rows=n_r, + BATCH=br, + BLOCK_SIZE=bucket["block_size"], + MAX_SEGS=bucket["max_segs"], + num_warps=bucket["num_warps"], + num_stages=bucket["num_stages"], + ) + else: + kernel = vec_f64 if dtype == torch.float64 else vec_f32 + grid = (n_r,) + kernel[grid]( + prepared.data, + prepared.kernel_indices, + prepared.kernel_indptr, + x, + y, + rows, + n_bucket_rows=n_r, + BLOCK_SIZE=bucket["block_size"], + MAX_SEGS=bucket["max_segs"], + num_warps=bucket["num_warps"], + num_stages=bucket["num_stages"], + ) + return y + + +def _prepare_spmv_csr_matrix(data, indices, indptr, shape): + if not all(torch.is_tensor(t) for t in (data, indices, indptr)): + raise TypeError("data, indices, indptr must all be torch.Tensor") + if data.ndim != 1 or indices.ndim != 1 or indptr.ndim != 1: + raise ValueError("data, indices, indptr must be 1D tensors") + n_rows, n_cols = int(shape[0]), int(shape[1]) + if indptr.numel() != n_rows + 1: + raise ValueError( + f"indptr length must be n_rows+1={n_rows + 1}, got {indptr.numel()}" + ) + if data.numel() != indices.numel(): + raise ValueError("data and indices must have the same length (nnz)") + if not all(t.is_cuda for t in (data, indices, indptr)): + raise ValueError("data, indices, indptr must be CUDA tensors") + if data.dtype not in SUPPORTED_VALUE_DTYPES: + raise TypeError( + "data dtype must be one of: float16, bfloat16, float32, float64, complex64, complex128" + ) + if indices.dtype not in SUPPORTED_INDEX_DTYPES: + raise TypeError("indices dtype must be torch.int32 or torch.int64") + + data = data.contiguous() + indices = indices.contiguous() + indptr = indptr.contiguous() + + if indptr.numel() > 
0: + if int(indptr[0].item()) != 0: + raise ValueError("indptr must start at zero") + if int(indptr[-1].item()) != data.numel(): + raise ValueError("indptr[-1] must equal nnz") + if indptr.numel() > 1 and torch.any(indptr[1:] < indptr[:-1]).item(): + raise ValueError("indptr must be non-decreasing") + + nnz = data.numel() + if nnz > 0: + min_index = int(indices.min().item()) + max_index = int(indices.max().item()) + if min_index < 0 or max_index >= n_cols: + raise IndexError("indices out of range for n_cols") + if max_index > _INDEX_LIMIT_INT32: + raise ValueError( + f"int64 column index {max_index} exceeds Triton int32 kernel range" + ) + kernel_indices = indices.to(torch.int32) if indices.dtype == torch.int64 else indices + kernel_indptr = ( + indptr.to(torch.int32) if nnz <= _INDEX_LIMIT_INT32 else indptr.to(torch.int64) + ) + row_lengths = kernel_indptr[1:] - kernel_indptr[:-1] + max_row_nnz = int(row_lengths.max().item()) if n_rows > 0 else 0 + return ( + data, + kernel_indices, + kernel_indptr, + n_rows, + n_cols, + row_lengths, + max_row_nnz, + ) + + +def _validate_spmv_x(x, prepared): + if x is None or not torch.is_tensor(x): + raise TypeError("x must be a torch.Tensor") + if x.ndim != 1: + raise ValueError("x must be a 1D tensor") + if not x.is_cuda: + raise ValueError("x must be a CUDA tensor") + if x.dtype != prepared.data.dtype: + raise TypeError("x dtype must match sparse matrix dtype") + if x.numel() != prepared.n_cols: + raise ValueError(f"x length must be n_cols={prepared.n_cols}, got {x.numel()}") + if x.device != prepared.data.device: + raise ValueError("x must be on the same device as sparse matrix data") + return x.contiguous() + + +def prepare_spmv_csr(data, indices, indptr, shape, block_nnz=256, max_segments=None): + ( + data, + kernel_indices, + kernel_indptr, + n_rows, + n_cols, + row_lengths, + max_row_nnz, + ) = _prepare_spmv_csr_matrix(data, indices, indptr, shape) + block_nnz_use = block_nnz + if max_segments is None: + max_segments_use = max((max_row_nnz + block_nnz_use - 1) // block_nnz_use, 1) + while max_segments_use > 2048 and block_nnz_use < 65536: + block_nnz_use *= 2 + max_segments_use = max( + (max_row_nnz + block_nnz_use - 1) // block_nnz_use, + 1, + ) + else: + max_segments_use = max_segments + row_index_dtype = torch.int32 if n_rows <= _INDEX_LIMIT_INT32 else torch.int64 + opt_buckets = _build_spmv_opt_buckets( + row_lengths, + max_row_nnz=max_row_nnz, + row_index_dtype=row_index_dtype, + max_segments=max_segments, + fp64=data.dtype == torch.float64, + ) + return PreparedCsrSpmv( + data=data, + kernel_indices=kernel_indices, + kernel_indptr=kernel_indptr, + shape=shape, + n_rows=n_rows, + n_cols=n_cols, + block_nnz=block_nnz_use, + max_segments=max_segments_use, + max_row_nnz=max_row_nnz, + opt_buckets=opt_buckets, + ) + + +def _get_spmv_baseline_data(prepared): + compute_dtype = prepared._baseline_compute_dtype + if compute_dtype == prepared.data.dtype: + return compute_dtype, prepared.data + if ( + prepared._baseline_data is None + or prepared._baseline_data.dtype != compute_dtype + ): + prepared._baseline_data = prepared.data.to(compute_dtype) + return compute_dtype, prepared._baseline_data + + +def _triton_spmv_csr_impl_prepared(prepared, x): + device = prepared.data.device + dtype = prepared.data.dtype + y = torch.empty(prepared.n_rows, dtype=dtype, device=device) + if prepared.n_rows == 0: + return y + compute_dtype, data_in = _get_spmv_baseline_data(prepared) + x_in = x + if compute_dtype != dtype: + x_in = x.to(compute_dtype) + if not 
_is_complex_dtype(compute_dtype): + y_out = torch.empty(prepared.n_rows, dtype=compute_dtype, device=device) + grid = (prepared.n_rows,) + _spmv_csr_real_kernel[grid]( + data_in, + prepared.kernel_indices, + prepared.kernel_indptr, + x_in, + y_out, + n_rows=prepared.n_rows, + BLOCK_NNZ=prepared.block_nnz, + MAX_SEGMENTS=prepared.max_segments, + ) + if dtype != compute_dtype: + y_out = y_out.to(dtype) + y.copy_(y_out) + return y + data_ri = torch.view_as_real(data_in).reshape(-1) + x_ri = torch.view_as_real(x_in).reshape(-1) + comp_dtype = data_ri.dtype + y_ri = torch.empty(prepared.n_rows * 2, dtype=comp_dtype, device=device) + grid = (prepared.n_rows,) + _spmv_csr_complex_kernel[grid]( + data_ri, + prepared.kernel_indices, + prepared.kernel_indptr, + x_ri, + y_ri, + n_rows=prepared.n_rows, + BLOCK_NNZ=prepared.block_nnz, + MAX_SEGMENTS=prepared.max_segments, + ) + y_ri = y_ri.reshape(prepared.n_rows, 2) + y.copy_(torch.view_as_complex(y_ri)) + return y + + +def flagsparse_spmv_csr( + data=None, + indices=None, + indptr=None, + x=None, + shape=None, + block_nnz=256, + max_segments=None, + out=None, + return_time=False, + use_opt=False, + prepared=None, +): + """ + CSR SpMV: y = A @ x using Triton (cuSPARSE-aligned dtypes). + data, indices, indptr: CSR arrays; x: dense vector; shape: (n_rows, n_cols). + prepared: cached CSR metadata from prepare_spmv_csr for steady-state runs. + max_segments: None = auto-compute from indptr so all NNZ per row are covered. + use_opt: if True, use the faster CSR-Vector bucketed path (fp32/fp64 native accum). + """ + if prepared is None: + if any(arg is None for arg in (data, indices, indptr, shape)): + raise ValueError( + "data, indices, indptr, and shape are required when prepared is not provided" + ) + prepared = prepare_spmv_csr( + data, + indices, + indptr, + shape, + block_nnz=block_nnz, + max_segments=max_segments, + ) + x = _validate_spmv_x(x, prepared) + t0 = None + if return_time: + torch.cuda.synchronize() + t0 = time.perf_counter() + if use_opt and prepared.supports_opt: + y = _triton_spmv_csr_impl_opt_prepared(prepared, x) + else: + y = _triton_spmv_csr_impl_prepared(prepared, x) + elapsed_ms = None + if return_time: + torch.cuda.synchronize() + elapsed_ms = (time.perf_counter() - t0) * 1000.0 + if out is not None: + if out.shape != y.shape or out.dtype != y.dtype: + raise ValueError("out shape/dtype must match result") + out.copy_(y) + y = out + if return_time: + return y, elapsed_ms + return y + + +def _coo_is_sorted_lex(row_i64, col_i64, n_cols): + """True iff COO rows are non-decreasing lex order (row, col).""" + n = row_i64.numel() + if n <= 1: + return True + scale = max(1, int(n_cols)) + key = row_i64 * scale + col_i64 + return bool((key[1:] >= key[:-1]).all().item()) + + +def coo_to_csr_for_spmv(data, row, col, shape, assume_sorted=False): + """Convert COO to CSR triple (data, csr_col_indices, indptr) for SpMV.""" + n_rows, n_cols = int(shape[0]), int(shape[1]) + row64 = row.to(torch.int64) + col64 = col.to(torch.int64) + if row64.numel() == 0: + indptr = torch.zeros(n_rows + 1, dtype=torch.int64, device=data.device) + return data, col64.to(torch.int32), indptr + + if assume_sorted or _coo_is_sorted_lex(row64, col64, n_cols): + row_s, col_s, data_s = row64, col64, data + else: + key = row64 * max(1, n_cols) + col64 + order = torch.argsort(key) + row_s = row64[order] + col_s = col64[order] + data_s = data[order].to(data.dtype) + + indptr = torch.zeros(n_rows + 1, dtype=torch.int64, device=data.device) + nnz = data_s.numel() + if nnz > 0: + 
nnz_per_row = torch.bincount(row_s, minlength=n_rows) + indptr[1:] = torch.cumsum(nnz_per_row, dim=0) + indices = col_s.to(torch.int32) + return data_s, indices, indptr + + +def prepare_spmv_coo_tocsr( + data, + row, + col, + shape, + block_nnz=256, + max_segments=None, + assume_sorted=False, +): + """One-time COO → CSR + bucket metadata; use with ``flagsparse_spmv_coo_tocsr(..., prepared=p)``.""" + if not all(torch.is_tensor(t) for t in (data, row, col)): + raise TypeError("data, row, col must all be torch.Tensor") + if not all(t.is_cuda for t in (data, row, col)): + raise ValueError("data, row, col must all be CUDA tensors") + if data.ndim != 1 or row.ndim != 1 or col.ndim != 1: + raise ValueError("data, row, col must all be 1D tensors") + if data.dtype not in SUPPORTED_VALUE_DTYPES: + raise TypeError( + "data dtype must be one of: float16, bfloat16, float32, float64, complex64, complex128" + ) + n_rows, n_cols = int(shape[0]), int(shape[1]) + if row.numel() != col.numel() or data.numel() != row.numel(): + raise ValueError("data, row, col must have the same length") + + data_s, indices, indptr = coo_to_csr_for_spmv( + data, row, col, shape, assume_sorted=assume_sorted + ) + return prepare_spmv_csr( + data_s, + indices, + indptr, + shape, + block_nnz=block_nnz, + max_segments=max_segments, + ) + + +def flagsparse_spmv_coo_tocsr( + data=None, + row=None, + col=None, + x=None, + shape=None, + block_nnz=256, + max_segments=None, + out=None, + return_time=False, + use_opt=True, + prepared=None, + assume_sorted=False, +): + """COO SpMV via CSR conversion: y = A @ x. + + Default ``use_opt=True`` enables the fast CSR-Vector path for float32/float64. + If COO is already lex-sorted by (row, col), pass ``assume_sorted=True`` to skip ``argsort``. + + Steady-state: ``p = prepare_spmv_coo_tocsr(data, row, col, shape)`` then call with ``prepared=p`` + (``data``/``row``/``col`` may be omitted). 
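+
+    A minimal steady-state sketch (``n``, ``x``, and the COO tensors here are
+    illustrative placeholders, not values from the library's tests)::
+
+        p = prepare_spmv_coo_tocsr(data, row, col, shape=(n, n))
+        y = flagsparse_spmv_coo_tocsr(x=x, prepared=p)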
+ """ + if prepared is not None: + if x is None: + raise TypeError("x is required") + if shape is None: + shape = prepared.shape + sh = (int(shape[0]), int(shape[1])) + if sh != prepared.shape: + raise ValueError(f"shape {sh} does not match prepared.shape {prepared.shape}") + return flagsparse_spmv_csr( + x=x, + shape=shape, + block_nnz=block_nnz, + max_segments=max_segments, + out=out, + return_time=return_time, + use_opt=use_opt, + prepared=prepared, + ) + + if not all(torch.is_tensor(t) for t in (data, row, col, x)): + raise TypeError("data, row, col, x must all be torch.Tensor") + if not all(t.is_cuda for t in (data, row, col, x)): + raise ValueError("data, row, col, x must all be CUDA tensors") + if data.ndim != 1 or row.ndim != 1 or col.ndim != 1 or x.ndim != 1: + raise ValueError("data, row, col, x must all be 1D tensors") + + n_rows, n_cols = int(shape[0]), int(shape[1]) + if data.dtype not in SUPPORTED_VALUE_DTYPES: + raise TypeError( + "data dtype must be one of: float16, bfloat16, float32, float64, complex64, complex128" + ) + if x.dtype != data.dtype: + raise TypeError("x dtype must match data dtype") + + data_s, indices, indptr = coo_to_csr_for_spmv( + data, row, col, shape, assume_sorted=assume_sorted + ) + + return flagsparse_spmv_csr( + data_s, + indices, + indptr, + x, + shape, + block_nnz=block_nnz, + max_segments=max_segments, + out=out, + return_time=return_time, + use_opt=use_opt, + ) diff --git a/src/flagsparse/sparse_operations/spsm.py b/src/flagsparse/sparse_operations/spsm.py new file mode 100644 index 0000000..28e26d3 --- /dev/null +++ b/src/flagsparse/sparse_operations/spsm.py @@ -0,0 +1,794 @@ +"""Sparse triangular matrix-matrix solve (SpSM) for CSR/COO.""" + +from ._common import * + +import time + +try: + from cupyx.scipy.sparse.linalg import spsolve_triangular as cpx_spsolve_triangular +except Exception: + cpx_spsolve_triangular = None + +SUPPORTED_SPSM_VALUE_DTYPES = (torch.float32, torch.float64) +SUPPORTED_SPSM_INDEX_DTYPES = (torch.int32, torch.int64) +SPSM_NON_TRANS_PRIMARY_COMBOS = ( + ("csr", torch.float32, torch.int32), + ("csr", torch.float64, torch.int32), + ("coo", torch.float32, torch.int32), + ("coo", torch.float64, torch.int32), +) + + +def _is_non_transpose(op): + op_str = str(op).upper() + return op_str in ("NON_TRANS", "NON_TRANSPOSE", "N") + + +def _validate_spsm_non_trans_combo(fmt_name, value_dtype, index_dtype): + combo = (str(fmt_name).lower(), value_dtype, index_dtype) + if combo not in SPSM_NON_TRANS_PRIMARY_COMBOS: + raise TypeError( + f"{fmt_name} SpSM currently supports NON_TRANS combinations with int32 kernel indices: " + "(float32, int32), (float64, int32)" + ) + + +def _validate_spsm_op_and_layout(opA, opB, major): + if not _is_non_transpose(opA): + raise NotImplementedError("Only op(A)=NON_TRANS is supported") + if not _is_non_transpose(opB): + raise NotImplementedError("Only op(B)=NON_TRANS is supported") + if str(major).lower() != "row": + raise NotImplementedError("Only row-major dense layout is supported") + + +def _prepare_spsm_csr_inputs(data, indices, indptr, B, shape, opA, opB, major): + if not all(torch.is_tensor(t) for t in (data, indices, indptr, B)): + raise TypeError("data, indices, indptr, B must all be torch.Tensor") + if not all(t.is_cuda for t in (data, indices, indptr, B)): + raise ValueError("data, indices, indptr, B must all be CUDA tensors") + if data.ndim != 1 or indices.ndim != 1 or indptr.ndim != 1: + raise ValueError("data, indices, indptr must be 1D") + if B.ndim != 2: + raise ValueError("B must be a 2D 
dense matrix") + n_rows, n_cols = int(shape[0]), int(shape[1]) + if n_rows != n_cols: + raise ValueError(f"SpSM requires square A, got shape={shape}") + if indptr.numel() != n_rows + 1: + raise ValueError(f"indptr length must be n_rows+1={n_rows + 1}") + if data.numel() != indices.numel(): + raise ValueError("data and indices must have the same length (nnz)") + if B.shape[0] != n_rows: + raise ValueError(f"B.shape[0] must equal n_rows={n_rows}") + _validate_spsm_op_and_layout(opA, opB, major) + if data.dtype not in SUPPORTED_SPSM_VALUE_DTYPES: + raise TypeError("data dtype must be float32 or float64") + if B.dtype != data.dtype: + raise TypeError("B dtype must match data dtype") + if indices.dtype not in SUPPORTED_SPSM_INDEX_DTYPES: + raise TypeError("indices dtype must be torch.int32 or torch.int64") + if indptr.dtype not in SUPPORTED_SPSM_INDEX_DTYPES: + raise TypeError("indptr dtype must be torch.int32 or torch.int64") + if data.numel() > 0: + if int(indices.to(torch.int64).max().item()) > _INDEX_LIMIT_INT32: + raise ValueError("index value exceeds int32 kernel range") + _validate_spsm_non_trans_combo("csr", data.dtype, torch.int32) + return ( + data.contiguous(), + indices.to(torch.int64).contiguous(), + indptr.to(torch.int64).contiguous(), + B.contiguous(), + n_rows, + ) + + +def _coo_to_csr_sorted_unique(data, row, col, n_rows, n_cols): + if data.numel() == 0: + return ( + data, + torch.empty(0, dtype=torch.int64, device=data.device), + torch.zeros(n_rows + 1, dtype=torch.int64, device=data.device), + ) + key = row * max(1, n_cols) + col + try: + order = torch.argsort(key, stable=True) + except TypeError: + order = torch.argsort(key) + key_s = key[order] + data_s = data[order] + unique_key, inverse = torch.unique_consecutive(key_s, return_inverse=True) + out_nnz = unique_key.numel() + data_u = torch.zeros(out_nnz, dtype=data.dtype, device=data.device) + data_u.scatter_add_(0, inverse, data_s) + row_u = torch.div(unique_key, max(1, n_cols), rounding_mode="floor") + col_u = unique_key - row_u * max(1, n_cols) + indptr = torch.zeros(n_rows + 1, dtype=torch.int64, device=data.device) + if out_nnz > 0: + nnz_per_row = torch.bincount(row_u, minlength=n_rows) + indptr[1:] = torch.cumsum(nnz_per_row, dim=0) + return data_u, col_u.to(torch.int64), indptr + + +def _prepare_spsm_coo_inputs(data, row, col, B, shape, opA, opB, major): + if not all(torch.is_tensor(t) for t in (data, row, col, B)): + raise TypeError("data, row, col, B must all be torch.Tensor") + if not all(t.is_cuda for t in (data, row, col, B)): + raise ValueError("data, row, col, B must all be CUDA tensors") + if data.ndim != 1 or row.ndim != 1 or col.ndim != 1: + raise ValueError("data, row, col must be 1D") + if data.numel() != row.numel() or data.numel() != col.numel(): + raise ValueError("data, row, col must have same length") + n_rows, n_cols = int(shape[0]), int(shape[1]) + if n_rows != n_cols: + raise ValueError(f"SpSM requires square A, got shape={shape}") + if B.ndim != 2 or B.shape[0] != n_rows: + raise ValueError("B must be 2D and B.shape[0] == n_rows") + _validate_spsm_op_and_layout(opA, opB, major) + if data.dtype not in SUPPORTED_SPSM_VALUE_DTYPES: + raise TypeError("data dtype must be float32 or float64") + if B.dtype != data.dtype: + raise TypeError("B dtype must match data dtype") + if row.dtype not in SUPPORTED_SPSM_INDEX_DTYPES or col.dtype not in SUPPORTED_SPSM_INDEX_DTYPES: + raise TypeError("row/col dtype must be torch.int32 or torch.int64") + row64 = row.to(torch.int64).contiguous() + col64 = 
col.to(torch.int64).contiguous() + if row64.numel() > 0: + if bool(torch.any(row64 < 0).item()) or bool(torch.any(col64 < 0).item()): + raise IndexError("row/col must be non-negative") + if int(row64.max().item()) >= n_rows: + raise IndexError("row index out of range") + if int(col64.max().item()) >= n_cols: + raise IndexError("col index out of range") + _validate_spsm_non_trans_combo("coo", data.dtype, torch.int32) + return data.contiguous(), row64, col64, B.contiguous(), n_rows, n_cols + + + + +@triton.jit +def _spsm_csr_level_kernel_real( + data_ptr, + indices_ptr, + indptr_ptr, + b_ptr, + x_ptr, + rows_ptr, + n_level_rows, + n_rhs, + stride_b0, + stride_x0, + BLOCK_NNZ: tl.constexpr, + BLOCK_RHS: tl.constexpr, + MAX_SEGMENTS: tl.constexpr, + LOWER: tl.constexpr, + UNIT_DIAG: tl.constexpr, + USE_FP64_ACC: tl.constexpr, + DIAG_EPS: tl.constexpr, +): + pid_row = tl.program_id(0) + pid_rhs = tl.program_id(1) + if pid_row >= n_level_rows: + return + + row = tl.load(rows_ptr + pid_row) + rhs_offsets = pid_rhs * BLOCK_RHS + tl.arange(0, BLOCK_RHS) + rhs_mask = rhs_offsets < n_rhs + + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) + + if USE_FP64_ACC: + acc = tl.zeros((BLOCK_RHS,), dtype=tl.float64) + diag = tl.zeros((1,), dtype=tl.float64) + else: + acc = tl.zeros((BLOCK_RHS,), dtype=tl.float32) + diag = tl.zeros((1,), dtype=tl.float32) + + if UNIT_DIAG: + diag = diag + 1.0 + + for seg in range(MAX_SEGMENTS): + nnz_offsets = start + seg * BLOCK_NNZ + tl.arange(0, BLOCK_NNZ) + nnz_mask = nnz_offsets < end + a = tl.load(data_ptr + nnz_offsets, mask=nnz_mask, other=0.0) + col = tl.load(indices_ptr + nnz_offsets, mask=nnz_mask, other=0) + + if USE_FP64_ACC: + a = a.to(tl.float64) + else: + a = a.to(tl.float32) + + solved_mask = col < row if LOWER else col > row + diag_mask = col == row + + x_ptrs = x_ptr + col[:, None] * stride_x0 + rhs_offsets[None, :] + x_mask = nnz_mask[:, None] & rhs_mask[None, :] + x_vals = tl.load(x_ptrs, mask=x_mask, other=0.0) + if USE_FP64_ACC: + x_vals = x_vals.to(tl.float64) + else: + x_vals = x_vals.to(tl.float32) + + contrib = tl.where((nnz_mask & solved_mask)[:, None], a[:, None] * x_vals, 0.0) + acc += tl.sum(contrib, axis=0) + + if not UNIT_DIAG: + diag += tl.sum(tl.where(nnz_mask & diag_mask, a, 0.0), axis=0) + + b_ptrs = b_ptr + row * stride_b0 + rhs_offsets + rhs = tl.load(b_ptrs, mask=rhs_mask, other=0.0) + rhs = rhs.to(tl.float64) if USE_FP64_ACC else rhs.to(tl.float32) + + diag_safe = tl.where(tl.abs(diag) < DIAG_EPS, 1.0, diag) + out = (rhs - acc) / diag_safe + out = tl.where(out == out, out, 0.0) + + out_ptrs = x_ptr + row * stride_x0 + rhs_offsets + tl.store(out_ptrs, out, mask=rhs_mask) + + +@triton.jit +def _spsm_coo_level_kernel_real( + data_ptr, + row_ptr_ptr, + col_ptr, + b_ptr, + x_ptr, + rows_ptr, + n_level_rows, + n_rhs, + stride_b0, + stride_x0, + BLOCK_NNZ: tl.constexpr, + BLOCK_RHS: tl.constexpr, + MAX_SEGMENTS: tl.constexpr, + LOWER: tl.constexpr, + UNIT_DIAG: tl.constexpr, + USE_FP64_ACC: tl.constexpr, + DIAG_EPS: tl.constexpr, +): + pid_row = tl.program_id(0) + pid_rhs = tl.program_id(1) + if pid_row >= n_level_rows: + return + + row = tl.load(rows_ptr + pid_row) + rhs_offsets = pid_rhs * BLOCK_RHS + tl.arange(0, BLOCK_RHS) + rhs_mask = rhs_offsets < n_rhs + + start = tl.load(row_ptr_ptr + row) + end = tl.load(row_ptr_ptr + row + 1) + + if USE_FP64_ACC: + acc = tl.zeros((BLOCK_RHS,), dtype=tl.float64) + diag = tl.zeros((1,), dtype=tl.float64) + else: + acc = tl.zeros((BLOCK_RHS,), dtype=tl.float32) + diag = 
tl.zeros((1,), dtype=tl.float32) + + if UNIT_DIAG: + diag = diag + 1.0 + + for seg in range(MAX_SEGMENTS): + nnz_offsets = start + seg * BLOCK_NNZ + tl.arange(0, BLOCK_NNZ) + nnz_mask = nnz_offsets < end + a = tl.load(data_ptr + nnz_offsets, mask=nnz_mask, other=0.0) + col = tl.load(col_ptr + nnz_offsets, mask=nnz_mask, other=0) + + if USE_FP64_ACC: + a = a.to(tl.float64) + else: + a = a.to(tl.float32) + + solved_mask = col < row if LOWER else col > row + diag_mask = col == row + + x_ptrs = x_ptr + col[:, None] * stride_x0 + rhs_offsets[None, :] + x_mask = nnz_mask[:, None] & rhs_mask[None, :] + x_vals = tl.load(x_ptrs, mask=x_mask, other=0.0) + if USE_FP64_ACC: + x_vals = x_vals.to(tl.float64) + else: + x_vals = x_vals.to(tl.float32) + + contrib = tl.where((nnz_mask & solved_mask)[:, None], a[:, None] * x_vals, 0.0) + acc += tl.sum(contrib, axis=0) + + if not UNIT_DIAG: + diag += tl.sum(tl.where(nnz_mask & diag_mask, a, 0.0), axis=0) + + b_ptrs = b_ptr + row * stride_b0 + rhs_offsets + rhs = tl.load(b_ptrs, mask=rhs_mask, other=0.0) + rhs = rhs.to(tl.float64) if USE_FP64_ACC else rhs.to(tl.float32) + + diag_safe = tl.where(tl.abs(diag) < DIAG_EPS, 1.0, diag) + out = (rhs - acc) / diag_safe + out = tl.where(out == out, out, 0.0) + + out_ptrs = x_ptr + row * stride_x0 + rhs_offsets + tl.store(out_ptrs, out, mask=rhs_mask) + + +def _build_spsm_levels(indptr, indices, n_rows, lower=True): + if n_rows == 0: + return [] + indptr_h = indptr.to(torch.int64).cpu() + indices_h = indices.to(torch.int64).cpu() + levels = [0] * n_rows + + if lower: + for i in range(n_rows): + s = int(indptr_h[i].item()) + e = int(indptr_h[i + 1].item()) + lvl = 0 + for p in range(s, e): + c = int(indices_h[p].item()) + if c < i: + lvl = max(lvl, levels[c] + 1) + levels[i] = lvl + else: + for i in range(n_rows - 1, -1, -1): + s = int(indptr_h[i].item()) + e = int(indptr_h[i + 1].item()) + lvl = 0 + for p in range(s, e): + c = int(indices_h[p].item()) + if c > i: + lvl = max(lvl, levels[c] + 1) + levels[i] = lvl + + max_level = max(levels) + buckets = [[] for _ in range(max_level + 1)] + for r, lv in enumerate(levels): + buckets[lv].append(r) + + device = indptr.device + return [ + torch.tensor(rows, dtype=torch.int32, device=device) + for rows in buckets + if rows + ] + + +def _auto_spsm_launch_config(indptr, block_nnz=None, max_segments=None): + if indptr.numel() <= 1: + max_nnz_per_row = 0 + else: + row_lengths = indptr[1:] - indptr[:-1] + max_nnz_per_row = int(row_lengths.max().item()) + + auto_block = block_nnz is None + if block_nnz is None: + if max_nnz_per_row <= 64: + block_nnz_use = 64 + elif max_nnz_per_row <= 256: + block_nnz_use = 128 + elif max_nnz_per_row <= 1024: + block_nnz_use = 256 + elif max_nnz_per_row <= 4096: + block_nnz_use = 512 + else: + block_nnz_use = 1024 + else: + block_nnz_use = int(block_nnz) + if block_nnz_use <= 0: + raise ValueError("block_nnz must be positive") + + required_segments = max((max_nnz_per_row + block_nnz_use - 1) // block_nnz_use, 1) + if max_segments is None: + max_segments_use = required_segments + if auto_block: + while max_segments_use > 2048 and block_nnz_use < 65536: + block_nnz_use *= 2 + max_segments_use = max((max_nnz_per_row + block_nnz_use - 1) // block_nnz_use, 1) + else: + max_segments_use = int(max_segments) + if max_segments_use <= 0: + raise ValueError("max_segments must be positive") + if max_segments_use < required_segments: + raise ValueError( + f"max_segments={max_segments_use} is too small; at least {required_segments} required" + ) + return 
block_nnz_use, max_segments_use + + +def _auto_rhs_block(n_rhs): + n_rhs = int(n_rhs) + if n_rhs <= 8: + return 8 + if n_rhs <= 16: + return 16 + if n_rhs <= 32: + return 32 + return 64 + + +def _coo_sort_unique_and_rowptr(data, row64, col64, n_rows, n_cols): + if data.numel() == 0: + row_ptr = torch.zeros(n_rows + 1, dtype=torch.int64, device=data.device) + return data, torch.empty(0, dtype=torch.int64, device=data.device), row_ptr + + key = row64 * max(1, n_cols) + col64 + try: + order = torch.argsort(key, stable=True) + except TypeError: + order = torch.argsort(key) + key_s = key[order] + data_s = data[order] + + unique_key, inverse = torch.unique_consecutive(key_s, return_inverse=True) + out_nnz = unique_key.numel() + data_u = torch.zeros(out_nnz, dtype=data.dtype, device=data.device) + data_u.scatter_add_(0, inverse, data_s) + + row_u = torch.div(unique_key, max(1, n_cols), rounding_mode="floor") + col_u = unique_key - row_u * max(1, n_cols) + row_ptr = torch.zeros(n_rows + 1, dtype=torch.int64, device=data.device) + nnz_per_row = torch.bincount(row_u, minlength=n_rows) + row_ptr[1:] = torch.cumsum(nnz_per_row, dim=0) + return data_u, col_u.to(torch.int64), row_ptr + + +def _run_spsm_csr_core( + data, + indices64, + indptr64, + rhs, + n_rows, + lower=True, + unit_diagonal=False, + block_nnz=None, + max_segments=None, + block_rhs=None, +): + if rhs.ndim != 2: + raise ValueError("rhs must be 2D") + rhs = rhs.contiguous() + if rhs.shape[0] != n_rows: + raise ValueError("rhs first dim must equal n_rows") + n_rhs = int(rhs.shape[1]) + x = torch.zeros_like(rhs) + if n_rows == 0 or n_rhs == 0: + return x + + indices32 = indices64.to(torch.int32) + levels = _build_spsm_levels(indptr64, indices32, n_rows, lower=lower) + block_nnz_use, max_segments_use = _auto_spsm_launch_config( + indptr64, block_nnz=block_nnz, max_segments=max_segments + ) + block_rhs_use = _auto_rhs_block(n_rhs) if block_rhs is None else int(block_rhs) + if block_rhs_use <= 0: + raise ValueError("block_rhs must be positive") + + use_fp64 = data.dtype == torch.float64 + diag_eps = 1e-12 if use_fp64 else 1e-6 + + for rows_lv in levels: + n_lv = rows_lv.numel() + if n_lv == 0: + continue + grid = (n_lv, triton.cdiv(n_rhs, block_rhs_use)) + _spsm_csr_level_kernel_real[grid]( + data, + indices32, + indptr64, + rhs, + x, + rows_lv, + n_level_rows=n_lv, + n_rhs=n_rhs, + stride_b0=rhs.stride(0), + stride_x0=x.stride(0), + BLOCK_NNZ=block_nnz_use, + BLOCK_RHS=block_rhs_use, + MAX_SEGMENTS=max_segments_use, + LOWER=lower, + UNIT_DIAG=unit_diagonal, + USE_FP64_ACC=use_fp64, + DIAG_EPS=diag_eps, + ) + return x + + +def _run_spsm_coo_core( + data, + row64, + col64, + rhs, + n_rows, + n_cols, + lower=True, + unit_diagonal=False, + block_nnz=None, + max_segments=None, + block_rhs=None, +): + if rhs.ndim != 2: + raise ValueError("rhs must be 2D") + rhs = rhs.contiguous() + if rhs.shape[0] != n_rows: + raise ValueError("rhs first dim must equal n_rows") + n_rhs = int(rhs.shape[1]) + x = torch.zeros_like(rhs) + if n_rows == 0 or n_rhs == 0: + return x + + data_u, col_u64, row_ptr = _coo_sort_unique_and_rowptr(data, row64, col64, n_rows, n_cols) + cols32 = col_u64.to(torch.int32) + levels = _build_spsm_levels(row_ptr, cols32, n_rows, lower=lower) + block_nnz_use, max_segments_use = _auto_spsm_launch_config( + row_ptr, block_nnz=block_nnz, max_segments=max_segments + ) + block_rhs_use = _auto_rhs_block(n_rhs) if block_rhs is None else int(block_rhs) + if block_rhs_use <= 0: + raise ValueError("block_rhs must be positive") + + use_fp64 = 
data.dtype == torch.float64 + diag_eps = 1e-12 if use_fp64 else 1e-6 + + for rows_lv in levels: + n_lv = rows_lv.numel() + if n_lv == 0: + continue + grid = (n_lv, triton.cdiv(n_rhs, block_rhs_use)) + _spsm_coo_level_kernel_real[grid]( + data_u, + row_ptr, + cols32, + rhs, + x, + rows_lv, + n_level_rows=n_lv, + n_rhs=n_rhs, + stride_b0=rhs.stride(0), + stride_x0=x.stride(0), + BLOCK_NNZ=block_nnz_use, + BLOCK_RHS=block_rhs_use, + MAX_SEGMENTS=max_segments_use, + LOWER=lower, + UNIT_DIAG=unit_diagonal, + USE_FP64_ACC=use_fp64, + DIAG_EPS=diag_eps, + ) + return x + + +def flagsparse_spsm_csr( + data, + indices, + indptr, + B, + shape, + alpha=1.0, + lower=True, + unit_diagonal=False, + opA="NON_TRANS", + opB="NON_TRANS", + major="row", + out=None, + return_time=False, +): + data, indices, indptr, B, _ = _prepare_spsm_csr_inputs( + data, indices, indptr, B, shape, opA, opB, major + ) + alpha_t = torch.as_tensor(alpha, dtype=B.dtype, device=B.device) + rhs = alpha_t * B + torch.cuda.synchronize() + t0 = time.perf_counter() + x = _run_spsm_csr_core( + data, + indices, + indptr, + rhs, + int(shape[0]), + lower=lower, + unit_diagonal=unit_diagonal, + ) + torch.cuda.synchronize() + elapsed_ms = (time.perf_counter() - t0) * 1000.0 + if out is not None: + if out.shape != x.shape or out.dtype != x.dtype: + raise ValueError("out shape/dtype must match result") + out.copy_(x) + x = out + if return_time: + return x, elapsed_ms + return x + + +def flagsparse_spsm_coo( + data, + row, + col, + B, + shape, + alpha=1.0, + lower=True, + unit_diagonal=False, + opA="NON_TRANS", + opB="NON_TRANS", + major="row", + out=None, + return_time=False, +): + data, row64, col64, B, n_rows, n_cols = _prepare_spsm_coo_inputs( + data, row, col, B, shape, opA, opB, major + ) + alpha_t = torch.as_tensor(alpha, dtype=B.dtype, device=B.device) + rhs = alpha_t * B + torch.cuda.synchronize() + t0 = time.perf_counter() + x = _run_spsm_coo_core( + data, + row64, + col64, + rhs, + n_rows, + n_cols, + lower=lower, + unit_diagonal=unit_diagonal, + ) + torch.cuda.synchronize() + elapsed_ms = (time.perf_counter() - t0) * 1000.0 + if out is not None: + if out.shape != x.shape or out.dtype != x.dtype: + raise ValueError("out shape/dtype must match result") + out.copy_(x) + x = out + if return_time: + return x, elapsed_ms + return x + + +def _cupy_spsm_baseline_from_csr( + data, indices, indptr, B, shape, alpha, lower, unit_diagonal, warmup=10, iters=50 +): + if cp is None or cpx_sparse is None or cpx_spsolve_triangular is None: + return None, None, "cupy/cusparse unavailable" + try: + data_cp = _cupy_from_torch(data) + idx_cp = _cupy_from_torch(indices.to(torch.int64)) + ptr_cp = _cupy_from_torch(indptr) + B_cp = _cupy_from_torch(B) + A_cp = cpx_sparse.csr_matrix((data_cp, idx_cp, ptr_cp), shape=shape) + alpha_cp = cp.asarray(alpha, dtype=B_cp.dtype) + for _ in range(max(0, int(warmup))): + _ = cpx_spsolve_triangular( + A_cp, alpha_cp * B_cp, lower=lower, unit_diagonal=unit_diagonal + ) + cp.cuda.runtime.deviceSynchronize() + e0 = cp.cuda.Event() + e1 = cp.cuda.Event() + e0.record() + for _ in range(max(1, int(iters))): + C_cp = cpx_spsolve_triangular( + A_cp, alpha_cp * B_cp, lower=lower, unit_diagonal=unit_diagonal + ) + e1.record() + e1.synchronize() + ms = cp.cuda.get_elapsed_time(e0, e1) / max(1, int(iters)) + return _torch_from_cupy(C_cp).to(B.dtype), ms, None + except Exception as exc: + return None, None, str(exc) + + +def benchmark_spsm_case( + fmt="csr", + n_rows=1024, + n_rhs=32, + nnz=8192, + value_dtype=torch.float32, + 
index_dtype=torch.int32, + alpha=1.0, + lower=True, + unit_diagonal=False, + warmup=10, + iters=50, +): + """Benchmark SpSM (NON_TRANS/NON_TRANS, row-major B) against cuSPARSE baseline.""" + device = torch.device("cuda") + data, indices, indptr = _build_random_csr( + n_rows, n_rows, nnz, value_dtype, index_dtype, device + ) + # Make A triangular and diagonally dominant. + row_ids = torch.repeat_interleave( + torch.arange(n_rows, device=device, dtype=torch.int64), + indptr.to(torch.int64)[1:] - indptr.to(torch.int64)[:-1], + ) + col_ids = indices.to(torch.int64) + tri_mask = (col_ids <= row_ids) if lower else (col_ids >= row_ids) + if tri_mask.numel() > 0: + data = data[tri_mask] + col_ids = col_ids[tri_mask] + row_ids = row_ids[tri_mask] + data, col_ids, indptr = _coo_to_csr_sorted_unique(data, row_ids, col_ids, n_rows, n_rows) + # Ensure diagonal exists without densifying A. + row = torch.repeat_interleave( + torch.arange(n_rows, device=device, dtype=torch.int64), + indptr[1:] - indptr[:-1], + ) + diag_mask = col_ids == row + diag_present = torch.zeros(n_rows, dtype=torch.bool, device=device) + if diag_mask.numel() > 0 and bool(torch.any(diag_mask).item()): + diag_present[row[diag_mask]] = True + missing_diag = torch.nonzero(~diag_present, as_tuple=False).reshape(-1).to(torch.int64) + if missing_diag.numel() > 0: + diag_data = torch.ones(missing_diag.numel(), dtype=value_dtype, device=device) + data = torch.cat([data, diag_data], dim=0) + row = torch.cat([row, missing_diag], dim=0) + col = torch.cat([col_ids, missing_diag], dim=0) + data, col_ids, indptr = _coo_to_csr_sorted_unique(data, row, col, n_rows, n_rows) + row = torch.repeat_interleave( + torch.arange(n_rows, device=device, dtype=torch.int64), + indptr[1:] - indptr[:-1], + ) + col = col_ids.to(torch.int64) + B = torch.randn((n_rows, n_rhs), dtype=value_dtype, device=device).contiguous() + shape = (n_rows, n_rows) + + if str(fmt).lower() == "coo": + triton_op = lambda: flagsparse_spsm_coo( + data, + row, + col, + B, + shape, + alpha=alpha, + lower=lower, + unit_diagonal=unit_diagonal, + opA="NON_TRANS", + opB="NON_TRANS", + major="row", + ) + else: + triton_op = lambda: flagsparse_spsm_csr( + data, + col_ids.to(index_dtype), + indptr, + B, + shape, + alpha=alpha, + lower=lower, + unit_diagonal=unit_diagonal, + opA="NON_TRANS", + opB="NON_TRANS", + major="row", + ) + C_fs, fs_ms = _benchmark_cuda_op(triton_op, warmup=warmup, iters=iters) + atol, rtol = _tolerance_for_dtype(value_dtype) + + C_cu, cu_ms, cu_reason = _cupy_spsm_baseline_from_csr( + data, col_ids, indptr, B, shape, alpha, lower, unit_diagonal, warmup=warmup, iters=iters + ) + cu_ok = None + cu_err = None + if C_cu is not None: + cu_ok = torch.allclose(C_fs, C_cu, atol=atol, rtol=rtol) + cu_err = float(torch.max(torch.abs(C_fs - C_cu)).item()) if C_fs.numel() > 0 else 0.0 + + return { + "parameters": { + "format": str(fmt).lower(), + "n_rows": n_rows, + "n_rhs": n_rhs, + "nnz": int(data.numel()), + "value_dtype": str(value_dtype), + "index_dtype": str(index_dtype), + "opA": "NON_TRANS", + "opB": "NON_TRANS", + "major": "row", + }, + "performance": { + "flagsparse_ms": fs_ms, + "cusparse_ms": cu_ms, + "speedup_vs_cusparse": (cu_ms / fs_ms if (cu_ms is not None and fs_ms > 0) else None), + }, + "verification": { + "flagsparse_match_cusparse": cu_ok, + "flagsparse_vs_cusparse_max_error": cu_err, + }, + "backend_status": { + "cusparse_unavailable_reason": cu_reason, + }, + "samples": {"flagsparse": C_fs, "cusparse": C_cu}, + } \ No newline at end of file diff --git 
a/src/flagsparse/sparse_operations/spsv.py b/src/flagsparse/sparse_operations/spsv.py
new file mode 100644
index 0000000..a4952e5
--- /dev/null
+++ b/src/flagsparse/sparse_operations/spsv.py
@@ -0,0 +1,1658 @@
+"""Sparse triangular solve (SpSV) for CSR/COO."""
+
+from ._common import *
+
+from collections import OrderedDict
+import os
+import time
+import triton
+import triton.language as tl
+
+SUPPORTED_SPSV_VALUE_DTYPES = (
+    torch.float32,
+    torch.float64,
+    torch.complex64,
+    torch.complex128,
+)
+SUPPORTED_SPSV_INDEX_DTYPES = (torch.int32, torch.int64)
+SPSV_NON_TRANS_SUPPORTED_COMBOS = (
+    (torch.float32, torch.int32),
+    (torch.float64, torch.int32),
+    (torch.complex64, torch.int32),
+    (torch.complex128, torch.int32),
+    (torch.float32, torch.int64),
+    (torch.float64, torch.int64),
+    (torch.complex64, torch.int64),
+    (torch.complex128, torch.int64),
+)
+SPSV_TRANS_SUPPORTED_COMBOS = (
+    (torch.float32, torch.int32),
+    (torch.float64, torch.int32),
+    (torch.complex64, torch.int32),
+    (torch.complex128, torch.int32),
+    (torch.float32, torch.int64),
+    (torch.float64, torch.int64),
+    (torch.complex64, torch.int64),
+    (torch.complex128, torch.int64),
+)
+def _spsv_env_flag(name, default="0"):
+    return str(os.environ.get(name, default)).lower() in ("1", "true", "yes", "on")
+
+
+SPSV_PROMOTE_FP32_TO_FP64 = _spsv_env_flag("FLAGSPARSE_SPSV_PROMOTE_FP32_TO_FP64", "0")
+SPSV_PROMOTE_TRANSPOSE_FP32_TO_FP64 = _spsv_env_flag(
+    "FLAGSPARSE_SPSV_PROMOTE_TRANSPOSE_FP32_TO_FP64", "0"
+)
+SPSV_PROMOTE_TRANSPOSE_COMPLEX64_TO_COMPLEX128 = _spsv_env_flag(
+    "FLAGSPARSE_SPSV_PROMOTE_TRANSPOSE_COMPLEX64_TO_COMPLEX128", "0"
+)
+_SPSV_CSR_PREPROCESS_CACHE = OrderedDict()
+_SPSV_CSR_PREPROCESS_CACHE_SIZE = 8
+
+
+def _clear_spsv_csr_preprocess_cache():
+    _SPSV_CSR_PREPROCESS_CACHE.clear()
+
+
+def _csr_to_dense(data, indices, indptr, shape):
+    """Convert CSR (torch CUDA tensors) to dense matrix on the same device."""
+    device = data.device
+    dtype = data.dtype
+    n_rows, n_cols = int(shape[0]), int(shape[1])
+    if n_rows == 0 or n_cols == 0:
+        return torch.zeros((n_rows, n_cols), dtype=dtype, device=device)
+    row_ind = torch.repeat_interleave(
+        torch.arange(n_rows, device=device, dtype=torch.int64),
+        indptr[1:] - indptr[:-1],
+    )
+    col_ind = indices.to(torch.int64)
+    coo = torch.sparse_coo_tensor(
+        torch.stack([row_ind, col_ind]),
+        data,
+        (n_rows, n_cols),
+        device=device,
+    ).coalesce()
+    return coo.to_dense()
+
+
+def _validate_spsv_non_trans_combo(data_dtype, index_dtype, fmt_name):
+    """Validate NON_TRANS support matrix and keep error messages explicit."""
+    if (data_dtype, index_dtype) in SPSV_NON_TRANS_SUPPORTED_COMBOS:
+        return
+    raise TypeError(
+        f"{fmt_name} SpSV currently supports NON_TRANS combinations: "
+        "(float32, int32/int64), (float64, int32/int64), "
+        "(complex64, int32/int64), (complex128, int32/int64)"
+    )
+
+
+def _validate_spsv_trans_combo(data_dtype, index_dtype, fmt_name):
+    if (data_dtype, index_dtype) in SPSV_TRANS_SUPPORTED_COMBOS:
+        return
+    raise TypeError(
+        f"{fmt_name} SpSV currently supports TRANS/CONJ combinations: "
+        "(float32, int32/int64), (float64, int32/int64), "
+        "(complex64, int32/int64), (complex128, int32/int64)"
+    )
+
+
+def _normalize_spsv_transpose_mode(transpose):
+    if isinstance(transpose, bool):
+        return "T" if transpose else "N"
+    token = str(transpose).strip().upper()
+    if token in ("N", "NON", "NON_TRANS"):
+        return "N"
+    if token in ("T", "TRANS"):
+        return "T"
+    if token in ("C", "H", "CONJ", "CONJ_TRANS", "CONJUGATE_TRANSPOSE"):
+        return "C"
+    raise 
ValueError( + "transpose must be bool or one of: " + "N/NON/NON_TRANS, T/TRANS, C/H/CONJ/CONJ_TRANS/CONJUGATE_TRANSPOSE" + ) + + +def _prepare_spsv_inputs(data, indices, indptr, b, shape): + """Validate and normalize inputs for sparse solve A x = b with CSR A.""" + if not all(torch.is_tensor(t) for t in (data, indices, indptr, b)): + raise TypeError("data, indices, indptr, b must all be torch.Tensor") + if not all(t.is_cuda for t in (data, indices, indptr, b)): + raise ValueError("data, indices, indptr, b must all be CUDA tensors") + if data.ndim != 1 or indices.ndim != 1 or indptr.ndim != 1: + raise ValueError("data, indices, indptr must be 1D") + if b.ndim not in (1, 2): + raise ValueError("b must be 1D or 2D (vector or multiple RHS)") + + n_rows, n_cols = int(shape[0]), int(shape[1]) + if indptr.numel() != n_rows + 1: + raise ValueError(f"indptr length must be n_rows+1={n_rows + 1}") + if data.numel() != indices.numel(): + raise ValueError("data and indices must have the same length (nnz)") + if b.ndim == 1 and b.numel() != n_rows: + raise ValueError(f"b length must equal n_rows={n_rows}") + if b.ndim == 2 and b.shape[0] != n_rows: + raise ValueError(f"b.shape[0] must equal n_rows={n_rows}") + + if data.dtype not in SUPPORTED_SPSV_VALUE_DTYPES: + raise TypeError( + "data dtype must be one of: float32, float64, complex64, complex128" + ) + if indices.dtype not in SUPPORTED_SPSV_INDEX_DTYPES: + raise TypeError("indices dtype must be torch.int32 or torch.int64") + if indptr.dtype not in SUPPORTED_SPSV_INDEX_DTYPES: + raise TypeError("indptr dtype must be torch.int32 or torch.int64") + if b.dtype != data.dtype: + raise TypeError("b dtype must match data dtype") + + indices64 = indices.to(torch.int64).contiguous() + indptr64 = indptr.to(torch.int64).contiguous() + if indices64.numel() > 0 and int(indices64.max().item()) > _INDEX_LIMIT_INT32: + raise ValueError( + f"int64 index value {int(indices64.max().item())} exceeds Triton int32 kernel range" + ) + if indptr64.numel() > 0: + if int(indptr64[0].item()) != 0: + raise ValueError("indptr[0] must be 0") + if int(indptr64[-1].item()) != data.numel(): + raise ValueError("indptr[-1] must equal nnz") + if bool(torch.any(indptr64[1:] < indptr64[:-1]).item()): + raise ValueError("indptr must be non-decreasing") + if indices64.numel() > 0: + if bool(torch.any(indices64 < 0).item()): + raise IndexError("indices must be non-negative") + max_idx = int(indices64.max().item()) + if max_idx >= n_cols: + raise IndexError(f"indices out of range for n_cols={n_cols}") + + return ( + data.contiguous(), + indices.dtype, + indices64, + indptr64, + b.contiguous(), + n_rows, + n_cols, + ) + + +def _prepare_spsv_working_inputs(data, b): + return data, b, None + + +def _restore_spsv_output(x, target_dtype): + return x.to(target_dtype) + + +def _spsv_diag_eps_for_dtype(value_dtype): + return 1e-12 if value_dtype in (torch.float64, torch.complex128) else 1e-6 + + +def _tensor_cache_token(tensor): + try: + storage_ptr = int(tensor.untyped_storage().data_ptr()) + except Exception: + storage_ptr = 0 + return ( + str(tensor.device), + str(tensor.dtype), + tuple(int(v) for v in tensor.shape), + int(tensor.numel()), + storage_ptr, + int(getattr(tensor, "_version", 0)), + ) + + +def _spsv_cache_get(cache, key): + value = cache.get(key) + if value is not None: + cache.move_to_end(key) + return value + + +def _spsv_cache_put(cache, key, value, max_entries): + cache[key] = value + cache.move_to_end(key) + while len(cache) > max_entries: + cache.popitem(last=False) + + +def 
_csr_preprocess_cache_key(data, indices, indptr, shape, lower, trans_mode): + return ( + "csr_preprocess", + trans_mode, + bool(lower), + int(shape[0]), + int(shape[1]), + _tensor_cache_token(data), + _tensor_cache_token(indices), + _tensor_cache_token(indptr), + ) + + +def _build_spsv_frontiers(indptr, indices, levels, lower=True): + """Greedily merge rows from adjacent levels when they do not depend on the + currently active frontier. + + This keeps the same correctness contract as strict level scheduling while + trimming some kernel launches on matrices with narrow but not fully + serialized dependency wavefronts. + """ + if not levels: + return [] + + indptr_h = indptr.to(torch.int64).cpu() + indices_h = indices.to(torch.int64).cpu() + device = indptr.device + frontier_rows = [] + frontier_row_set = set() + merged = [] + + def _flush_frontier(): + nonlocal frontier_rows, frontier_row_set + if frontier_rows: + merged.append(torch.tensor(frontier_rows, dtype=torch.int32, device=device)) + frontier_rows = [] + frontier_row_set = set() + + for rows_lv in levels: + for row in rows_lv.to(torch.int64).cpu().tolist(): + start = int(indptr_h[row].item()) + end = int(indptr_h[row + 1].item()) + depends_on_frontier = False + for p in range(start, end): + col = int(indices_h[p].item()) + if lower: + is_dep = col < row + else: + is_dep = col > row + if is_dep and col in frontier_row_set: + depends_on_frontier = True + break + if depends_on_frontier: + _flush_frontier() + frontier_rows.append(int(row)) + frontier_row_set.add(int(row)) + _flush_frontier() + return merged + + +def _prepare_spsv_csr_system(data, indices64, indptr64, n_rows, n_cols, lower, trans_mode): + if trans_mode == "N": + levels = _build_spsv_levels(indptr64, indices64, n_rows, lower=lower) + return { + "solve_kind": "csr_levels", + "kernel_data": data, + "kernel_indices64": indices64, + "kernel_indptr64": indptr64, + "lower_eff": lower, + "launch_groups": _build_spsv_frontiers( + indptr64, indices64, levels, lower=lower + ), + "transpose_conjugate": False, + } + + levels = _build_spsv_levels(indptr64, indices64, n_rows, lower=lower) + return { + "solve_kind": "transpose_push", + "kernel_data": data, + "kernel_indices64": indices64, + "kernel_indptr64": indptr64, + "lower_eff": lower, + "launch_groups": list(reversed(levels)), + "transpose_conjugate": trans_mode == "C", + } + + +def _resolve_spsv_csr_runtime( + data, + indices, + indptr, + b, + shape, + lower, + transpose, +): + input_data = data + input_indices = indices + input_indptr = indptr + trans_mode = _normalize_spsv_transpose_mode(transpose) + data, input_index_dtype, indices, indptr, b, n_rows, n_cols = _prepare_spsv_inputs( + data, indices, indptr, b, shape + ) + original_output_dtype = None + data, b, original_output_dtype = _prepare_spsv_working_inputs(data, b) + if n_rows != n_cols: + raise ValueError(f"A must be square, got shape={shape}") + if trans_mode == "N": + _validate_spsv_non_trans_combo(data.dtype, input_index_dtype, "CSR") + else: + _validate_spsv_trans_combo(data.dtype, input_index_dtype, "CSR") + + preprocess_key = _csr_preprocess_cache_key( + input_data, input_indices, input_indptr, (n_rows, n_cols), lower, trans_mode + ) + cached = _spsv_cache_get(_SPSV_CSR_PREPROCESS_CACHE, preprocess_key) + if cached is None: + cached = _prepare_spsv_csr_system( + data, + indices, + indptr, + n_rows, + n_cols, + lower, + trans_mode, + ) + _spsv_cache_put( + _SPSV_CSR_PREPROCESS_CACHE, + preprocess_key, + cached, + _SPSV_CSR_PREPROCESS_CACHE_SIZE, + ) + return ( + 
data, + b, + original_output_dtype, + trans_mode, + n_rows, + n_cols, + cached, + ) + + +@triton.jit +def _spsv_csr_level_kernel( + data_ptr, + indices_ptr, + indptr_ptr, + b_ptr, + x_ptr, + rows_ptr, + n_level_rows, + BLOCK_NNZ: tl.constexpr, + MAX_SEGMENTS: tl.constexpr, + LOWER: tl.constexpr, + UNIT_DIAG: tl.constexpr, + DIAG_EPS: tl.constexpr, +): + pid = tl.program_id(0) + if pid >= n_level_rows: + return + row = tl.load(rows_ptr + pid) + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) + acc = tl.load(data_ptr + start, mask=start < end, other=0.0) * 0 + diag = tl.load(data_ptr + start, mask=start < end, other=0.0) * 0 + if UNIT_DIAG: + diag = diag + 1.0 + + for seg in range(MAX_SEGMENTS): + idx = start + seg * BLOCK_NNZ + offsets = idx + tl.arange(0, BLOCK_NNZ) + mask = offsets < end + a = tl.load(data_ptr + offsets, mask=mask, other=0.0) + col = tl.load(indices_ptr + offsets, mask=mask, other=0) + x_vals = tl.load(x_ptr + col, mask=mask, other=0.0) + + if LOWER: + solved = col < row + else: + solved = col > row + is_diag = col == row + + acc = acc + tl.sum(tl.where(mask & solved, a * x_vals, 0.0)) + if not UNIT_DIAG: + diag = diag + tl.sum(tl.where(mask & is_diag, a, 0.0)) + + rhs = tl.load(b_ptr + row) + diag_safe = tl.where(tl.abs(diag) < DIAG_EPS, 1.0, diag) + x_row = (rhs - acc) / diag_safe + # Prevent NaN propagation in ill-conditioned rows. + x_row = tl.where(x_row == x_row, x_row, 0.0) + tl.store(x_ptr + row, x_row) + + +@triton.jit +def _spsv_csr_level_kernel_complex( + data_ri_ptr, + indices_ptr, + indptr_ptr, + b_ri_ptr, + x_ri_ptr, + rows_ptr, + n_level_rows, + BLOCK_NNZ: tl.constexpr, + MAX_SEGMENTS: tl.constexpr, + LOWER: tl.constexpr, + UNIT_DIAG: tl.constexpr, + USE_FP64_ACC: tl.constexpr, + DIAG_EPS: tl.constexpr, +): + pid = tl.program_id(0) + if pid >= n_level_rows: + return + row = tl.load(rows_ptr + pid) + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) + + if USE_FP64_ACC: + acc_re = tl.zeros((1,), dtype=tl.float64) + acc_im = tl.zeros((1,), dtype=tl.float64) + diag_re = tl.zeros((1,), dtype=tl.float64) + diag_im = tl.zeros((1,), dtype=tl.float64) + else: + acc_re = tl.zeros((1,), dtype=tl.float32) + acc_im = tl.zeros((1,), dtype=tl.float32) + diag_re = tl.zeros((1,), dtype=tl.float32) + diag_im = tl.zeros((1,), dtype=tl.float32) + + if UNIT_DIAG: + diag_re = diag_re + 1.0 + + for seg in range(MAX_SEGMENTS): + idx = start + seg * BLOCK_NNZ + offsets = idx + tl.arange(0, BLOCK_NNZ) + mask = offsets < end + + col = tl.load(indices_ptr + offsets, mask=mask, other=0) + a_re = tl.load(data_ri_ptr + offsets * 2, mask=mask, other=0.0) + a_im = tl.load(data_ri_ptr + offsets * 2 + 1, mask=mask, other=0.0) + x_re = tl.load(x_ri_ptr + col * 2, mask=mask, other=0.0) + x_im = tl.load(x_ri_ptr + col * 2 + 1, mask=mask, other=0.0) + + if USE_FP64_ACC: + a_re = a_re.to(tl.float64) + a_im = a_im.to(tl.float64) + x_re = x_re.to(tl.float64) + x_im = x_im.to(tl.float64) + else: + a_re = a_re.to(tl.float32) + a_im = a_im.to(tl.float32) + x_re = x_re.to(tl.float32) + x_im = x_im.to(tl.float32) + + if LOWER: + solved = col < row + else: + solved = col > row + is_diag = col == row + + prod_re = a_re * x_re - a_im * x_im + prod_im = a_re * x_im + a_im * x_re + acc_re = acc_re + tl.sum(tl.where(mask & solved, prod_re, 0.0)) + acc_im = acc_im + tl.sum(tl.where(mask & solved, prod_im, 0.0)) + + if not UNIT_DIAG: + diag_re = diag_re + tl.sum(tl.where(mask & is_diag, a_re, 0.0)) + diag_im = diag_im + tl.sum(tl.where(mask & is_diag, a_im, 0.0)) + 
+ rhs_re = tl.load(b_ri_ptr + row * 2) + rhs_im = tl.load(b_ri_ptr + row * 2 + 1) + if USE_FP64_ACC: + rhs_re = rhs_re.to(tl.float64) + rhs_im = rhs_im.to(tl.float64) + else: + rhs_re = rhs_re.to(tl.float32) + rhs_im = rhs_im.to(tl.float32) + + num_re = rhs_re - acc_re + num_im = rhs_im - acc_im + den = diag_re * diag_re + diag_im * diag_im + den_safe = tl.where(den < (DIAG_EPS * DIAG_EPS), 1.0, den) + + x_re_out = (num_re * diag_re + num_im * diag_im) / den_safe + x_im_out = (num_im * diag_re - num_re * diag_im) / den_safe + x_re_out = tl.where(x_re_out == x_re_out, x_re_out, 0.0) + x_im_out = tl.where(x_im_out == x_im_out, x_im_out, 0.0) + + offs1 = tl.arange(0, 1) + tl.store(x_ri_ptr + row * 2 + offs1, x_re_out) + tl.store(x_ri_ptr + row * 2 + 1 + offs1, x_im_out) + + +@triton.jit +def _spsv_csr_transpose_push_kernel( + data_ptr, + indices_ptr, + indptr_ptr, + residual_ptr, + x_ptr, + rows_ptr, + n_level_rows, + BLOCK_NNZ: tl.constexpr, + MAX_SEGMENTS: tl.constexpr, + LOWER: tl.constexpr, + UNIT_DIAG: tl.constexpr, + DIAG_EPS: tl.constexpr, +): + pid = tl.program_id(0) + if pid >= n_level_rows: + return + row = tl.load(rows_ptr + pid) + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) + + diag = tl.load(data_ptr + start, mask=start < end, other=0.0) * 0 + if UNIT_DIAG: + diag = diag + 1.0 + else: + for seg in range(MAX_SEGMENTS): + idx = start + seg * BLOCK_NNZ + offsets = idx + tl.arange(0, BLOCK_NNZ) + mask = offsets < end + a = tl.load(data_ptr + offsets, mask=mask, other=0.0) + col = tl.load(indices_ptr + offsets, mask=mask, other=0) + is_diag = col == row + diag = diag + tl.sum(tl.where(mask & is_diag, a, 0.0)) + + rhs = tl.load(residual_ptr + row) + diag_safe = tl.where(tl.abs(diag) < DIAG_EPS, 1.0, diag) + x_row = rhs / diag_safe + x_row = tl.where(x_row == x_row, x_row, 0.0) + tl.store(x_ptr + row, x_row) + + for seg in range(MAX_SEGMENTS): + idx = start + seg * BLOCK_NNZ + offsets = idx + tl.arange(0, BLOCK_NNZ) + mask = offsets < end + a = tl.load(data_ptr + offsets, mask=mask, other=0.0) + col = tl.load(indices_ptr + offsets, mask=mask, other=0) + if LOWER: + target_mask = mask & (col < row) + else: + target_mask = mask & (col > row) + tl.atomic_add(residual_ptr + col, -a * x_row, mask=target_mask) + + +@triton.jit +def _spsv_csr_transpose_push_kernel_complex( + data_ri_ptr, + indices_ptr, + indptr_ptr, + residual_ri_ptr, + x_ri_ptr, + rows_ptr, + n_level_rows, + BLOCK_NNZ: tl.constexpr, + MAX_SEGMENTS: tl.constexpr, + LOWER: tl.constexpr, + UNIT_DIAG: tl.constexpr, + CONJ_TRANS: tl.constexpr, + USE_FP64_ACC: tl.constexpr, + DIAG_EPS: tl.constexpr, +): + pid = tl.program_id(0) + if pid >= n_level_rows: + return + row = tl.load(rows_ptr + pid) + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) + + if USE_FP64_ACC: + diag_re = tl.zeros((1,), dtype=tl.float64) + diag_im = tl.zeros((1,), dtype=tl.float64) + else: + diag_re = tl.zeros((1,), dtype=tl.float32) + diag_im = tl.zeros((1,), dtype=tl.float32) + + if UNIT_DIAG: + diag_re = diag_re + 1.0 + else: + for seg in range(MAX_SEGMENTS): + idx = start + seg * BLOCK_NNZ + offsets = idx + tl.arange(0, BLOCK_NNZ) + mask = offsets < end + col = tl.load(indices_ptr + offsets, mask=mask, other=0) + a_re = tl.load(data_ri_ptr + offsets * 2, mask=mask, other=0.0) + a_im = tl.load(data_ri_ptr + offsets * 2 + 1, mask=mask, other=0.0) + if CONJ_TRANS: + a_im = -a_im + if USE_FP64_ACC: + a_re = a_re.to(tl.float64) + a_im = a_im.to(tl.float64) + else: + a_re = a_re.to(tl.float32) + a_im = 
a_im.to(tl.float32) + is_diag = col == row + diag_re = diag_re + tl.sum(tl.where(mask & is_diag, a_re, 0.0)) + diag_im = diag_im + tl.sum(tl.where(mask & is_diag, a_im, 0.0)) + + rhs_re = tl.load(residual_ri_ptr + row * 2) + rhs_im = tl.load(residual_ri_ptr + row * 2 + 1) + if USE_FP64_ACC: + rhs_re = rhs_re.to(tl.float64) + rhs_im = rhs_im.to(tl.float64) + else: + rhs_re = rhs_re.to(tl.float32) + rhs_im = rhs_im.to(tl.float32) + + den = diag_re * diag_re + diag_im * diag_im + den_safe = tl.where(den < (DIAG_EPS * DIAG_EPS), 1.0, den) + x_re_out = (rhs_re * diag_re + rhs_im * diag_im) / den_safe + x_im_out = (rhs_im * diag_re - rhs_re * diag_im) / den_safe + x_re_out = tl.where(x_re_out == x_re_out, x_re_out, 0.0) + x_im_out = tl.where(x_im_out == x_im_out, x_im_out, 0.0) + + offs1 = tl.arange(0, 1) + tl.store(x_ri_ptr + row * 2 + offs1, x_re_out) + tl.store(x_ri_ptr + row * 2 + 1 + offs1, x_im_out) + + for seg in range(MAX_SEGMENTS): + idx = start + seg * BLOCK_NNZ + offsets = idx + tl.arange(0, BLOCK_NNZ) + mask = offsets < end + col = tl.load(indices_ptr + offsets, mask=mask, other=0) + a_re = tl.load(data_ri_ptr + offsets * 2, mask=mask, other=0.0) + a_im = tl.load(data_ri_ptr + offsets * 2 + 1, mask=mask, other=0.0) + if CONJ_TRANS: + a_im = -a_im + if USE_FP64_ACC: + a_re = a_re.to(tl.float64) + a_im = a_im.to(tl.float64) + else: + a_re = a_re.to(tl.float32) + a_im = a_im.to(tl.float32) + if LOWER: + target_mask = mask & (col < row) + else: + target_mask = mask & (col > row) + prod_re = a_re * x_re_out - a_im * x_im_out + prod_im = a_re * x_im_out + a_im * x_re_out + tl.atomic_add(residual_ri_ptr + col * 2, -prod_re, mask=target_mask) + tl.atomic_add(residual_ri_ptr + col * 2 + 1, -prod_im, mask=target_mask) + + +@triton.jit +def _spsv_coo_level_kernel_real( + data_ptr, + row_ptr_ptr, + col_ptr, + b_ptr, + x_ptr, + rows_ptr, + n_level_rows, + BLOCK_NNZ: tl.constexpr, + MAX_SEGMENTS: tl.constexpr, + LOWER: tl.constexpr, + UNIT_DIAG: tl.constexpr, + DIAG_EPS: tl.constexpr, +): + pid = tl.program_id(0) + if pid >= n_level_rows: + return + row = tl.load(rows_ptr + pid) + start = tl.load(row_ptr_ptr + row) + end = tl.load(row_ptr_ptr + row + 1) + acc = tl.load(data_ptr + start, mask=start < end, other=0.0) * 0 + diag = tl.load(data_ptr + start, mask=start < end, other=0.0) * 0 + if UNIT_DIAG: + diag = diag + 1.0 + + for seg in range(MAX_SEGMENTS): + idx = start + seg * BLOCK_NNZ + offsets = idx + tl.arange(0, BLOCK_NNZ) + mask = offsets < end + a = tl.load(data_ptr + offsets, mask=mask, other=0.0) + col = tl.load(col_ptr + offsets, mask=mask, other=0) + x_vals = tl.load(x_ptr + col, mask=mask, other=0.0) + + if LOWER: + solved = col < row + else: + solved = col > row + is_diag = col == row + + acc = acc + tl.sum(tl.where(mask & solved, a * x_vals, 0.0)) + if not UNIT_DIAG: + diag = diag + tl.sum(tl.where(mask & is_diag, a, 0.0)) + + rhs = tl.load(b_ptr + row) + diag_safe = tl.where(tl.abs(diag) < DIAG_EPS, 1.0, diag) + x_row = (rhs - acc) / diag_safe + x_row = tl.where(x_row == x_row, x_row, 0.0) + tl.store(x_ptr + row, x_row) + + +def _build_spsv_levels(indptr, indices, n_rows, lower=True): + """Build dependency levels for triangular solve so each level can run in parallel.""" + if n_rows == 0: + return [] + indptr_h = indptr.to(torch.int64).cpu() + indices_h = indices.to(torch.int64).cpu() + levels = [0] * n_rows + if lower: + for i in range(n_rows): + s = int(indptr_h[i].item()) + e = int(indptr_h[i + 1].item()) + lvl = 0 + for p in range(s, e): + c = int(indices_h[p].item()) + if c < 
i: + lvl = max(lvl, levels[c] + 1) + levels[i] = lvl + else: + for i in range(n_rows - 1, -1, -1): + s = int(indptr_h[i].item()) + e = int(indptr_h[i + 1].item()) + lvl = 0 + for p in range(s, e): + c = int(indices_h[p].item()) + if c > i: + lvl = max(lvl, levels[c] + 1) + levels[i] = lvl + + max_level = max(levels) + buckets = [[] for _ in range(max_level + 1)] + for r, lv in enumerate(levels): + buckets[lv].append(r) + + device = indptr.device + return [ + torch.tensor(rows, dtype=torch.int32, device=device) + for rows in buckets + if rows + ] + + +def _auto_spsv_launch_config(indptr, block_nnz=None, max_segments=None): + if indptr.numel() <= 1: + max_nnz_per_row = 0 + else: + row_lengths = indptr[1:] - indptr[:-1] + max_nnz_per_row = int(row_lengths.max().item()) + + auto_block = block_nnz is None + if block_nnz is None: + if max_nnz_per_row <= 64: + block_nnz_use = 64 + elif max_nnz_per_row <= 256: + block_nnz_use = 128 + elif max_nnz_per_row <= 1024: + block_nnz_use = 256 + elif max_nnz_per_row <= 4096: + block_nnz_use = 512 + elif max_nnz_per_row <= 16384: + block_nnz_use = 1024 + else: + block_nnz_use = 2048 + else: + block_nnz_use = int(block_nnz) + if block_nnz_use <= 0: + raise ValueError("block_nnz must be a positive integer") + + required_segments = max( + (max_nnz_per_row + block_nnz_use - 1) // block_nnz_use, 1 + ) + if max_segments is None: + max_segments_use = required_segments + if auto_block: + while max_segments_use > 2048 and block_nnz_use < 65536: + block_nnz_use *= 2 + max_segments_use = max( + (max_nnz_per_row + block_nnz_use - 1) // block_nnz_use, 1 + ) + else: + max_segments_use = int(max_segments) + if max_segments_use <= 0: + raise ValueError("max_segments must be a positive integer") + if max_segments_use < required_segments: + raise ValueError( + f"max_segments={max_segments_use} is too small; at least {required_segments} required" + ) + + return block_nnz_use, max_segments_use + + +def _triton_spsv_csr_vector( + data, + indices, + indptr, + b_vec, + n_rows, + lower=True, + unit_diagonal=False, + block_nnz=None, + max_segments=None, + diag_eps=1e-12, + levels=None, + block_nnz_use=None, + max_segments_use=None, +): + x = torch.zeros_like(b_vec) + if n_rows == 0: + return x + if levels is None: + levels = _build_spsv_levels(indptr, indices, n_rows, lower=lower) + if block_nnz_use is None or max_segments_use is None: + block_nnz_use, max_segments_use = _auto_spsv_launch_config( + indptr, block_nnz=block_nnz, max_segments=max_segments + ) + + for rows_lv in levels: + n_lv = rows_lv.numel() + if n_lv == 0: + continue + grid = (n_lv,) + _spsv_csr_level_kernel[grid]( + data, + indices, + indptr, + b_vec, + x, + rows_lv, + n_level_rows=n_lv, + BLOCK_NNZ=block_nnz_use, + MAX_SEGMENTS=max_segments_use, + LOWER=lower, + UNIT_DIAG=unit_diagonal, + DIAG_EPS=diag_eps, + ) + return x + + +def _triton_spsv_csr_vector_complex( + data, + indices, + indptr, + b_vec, + n_rows, + lower=True, + unit_diagonal=False, + block_nnz=None, + max_segments=None, + diag_eps=1e-12, + levels=None, + block_nnz_use=None, + max_segments_use=None, +): + x = torch.zeros_like(b_vec) + if n_rows == 0: + return x + if levels is None: + levels = _build_spsv_levels(indptr, indices, n_rows, lower=lower) + if block_nnz_use is None or max_segments_use is None: + block_nnz_use, max_segments_use = _auto_spsv_launch_config( + indptr, block_nnz=block_nnz, max_segments=max_segments + ) + + # Some PyTorch builds return CSR values with a non-strided layout wrapper. 
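+    # (``torch.view_as_real`` below rejects non-strided inputs.)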
+ # Materialize a plain 1D strided buffer before splitting into real/imag parts. + if data.layout != torch.strided: + data_strided = torch.empty(data.shape, dtype=data.dtype, device=data.device) + data_strided.copy_(data) + else: + data_strided = data.contiguous() + + data_ri = torch.view_as_real(data_strided).reshape(-1).contiguous() + b_ri = torch.view_as_real(b_vec.contiguous()).reshape(-1).contiguous() + component_dtype = _component_dtype_for_complex(data.dtype) + use_fp64 = component_dtype == torch.float64 + if component_dtype == torch.float16: + x_ri_work = torch.zeros((n_rows, 2), dtype=torch.float32, device=b_vec.device) + x_ri = x_ri_work.reshape(-1).contiguous() + else: + x_ri = torch.view_as_real(x.contiguous()).reshape(-1).contiguous() + + for rows_lv in levels: + n_lv = rows_lv.numel() + if n_lv == 0: + continue + grid = (n_lv,) + _spsv_csr_level_kernel_complex[grid]( + data_ri, + indices, + indptr, + b_ri, + x_ri, + rows_lv, + n_level_rows=n_lv, + BLOCK_NNZ=block_nnz_use, + MAX_SEGMENTS=max_segments_use, + LOWER=lower, + UNIT_DIAG=unit_diagonal, + USE_FP64_ACC=use_fp64, + DIAG_EPS=diag_eps, + ) + if component_dtype == torch.float16: + return torch.view_as_complex(x_ri_work.contiguous()) + return x + + +def _triton_spsv_csr_transpose_push_vector( + data, + indices, + indptr, + b_vec, + n_rows, + lower=True, + unit_diagonal=False, + block_nnz=None, + max_segments=None, + diag_eps=1e-12, + launch_groups=None, + block_nnz_use=None, + max_segments_use=None, +): + x = torch.zeros_like(b_vec) + if n_rows == 0: + return x + residual = b_vec.clone() + if launch_groups is None: + levels = _build_spsv_levels(indptr, indices, n_rows, lower=lower) + launch_groups = list(reversed(levels)) + if block_nnz_use is None or max_segments_use is None: + block_nnz_use, max_segments_use = _choose_transpose_family_launch_config( + indptr, block_nnz=block_nnz, max_segments=max_segments + ) + + for rows_lv in launch_groups: + n_lv = rows_lv.numel() + if n_lv == 0: + continue + grid = (n_lv,) + _spsv_csr_transpose_push_kernel[grid]( + data, + indices, + indptr, + residual, + x, + rows_lv, + n_level_rows=n_lv, + BLOCK_NNZ=block_nnz_use, + MAX_SEGMENTS=max_segments_use, + LOWER=lower, + UNIT_DIAG=unit_diagonal, + DIAG_EPS=diag_eps, + ) + return x + + +def _triton_spsv_csr_transpose_push_vector_complex( + data, + indices, + indptr, + b_vec, + n_rows, + lower=True, + unit_diagonal=False, + conjugate=False, + block_nnz=None, + max_segments=None, + diag_eps=1e-12, + launch_groups=None, + block_nnz_use=None, + max_segments_use=None, +): + x = torch.zeros_like(b_vec) + if n_rows == 0: + return x + if launch_groups is None: + levels = _build_spsv_levels(indptr, indices, n_rows, lower=lower) + launch_groups = list(reversed(levels)) + if block_nnz_use is None or max_segments_use is None: + block_nnz_use, max_segments_use = _choose_transpose_family_launch_config( + indptr, block_nnz=block_nnz, max_segments=max_segments + ) + + if data.layout != torch.strided: + data_strided = torch.empty(data.shape, dtype=data.dtype, device=data.device) + data_strided.copy_(data) + else: + data_strided = data.contiguous() + + residual_work = b_vec.contiguous().clone() + data_ri = torch.view_as_real(data_strided).reshape(-1).contiguous() + residual_ri = torch.view_as_real(residual_work).reshape(-1).contiguous() + component_dtype = _component_dtype_for_complex(data.dtype) + use_fp64 = component_dtype == torch.float64 + if component_dtype == torch.float16: + x_ri_work = torch.zeros((n_rows, 2), dtype=torch.float32, 
device=b_vec.device) + x_ri = x_ri_work.reshape(-1).contiguous() + else: + x_ri = torch.view_as_real(x.contiguous()).reshape(-1).contiguous() + + for rows_lv in launch_groups: + n_lv = rows_lv.numel() + if n_lv == 0: + continue + grid = (n_lv,) + _spsv_csr_transpose_push_kernel_complex[grid]( + data_ri, + indices, + indptr, + residual_ri, + x_ri, + rows_lv, + n_level_rows=n_lv, + BLOCK_NNZ=block_nnz_use, + MAX_SEGMENTS=max_segments_use, + LOWER=lower, + UNIT_DIAG=unit_diagonal, + CONJ_TRANS=conjugate, + USE_FP64_ACC=use_fp64, + DIAG_EPS=diag_eps, + ) + if component_dtype == torch.float16: + return torch.view_as_complex(x_ri_work.contiguous()) + return x + + +def _choose_transpose_family_launch_config(indptr, block_nnz=None, max_segments=None): + if block_nnz is not None or max_segments is not None: + return _auto_spsv_launch_config(indptr, block_nnz=block_nnz, max_segments=max_segments) + + if indptr.numel() <= 1: + return 32, 1 + max_nnz_per_row = int((indptr[1:] - indptr[:-1]).max().item()) + for cand in (32, 64, 128, 256, 512, 1024): + req = max((max_nnz_per_row + cand - 1) // cand, 1) + if req <= 2048: + return cand, req + cand = 2048 + req = max((max_nnz_per_row + cand - 1) // cand, 1) + return cand, req + + +def _prepare_spsv_coo_inputs(data, row, col, b, shape, transpose=False): + if not all(torch.is_tensor(t) for t in (data, row, col, b)): + raise TypeError("data, row, col, b must all be torch.Tensor") + if not all(t.is_cuda for t in (data, row, col, b)): + raise ValueError("data, row, col, b must all be CUDA tensors") + if data.ndim != 1 or row.ndim != 1 or col.ndim != 1: + raise ValueError("data, row, col must be 1D") + if row.numel() != data.numel() or col.numel() != data.numel(): + raise ValueError("data, row, col must have the same length") + if b.ndim not in (1, 2): + raise ValueError("b must be 1D or 2D (vector or multiple RHS)") + + n_rows, n_cols = int(shape[0]), int(shape[1]) + if b.ndim == 1 and b.numel() != n_rows: + raise ValueError(f"b length must equal n_rows={n_rows}") + if b.ndim == 2 and b.shape[0] != n_rows: + raise ValueError(f"b.shape[0] must equal n_rows={n_rows}") + + if data.dtype not in ( + torch.float32, + torch.float64, + torch.complex64, + torch.complex128, + ): + raise TypeError( + "data dtype must be one of: float32, float64, complex64, complex128" + ) + if b.dtype != data.dtype: + raise TypeError("b dtype must match data dtype") + if row.dtype not in SUPPORTED_SPSV_INDEX_DTYPES: + raise TypeError("row dtype must be torch.int32 or torch.int64") + if col.dtype not in SUPPORTED_SPSV_INDEX_DTYPES: + raise TypeError("col dtype must be torch.int32 or torch.int64") + row64 = row.to(torch.int64).contiguous() + col64 = col.to(torch.int64).contiguous() + if col64.numel() > 0 and int(col64.max().item()) > _INDEX_LIMIT_INT32: + raise ValueError( + f"int64 index value {int(col64.max().item())} exceeds Triton int32 kernel range" + ) + if row64.numel() > 0: + if bool(torch.any(row64 < 0).item()): + raise IndexError("row indices must be non-negative") + if bool(torch.any(col64 < 0).item()): + raise IndexError("col indices must be non-negative") + max_row = int(row64.max().item()) + max_col = int(col64.max().item()) + if max_row >= n_rows: + raise IndexError(f"row indices out of range for n_rows={n_rows}") + if max_col >= n_cols: + raise IndexError(f"col indices out of range for n_cols={n_cols}") + + _validate_spsv_non_trans_combo(data.dtype, torch.int32, "COO") + return ( + data.contiguous(), + row64, + col64, + b.contiguous(), + n_rows, + n_cols, + ) + + +def 
_csr_transpose(data, indices64, indptr64, n_rows, n_cols, conjugate=False): + if data.numel() == 0: + out_data = data + out_indices = torch.empty(0, dtype=torch.int64, device=data.device) + out_indptr = torch.zeros(n_cols + 1, dtype=torch.int64, device=data.device) + return out_data, out_indices, out_indptr + + row_ids = torch.repeat_interleave( + torch.arange(n_rows, device=data.device, dtype=torch.int64), + indptr64[1:] - indptr64[:-1], + ) + new_row = indices64 + new_col = row_ids + data_eff = data.conj() if conjugate and torch.is_complex(data) else data + data_t, indices_t, indptr_t = _coo_to_csr_sorted_unique( + data_eff, new_row, new_col, n_cols, n_rows + ) + return data_t, indices_t, indptr_t + + +def _coo_is_sorted_unique(row64, col64, n_cols): + nnz = row64.numel() + if nnz <= 1: + return True + key = row64 * max(1, n_cols) + col64 + is_sorted = bool(torch.all(key[1:] >= key[:-1]).item()) + is_unique = bool(torch.all(key[1:] != key[:-1]).item()) + return is_sorted and is_unique + + +def _build_coo_row_ptr(row_sorted, n_rows): + row_ptr = torch.zeros(n_rows + 1, dtype=torch.int64, device=row_sorted.device) + if row_sorted.numel() > 0: + nnz_per_row = torch.bincount(row_sorted, minlength=n_rows) + row_ptr[1:] = torch.cumsum(nnz_per_row, dim=0) + return row_ptr + + +def _coo_to_csr_sorted_unique(data, row64, col64, n_rows, n_cols): + nnz = data.numel() + if nnz == 0: + indptr = torch.zeros(n_rows + 1, dtype=torch.int64, device=data.device) + indices = torch.empty(0, dtype=torch.int64, device=data.device) + return data, indices, indptr + + key = row64 * max(1, n_cols) + col64 + try: + order = torch.argsort(key, stable=True) + except TypeError: + order = torch.argsort(key) + key_s = key[order] + data_s = data[order] + + unique_key, inverse = torch.unique_consecutive(key_s, return_inverse=True) + out_nnz = unique_key.numel() + + data_u = torch.zeros(out_nnz, dtype=data.dtype, device=data.device) + data_u.scatter_add_(0, inverse, data_s) + + row_u = torch.div(unique_key, max(1, n_cols), rounding_mode="floor") + col_u = unique_key - row_u * max(1, n_cols) + indptr = _build_coo_row_ptr(row_u, n_rows) + indices = col_u.to(torch.int64) + return data_u, indices, indptr + + +def _triton_spsv_coo_vector( + data, + cols, + row_ptr, + b_vec, + n_rows, + lower=True, + unit_diagonal=False, + block_nnz=None, + max_segments=None, + diag_eps=1e-12, + levels=None, + block_nnz_use=None, + max_segments_use=None, +): + x = torch.zeros_like(b_vec) + if n_rows == 0: + return x + if levels is None: + levels = _build_spsv_levels(row_ptr, cols, n_rows, lower=lower) + if block_nnz_use is None or max_segments_use is None: + block_nnz_use, max_segments_use = _auto_spsv_launch_config( + row_ptr, block_nnz=block_nnz, max_segments=max_segments + ) + + for rows_lv in levels: + n_lv = rows_lv.numel() + if n_lv == 0: + continue + grid = (n_lv,) + _spsv_coo_level_kernel_real[grid]( + data, + row_ptr, + cols, + b_vec, + x, + rows_lv, + n_level_rows=n_lv, + BLOCK_NNZ=block_nnz_use, + MAX_SEGMENTS=max_segments_use, + LOWER=lower, + UNIT_DIAG=unit_diagonal, + DIAG_EPS=diag_eps, + ) + return x + + +def flagsparse_spsv_csr( + data, + indices, + indptr, + b, + shape, + lower=True, + unit_diagonal=False, + transpose=False, + block_nnz=None, + max_segments=None, + out=None, + return_time=False, +): + """Sparse triangular solve using Triton CSR kernels. 
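+
+    Illustrative usage (a sketch, assuming a CUDA device and a well-conditioned
+    lower-triangular matrix; tolerances depend on dtype)::
+
+        A = torch.tril(torch.randn(16, 16, device="cuda"))
+        A = A + torch.eye(16, device="cuda") * 10.0
+        A_csr = A.to_sparse_csr()
+        b = torch.randn(16, device="cuda")
+        x = flagsparse_spsv_csr(
+            A_csr.values(),
+            A_csr.col_indices().to(torch.int32),
+            A_csr.crow_indices(),
+            b,
+            (16, 16),
+            lower=True,
+        )
+        # Expected to agree with
+        # torch.linalg.solve_triangular(A, b.unsqueeze(-1), upper=False)
+        # within dtype-appropriate tolerances.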
+
+    Primary support matrix:
+    - NON_TRANS: float32/float64/complex32/complex64 with int32/int64 indices
+      (bfloat16 with int32 only)
+    - TRANS/CONJ: float32/float64/complex32/complex64 with int32 indices only
+    """
+    (
+        data,
+        b,
+        original_output_dtype,
+        trans_mode,
+        n_rows,
+        n_cols,
+        solve_plan,
+    ) = _resolve_spsv_csr_runtime(
+        data,
+        indices,
+        indptr,
+        b,
+        shape,
+        lower,
+        transpose,
+    )
+
+    solve_kind = solve_plan["solve_kind"]
+    kernel_data = solve_plan["kernel_data"]
+    kernel_indices64 = solve_plan["kernel_indices64"]
+    kernel_indptr64 = solve_plan["kernel_indptr64"]
+    lower_eff = solve_plan["lower_eff"]
+    launch_groups = solve_plan["launch_groups"]
+    transpose_conjugate = solve_plan["transpose_conjugate"]
+    kernel_indices = (
+        kernel_indices64.to(torch.int32)
+        if kernel_indices64.dtype != torch.int32
+        else kernel_indices64
+    )
+    kernel_indptr = kernel_indptr64
+    compute_dtype = data.dtype
+    data_in = kernel_data
+    b_in = b
+    if (
+        data.dtype == torch.complex64
+        and trans_mode in ("T", "C")
+        and SPSV_PROMOTE_TRANSPOSE_COMPLEX64_TO_COMPLEX128
+    ):
+        compute_dtype = torch.complex128
+        data_in = kernel_data.to(torch.complex128)
+        b_in = b.to(torch.complex128)
+    elif data.dtype == torch.float32 and SPSV_PROMOTE_FP32_TO_FP64:
+        compute_dtype = torch.float64
+        data_in = kernel_data.to(torch.float64)
+        b_in = b.to(torch.float64)
+    elif (
+        data.dtype == torch.float32
+        and trans_mode in ("T", "C")
+        and SPSV_PROMOTE_TRANSPOSE_FP32_TO_FP64
+    ):
+        compute_dtype = torch.float64
+        data_in = kernel_data.to(torch.float64)
+        b_in = b.to(torch.float64)
+
+    if solve_kind == "transpose_push":
+        block_nnz_use, max_segments_use = _choose_transpose_family_launch_config(
+            kernel_indptr, block_nnz=block_nnz, max_segments=max_segments
+        )
+        vec_real = _triton_spsv_csr_transpose_push_vector
+        vec_complex = _triton_spsv_csr_transpose_push_vector_complex
+    else:
+        block_nnz_use, max_segments_use = _auto_spsv_launch_config(
+            kernel_indptr, block_nnz=block_nnz, max_segments=max_segments
+        )
+        vec_real = _triton_spsv_csr_vector
+        vec_complex = _triton_spsv_csr_vector_complex
+    diag_eps = _spsv_diag_eps_for_dtype(compute_dtype)
+
+    if return_time:
+        torch.cuda.synchronize()
+        t0 = time.perf_counter()
+    if b_in.ndim == 1:
+        if torch.is_complex(data_in):
+            if solve_kind == "transpose_push":
+                x = vec_complex(
+                    data_in,
+                    kernel_indices,
+                    kernel_indptr,
+                    b_in,
+                    n_rows,
+                    lower=lower_eff,
+                    unit_diagonal=unit_diagonal,
+                    conjugate=transpose_conjugate,
+                    block_nnz=block_nnz,
+                    max_segments=max_segments,
+                    diag_eps=diag_eps,
+                    launch_groups=launch_groups,
+                    block_nnz_use=block_nnz_use,
+                    max_segments_use=max_segments_use,
+                )
+            else:
+                x = vec_complex(
+                    data_in,
+                    kernel_indices,
+                    kernel_indptr,
+                    b_in,
+                    n_rows,
+                    lower=lower_eff,
+                    unit_diagonal=unit_diagonal,
+                    block_nnz=block_nnz,
+                    max_segments=max_segments,
+                    diag_eps=diag_eps,
+                    levels=launch_groups,
+                    block_nnz_use=block_nnz_use,
+                    max_segments_use=max_segments_use,
+                )
+        else:
+            if solve_kind == "transpose_push":
+                x = vec_real(
+                    data_in,
+                    kernel_indices,
+                    kernel_indptr,
+                    b_in,
+                    n_rows,
+                    lower=lower_eff,
+                    unit_diagonal=unit_diagonal,
+                    block_nnz=block_nnz,
+                    max_segments=max_segments,
+                    diag_eps=diag_eps,
+                    launch_groups=launch_groups,
+                    block_nnz_use=block_nnz_use,
+                    max_segments_use=max_segments_use,
+                )
+            else:
+                x = vec_real(
+                    data_in,
+                    kernel_indices,
+                    kernel_indptr,
+                    b_in,
+                    n_rows,
+                    lower=lower_eff,
+                    unit_diagonal=unit_diagonal,
+                    block_nnz=block_nnz,
+                    max_segments=max_segments,
+                    diag_eps=diag_eps,
+                    levels=launch_groups,
+
block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + ) + else: + cols = [] + for j in range(b_in.shape[1]): + bj = b_in[:, j].contiguous() + if torch.is_complex(data_in): + if solve_kind == "transpose_push": + cols.append( + vec_complex( + data_in, + kernel_indices, + kernel_indptr, + bj, + n_rows, + lower=lower_eff, + unit_diagonal=unit_diagonal, + conjugate=transpose_conjugate, + block_nnz=block_nnz, + max_segments=max_segments, + diag_eps=diag_eps, + launch_groups=launch_groups, + block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + ) + ) + else: + cols.append( + vec_complex( + data_in, + kernel_indices, + kernel_indptr, + bj, + n_rows, + lower=lower_eff, + unit_diagonal=unit_diagonal, + block_nnz=block_nnz, + max_segments=max_segments, + diag_eps=diag_eps, + levels=launch_groups, + block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + ) + ) + else: + if solve_kind == "transpose_push": + cols.append( + vec_real( + data_in, + kernel_indices, + kernel_indptr, + bj, + n_rows, + lower=lower_eff, + unit_diagonal=unit_diagonal, + block_nnz=block_nnz, + max_segments=max_segments, + diag_eps=diag_eps, + launch_groups=launch_groups, + block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + ) + ) + else: + cols.append( + vec_real( + data_in, + kernel_indices, + kernel_indptr, + bj, + n_rows, + lower=lower_eff, + unit_diagonal=unit_diagonal, + block_nnz=block_nnz, + max_segments=max_segments, + diag_eps=diag_eps, + levels=launch_groups, + block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + ) + ) + x = torch.stack(cols, dim=1) + target_dtype = original_output_dtype if original_output_dtype is not None else data.dtype + if x.dtype != target_dtype: + x = _restore_spsv_output(x, target_dtype) + if return_time: + torch.cuda.synchronize() + elapsed_ms = (time.perf_counter() - t0) * 1000.0 + if out is not None: + if out.shape != x.shape or out.dtype != x.dtype: + raise ValueError("out shape/dtype must match result") + out.copy_(x) + x = out + + if return_time: + return x, elapsed_ms + return x + + +def _analyze_spsv_csr( + data, + indices, + indptr, + b, + shape, + lower=True, + unit_diagonal=False, + transpose=False, + clear_cache=False, + return_time=False, +): + del unit_diagonal + if clear_cache: + _clear_spsv_csr_preprocess_cache() + if return_time: + torch.cuda.synchronize() + t0 = time.perf_counter() + _resolve_spsv_csr_runtime( + data, + indices, + indptr, + b, + shape, + lower, + transpose, + ) + if return_time: + torch.cuda.synchronize() + return (time.perf_counter() - t0) * 1000.0 + + +def flagsparse_spsv_coo( + + data, + row, + col, + b, + shape, + lower=True, + unit_diagonal=False, + transpose=False, + coo_mode="auto", + block_nnz=None, + max_segments=None, + out=None, + return_time=False, +): + """COO SpSV with dual mode: + - direct: use COO level kernel directly (requires sorted+unique COO) + - csr: convert COO -> CSR (sorted+deduplicated) then call flagsparse_spsv_csr + - auto: pick direct when sorted+unique and supported, otherwise csr + + Notes: + - direct mode currently supports only non-transposed real-valued inputs + - complex dtypes and TRANS/CONJ always route through the CSR implementation + """ + data, row64, col64, b, n_rows, n_cols = _prepare_spsv_coo_inputs( + data, row, col, b, shape, transpose=transpose + ) + if n_rows != n_cols: + raise ValueError(f"A must be square, got shape={shape}") + + mode = str(coo_mode).lower() + if mode not in ("auto", "direct", "csr"): + raise ValueError("coo_mode must be one 
of: 'auto', 'direct', 'csr'") + + sorted_unique = _coo_is_sorted_unique(row64, col64, n_cols) + trans_mode = _normalize_spsv_transpose_mode(transpose) + direct_supported = (trans_mode == "N") and (not torch.is_complex(data)) + use_direct = direct_supported and (mode == "direct" or (mode == "auto" and sorted_unique)) + if mode == "direct" and not direct_supported: + raise ValueError( + "coo_mode='direct' supports only non-transposed real-valued inputs; " + "use coo_mode='csr' or 'auto' for TRANS/CONJ or complex dtypes" + ) + if mode == "direct" and not sorted_unique: + raise ValueError( + "coo_mode='direct' requires COO sorted by (row, col) with no duplicate coordinates; " + "use coo_mode='csr' or 'auto' for unsorted/duplicate COO input" + ) + + if not use_direct: + data_csr, indices_csr, indptr_csr = _coo_to_csr_sorted_unique( + data, row64, col64, n_rows, n_cols + ) + return flagsparse_spsv_csr( + data_csr, + indices_csr, + indptr_csr, + b, + shape, + lower=lower, + unit_diagonal=unit_diagonal, + transpose=transpose, + block_nnz=block_nnz, + max_segments=max_segments, + out=out, + return_time=return_time, + ) + + kernel_cols = col64.to(torch.int32) + row_ptr = _build_coo_row_ptr(row64, n_rows) + + compute_dtype = data.dtype + data_in = data + b_in = b + if data.dtype == torch.float32 and SPSV_PROMOTE_FP32_TO_FP64: + compute_dtype = torch.float64 + data_in = data.to(torch.float64) + b_in = b.to(torch.float64) + levels = _build_spsv_levels(row_ptr, kernel_cols, n_rows, lower=lower) + block_nnz_use, max_segments_use = _auto_spsv_launch_config( + row_ptr, block_nnz=block_nnz, max_segments=max_segments + ) + diag_eps = _spsv_diag_eps_for_dtype(compute_dtype) + + if return_time: + torch.cuda.synchronize() + t0 = time.perf_counter() + if b_in.ndim == 1: + x = _triton_spsv_coo_vector( + data_in, + kernel_cols, + row_ptr, + b_in, + n_rows, + lower=lower, + unit_diagonal=unit_diagonal, + block_nnz=block_nnz, + max_segments=max_segments, + diag_eps=diag_eps, + levels=levels, + block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + ) + else: + cols_out = [] + for j in range(b_in.shape[1]): + cols_out.append( + _triton_spsv_coo_vector( + data_in, + kernel_cols, + row_ptr, + b_in[:, j].contiguous(), + n_rows, + lower=lower, + unit_diagonal=unit_diagonal, + block_nnz=block_nnz, + max_segments=max_segments, + diag_eps=diag_eps, + levels=levels, + block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + ) + ) + x = torch.stack(cols_out, dim=1) + if compute_dtype != data.dtype: + x = x.to(data.dtype) + if return_time: + torch.cuda.synchronize() + elapsed_ms = (time.perf_counter() - t0) * 1000.0 + + if out is not None: + if out.shape != x.shape or out.dtype != x.dtype: + raise ValueError("out shape/dtype must match result") + out.copy_(x) + x = out + + if return_time: + return x, elapsed_ms + return x diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..23f0718 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Test package root (allows ``tests.pytest`` imports for parametrized suites).""" diff --git a/tests/diagnose_spmm_opt.py b/tests/diagnose_spmm_opt.py new file mode 100644 index 0000000..e9afb6f --- /dev/null +++ b/tests/diagnose_spmm_opt.py @@ -0,0 +1,628 @@ +""" +Batch-capable diagnostics for CSR SpMM-opt native accuracy. 
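+
+Requires a CUDA device. Each matrix is loaded via ``load_mtx_to_csr_torch`` and
+compared against both a native-precision and a higher-precision torch.sparse
+reference; results are written as CSV (one summary file plus per-matrix row and
+bucket detail files).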
+ +Usage: + python tests/diagnose_spmm_opt.py path/to/matrix.mtx --dense-cols 32 --seed 0 + python tests/diagnose_spmm_opt.py path/to/mtx_dir --dense-cols 32 --seed 0 --out-dir diag_out + python tests/diagnose_spmm_opt.py path/to/mtx_dir --csv spmm_opt.csv --only-status fail --out-dir diag_out +""" + +import argparse +import csv +import glob +import os +import sys +from pathlib import Path + +import torch + +_PROJECT_ROOT = Path(__file__).resolve().parents[1] +_SRC_ROOT = _PROJECT_ROOT / "src" +if str(_SRC_ROOT) not in sys.path: + sys.path.insert(0, str(_SRC_ROOT)) + +import flagsparse as fs +import flagsparse.sparse_operations.spmm_csr as fs_spmm_csr +from test_spmm_opt import _seeded_dense_matrix, load_mtx_to_csr_torch + + +SUMMARY_FIELDS = [ + "matrix", + "value_dtype", + "dense_cols", + "seed", + "n_rows", + "n_cols", + "nnz", + "avg_nnz_per_row", + "base_err_native", + "base_err_hp", + "legacy_err_native", + "candidate_err_native", + "legacy_err_hp", + "candidate_err_hp", + "legacy_rows_native_gt_1", + "candidate_rows_native_gt_1", + "legacy_rows_native_gt_0_5", + "candidate_rows_native_gt_0_5", + "legacy_worst_row", + "candidate_worst_row", + "legacy_worst_row_nnz", + "candidate_worst_row_nnz", + "legacy_worst_col", + "candidate_worst_col", + "legacy_worst_native_ratio", + "candidate_worst_native_ratio", + "legacy_status", + "candidate_status", + "status", +] + +ROW_FIELDS = [ + "row", + "row_nnz", + "bucket_kind", + "legacy_max_native_ratio", + "candidate_max_native_ratio", + "legacy_mean_native_ratio", + "candidate_mean_native_ratio", + "legacy_max_hp_ratio", + "candidate_max_hp_ratio", + "legacy_max_abs_native_diff", + "candidate_max_abs_native_diff", + "legacy_max_abs_hp_diff", + "candidate_max_abs_hp_diff", + "legacy_worst_col", + "candidate_worst_col", + "legacy_worst_native_ref", + "candidate_worst_native_ref", + "legacy_worst_hp_ref", + "candidate_worst_hp_ref", + "legacy_worst_opt", + "candidate_worst_opt", +] + +BUCKET_FIELDS = [ + "bucket_kind", + "row_count", + "row_nnz_min", + "row_nnz_mean", + "row_nnz_max", + "legacy_bad_rows_native_gt_1", + "candidate_bad_rows_native_gt_1", + "legacy_bad_rows_native_gt_0_5", + "candidate_bad_rows_native_gt_0_5", + "legacy_worst_native_ratio", + "candidate_worst_native_ratio", + "legacy_p50_native_ratio", + "candidate_p50_native_ratio", + "legacy_p90_native_ratio", + "candidate_p90_native_ratio", + "legacy_p99_native_ratio", + "candidate_p99_native_ratio", + "legacy_p99_hp_ratio", + "candidate_p99_hp_ratio", +] + + +def _dtype_from_name(name): + return {"float32": torch.float32, "float64": torch.float64}[name] + + +def _ensure_parent_dir(path): + parent = os.path.dirname(os.path.abspath(path)) + if parent: + os.makedirs(parent, exist_ok=True) + + +def _write_csv(path, rows, fieldnames): + _ensure_parent_dir(path) + with open(path, "w", newline="", encoding="utf-8") as handle: + writer = csv.DictWriter(handle, fieldnames=fieldnames, extrasaction="ignore") + writer.writeheader() + for row in rows: + writer.writerow(row) + + +def _normalize_matrix_name(name): + base = os.path.basename(str(name).strip()) + if not base: + return "" + return base if base.endswith(".mtx") else f"{base}.mtx" + + +def _safe_stem(path): + return os.path.splitext(os.path.basename(path))[0] + + +def _resolve_input_paths(input_path): + if os.path.isfile(input_path): + if not str(input_path).lower().endswith(".mtx"): + raise ValueError(f"Input file must be a .mtx file, got: {input_path}") + return [os.path.abspath(input_path)] + if os.path.isdir(input_path): + paths 
= sorted(glob.glob(os.path.join(input_path, "*.mtx"))) + if not paths: + raise ValueError(f"No .mtx files found under directory: {input_path}") + return [os.path.abspath(path) for path in paths] + raise FileNotFoundError(f"Input path does not exist: {input_path}") + + +def _load_status_metadata(csv_path): + if csv_path is None: + return {} + metadata = {} + with open(csv_path, "r", encoding="utf-8") as handle: + reader = csv.DictReader(handle) + for row in reader: + matrix = _normalize_matrix_name(row.get("matrix", "")) + if matrix: + metadata[matrix] = row + return metadata + + +def _build_native_reference(data, indices, indptr, B, shape, dtype): + sparse = torch.sparse_csr_tensor( + indptr.to(torch.int64), + indices.to(torch.int64), + data.to(dtype), + size=shape, + device=data.device, + ) + return torch.sparse.mm(sparse, B.to(dtype)).to(dtype) + + +def _build_hp_reference(data, indices, indptr, B, shape, dtype): + if dtype == torch.float32: + ref_dtype = torch.float64 + else: + ref_dtype = dtype + sparse = torch.sparse_csr_tensor( + indptr.to(torch.int64), + indices.to(torch.int64), + data.to(ref_dtype), + size=shape, + device=data.device, + ) + return torch.sparse.mm(sparse, B.to(ref_dtype)).to(dtype) + + +def _reference_tolerance(dtype): + if dtype == torch.float32: + return 1e-4, 1e-2 + return 1e-12, 1e-10 + + +def _quantiles(values): + if values.numel() == 0: + return 0.0, 0.0, 0.0 + compare = values.to(torch.float64) + q = torch.quantile( + compare, + torch.tensor([0.5, 0.9, 0.99], device=compare.device, dtype=compare.dtype), + ) + return float(q[0].item()), float(q[1].item()), float(q[2].item()) + + +def _describe_quantiles(values): + if values.numel() == 0: + return "empty" + p50, p90, p99 = _quantiles(values) + compare = values.to(torch.float64) + return ( + f"min={float(compare.min().item()):.2f}, " + f"p50={p50:.2f}, " + f"p90={p90:.2f}, " + f"p99={p99:.2f}, " + f"max={float(compare.max().item()):.2f}" + ) + + +def _compute_error_metrics(candidate, reference, dtype): + atol, rtol = _reference_tolerance(dtype) + diff = torch.abs(candidate - reference).to(torch.float64) + denom = (atol + rtol * torch.abs(reference)).to(torch.float64) + ratio = diff / denom + row_max_ratio = torch.max(ratio, dim=1).values + row_mean_ratio = torch.mean(ratio, dim=1) + row_max_abs = torch.max(diff, dim=1).values + row_worst_col = torch.argmax(ratio, dim=1) + if row_max_ratio.numel() == 0: + worst_row = 0 + worst_col = 0 + worst_ratio = 0.0 + else: + worst_row = int(torch.argmax(row_max_ratio).item()) + worst_col = int(row_worst_col[worst_row].item()) + worst_ratio = float(row_max_ratio[worst_row].item()) + return { + "global_err": float(torch.max(ratio).item()) if ratio.numel() > 0 else 0.0, + "ratio": ratio, + "row_max_ratio": row_max_ratio, + "row_mean_ratio": row_mean_ratio, + "row_max_abs": row_max_abs, + "row_worst_col": row_worst_col, + "rows_err_gt_1": int(torch.count_nonzero(row_max_ratio > 1.0).item()), + "rows_err_gt_0_5": int(torch.count_nonzero(row_max_ratio > 0.5).item()), + "worst_row": worst_row, + "worst_col": worst_col, + "worst_ratio": worst_ratio, + "status": "PASS" if (worst_ratio <= 1.0) else "FAIL", + } + + +def _candidate_bucket_rows(prepared): + bucket_rows = {} + for bucket in fs_spmm_csr._build_spmm_opt_candidate_buckets(prepared): + bucket_rows[bucket["label"]] = bucket["rows"] + return bucket_rows + + +def _row_bucket_names(prepared): + names = ["unassigned"] * prepared.n_rows + for label, rows in _candidate_bucket_rows(prepared).items(): + for row in 
rows.to(torch.int64).cpu().tolist(): + names[row] = label + return names + + +def _bucket_row_nnz(prepared, rows): + if rows.numel() == 0: + return torch.empty((0,), dtype=torch.int64, device=prepared.row_lengths.device) + return prepared.row_lengths.index_select(0, rows.to(prepared.row_lengths.device)).to(torch.int64) + + +def _build_bucket_rows(prepared, legacy_native, candidate_native, legacy_hp, candidate_hp): + rows = [] + for label, bucket_rows in _candidate_bucket_rows(prepared).items(): + if bucket_rows.numel() == 0: + continue + bucket_row_nnz = _bucket_row_nnz(prepared, bucket_rows) + rows_device = bucket_rows.to(legacy_native["row_max_ratio"].device) + legacy_native_ratio = legacy_native["row_max_ratio"].index_select(0, rows_device) + candidate_native_ratio = candidate_native["row_max_ratio"].index_select(0, rows_device) + legacy_hp_ratio = legacy_hp["row_max_ratio"].index_select(0, rows_device) + candidate_hp_ratio = candidate_hp["row_max_ratio"].index_select(0, rows_device) + legacy_p50_native, legacy_p90_native, legacy_p99_native = _quantiles(legacy_native_ratio) + candidate_p50_native, candidate_p90_native, candidate_p99_native = _quantiles(candidate_native_ratio) + _, _, legacy_p99_hp = _quantiles(legacy_hp_ratio) + _, _, candidate_p99_hp = _quantiles(candidate_hp_ratio) + rows.append( + { + "bucket_kind": label, + "row_count": int(bucket_rows.numel()), + "row_nnz_min": int(bucket_row_nnz.min().item()), + "row_nnz_mean": float(bucket_row_nnz.to(torch.float64).mean().item()), + "row_nnz_max": int(bucket_row_nnz.max().item()), + "legacy_bad_rows_native_gt_1": int(torch.count_nonzero(legacy_native_ratio > 1.0).item()), + "candidate_bad_rows_native_gt_1": int(torch.count_nonzero(candidate_native_ratio > 1.0).item()), + "legacy_bad_rows_native_gt_0_5": int(torch.count_nonzero(legacy_native_ratio > 0.5).item()), + "candidate_bad_rows_native_gt_0_5": int(torch.count_nonzero(candidate_native_ratio > 0.5).item()), + "legacy_worst_native_ratio": float(legacy_native_ratio.max().item()), + "candidate_worst_native_ratio": float(candidate_native_ratio.max().item()), + "legacy_p50_native_ratio": legacy_p50_native, + "candidate_p50_native_ratio": candidate_p50_native, + "legacy_p90_native_ratio": legacy_p90_native, + "candidate_p90_native_ratio": candidate_p90_native, + "legacy_p99_native_ratio": legacy_p99_native, + "candidate_p99_native_ratio": candidate_p99_native, + "legacy_p99_hp_ratio": legacy_p99_hp, + "candidate_p99_hp_ratio": candidate_p99_hp, + } + ) + return rows + + +def _build_row_rows( + prepared, + native_ref, + hp_ref, + legacy_out, + candidate_out, + legacy_native, + candidate_native, + legacy_hp, + candidate_hp, + bucket_names, +): + rows = [] + for row_id in range(prepared.n_rows): + legacy_col = int(legacy_native["row_worst_col"][row_id].item()) + candidate_col = int(candidate_native["row_worst_col"][row_id].item()) + rows.append( + { + "row": row_id, + "row_nnz": int(prepared.row_lengths[row_id].item()), + "bucket_kind": bucket_names[row_id], + "legacy_max_native_ratio": float(legacy_native["row_max_ratio"][row_id].item()), + "candidate_max_native_ratio": float(candidate_native["row_max_ratio"][row_id].item()), + "legacy_mean_native_ratio": float(legacy_native["row_mean_ratio"][row_id].item()), + "candidate_mean_native_ratio": float(candidate_native["row_mean_ratio"][row_id].item()), + "legacy_max_hp_ratio": float(legacy_hp["row_max_ratio"][row_id].item()), + "candidate_max_hp_ratio": float(candidate_hp["row_max_ratio"][row_id].item()), + "legacy_max_abs_native_diff": 
float(legacy_native["row_max_abs"][row_id].item()), + "candidate_max_abs_native_diff": float(candidate_native["row_max_abs"][row_id].item()), + "legacy_max_abs_hp_diff": float(legacy_hp["row_max_abs"][row_id].item()), + "candidate_max_abs_hp_diff": float(candidate_hp["row_max_abs"][row_id].item()), + "legacy_worst_col": legacy_col, + "candidate_worst_col": candidate_col, + "legacy_worst_native_ref": float(native_ref[row_id, legacy_col].item()), + "candidate_worst_native_ref": float(native_ref[row_id, candidate_col].item()), + "legacy_worst_hp_ref": float(hp_ref[row_id, legacy_col].item()), + "candidate_worst_hp_ref": float(hp_ref[row_id, candidate_col].item()), + "legacy_worst_opt": float(legacy_out[row_id, legacy_col].item()), + "candidate_worst_opt": float(candidate_out[row_id, candidate_col].item()), + } + ) + return rows + + +def _print_matrix_summary(summary_row, prepared, legacy_native, candidate_native, bucket_rows): + print("=" * 120) + print(f"Matrix: {summary_row['matrix']}") + print( + f"dtype={summary_row['value_dtype']} dense_cols={summary_row['dense_cols']} " + f"seed={summary_row['seed']}" + ) + print( + f"shape=({summary_row['n_rows']}, {summary_row['n_cols']}) " + f"nnz={summary_row['nnz']} avg_nnz_per_row={summary_row['avg_nnz_per_row']:.2f}" + ) + print( + f"base_err_native={summary_row['base_err_native']:.6f} " + f"base_err_hp={summary_row['base_err_hp']:.6f}" + ) + print( + f"legacy_err_native={summary_row['legacy_err_native']:.6f} " + f"candidate_err_native={summary_row['candidate_err_native']:.6f} " + f"legacy_err_hp={summary_row['legacy_err_hp']:.6f} " + f"candidate_err_hp={summary_row['candidate_err_hp']:.6f} " + f"status={summary_row['status']}" + ) + print(f"row_nnz_quantiles: {_describe_quantiles(prepared.row_lengths)}") + print(f"legacy_native_row_err_quantiles: {_describe_quantiles(legacy_native['row_max_ratio'])}") + print(f"candidate_native_row_err_quantiles: {_describe_quantiles(candidate_native['row_max_ratio'])}") + print( + f"legacy rows native err>1: {summary_row['legacy_rows_native_gt_1']} / {summary_row['n_rows']} " + f"candidate rows native err>1: {summary_row['candidate_rows_native_gt_1']} / {summary_row['n_rows']}" + ) + print("-" * 120) + print("Bucket summary") + for bucket_row in bucket_rows: + print( + f"kind={bucket_row['bucket_kind']:<11} rows={bucket_row['row_count']:>8} " + f"row_nnz[min/mean/max]={bucket_row['row_nnz_min']:>4}/" + f"{bucket_row['row_nnz_mean']:>8.2f}/" + f"{bucket_row['row_nnz_max']:>4} " + f"legacy_bad_gt_1={bucket_row['legacy_bad_rows_native_gt_1']:>6} " + f"candidate_bad_gt_1={bucket_row['candidate_bad_rows_native_gt_1']:>6} " + f"legacy_worst={bucket_row['legacy_worst_native_ratio']:>10.4f} " + f"candidate_worst={bucket_row['candidate_worst_native_ratio']:>10.4f}" + ) + + +def _print_top_rows(row_rows, topk): + top_rows = sorted( + row_rows, + key=lambda row: float(row["candidate_max_native_ratio"]), + reverse=True, + )[:topk] + print("-" * 120) + print(f"Top-{len(top_rows)} rows by candidate native ratio") + for rank, row in enumerate(top_rows, start=1): + print( + f"{rank:>2}. 
row={row['row']:>8} row_nnz={row['row_nnz']:>6} " + f"bucket={row['bucket_kind']:<11} " + f"legacy_native={row['legacy_max_native_ratio']:>10.4f} " + f"candidate_native={row['candidate_max_native_ratio']:>10.4f} " + f"legacy_hp={row['legacy_max_hp_ratio']:>10.4f} " + f"candidate_hp={row['candidate_max_hp_ratio']:>10.4f}" + ) + + +def diagnose_one(path, dense_cols, dtype, seed, topk): + device = torch.device("cuda") + data, indices, indptr, shape = load_mtx_to_csr_torch(path, dtype=dtype, device=device) + n_rows, n_cols = shape + B = _seeded_dense_matrix((n_cols, dense_cols), dtype, device, seed) + + native_ref = _build_native_reference(data, indices, indptr, B, shape, dtype) + hp_ref = _build_hp_reference(data, indices, indptr, B, shape, dtype) + base = fs.flagsparse_spmm_csr(data, indices, indptr, B, shape) + prepared = fs.prepare_spmm_csr_opt(data, indices, indptr, shape) + legacy_out = fs.flagsparse_spmm_csr_opt(B=B, prepared=prepared) + candidate_out = fs_spmm_csr._flagsparse_spmm_csr_opt_candidate_for_diagnose(prepared, B) + + base_native = _compute_error_metrics(base, native_ref, dtype) + base_hp = _compute_error_metrics(base, hp_ref, dtype) + legacy_native = _compute_error_metrics(legacy_out, native_ref, dtype) + candidate_native = _compute_error_metrics(candidate_out, native_ref, dtype) + legacy_hp = _compute_error_metrics(legacy_out, hp_ref, dtype) + candidate_hp = _compute_error_metrics(candidate_out, hp_ref, dtype) + bucket_names = _row_bucket_names(prepared) + + summary_row = { + "matrix": os.path.basename(path), + "value_dtype": str(dtype).replace("torch.", ""), + "dense_cols": int(dense_cols), + "seed": seed, + "n_rows": int(n_rows), + "n_cols": int(n_cols), + "nnz": int(data.numel()), + "avg_nnz_per_row": float(data.numel()) / max(1, int(n_rows)), + "base_err_native": base_native["global_err"], + "base_err_hp": base_hp["global_err"], + "legacy_err_native": legacy_native["global_err"], + "candidate_err_native": candidate_native["global_err"], + "legacy_err_hp": legacy_hp["global_err"], + "candidate_err_hp": candidate_hp["global_err"], + "legacy_rows_native_gt_1": legacy_native["rows_err_gt_1"], + "candidate_rows_native_gt_1": candidate_native["rows_err_gt_1"], + "legacy_rows_native_gt_0_5": legacy_native["rows_err_gt_0_5"], + "candidate_rows_native_gt_0_5": candidate_native["rows_err_gt_0_5"], + "legacy_worst_row": legacy_native["worst_row"], + "candidate_worst_row": candidate_native["worst_row"], + "legacy_worst_row_nnz": int(prepared.row_lengths[legacy_native["worst_row"]].item()) if n_rows > 0 else 0, + "candidate_worst_row_nnz": int(prepared.row_lengths[candidate_native["worst_row"]].item()) if n_rows > 0 else 0, + "legacy_worst_col": legacy_native["worst_col"], + "candidate_worst_col": candidate_native["worst_col"], + "legacy_worst_native_ratio": legacy_native["worst_ratio"], + "candidate_worst_native_ratio": candidate_native["worst_ratio"], + "legacy_status": legacy_native["status"], + "candidate_status": candidate_native["status"], + "status": candidate_native["status"], + } + + row_rows = _build_row_rows( + prepared, + native_ref, + hp_ref, + legacy_out, + candidate_out, + legacy_native, + candidate_native, + legacy_hp, + candidate_hp, + bucket_names, + ) + bucket_rows = _build_bucket_rows(prepared, legacy_native, candidate_native, legacy_hp, candidate_hp) + + _print_matrix_summary(summary_row, prepared, legacy_native, candidate_native, bucket_rows) + _print_top_rows(row_rows, max(1, min(int(topk), len(row_rows)))) + return summary_row, row_rows, bucket_rows + + +def 
_collect_selected_paths(all_paths, csv_metadata, only_status, only_matrices): + allowed_names = None + if only_matrices: + allowed_names = { + _normalize_matrix_name(name) + for name in only_matrices.split(",") + if _normalize_matrix_name(name) + } + + selected = [] + for path in all_paths: + matrix_name = _normalize_matrix_name(os.path.basename(path)) + if allowed_names is not None and matrix_name not in allowed_names: + continue + if only_status != "all" and matrix_name in csv_metadata: + csv_status = str(csv_metadata[matrix_name].get("status", "")).strip().lower() + if csv_status != only_status: + continue + selected.append(path) + return selected + + +def _filter_results_by_status(results, only_status): + if only_status == "all": + return results + target = only_status.upper() + return [result for result in results if result["summary"]["status"] == target] + + +def _default_output_paths(out_dir): + abs_out_dir = os.path.abspath(out_dir) + rows_dir = os.path.join(abs_out_dir, "diag_rows") + buckets_dir = os.path.join(abs_out_dir, "diag_buckets") + summary_csv = os.path.join(abs_out_dir, "spmm_diag_summary.csv") + os.makedirs(rows_dir, exist_ok=True) + os.makedirs(buckets_dir, exist_ok=True) + return summary_csv, rows_dir, buckets_dir + + +def run_batch(input_path, dtype, dense_cols, seed, topk, csv_path=None, out_dir="spmm_diag", only_status="all", only_matrices=None): + all_paths = _resolve_input_paths(input_path) + csv_metadata = _load_status_metadata(csv_path) + selected_paths = _collect_selected_paths(all_paths, csv_metadata, only_status, only_matrices) + if not selected_paths: + raise ValueError("No matrices matched the requested filters.") + + summary_csv, rows_dir, buckets_dir = _default_output_paths(out_dir) + results = [] + for index, path in enumerate(selected_paths, start=1): + print(f"[{index}/{len(selected_paths)}] Diagnosing {os.path.basename(path)}") + summary_row, row_rows, bucket_rows = diagnose_one( + path=path, + dense_cols=dense_cols, + dtype=dtype, + seed=seed, + topk=topk, + ) + results.append({"path": path, "summary": summary_row, "rows": row_rows, "buckets": bucket_rows}) + + filtered_results = _filter_results_by_status(results, only_status) + if not filtered_results: + raise ValueError("Diagnostics completed, but no matrices matched the requested status filter.") + + _write_csv(summary_csv, [result["summary"] for result in filtered_results], SUMMARY_FIELDS) + for result in filtered_results: + stem = _safe_stem(result["path"]) + _write_csv(os.path.join(rows_dir, f"{stem}.rows.csv"), result["rows"], ROW_FIELDS) + _write_csv(os.path.join(buckets_dir, f"{stem}.buckets.csv"), result["buckets"], BUCKET_FIELDS) + + print("=" * 120) + print(f"Wrote summary CSV to {summary_csv}") + print(f"Wrote row diagnostics under {rows_dir}") + print(f"Wrote bucket diagnostics under {buckets_dir}") + + +def main(): + parser = argparse.ArgumentParser(description="Diagnose native CSR SpMM-opt accuracy for one matrix or a directory of matrices.") + parser.add_argument("input_path", help="Path to one .mtx file or a directory containing .mtx files") + parser.add_argument("--dtype", default="float32", choices=["float32", "float64"]) + parser.add_argument("--dense-cols", type=int, default=32) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--topk", type=int, default=10, help="How many top bad rows to print") + parser.add_argument("--csv", type=str, default=None, help="Optional existing CSV used for metadata or prefiltering") + parser.add_argument("--out-dir", 
type=str, default="spmm_diag", help="Output directory for batch summary and detail CSVs") + parser.add_argument("--only-status", choices=["all", "fail", "pass"], default="all", help="If CSV metadata is provided, prefilter by CSV status; otherwise filter final outputs by diagnosed status") + parser.add_argument("--only-matrices", type=str, default=None, help="Comma-separated matrix basenames to include, with or without .mtx suffix") + parser.add_argument("--row-csv", type=str, default=None, help="Single-file mode only: explicit output path for row diagnostics CSV") + args = parser.parse_args() + + dtype = _dtype_from_name(args.dtype) + input_is_file = os.path.isfile(args.input_path) + + if input_is_file: + summary_row, row_rows, bucket_rows = diagnose_one( + path=args.input_path, + dense_cols=args.dense_cols, + dtype=dtype, + seed=args.seed, + topk=args.topk, + ) + summary_csv, rows_dir, buckets_dir = _default_output_paths(args.out_dir) + stem = _safe_stem(args.input_path) + _write_csv(summary_csv, [summary_row], SUMMARY_FIELDS) + if args.row_csv: + _write_csv(args.row_csv, row_rows, ROW_FIELDS) + print("-" * 120) + print(f"Wrote row diagnostics to {args.row_csv}") + else: + _write_csv(os.path.join(rows_dir, f"{stem}.rows.csv"), row_rows, ROW_FIELDS) + print("-" * 120) + print(f"Wrote row diagnostics to {os.path.join(rows_dir, f'{stem}.rows.csv')}") + _write_csv(os.path.join(buckets_dir, f"{stem}.buckets.csv"), bucket_rows, BUCKET_FIELDS) + print(f"Wrote summary CSV to {summary_csv}") + print(f"Wrote bucket diagnostics to {os.path.join(buckets_dir, f'{stem}.buckets.csv')}") + return + + if args.row_csv: + raise ValueError("--row-csv is only supported in single-file mode; use --out-dir in directory mode") + + run_batch( + input_path=args.input_path, + dtype=dtype, + dense_cols=args.dense_cols, + seed=args.seed, + topk=args.topk, + csv_path=args.csv, + out_dir=args.out_dir, + only_status=args.only_status, + only_matrices=args.only_matrices, + ) + + +if __name__ == "__main__": + main() diff --git a/tests/pytest/Note.md b/tests/pytest/Note.md new file mode 100644 index 0000000..d5418ef --- /dev/null +++ b/tests/pytest/Note.md @@ -0,0 +1,189 @@ +# tests/pytest:参数化用例说明 / Parametrized test notes + +--- + +## 中文 + +### 用例名里的方括号 `[…]` 是什么 + +形如 `test_xxx[float32-1024-256]` 的片段是 **pytest 自动生成的用例实例 ID**:把这一次 **`@pytest.mark.parametrize` 选用的实参** 用 `-` 拼成一段,便于区分笛卡尔积里的某一组。 + +- **dtype**:`FLOAT_DTYPE_IDS` 用于 gather/scatter 与 CSR SpMV;COO SpMV 使用 `SPMV_COO_DTYPE_IDS`(仅 `float32` / `float64`)。 +- **数字段**:对应当前测试里形状或阶数参数,含义见下文。 + +参数列表定义在 `param_shapes.py`;`--mode quick|normal` 在 `conftest.py` 中切换 **QUICK_MODE**,会缩短形状列表。 + +--- + +### 合成数据生成规则(CUDA) + +全部参数化用例在 **GPU** 上构造张量(`torch.device("cuda")`),**不**读 `.mtx`、ODPS 或外部矩阵文件。规则与对应 `test_*_accuracy.py` 一致,摘要如下: + +- **公共**:掩码采样概率(CSR / COO SpMV 共用)对形状 `(rows, cols)` 取 `p = min(0.25, max(0.06, 32 / max(rows*cols, 1)))`;`mask = (torch.rand(rows, cols, device=cuda) < p)`;若 `mask` 全假则强制 `mask[0, 0] = True`;非零处值为 `randn(rows, cols, dtype, device=cuda) * mask`。 +- **CSR SpMV**:由上式得到稠密阵后 `to_sparse_csr()` 得 **A**;`x = randn(N, dtype, device=cuda)`;参照 `torch.sparse.mm(A_csr, x.unsqueeze(1)).squeeze(1)`;被测 `flagsparse_spmv_csr(...)`。 +- **COO SpMV**:**同一掩码与 `randn*mask` 的稠密阵**再 `to_sparse_coo()`;`x` 与参照同上(`torch.sparse.mm` 作用于该 COO 张量);仅 **float32 / float64**;被测 `flagsparse_spmv_coo(...)`。 +- **Gather**:`dense = randn(dense_size, dtype, device=cuda)`;`nnz_eff = min(nnz, dense_size)`;`indices = randperm(dense_size, device=cuda)[:nnz_eff].to(int32)`;参照为 
`dense[indices.to(int64)]`。 +- **Scatter**:`vals = randn(nnz_eff, ...)`,索引规则与 gather 相同;输出初值为全零;参照为对同一下标的 `index_copy_`。 +- **SpSV(CSR)**:`base = tril(randn(n, n, dtype, device=cuda))`,`A = base + eye(n) * (n/2 + 2)`(加强对角),`b = randn(n, ...)`;`A_csr = A.to_sparse_csr()`;参照 `torch.linalg.solve_triangular(A, b.unsqueeze(-1), upper=False).squeeze(-1)`;被测 `flagsparse_spsv_csr(...)`(仅下三角、非单位对角)。 + +容差:`gather`/`scatter` 用 `torch.equal`;SpMV / SpSV 用 `torch.allclose`,阈值见各测试文件中的 `_tol` 或固定 `rtol`/`atol`。 + +--- + +### `test_gather_matches_indexing` + +**ID 格式**:`[dtype-dense_size-nnz]` + +| 段 | 含义 | +|----|------| +| dtype | 稠密向量与输出流的数据类型 | +| dense_size | 被索引的稠密向量长度 | +| nnz | 索引条数(gather 输出长度);代码里会做 `min(nnz, dense_size)` | + +**形状来源**:`GATHER_SCATTER_SHAPES` × `FLOAT_DTYPES`。 + +--- + +### `test_scatter_matches_index_copy` + +**ID 格式**:与 gather 相同 — `[dtype-dense_size-nnz]`。 + +测的是 **scatter**(覆盖写)与 `index_copy_` 的一致性;索引为随机 **不重复** 子集。 + +**形状来源**:同上。 + +--- + +### `test_spmv_csr_matches_torch` + +**ID 格式**:`[dtype-M-N]` + +| 段 | 含义 | +|----|------| +| dtype | 稀疏矩阵与向量 **x** 的数据类型 | +| M | **A** 行数,**y** 的长度 | +| N | **A** 列数,**x** 的长度(`y = A @ x`) | + +**形状来源**:`SPMV_MN_SHAPES` × `FLOAT_DTYPES`。 + +--- + +### `test_spmv_coo_matches_torch` + +**ID 格式**:`[dtype-M-N]`(与 CSR SpMV 相同) + +| 段 | 含义 | +|----|------| +| dtype | 仅 **float32 / float64**(与 Triton COO SpMV 核一致;见 `SPMV_COO_DTYPES`) | +| M | **A** 行数,**y** 长度 | +| N | **A** 列数,**x** 长度 | + +**形状来源**:`SPMV_MN_SHAPES` × `SPMV_COO_DTYPES`。数据由与 CSR 相同的掩码规则得到稠密矩阵后 `to_sparse_coo()`。 + +--- + +### `test_spsv_csr_lower_matches_dense` + +**ID 格式**:`[dtype-n]`(dtype 在测试里显式写为 `float32` / `float64`) + +| 段 | 含义 | +|----|------| +| dtype | 仅 **float32 / float64**(`flagsparse_spsv_csr` / CSR) | +| n | 方阵阶数:**A** 为 `n×n` 下三角 CSR,**b** 长度 `n`,解 **A x = b** | + +**形状来源**:`SPSV_N` × `{float32, float64}`。**仅 CSR**:无 COO SpSV 参数化用例。 + +--- + +## English + +### What the `[…]` suffix means + +Strings like `test_xxx[float32-1024-256]` are **pytest case instance IDs**: the **concrete arguments** chosen by `@pytest.mark.parametrize` for that run, joined with `-`, so you can tell one combination in the Cartesian product from another. + +- **dtype**: `FLOAT_DTYPE_IDS` for gather/scatter and CSR SpMV; COO SpMV uses `SPMV_COO_DTYPE_IDS` (`float32` / `float64` only). +- **Numeric segments**: Shape or order parameters for that test (see below). + +Grids live in `param_shapes.py`. `--mode quick|normal` in `conftest.py` toggles **QUICK_MODE** (fewer shapes). + +--- + +### Synthetic data construction (CUDA) + +All parametrized tests build tensors on **`torch.device("cuda")`** and do **not** read `.mtx`, ODPS, or external matrix files. This matches the `test_*_accuracy.py` sources; summary: + +- **Shared SpMV mask** (CSR & COO): for shape `(rows, cols)`, `p = min(0.25, max(0.06, 32 / max(rows*cols, 1)))`; `mask = (torch.rand(rows, cols, device=cuda) < p)`; if the mask is all false, set `mask[0, 0] = True`; values are `randn(rows, cols, dtype, device=cuda) * mask`. +- **CSR SpMV**: dense field → `to_sparse_csr()` → **A**; `x = randn(N, dtype, device=cuda)`; reference `torch.sparse.mm(A_csr, x.unsqueeze(1)).squeeze(1)`; DUT `flagsparse_spmv_csr(...)`. +- **COO SpMV**: **same** masked dense field → `to_sparse_coo()`; same `x` / reference path with `torch.sparse.mm` on that COO tensor; dtypes **float32 / float64** only; DUT `flagsparse_spmv_coo(...)`. 
+- **Gather**: `dense = randn(dense_size, dtype, device=cuda)`; `nnz_eff = min(nnz, dense_size)`; `indices = randperm(dense_size, device=cuda)[:nnz_eff].to(int32)`; reference fancy indexing on `dense`. +- **Scatter**: `vals = randn(nnz_eff, ...)` with the same index rule; output starts at zeros; reference `index_copy_` on the same indices. +- **SpSV (CSR)**: `base = tril(randn(n, n, dtype, device=cuda))`, `A = base + eye(n) * (n/2 + 2)`, `b = randn(n, ...)`; `A_csr = A.to_sparse_csr()`; reference `torch.linalg.solve_triangular(A, b.unsqueeze(-1), upper=False).squeeze(-1)`; DUT `flagsparse_spsv_csr(...)` (lower-triangular, non-unit diagonal). + +Checks: `gather` / `scatter` use `torch.equal`; SpMV / SpSV use `torch.allclose` (see per-file `_tol` or fixed rtol/atol). + +--- + +### `test_gather_matches_indexing` + +**ID pattern**: `[dtype-dense_size-nnz]` + +| Segment | Meaning | +|---------|---------| +| dtype | Dtype of the dense vector and gathered stream | +| dense_size | Length of the indexed dense vector | +| nnz | Number of indices (output length); code uses `min(nnz, dense_size)` | + +**Source**: `GATHER_SCATTER_SHAPES` × `FLOAT_DTYPES`. + +--- + +### `test_scatter_matches_index_copy` + +**ID pattern**: same as gather — `[dtype-dense_size-nnz]`. + +Checks **scatter** (overwrite) vs `index_copy_` with a random **unique** index set. + +**Source**: same as gather. + +--- + +### `test_spmv_csr_matches_torch` + +**ID pattern**: `[dtype-M-N]` + +| Segment | Meaning | +|---------|---------| +| dtype | Dtype of **A** and **x** | +| M | Rows of **A**, length of **y** | +| N | Columns of **A**, length of **x** (`y = A @ x`) | + +**Source**: `SPMV_MN_SHAPES` × `FLOAT_DTYPES`. + +--- + +### `test_spmv_coo_matches_torch` + +**ID pattern**: `[dtype-M-N]` (same as CSR SpMV) + +| Segment | Meaning | +|---------|---------| +| dtype | **float32 / float64** only (Triton COO SpMV kernels; see `SPMV_COO_DTYPES`) | +| M | Rows of **A**, length of **y** | +| N | Columns of **A**, length of **x** | + +**Source**: `SPMV_MN_SHAPES` × `SPMV_COO_DTYPES`. Matrix is the same masked dense field as CSR SpMV, then `to_sparse_coo()`. + +--- + +### `test_spsv_csr_lower_matches_dense` + +**ID pattern**: `[dtype-n]` (dtype ids are explicitly `float32` / `float64`) + +| Segment | Meaning | +|---------|---------| +| dtype | **float32 / float64** only (`flagsparse_spsv_csr`, CSR storage) | +| n | Square order: lower-triangular **CSR** `n×n`, **b** length `n`, solve **A x = b** | + +**Source**: `SPSV_N` × `{float32, float64}`. **CSR SpSV only**; no COO SpSV parametrized test. 
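+
+### Reference sketch
+
+For readers porting these generators elsewhere, here is a minimal, self-contained sketch of the SpSV case builder described above (the name `make_spsv_case` is illustrative, not an actual test helper):
+
+```python
+import torch
+
+def make_spsv_case(n, dtype=torch.float32, device="cuda"):
+    # Lower-triangular A with a strengthened diagonal, per the recipe above.
+    base = torch.tril(torch.randn(n, n, dtype=dtype, device=device))
+    A = base + torch.eye(n, dtype=dtype, device=device) * (n / 2 + 2)
+    b = torch.randn(n, dtype=dtype, device=device)
+    A_csr = A.to_sparse_csr()
+    ref = torch.linalg.solve_triangular(A, b.unsqueeze(-1), upper=False).squeeze(-1)
+    return A_csr, b, ref
+```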
+ +--- diff --git a/tests/pytest/__init__.py b/tests/pytest/__init__.py new file mode 100644 index 0000000..24486c3 --- /dev/null +++ b/tests/pytest/__init__.py @@ -0,0 +1 @@ +"""Parametrized pytest suite (marks + shapes/dtypes patterned after FlagGems).""" diff --git a/tests/pytest/conftest.py b/tests/pytest/conftest.py new file mode 100644 index 0000000..5bfb993 --- /dev/null +++ b/tests/pytest/conftest.py @@ -0,0 +1,20 @@ +"""Pytest hooks: ``--mode quick|normal`` toggles shape/dtype lists ``QUICK_MODE``.""" + +import pytest + +QUICK_MODE = False + + +def pytest_addoption(parser): + parser.addoption( + "--mode", + action="store", + default="normal", + choices=["normal", "quick"], + help="quick: fewer shapes/dtypes (FlagGems-style QUICK_MODE).", + ) + + +def pytest_configure(config): + global QUICK_MODE + QUICK_MODE = config.getoption("--mode") == "quick" diff --git a/tests/pytest/param_shapes.py b/tests/pytest/param_shapes.py new file mode 100644 index 0000000..57bcf9b --- /dev/null +++ b/tests/pytest/param_shapes.py @@ -0,0 +1,47 @@ +"""Shape / dtype grids for parametrized tests (gather/scatter, CSR+COO SpMV, CSR SpSV).""" + +import torch + +from tests.pytest.conftest import QUICK_MODE + + +if QUICK_MODE: + SPMV_MN_SHAPES = [ + (1, 32), + ] + GATHER_SCATTER_SHAPES = [ + (512, 128), + ] + SPSV_N = [16] +else: + SPMV_MN_SHAPES = [ + (1, 32), + (160, 1024), + (128, 256), + ] + GATHER_SCATTER_SHAPES = [ + (1024, 256), + (4096, 512), + ] + SPSV_N = [16, 64] + + +def _bf16_ok(): + if not torch.cuda.is_available(): + return False + fn = getattr(torch.cuda, "is_bf16_supported", None) + return bool(fn()) if callable(fn) else False + + +_PRIMARY_FLOAT = [torch.float16, torch.float32] +FLOAT_DTYPES = _PRIMARY_FLOAT + ([torch.bfloat16] if _bf16_ok() else []) + [torch.float64] + + +def _dtype_node_id(value): + """Short name for pytest node IDs (e.g. 
``torch.float32`` -> ``float32``).""" + return str(value).replace("torch.", "") + + +FLOAT_DTYPE_IDS = [_dtype_node_id(d) for d in FLOAT_DTYPES] +SPMV_COO_DTYPES = (torch.float32, torch.float64) +SPMV_COO_DTYPE_IDS = ("float32", "float64") diff --git a/tests/pytest/test_gather_scatter_accuracy.py b/tests/pytest/test_gather_scatter_accuracy.py new file mode 100644 index 0000000..a30beec --- /dev/null +++ b/tests/pytest/test_gather_scatter_accuracy.py @@ -0,0 +1,192 @@ +import pytest +import torch + +from flagsparse import flagsparse_gather, flagsparse_scatter +from flagsparse.sparse_operations import gather_scatter as gather_scatter_ops + +from tests.pytest.param_shapes import FLOAT_DTYPES, GATHER_SCATTER_SHAPES + + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +INDEX_DTYPES = [torch.int32, torch.int64] +INDEX_DTYPE_IDS = ["int32", "int64"] +RESET_OUTPUT_CASES = [True, False] +RESET_OUTPUT_IDS = ["reset", "inplace"] + +GATHER_DTYPE_CASES = [ + ("float", torch.float32), + ("double", torch.float64), + ("half", torch.float16), + ("bfloat16", torch.bfloat16), + ("complex64", torch.complex64), + ("complex128", torch.complex128), +] +GATHER_DTYPE_IDS = [name for name, _ in GATHER_DTYPE_CASES] + + +def _scatter_dtype_cases(): + cases = [(str(dtype).replace("torch.", ""), dtype) for dtype in FLOAT_DTYPES] + cases.append(("complex64", torch.complex64)) + cases.append(("complex128", torch.complex128)) + return cases + + +SCATTER_DTYPE_CASES = _scatter_dtype_cases() +SCATTER_DTYPE_IDS = [name for name, _ in SCATTER_DTYPE_CASES] + + +def _skip_unavailable_dtype(dtype_name, dtype): + if dtype is None: + pytest.skip(f"{dtype_name} dtype is unavailable in this torch build") + if dtype == torch.bfloat16 and not ( + torch.cuda.is_available() and torch.cuda.is_bf16_supported() + ): + pytest.skip("bfloat16 not supported on this GPU") + + +def _build_random_values(size, dtype, device): + if dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64): + return torch.randn(size, dtype=dtype, device=device) + if dtype == torch.complex64: + real = torch.randn(size, dtype=torch.float32, device=device) + imag = torch.randn(size, dtype=torch.float32, device=device) + return torch.complex(real, imag) + if dtype == torch.complex128: + real = torch.randn(size, dtype=torch.float64, device=device) + imag = torch.randn(size, dtype=torch.float64, device=device) + return torch.complex(real, imag) + raise TypeError(f"Unsupported dtype in test: {dtype}") + + +@pytest.mark.gather +@pytest.mark.parametrize("dense_size, nnz", GATHER_SCATTER_SHAPES) +@pytest.mark.parametrize("dtype_name,dtype", GATHER_DTYPE_CASES, ids=GATHER_DTYPE_IDS) +@pytest.mark.parametrize("index_dtype", INDEX_DTYPES, ids=INDEX_DTYPE_IDS) +def test_gather_matches_indexing(dense_size, nnz, dtype_name, dtype, index_dtype): + _skip_unavailable_dtype(dtype_name, dtype) + device = torch.device("cuda") + nnz = min(nnz, dense_size) + dense = _build_random_values(dense_size, dtype, device) + indices = torch.randperm(dense_size, device=device)[:nnz].to(index_dtype) + ref = dense[indices.to(torch.int64)] + got = flagsparse_gather(dense, indices) + assert torch.equal(ref, got) + + +@pytest.mark.gather +@pytest.mark.parametrize("index_dtype", INDEX_DTYPES, ids=INDEX_DTYPE_IDS) +def test_gather_complex128_matches_indexing(index_dtype): + device = torch.device("cuda") + dense_size = 4096 + nnz = 1024 + real = torch.randn(dense_size, dtype=torch.float64, device=device) + imag = torch.randn(dense_size, dtype=torch.float64, 
device=device) + dense = torch.complex(real, imag) + indices = torch.randperm(dense_size, device=device)[:nnz].to(index_dtype) + ref = dense.index_select(0, indices.to(torch.int64)) + got = flagsparse_gather(dense, indices) + assert torch.allclose(got, ref, atol=1e-10, rtol=1e-8) + + +@pytest.mark.scatter +@pytest.mark.parametrize("dense_size, nnz", GATHER_SCATTER_SHAPES) +@pytest.mark.parametrize("dtype_name,dtype", SCATTER_DTYPE_CASES, ids=SCATTER_DTYPE_IDS) +@pytest.mark.parametrize("index_dtype", INDEX_DTYPES, ids=INDEX_DTYPE_IDS) +@pytest.mark.parametrize("reset_output", RESET_OUTPUT_CASES, ids=RESET_OUTPUT_IDS) +def test_scatter_matches_index_copy( + dense_size, nnz, dtype_name, dtype, index_dtype, reset_output +): + _skip_unavailable_dtype(dtype_name, dtype) + device = torch.device("cuda") + nnz = min(nnz, dense_size) + vals = _build_random_values(nnz, dtype, device) + indices = torch.randperm(dense_size, device=device)[:nnz].to(index_dtype) + dense_ref = _build_random_values(dense_size, dtype, device) + dense = dense_ref.clone() + if reset_output: + dense_ref.zero_() + dense_ref.index_copy_(0, indices.to(torch.int64), vals) + flagsparse_scatter( + dense, + indices, + vals, + reset_output=reset_output, + dtype_policy="auto", + ) + assert torch.equal(dense, dense_ref) + + +@pytest.mark.scatter +def test_scatter_int64_auto_fallback_to_int32(monkeypatch): + device = torch.device("cuda") + dense_size = 257 + nnz = 129 + vals = torch.randn(nnz, dtype=torch.float32, device=device) + indices = torch.randperm(dense_size, device=device)[:nnz].to(torch.int64) + dense = torch.randn(dense_size, dtype=torch.float32, device=device) + dense_ref = dense.clone() + dense_ref.zero_() + dense_ref.index_copy_(0, indices.to(torch.int64), vals) + + original_launch = gather_scatter_ops._launch_triton_scatter_kernel + state = {"forced_once": False} + + def fake_launch(dense_values, sparse_values, kernel_indices, nnz, block_size=1024): + if kernel_indices.dtype == torch.int64 and not state["forced_once"]: + state["forced_once"] = True + raise RuntimeError("forced int64 launch failure") + return original_launch( + dense_values, + sparse_values, + kernel_indices, + nnz, + block_size=block_size, + ) + + monkeypatch.setattr(gather_scatter_ops, "_launch_triton_scatter_kernel", fake_launch) + + flagsparse_scatter( + dense, + indices, + vals, + reset_output=True, + dtype_policy="auto", + index_fallback_policy="auto", + ) + assert state["forced_once"] + assert torch.equal(dense, dense_ref) + + +@pytest.mark.scatter +def test_scatter_int64_strict_no_fallback(monkeypatch): + device = torch.device("cuda") + dense_size = 257 + nnz = 129 + vals = torch.randn(nnz, dtype=torch.float32, device=device) + indices = torch.randperm(dense_size, device=device)[:nnz].to(torch.int64) + dense = torch.randn(dense_size, dtype=torch.float32, device=device) + + original_launch = gather_scatter_ops._launch_triton_scatter_kernel + + def fake_launch(dense_values, sparse_values, kernel_indices, nnz, block_size=1024): + if kernel_indices.dtype == torch.int64: + raise RuntimeError("forced int64 launch failure") + return original_launch( + dense_values, + sparse_values, + kernel_indices, + nnz, + block_size=block_size, + ) + + monkeypatch.setattr(gather_scatter_ops, "_launch_triton_scatter_kernel", fake_launch) + + with pytest.raises(RuntimeError, match="Triton scatter failed for index dtype"): + flagsparse_scatter( + dense, + indices, + vals, + reset_output=True, + dtype_policy="auto", + index_fallback_policy="strict", + ) diff --git 
a/tests/pytest/test_spgemm_sddmm_accuracy.py b/tests/pytest/test_spgemm_sddmm_accuracy.py new file mode 100644 index 0000000..42c6019 --- /dev/null +++ b/tests/pytest/test_spgemm_sddmm_accuracy.py @@ -0,0 +1,100 @@ +import pytest +import torch + +from flagsparse import flagsparse_sddmm_csr, flagsparse_spgemm_csr + +from tests.pytest.param_shapes import ( + SDDMM_DTYPES, + SDDMM_DTYPE_IDS, + SDDMM_MNK_SHAPES, + SPGEMM_DTYPES, + SPGEMM_DTYPE_IDS, + SPGEMM_MNK_SHAPES, +) + + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + + +def _random_csr(rows, cols, dtype, device): + denom = max(rows * cols, 1) + p = min(0.25, max(0.06, 32.0 / denom)) + mask = torch.rand(rows, cols, device=device) < p + if int(mask.sum().item()) == 0: + mask[0, 0] = True + vals = torch.randn(rows, cols, dtype=dtype, device=device) * mask.to(dtype=dtype) + return vals.to_sparse_csr() + + +def _csr_to_dense(data, indices, indptr, shape): + csr = torch.sparse_csr_tensor( + indptr, + indices, + data, + size=shape, + dtype=data.dtype, + device=data.device, + ) + return csr.to_dense() + + +def _tol(dtype): + if dtype == torch.float32: + return 1e-4, 1e-4 + return 1e-10, 1e-8 + + +@pytest.mark.spgemm_csr +@pytest.mark.parametrize("M, N, K", SPGEMM_MNK_SHAPES) +@pytest.mark.parametrize("dtype", SPGEMM_DTYPES, ids=SPGEMM_DTYPE_IDS) +def test_spgemm_csr_matches_torch(M, N, K, dtype): + device = torch.device("cuda") + A = _random_csr(M, K, dtype, device) + B = _random_csr(K, N, dtype, device) + c_data, c_indices, c_indptr, c_shape = flagsparse_spgemm_csr( + A.values(), + A.col_indices().to(torch.int32), + A.crow_indices(), + (M, K), + B.values(), + B.col_indices().to(torch.int32), + B.crow_indices(), + (K, N), + ) + got = _csr_to_dense(c_data, c_indices, c_indptr, c_shape) + ref = torch.sparse.mm(A, B.to_dense()) + rtol, atol = _tol(dtype) + assert torch.allclose(got, ref, rtol=rtol, atol=atol) + + +@pytest.mark.sddmm_csr +@pytest.mark.parametrize("M, N, K", SDDMM_MNK_SHAPES) +@pytest.mark.parametrize("dtype", SDDMM_DTYPES, ids=SDDMM_DTYPE_IDS) +def test_sddmm_csr_matches_sampled_dense_reference(M, N, K, dtype): + device = torch.device("cuda") + pattern = _random_csr(M, N, dtype, device) + indices = pattern.col_indices().to(torch.int32) + indptr = pattern.crow_indices() + data = pattern.values() + x = torch.randn(M, K, dtype=dtype, device=device) + y = torch.randn(N, K, dtype=dtype, device=device) + alpha = 1.25 + beta = 0.5 + + got = flagsparse_sddmm_csr( + data=data, + indices=indices, + indptr=indptr, + x=x, + y=y, + shape=(M, N), + alpha=alpha, + beta=beta, + ) + row_ids = torch.repeat_interleave( + torch.arange(M, dtype=torch.int64, device=device), + indptr[1:] - indptr[:-1], + ) + ref = alpha * torch.sum(x[row_ids] * y[indices.to(torch.int64)], dim=1) + beta * data + rtol, atol = _tol(dtype) + assert torch.allclose(got, ref, rtol=rtol, atol=atol) diff --git a/tests/pytest/test_spmm_coo_accuracy.py b/tests/pytest/test_spmm_coo_accuracy.py new file mode 100644 index 0000000..3cde262 --- /dev/null +++ b/tests/pytest/test_spmm_coo_accuracy.py @@ -0,0 +1,45 @@ +import pytest +import torch + +from flagsparse import flagsparse_spmm_coo + +from tests.pytest.param_shapes import MNK_SHAPES, SPMM_OPT_DTYPES, SPMM_OPT_DTYPE_IDS + + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + + +def _random_coo_mk(M, K, dtype, device): + denom = max(M * K, 1) + p = min(0.25, max(0.06, 32.0 / denom)) + mask = torch.rand(M, K, device=device) < p + if int(mask.sum().item()) 
== 0: + mask[0, 0] = True + vals = torch.randn(M, K, dtype=dtype, device=device) * mask.to(dtype=dtype) + return vals.to_sparse_coo().coalesce() + + +def _tol(dtype): + if dtype == torch.float32: + return 1e-4, 1e-4 + return 1e-10, 1e-8 + + +@pytest.mark.spmm_coo +@pytest.mark.parametrize("M, N, K", MNK_SHAPES) +@pytest.mark.parametrize("dtype", SPMM_OPT_DTYPES, ids=SPMM_OPT_DTYPE_IDS) +def test_spmm_coo_matches_torch(M, N, K, dtype): + device = torch.device("cuda") + Asp = _random_coo_mk(M, K, dtype, device) + indices = Asp.indices() + data = Asp.values() + row = indices[0].contiguous() + col = indices[1].contiguous() + B = torch.randn(K, N, dtype=dtype, device=device) + if dtype == torch.float32: + ref = torch.sparse.mm(Asp.double(), B.double()).float() + else: + ref = torch.sparse.mm(Asp, B) + out = flagsparse_spmm_coo(data, row, col, B, (M, K)) + rtol, atol = _tol(dtype) + assert torch.allclose(out, ref, rtol=rtol, atol=atol) diff --git a/tests/pytest/test_spmm_csr_accuracy.py b/tests/pytest/test_spmm_csr_accuracy.py new file mode 100644 index 0000000..1e3dbc0 --- /dev/null +++ b/tests/pytest/test_spmm_csr_accuracy.py @@ -0,0 +1,58 @@ +import pytest +import torch + +from flagsparse import flagsparse_spmm_csr + +from tests.pytest.param_shapes import MNK_SHAPES, SPMM_FLOAT_DTYPES, SPMM_FLOAT_DTYPE_IDS + + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + + +def _random_csr_mk(M, K, dtype, device): + denom = max(M * K, 1) + p = min(0.25, max(0.06, 32.0 / denom)) + mask = torch.rand(M, K, device=device) < p + if int(mask.sum().item()) == 0: + mask[0, 0] = True + vals = torch.randn(M, K, dtype=dtype, device=device) * mask.to(dtype=dtype) + return vals.to_sparse_csr() + + +def _tol(dtype): + if dtype == torch.bfloat16: + return 1e-1, 1e-1 + if dtype == torch.float32: + return 1e-4, 1e-4 + return 1e-10, 1e-8 + + +@pytest.mark.spmm_csr +@pytest.mark.parametrize("M, N, K", MNK_SHAPES) +@pytest.mark.parametrize("dtype", SPMM_FLOAT_DTYPES, ids=SPMM_FLOAT_DTYPE_IDS) +def test_spmm_csr_matches_torch(M, N, K, dtype): + if dtype == torch.bfloat16 and not ( + torch.cuda.is_available() and torch.cuda.is_bf16_supported() + ): + pytest.skip("bfloat16 not supported on this GPU") + device = torch.device("cuda") + Asp = _random_csr_mk(M, K, dtype, device) + data = Asp.values() + indices = Asp.col_indices() + indptr = Asp.crow_indices() + B = torch.randn(K, N, dtype=dtype, device=device) + if dtype == torch.float32: + Asp64 = torch.sparse_csr_tensor( + crow_indices=indptr, + col_indices=indices, + values=data.double(), + size=(M, K), + dtype=torch.float64, + device=device, + ) + ref = torch.sparse.mm(Asp64, B.double()).float() + else: + ref = torch.sparse.mm(Asp, B) + out = flagsparse_spmm_csr(data, indices, indptr, B, (M, K)) + rtol, atol = _tol(dtype) + assert torch.allclose(out, ref, rtol=rtol, atol=atol) diff --git a/tests/pytest/test_spmv_coo_accuracy.py b/tests/pytest/test_spmv_coo_accuracy.py new file mode 100644 index 0000000..578eb20 --- /dev/null +++ b/tests/pytest/test_spmv_coo_accuracy.py @@ -0,0 +1,46 @@ +import pytest +import torch + +from flagsparse import flagsparse_spmv_coo + +from tests.pytest.param_shapes import ( + SPMV_COO_DTYPES, + SPMV_COO_DTYPE_IDS, + SPMV_MN_SHAPES, +) + + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + + +def _random_coo_mn(M, N, dtype, device): + denom = max(M * N, 1) + p = min(0.25, max(0.06, 32.0 / denom)) + mask = torch.rand(M, N, device=device) < p + if int(mask.sum().item()) == 
0: + mask[0, 0] = True + vals = torch.randn(M, N, dtype=dtype, device=device) * mask.to(dtype=dtype) + return vals.to_sparse_coo() + + +def _tol(dtype): + if dtype == torch.float32: + return 1e-4, 1e-4 + return 1e-10, 1e-8 + + +@pytest.mark.spmv_coo +@pytest.mark.parametrize("M, N", SPMV_MN_SHAPES) +@pytest.mark.parametrize("dtype", SPMV_COO_DTYPES, ids=SPMV_COO_DTYPE_IDS) +def test_spmv_coo_matches_torch(M, N, dtype): + device = torch.device("cuda") + Asp = _random_coo_mn(M, N, dtype, device) + data = Asp.values() + indices = Asp.indices() + row = indices[0].contiguous() + col = indices[1].contiguous() + x = torch.randn(N, dtype=dtype, device=device) + ref = torch.sparse.mm(Asp, x.unsqueeze(1)).squeeze(1) + out = flagsparse_spmv_coo(data, row, col, x, shape=(M, N)) + rtol, atol = _tol(dtype) + assert torch.allclose(out, ref, rtol=rtol, atol=atol) diff --git a/tests/pytest/test_spmv_csr_accuracy.py b/tests/pytest/test_spmv_csr_accuracy.py new file mode 100644 index 0000000..c547619 --- /dev/null +++ b/tests/pytest/test_spmv_csr_accuracy.py @@ -0,0 +1,49 @@ +import pytest +import torch + +from flagsparse import flagsparse_spmv_csr + +from tests.pytest.param_shapes import FLOAT_DTYPE_IDS, FLOAT_DTYPES, SPMV_MN_SHAPES + + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + + +def _random_csr_mn(M, N, dtype, device): + denom = max(M * N, 1) + p = min(0.25, max(0.06, 32.0 / denom)) + mask = torch.rand(M, N, device=device) < p + if int(mask.sum().item()) == 0: + mask[0, 0] = True + vals = torch.randn(M, N, dtype=dtype, device=device) * mask.to(dtype=dtype) + return vals.to_sparse_csr() + + +def _tol(dtype): + if dtype == torch.float16: + return 5e-3, 5e-3 + if dtype == torch.bfloat16: + return 1e-1, 1e-1 + if dtype == torch.float32: + return 1e-4, 1e-4 + return 1e-10, 1e-8 + + +@pytest.mark.spmv_csr +@pytest.mark.parametrize("M, N", SPMV_MN_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES, ids=FLOAT_DTYPE_IDS) +def test_spmv_csr_matches_torch(M, N, dtype): + if dtype == torch.bfloat16 and not ( + torch.cuda.is_available() and torch.cuda.is_bf16_supported() + ): + pytest.skip("bfloat16 not supported on this GPU") + device = torch.device("cuda") + Asp = _random_csr_mn(M, N, dtype, device) + data = Asp.values() + indices = Asp.col_indices() + indptr = Asp.crow_indices() + x = torch.randn(N, dtype=dtype, device=device) + ref = torch.sparse.mm(Asp, x.unsqueeze(1)).squeeze(1) + out = flagsparse_spmv_csr(data, indices, indptr, x, shape=(M, N)) + rtol, atol = _tol(dtype) + assert torch.allclose(out, ref, rtol=rtol, atol=atol) diff --git a/tests/pytest/test_spsm_accuracy.py b/tests/pytest/test_spsm_accuracy.py new file mode 100644 index 0000000..76e11e0 --- /dev/null +++ b/tests/pytest/test_spsm_accuracy.py @@ -0,0 +1,68 @@ +import pytest +import torch + +from flagsparse import flagsparse_spsm_coo, flagsparse_spsm_csr + +from tests.pytest.param_shapes import SPSM_N_RHS, TRIANGULAR_DTYPE_IDS, TRIANGULAR_DTYPES + + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + + +def _build_lower_dense(n, dtype, device): + base = torch.tril(torch.randn(n, n, dtype=dtype, device=device)) + eye = torch.eye(n, dtype=dtype, device=device) + return base + eye * (float(n) * 0.5 + 2.0) + + +def _tol(dtype): + if dtype == torch.float32: + return 1e-4, 1e-5 + return 1e-10, 1e-10 + + +@pytest.mark.spsm +@pytest.mark.spsm_csr +@pytest.mark.parametrize("n, n_rhs", SPSM_N_RHS) +@pytest.mark.parametrize("dtype", TRIANGULAR_DTYPES, 
ids=TRIANGULAR_DTYPE_IDS) +def test_spsm_csr_lower_matches_dense(n, n_rhs, dtype): + device = torch.device("cuda") + A = _build_lower_dense(n, dtype, device) + B = torch.randn(n, n_rhs, dtype=dtype, device=device) + ref = torch.linalg.solve_triangular(A, B, upper=False) + Acsr = A.to_sparse_csr() + out = flagsparse_spsm_csr( + Acsr.values(), + Acsr.col_indices(), + Acsr.crow_indices(), + B, + (n, n), + lower=True, + unit_diagonal=False, + ) + rtol, atol = _tol(dtype) + assert torch.allclose(out, ref, rtol=rtol, atol=atol) + + +@pytest.mark.spsm +@pytest.mark.spsm_coo +@pytest.mark.parametrize("n, n_rhs", SPSM_N_RHS) +@pytest.mark.parametrize("dtype", TRIANGULAR_DTYPES, ids=TRIANGULAR_DTYPE_IDS) +def test_spsm_coo_lower_matches_dense(n, n_rhs, dtype): + device = torch.device("cuda") + A = _build_lower_dense(n, dtype, device) + B = torch.randn(n, n_rhs, dtype=dtype, device=device) + ref = torch.linalg.solve_triangular(A, B, upper=False) + Acoo = A.to_sparse_coo().coalesce() + indices = Acoo.indices() + out = flagsparse_spsm_coo( + Acoo.values(), + indices[0].contiguous(), + indices[1].contiguous(), + B, + (n, n), + lower=True, + unit_diagonal=False, + ) + rtol, atol = _tol(dtype) + assert torch.allclose(out, ref, rtol=rtol, atol=atol) diff --git a/tests/pytest/test_spsv_csr_accuracy.py b/tests/pytest/test_spsv_csr_accuracy.py new file mode 100644 index 0000000..fc9c891 --- /dev/null +++ b/tests/pytest/test_spsv_csr_accuracy.py @@ -0,0 +1,447 @@ +import pytest +import torch + +from flagsparse import flagsparse_spsv_coo, flagsparse_spsv_csr + +from tests.pytest.param_shapes import SPSV_N + +try: + import cupy as cp + import cupyx.scipy.sparse as cpx_sparse + from cupyx.scipy.sparse.linalg import spsolve_triangular as cpx_spsolve_triangular +except Exception: + cp = None + cpx_sparse = None + cpx_spsolve_triangular = None + + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +SUPPORTED_COMPLEX_DTYPES = [torch.complex64, torch.complex128] + +SUPPORTED_DTYPES = [torch.float32, torch.float64, *SUPPORTED_COMPLEX_DTYPES] +NON_TRANS_DTYPES = [torch.float32, torch.float64, torch.complex64, torch.complex128] +TRANS_CONJ_DTYPES = [torch.float32, torch.float64, torch.complex64, torch.complex128] +TRANS_CONJ_MODES = ["TRANS", "CONJ"] + + +def _dtype_id(dtype): + return str(dtype).replace("torch.", "") + + +def _tol(dtype): + if dtype in (torch.float32, torch.complex64): + return 1e-4, 1e-3 + return 1e-10, 1e-8 + + +def _rand_like(dtype, shape, device): + if dtype in (torch.float32, torch.float64): + return torch.randn(shape, dtype=dtype, device=device) + base = torch.float32 if dtype == torch.complex64 else torch.float64 + r = torch.randn(shape, dtype=base, device=device) + i = torch.randn(shape, dtype=base, device=device) + return torch.complex(r, i) + + +def _ref_dtype(dtype): + return dtype + + +def _safe_cast_tensor(tensor, dtype): + return tensor.to(dtype) + + +def _cmp_view(tensor, dtype): + return tensor + + +def _apply_ref_op(A, op_mode): + if op_mode == "TRANS": + return A.transpose(-2, -1) + if op_mode == "CONJ": + return A.transpose(-2, -1).conj() if torch.is_complex(A) else A.transpose(-2, -1) + return A + + +def _effective_upper(lower, op_mode): + return lower if op_mode in ("TRANS", "CONJ") else not lower + + +def _effective_lower_for_op(lower, op_mode): + return (not lower) if op_mode in ("TRANS", "CONJ") else lower + + +def _transpose_arg(op_mode): + if op_mode == "NON": + return False + return op_mode + + +def _cupy_apply_op(A_cp, op_mode): + if 
op_mode == "TRANS": + return A_cp.transpose().tocsr() + if op_mode == "CONJ": + return A_cp.transpose().conj().tocsr() + return A_cp + + +def _build_triangular(n, dtype, device, lower=True): + off = _rand_like(dtype, (n, n), device) * 0.02 + A = torch.tril(off) if lower else torch.triu(off) + if torch.is_complex(A): + diag = (torch.rand(n, device=device, dtype=A.real.dtype) + 2.0).to(A.real.dtype) + A = A + torch.diag(torch.complex(diag, torch.zeros_like(diag))) + else: + diag = torch.rand(n, device=device, dtype=A.dtype) + 2.0 + A = A + torch.diag(diag) + return A + + +def _cupy_csr_from_torch(data, indices, indptr, shape): + if cp is None or cpx_sparse is None: + return None + data_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(data.contiguous())) + idx_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(indices.to(torch.int64).contiguous())) + ptr_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(indptr.to(torch.int64).contiguous())) + return cpx_sparse.csr_matrix((data_cp, idx_cp, ptr_cp), shape=shape) + + +def _cupy_ref_spsv(A_cp, b_t, *, lower, unit_diagonal=False): + if cp is None or cpx_spsolve_triangular is None: + return None + b_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(b_t.contiguous())) + x_cp = cpx_spsolve_triangular(A_cp, b_cp, lower=lower, unit_diagonal=unit_diagonal) + x_t = torch.utils.dlpack.from_dlpack(x_cp.toDlpack()) + return x_t.to(b_t.dtype) + + +@pytest.mark.spsv +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize( + "dtype", + [torch.float32, torch.float64], + ids=["float32", "float64"], +) +def test_spsv_csr_lower_matches_dense(n, dtype): + # Keep the original baseline test case untouched in semantics. + device = torch.device("cuda") + base = torch.tril(torch.randn(n, n, dtype=dtype, device=device)) + eye = torch.eye(n, dtype=dtype, device=device) + A = base + eye * (float(n) * 0.5 + 2.0) + b = torch.randn(n, dtype=dtype, device=device) + x_ref = torch.linalg.solve_triangular( + A, b.unsqueeze(-1), upper=False + ).squeeze(-1) + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices() + indptr = Asp.crow_indices().to(torch.int64) + x = flagsparse_spsv_csr( + data, + indices, + indptr, + b, + (n, n), + lower=True, + unit_diagonal=False, + ) + rtol = 1e-4 if dtype == torch.float32 else 1e-10 + atol = 1e-5 if dtype == torch.float32 else 1e-10 + assert torch.allclose(x, x_ref, rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("dtype", NON_TRANS_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) +def test_spsv_csr_non_trans_supported_combos(n, dtype, index_dtype): + device = torch.device("cuda") + A = _build_triangular(n, dtype, device, lower=True) + b = _rand_like(dtype, (n,), device) + x_ref = torch.linalg.solve_triangular( + A.to(_ref_dtype(dtype)), b.to(_ref_dtype(dtype)).unsqueeze(-1), upper=False + ).squeeze(-1) + + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(index_dtype) + indptr = Asp.crow_indices().to(index_dtype) + + x = flagsparse_spsv_csr( + data, + indices, + indptr, + b, + (n, n), + lower=True, + unit_diagonal=False, + transpose=False, + ) + rtol, atol = _tol(dtype) + assert torch.allclose(_cmp_view(x, dtype), _cmp_view(x_ref, dtype), rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("dtype", TRANS_CONJ_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], 
ids=["int32", "int64"]) +@pytest.mark.parametrize("op_mode", TRANS_CONJ_MODES) +def test_spsv_csr_transpose_family_supported_combos(n, dtype, index_dtype, op_mode): + device = torch.device("cuda") + A = _build_triangular(n, dtype, device, lower=True) + b = _rand_like(dtype, (n,), device) + A_ref = A.to(_ref_dtype(dtype)) + b_ref = b.to(_ref_dtype(dtype)) + x_ref = torch.linalg.solve_triangular( + _apply_ref_op(A_ref, op_mode), b_ref.unsqueeze(-1), upper=_effective_upper(True, op_mode) + ).squeeze(-1) + + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(index_dtype) + indptr = Asp.crow_indices().to(index_dtype) + + x = flagsparse_spsv_csr( + data, + indices, + indptr, + b, + (n, n), + lower=True, + unit_diagonal=False, + transpose=_transpose_arg(op_mode), + ) + rtol, atol = _tol(dtype) + assert torch.allclose(_cmp_view(x, dtype), _cmp_view(x_ref, dtype), rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.skipif( + cp is None or cpx_sparse is None or cpx_spsolve_triangular is None, + reason="CuPy/cuSPARSE required", +) +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("dtype", NON_TRANS_DTYPES, ids=_dtype_id) +def test_spsv_csr_matches_cusparse_non_trans(n, dtype): + device = torch.device("cuda") + A = _build_triangular(n, dtype, device, lower=True) + b = _rand_like(dtype, (n,), device) + + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(torch.int32) + indptr = Asp.crow_indices().to(torch.int32) + A_cp = _cupy_csr_from_torch(data, indices, indptr, (n, n)) + + x_non = flagsparse_spsv_csr( + data, indices, indptr, b, (n, n), lower=True, unit_diagonal=False, transpose=False + ) + x_non_ref = _cupy_ref_spsv(A_cp, b, lower=True, unit_diagonal=False) + + rtol, atol = _tol(dtype) + assert torch.allclose(_cmp_view(x_non, dtype), _cmp_view(x_non_ref, dtype), rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.skipif( + cp is None or cpx_sparse is None or cpx_spsolve_triangular is None, + reason="CuPy/cuSPARSE required", +) +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("dtype", TRANS_CONJ_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) +@pytest.mark.parametrize("op_mode", TRANS_CONJ_MODES) +def test_spsv_csr_matches_cusparse_transpose_family(n, dtype, index_dtype, op_mode): + device = torch.device("cuda") + A = _build_triangular(n, dtype, device, lower=True) + b = _rand_like(dtype, (n,), device) + + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(index_dtype) + indptr = Asp.crow_indices().to(index_dtype) + A_cp = _cupy_csr_from_torch(data, indices, indptr, (n, n)) + + x_trans = flagsparse_spsv_csr( + data, + indices, + indptr, + b, + (n, n), + lower=True, + unit_diagonal=False, + transpose=_transpose_arg(op_mode), + ) + x_trans_ref = _cupy_ref_spsv( + _cupy_apply_op(A_cp, op_mode), + b, + lower=_effective_lower_for_op(True, op_mode), + unit_diagonal=False, + ) + + rtol, atol = _tol(dtype) + assert torch.allclose(_cmp_view(x_trans, dtype), _cmp_view(x_trans_ref, dtype), rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("dtype", NON_TRANS_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) +def test_spsv_csr_non_trans_upper_supported_combos(n, dtype, index_dtype): + device = torch.device("cuda") + A = _build_triangular(n, dtype, device, lower=False) + b = _rand_like(dtype, 
(n,), device) + x_ref = torch.linalg.solve_triangular( + A.to(_ref_dtype(dtype)), b.to(_ref_dtype(dtype)).unsqueeze(-1), upper=True + ).squeeze(-1) + + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(index_dtype) + indptr = Asp.crow_indices().to(index_dtype) + + x = flagsparse_spsv_csr( + data, + indices, + indptr, + b, + (n, n), + lower=False, + unit_diagonal=False, + transpose=False, + ) + rtol, atol = _tol(dtype) + assert torch.allclose(_cmp_view(x, dtype), _cmp_view(x_ref, dtype), rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("dtype", TRANS_CONJ_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) +@pytest.mark.parametrize("op_mode", TRANS_CONJ_MODES) +def test_spsv_csr_upper_transpose_family_supported_combos(n, dtype, index_dtype, op_mode): + device = torch.device("cuda") + A = _build_triangular(n, dtype, device, lower=False) + b = _rand_like(dtype, (n,), device) + A_ref = A.to(_ref_dtype(dtype)) + b_ref = b.to(_ref_dtype(dtype)) + x_ref = torch.linalg.solve_triangular( + _apply_ref_op(A_ref, op_mode), b_ref.unsqueeze(-1), upper=_effective_upper(False, op_mode) + ).squeeze(-1) + + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(index_dtype) + indptr = Asp.crow_indices().to(index_dtype) + + x = flagsparse_spsv_csr( + data, + indices, + indptr, + b, + (n, n), + lower=False, + unit_diagonal=False, + transpose=_transpose_arg(op_mode), + ) + rtol, atol = _tol(dtype) + assert torch.allclose(_cmp_view(x, dtype), _cmp_view(x_ref, dtype), rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.skipif( + cp is None or cpx_sparse is None or cpx_spsolve_triangular is None, + reason="CuPy/cuSPARSE required", +) +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("dtype", NON_TRANS_DTYPES, ids=_dtype_id) +def test_spsv_csr_matches_cusparse_upper_non_trans(n, dtype): + device = torch.device("cuda") + A = _build_triangular(n, dtype, device, lower=False) + b = _rand_like(dtype, (n,), device) + + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(torch.int32) + indptr = Asp.crow_indices().to(torch.int32) + A_cp = _cupy_csr_from_torch(data, indices, indptr, (n, n)) + + x_non = flagsparse_spsv_csr( + data, indices, indptr, b, (n, n), lower=False, unit_diagonal=False, transpose=False + ) + x_non_ref = _cupy_ref_spsv(A_cp, b, lower=False, unit_diagonal=False) + + rtol, atol = _tol(dtype) + assert torch.allclose(_cmp_view(x_non, dtype), _cmp_view(x_non_ref, dtype), rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.skipif( + cp is None or cpx_sparse is None or cpx_spsolve_triangular is None, + reason="CuPy/cuSPARSE required", +) +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("dtype", TRANS_CONJ_DTYPES, ids=_dtype_id) +@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64], ids=["int32", "int64"]) +@pytest.mark.parametrize("op_mode", TRANS_CONJ_MODES) +def test_spsv_csr_matches_cusparse_upper_transpose_family(n, dtype, index_dtype, op_mode): + device = torch.device("cuda") + A = _build_triangular(n, dtype, device, lower=False) + b = _rand_like(dtype, (n,), device) + + Asp = A.to_sparse_csr() + data = Asp.values() + indices = Asp.col_indices().to(index_dtype) + indptr = Asp.crow_indices().to(index_dtype) + A_cp = _cupy_csr_from_torch(data, indices, indptr, (n, n)) + + x_trans = flagsparse_spsv_csr( + data, + indices, + indptr, + b, + (n, n), + 
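+        # A stores the upper triangle here; `transpose` picks op(A) in
+        # op(A) x = b, so the CuPy reference below solves the flipped
+        # (lower) triangle via _effective_lower_for_op.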
lower=False, + unit_diagonal=False, + transpose=_transpose_arg(op_mode), + ) + x_trans_ref = _cupy_ref_spsv( + _cupy_apply_op(A_cp, op_mode), + b, + lower=_effective_lower_for_op(False, op_mode), + unit_diagonal=False, + ) + + rtol, atol = _tol(dtype) + assert torch.allclose(_cmp_view(x_trans, dtype), _cmp_view(x_trans_ref, dtype), rtol=rtol, atol=atol) + + +@pytest.mark.spsv +@pytest.mark.parametrize("n", SPSV_N) +@pytest.mark.parametrize("op_mode", TRANS_CONJ_MODES) +def test_spsv_coo_transpose_family_complex128_routes_through_csr(n, op_mode): + device = torch.device("cuda") + dtype = torch.complex128 + A = _build_triangular(n, dtype, device, lower=True) + b = _rand_like(dtype, (n,), device) + x_ref = torch.linalg.solve_triangular( + _apply_ref_op(A, op_mode), b.unsqueeze(-1), upper=_effective_upper(True, op_mode) + ).squeeze(-1) + + A_coo = A.to_sparse_coo().coalesce() + row, col = A_coo.indices() + data = A_coo.values() + + x = flagsparse_spsv_coo( + data, + row.to(torch.int32), + col.to(torch.int32), + b, + (n, n), + lower=True, + unit_diagonal=False, + transpose=_transpose_arg(op_mode), + coo_mode="auto", + ) + rtol, atol = _tol(dtype) + assert torch.allclose(x, x_ref, rtol=rtol, atol=atol) diff --git a/tests/test_gather.py b/tests/test_gather.py new file mode 100644 index 0000000..9508a02 --- /dev/null +++ b/tests/test_gather.py @@ -0,0 +1,548 @@ +import argparse +import csv +import math +import os +import time + +import torch + +import flagsparse as ast +from flagsparse.sparse_operations._common import cp + + +DEFAULT_CASES = [ + (32_768, 1_024), + (131_072, 4_096), + (524_288, 16_384), + (1_048_576, 65_536), +] +DEFAULT_VALUE_DTYPES = "float16,bfloat16,float32,float64,complex64,complex128" +DEFAULT_INDEX_DTYPES = "int32,int64" +WARMUP = 20 +ITERS = 200 + + +def _fmt_ms(value): + if value is None: + return "N/A" + return f"{value:.4f}" + + +def _fmt_speedup(value): + if value is None: + return "N/A" + if math.isinf(value): + return "inf" + return f"{value:.2f}x" + + +def _fmt_err(value): + if value is None: + return "N/A" + return f"{value:.2e}" + + +def _parse_value_dtypes(raw): + allowed = { + "float16", + "bfloat16", + "float32", + "float64", + "complex64", + "complex128", + } + tokens = [tok.strip().lower() for tok in str(raw).split(",") if tok.strip()] + if not tokens: + raise ValueError("value dtypes list is empty") + invalid = [tok for tok in tokens if tok not in allowed] + if invalid: + raise ValueError(f"unsupported value dtypes: {invalid}") + return tokens + + +def _parse_index_dtypes(raw): + mapping = {"int32": torch.int32, "int64": torch.int64} + tokens = [tok.strip().lower() for tok in str(raw).split(",") if tok.strip()] + if not tokens: + raise ValueError("index dtypes list is empty") + invalid = [tok for tok in tokens if tok not in mapping] + if invalid: + raise ValueError(f"unsupported index dtypes: {invalid}") + return [(tok, mapping[tok]) for tok in tokens] + + +def _parse_cases(raw): + if raw is None or not str(raw).strip(): + return list(DEFAULT_CASES) + pairs = [] + for chunk in str(raw).split(","): + item = chunk.strip() + if not item: + continue + if ":" not in item: + raise ValueError(f"invalid case '{item}', expected dense:nnz") + left, right = item.split(":", 1) + dense_size = int(left) + nnz = int(right) + if dense_size < 0 or nnz < 0: + raise ValueError(f"case values must be non-negative: {item}") + pairs.append((dense_size, nnz)) + if not pairs: + raise ValueError("case list is empty") + return pairs + + +def _ensure_parent_dir(path): + parent = 
os.path.dirname(os.path.abspath(path)) + if parent: + os.makedirs(parent, exist_ok=True) + + +def _write_csv(path, rows, fieldnames): + _ensure_parent_dir(path) + with open(path, "w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore") + writer.writeheader() + for row in rows: + writer.writerow({k: ("" if v is None else v) for k, v in row.items()}) + + +def _sync_all(): + torch.cuda.synchronize() + if cp is not None: + cp.cuda.runtime.deviceSynchronize() + + +def _bench_cuda_op(op, warmup, iters): + warmup = max(0, int(warmup)) + iters = max(1, int(iters)) + output = None + for _ in range(warmup): + output = op() + _sync_all() + start = time.perf_counter() + for _ in range(iters): + output = op() + _sync_all() + elapsed_ms = (time.perf_counter() - start) * 1000.0 / iters + return output, elapsed_ms + + +def _to_real_imag(value): + if hasattr(value, "is_complex") and value.is_complex(): + return float(value.real.item()), float(value.imag.item()) + if getattr(value, "ndim", 0) == 1 and int(value.numel()) == 2: + return float(value[0].item()), float(value[1].item()) + scalar = float(value.item()) + return scalar, 0.0 + + +def _collect_samples(case_id, expected, flagsparse_out, limit): + rows = [] + if expected is None or flagsparse_out is None: + return rows + max_items = min(int(limit), int(expected.shape[0]), int(flagsparse_out.shape[0])) + for pos in range(max_items): + exp_val = expected[pos] + fs_val = flagsparse_out[pos] + exp_real, exp_imag = _to_real_imag(exp_val) + fs_real, fs_imag = _to_real_imag(fs_val) + abs_error = float(torch.max(torch.abs(fs_val - exp_val)).item()) + rows.append( + { + "case_id": case_id, + "pos": pos, + "expected_real": exp_real, + "expected_imag": exp_imag, + "flagsparse_real": fs_real, + "flagsparse_imag": fs_imag, + "abs_error": abs_error, + } + ) + return rows + + +def _dtype_mode(value_dtype_req): + _ = value_dtype_req + return "gather_triton" + + +def _select_mode(value_dtype_req, index_dtype): + _ = index_dtype + return _dtype_mode(value_dtype_req) + + +def _build_dense(value_dtype_req, dense_size, device): + if value_dtype_req == "float16": + return torch.randn(dense_size, dtype=torch.float16, device=device) + if value_dtype_req == "bfloat16": + return torch.randn(dense_size, dtype=torch.bfloat16, device=device) + if value_dtype_req == "float32": + return torch.randn(dense_size, dtype=torch.float32, device=device) + if value_dtype_req == "float64": + return torch.randn(dense_size, dtype=torch.float64, device=device) + if value_dtype_req == "complex64": + real = torch.randn(dense_size, dtype=torch.float32, device=device) + imag = torch.randn(dense_size, dtype=torch.float32, device=device) + return torch.complex(real, imag) + if value_dtype_req == "complex128": + real = torch.randn(dense_size, dtype=torch.float64, device=device) + imag = torch.randn(dense_size, dtype=torch.float64, device=device) + return torch.complex(real, imag) + raise ValueError(f"Unsupported value dtype request: {value_dtype_req}") + + +def _effective_dtype_name(value_dtype_req): + mapping = { + "float16": "float16", + "bfloat16": "bfloat16", + "float32": "float32", + "float64": "float64", + "complex64": "complex64", + "complex128": "complex128", + } + return mapping[value_dtype_req] + + +def _tolerance(value_dtype_req): + if value_dtype_req == "float16": + return 5e-3, 5e-3 + if value_dtype_req in ("bfloat16",): + return 1e-2, 1e-2 + if value_dtype_req in ("float32", "complex64"): + return 1e-6, 1e-5 + if value_dtype_req in 
("float64", "complex128"): + return 1e-10, 1e-8 + return 1e-6, 1e-5 + + +def _check_dtype_supported(value_dtype_req): + if value_dtype_req in ("bfloat16",) and not torch.cuda.is_bf16_supported(): + raise RuntimeError("bfloat16 not supported on this GPU") + + +def _is_supported_gather_combo(index_dtype): + # Required gather coverage is the full 6 value dtypes x 2 index dtypes matrix. + return index_dtype in (torch.int32, torch.int64) + + +def _build_indices(dense_size, nnz, index_dtype, device): + return torch.randint(0, dense_size, (nnz,), dtype=index_dtype, device=device) + + +def _benchmark_gather_case( + value_dtype_req, + index_dtype, + dense_size, + nnz, + warmup, + iters, + run_cusparse, + index_fallback_policy, +): + device = torch.device("cuda") + _check_dtype_supported(value_dtype_req) + dense_vector = _build_dense(value_dtype_req, dense_size, device) + indices = _build_indices(dense_size, nnz, index_dtype, device) + expected = dense_vector.index_select(0, indices.to(torch.int64)) + + mode = _select_mode(value_dtype_req, index_dtype) + gather_meta = { + "index_fallback_applied": False, + "index_fallback_reason": None, + "kernel_index_dtype": str(index_dtype).replace("torch.", ""), + } + flagsparse_op = lambda: ast.flagsparse_gather(dense_vector, indices) + cusparse_op = lambda: ast.cusparse_spmv_gather(dense_vector, indices)[0] + + pytorch_op = lambda: dense_vector.index_select(0, indices.to(torch.int64)) + pytorch_values, pytorch_ms = _bench_cuda_op(pytorch_op, warmup=warmup, iters=iters) + flagsparse_values, flagsparse_ms = _bench_cuda_op(flagsparse_op, warmup=warmup, iters=iters) + + atol, rtol = _tolerance(value_dtype_req) + fs_match = torch.allclose(flagsparse_values, expected, atol=atol, rtol=rtol) + fs_max_error = ( + float(torch.max(torch.abs(flagsparse_values - expected)).item()) + if expected.numel() > 0 + else 0.0 + ) + + cusparse_values = None + cusparse_ms = None + cusparse_match = None + cusparse_max_error = None + cusparse_reason = None + if run_cusparse: + try: + cusparse_values, cusparse_ms = _bench_cuda_op(cusparse_op, warmup=warmup, iters=iters) + cusparse_match = torch.allclose(cusparse_values, expected, atol=atol, rtol=rtol) + cusparse_max_error = ( + float(torch.max(torch.abs(cusparse_values - expected)).item()) + if expected.numel() > 0 + else 0.0 + ) + except Exception as exc: + cusparse_reason = str(exc) + + fs_vs_pt = pytorch_ms / flagsparse_ms if flagsparse_ms > 0 else float("inf") + fs_vs_cs = ( + cusparse_ms / flagsparse_ms + if (cusparse_ms is not None and flagsparse_ms > 0) + else None + ) + + return { + "parameters": { + "dense_size": dense_size, + "nnz": int(indices.numel()), + "value_dtype": value_dtype_req, + "effective_value_dtype": _effective_dtype_name(value_dtype_req), + "index_dtype": str(index_dtype), + "mode": mode, + }, + "performance": { + "pytorch_ms": pytorch_ms, + "triton_ms": flagsparse_ms, + "cusparse_ms": cusparse_ms, + "triton_speedup_vs_pytorch": fs_vs_pt, + "triton_speedup_vs_cusparse": fs_vs_cs, + }, + "verification": { + "triton_match_pytorch": fs_match, + "triton_max_error": fs_max_error, + "cusparse_match_pytorch": cusparse_match, + "cusparse_max_error": cusparse_max_error, + }, + "backend_status": { + "cusparse_unavailable_reason": cusparse_reason, + "index_fallback_applied": bool(gather_meta.get("index_fallback_applied")), + "index_fallback_reason": gather_meta.get("index_fallback_reason"), + }, + "samples": { + "pytorch": pytorch_values, + "triton": flagsparse_values, + }, + } + + +def _status_from_result(verification): + 
triton_ok = verification.get("triton_match_pytorch") + cusparse_ok = verification.get("cusparse_match_pytorch") + overall_ok = bool(triton_ok) and (cusparse_ok is None or bool(cusparse_ok)) + return "PASS" if overall_ok else "FAIL" + + +def _print_header(): + print("-" * 196) + print( + f"{'ValueReq':>14} {'ValueEff':>18} {'Index':>6} {'Dense':>10} {'NNZ':>10} " + f"{'IFB':>4} {'PT(ms)':>10} {'FS(ms)':>10} {'CS(ms)':>10} " + f"{'FS/PT':>8} {'FS/CS':>8} {'Status':>6} {'Err(FS)':>12} {'Err(CS)':>12}" + ) + print("-" * 196) + + +def _print_row(row): + print( + f"{row['value_dtype_req']:>14} {row['value_dtype_compute']:>18} {row['index_dtype']:>6} " + f"{row['dense_size']:>10,d} {row['nnz']:>10,d} {str(row['index_fallback_applied']):>4} " + f"{_fmt_ms(row['pytorch_ms']):>10} {_fmt_ms(row['triton_ms']):>10} {_fmt_ms(row['cusparse_ms']):>10} " + f"{_fmt_speedup(row['triton_speedup_vs_pytorch']):>8} {_fmt_speedup(row['triton_speedup_vs_cusparse']):>8} " + f"{row['status']:>6} {_fmt_err(row['triton_max_error']):>12} {_fmt_err(row['cusparse_max_error']):>12}" + ) + + +def run_cli(args): + if not torch.cuda.is_available(): + print("CUDA is not available. Please run on a GPU-enabled system.") + return + + value_dtype_tokens = _parse_value_dtypes(args.value_dtypes) + index_dtype_pairs = _parse_index_dtypes(args.index_dtypes) + run_cusparse = not args.no_cusparse + work_cases = _parse_cases(args.cases) + + print("=" * 180) + print("FLAGSPARSE GATHER BENCHMARK/VALIDATION") + print("=" * 180) + print(f"GPU: {torch.cuda.get_device_name(0)}") + print( + f"Warmup: {args.warmup} | Iterations: {args.iters} | " + f"index_fallback_policy: {args.index_fallback_policy}" + ) + print() + _print_header() + + summary_rows = [] + sample_rows = [] + total_cases = 0 + failed_cases = 0 + + for value_dtype in value_dtype_tokens: + for index_name, index_dtype in index_dtype_pairs: + if not _is_supported_gather_combo(index_dtype): + continue + for dense_size, nnz in work_cases: + dense_size = int(dense_size) + nnz = int(nnz) + case_id = ( + f"{value_dtype}|{index_name}|dense={dense_size}|nnz={nnz}|" + f"ifb_policy={args.index_fallback_policy}" + ) + total_cases += 1 + try: + result = _benchmark_gather_case( + value_dtype_req=value_dtype, + index_dtype=index_dtype, + dense_size=dense_size, + nnz=nnz, + warmup=args.warmup, + iters=args.iters, + run_cusparse=run_cusparse, + index_fallback_policy=args.index_fallback_policy, + ) + perf = result["performance"] + verify = result["verification"] + params = result["parameters"] + backend = result["backend_status"] + status = _status_from_result(verify) + if status != "PASS": + failed_cases += 1 + row = { + "case_id": case_id, + "gpu": torch.cuda.get_device_name(0), + "value_dtype_req": params.get("value_dtype"), + "value_dtype_compute": str(params.get("effective_value_dtype")), + "index_dtype": str(params.get("index_dtype")).replace("torch.", ""), + "dense_size": int(params.get("dense_size")), + "nnz": int(params.get("nnz")), + "mode": params.get("mode"), + "index_fallback_policy": args.index_fallback_policy, + "index_fallback_applied": bool( + backend.get("index_fallback_applied") + ), + "triton_ms": perf.get("triton_ms"), + "pytorch_ms": perf.get("pytorch_ms"), + "cusparse_ms": perf.get("cusparse_ms"), + "triton_speedup_vs_pytorch": perf.get("triton_speedup_vs_pytorch"), + "triton_speedup_vs_cusparse": perf.get("triton_speedup_vs_cusparse"), + "triton_match_pytorch": verify.get("triton_match_pytorch"), + "cusparse_match_pytorch": verify.get("cusparse_match_pytorch"), + 
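+                        # Max abs error of each backend against the PyTorch reference.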
"triton_max_error": verify.get("triton_max_error"), + "cusparse_max_error": verify.get("cusparse_max_error"), + "cusparse_unavailable_reason": backend.get("cusparse_unavailable_reason"), + "index_fallback_reason": backend.get("index_fallback_reason"), + "status": status, + } + summary_rows.append(row) + _print_row(row) + + if args.csv_samples: + sample_rows.extend( + _collect_samples( + case_id, + result["samples"].get("pytorch"), + result["samples"].get("triton"), + args.sample_limit, + ) + ) + except Exception as exc: + failed_cases += 1 + row = { + "case_id": case_id, + "gpu": torch.cuda.get_device_name(0), + "value_dtype_req": value_dtype, + "value_dtype_compute": "N/A", + "index_dtype": index_name, + "dense_size": dense_size, + "nnz": nnz, + "mode": _select_mode(value_dtype, index_dtype), + "index_fallback_policy": args.index_fallback_policy, + "index_fallback_applied": False, + "triton_ms": None, + "pytorch_ms": None, + "cusparse_ms": None, + "triton_speedup_vs_pytorch": None, + "triton_speedup_vs_cusparse": None, + "triton_match_pytorch": None, + "cusparse_match_pytorch": None, + "triton_max_error": None, + "cusparse_max_error": None, + "cusparse_unavailable_reason": str(exc), + "index_fallback_reason": str(exc), + "status": "ERROR", + } + summary_rows.append(row) + _print_row(row) + + print("-" * 196) + print(f"Total cases: {total_cases}") + print(f"Failed cases: {failed_cases}") + print(f"Passed cases: {total_cases - failed_cases}") + + if args.csv_summary: + summary_fields = [ + "case_id", + "gpu", + "value_dtype_req", + "value_dtype_compute", + "index_dtype", + "dense_size", + "nnz", + "mode", + "index_fallback_policy", + "index_fallback_applied", + "triton_ms", + "pytorch_ms", + "cusparse_ms", + "triton_speedup_vs_pytorch", + "triton_speedup_vs_cusparse", + "triton_match_pytorch", + "cusparse_match_pytorch", + "triton_max_error", + "cusparse_max_error", + "cusparse_unavailable_reason", + "index_fallback_reason", + "status", + ] + _write_csv(args.csv_summary, summary_rows, summary_fields) + print(f"Wrote summary CSV: {args.csv_summary}") + + if args.csv_samples: + sample_fields = [ + "case_id", + "pos", + "expected_real", + "expected_imag", + "flagsparse_real", + "flagsparse_imag", + "abs_error", + ] + _write_csv(args.csv_samples, sample_rows, sample_fields) + print(f"Wrote samples CSV: {args.csv_samples}") + + +def build_parser(): + parser = argparse.ArgumentParser( + description="Gather benchmark/validation aligned with scatter-style CLI and CSV export." + ) + parser.add_argument("--value-dtypes", default=DEFAULT_VALUE_DTYPES) + parser.add_argument("--index-dtypes", default=DEFAULT_INDEX_DTYPES) + parser.add_argument( + "--cases", + default=",".join(f"{dense}:{nnz}" for dense, nnz in DEFAULT_CASES), + help="Comma-separated dense:nnz pairs, e.g. 
32768:1024,131072:4096", + ) + parser.add_argument("--warmup", type=int, default=WARMUP) + parser.add_argument("--iters", type=int, default=ITERS) + parser.add_argument("--no-cusparse", action="store_true") + parser.add_argument("--index-fallback-policy", choices=["auto", "strict"], default="auto") + parser.add_argument("--csv-summary", default=None) + parser.add_argument("--csv-samples", default=None) + parser.add_argument("--sample-limit", type=int, default=32) + return parser + + +if __name__ == "__main__": + cli_parser = build_parser() + run_cli(cli_parser.parse_args()) diff --git a/tests/test_scatter.py b/tests/test_scatter.py new file mode 100644 index 0000000..604adeb --- /dev/null +++ b/tests/test_scatter.py @@ -0,0 +1,397 @@ +import argparse +import csv +import math +import os + +import torch + +import flagsparse as ast + +DEFAULT_CASES = [ + (32_768, 1_024), + (131_072, 4_096), + (524_288, 16_384), + (1_048_576, 65_536), +] +DEFAULT_VALUE_DTYPES = "float16,float32,float64" +DEFAULT_INDEX_DTYPES = "int32,int64" +WARMUP = 20 +ITERS = 200 + + +def _fmt_ms(value): + if value is None: + return "N/A" + return f"{value:.4f}" + + +def _fmt_speedup(value): + if value is None: + return "N/A" + if math.isinf(value): + return "inf" + return f"{value:.2f}x" + + +def _fmt_err(value): + if value is None: + return "N/A" + return f"{value:.2e}" + + +def _parse_bool_token(raw, name): + token = str(raw).strip().lower() + if token in ("1", "true", "yes", "y", "on"): + return True + if token in ("0", "false", "no", "n", "off"): + return False + raise ValueError(f"{name} must be true/false, got: {raw}") + + +def _parse_value_dtypes(raw): + allowed = { + "float16", + "bfloat16", + "float32", + "float64", + "complex64", + "complex128", + } + tokens = [tok.strip().lower() for tok in str(raw).split(",") if tok.strip()] + if not tokens: + raise ValueError("value dtypes list is empty") + invalid = [tok for tok in tokens if tok not in allowed] + if invalid: + raise ValueError(f"unsupported value dtypes: {invalid}") + return tokens + + +def _parse_index_dtypes(raw): + mapping = {"int32": torch.int32, "int64": torch.int64} + tokens = [tok.strip().lower() for tok in str(raw).split(",") if tok.strip()] + if not tokens: + raise ValueError("index dtypes list is empty") + invalid = [tok for tok in tokens if tok not in mapping] + if invalid: + raise ValueError(f"unsupported index dtypes: {invalid}") + return [(tok, mapping[tok]) for tok in tokens] + + +def _parse_cases(raw): + if raw is None or not str(raw).strip(): + return list(DEFAULT_CASES) + pairs = [] + for chunk in str(raw).split(","): + item = chunk.strip() + if not item: + continue + if ":" not in item: + raise ValueError(f"invalid case '{item}', expected dense:nnz") + left, right = item.split(":", 1) + dense_size = int(left) + nnz = int(right) + if dense_size < 0 or nnz < 0: + raise ValueError(f"case values must be non-negative: {item}") + pairs.append((dense_size, nnz)) + if not pairs: + raise ValueError("case list is empty") + return pairs + + +def _iter_reset_output_modes(raw): + token = str(raw).strip().lower() + if token == "both": + return [True, False] + if token == "true": + return [True] + if token == "false": + return [False] + raise ValueError("reset-output must be one of: true, false, both") + + +def _status_from_result(verification): + triton_ok = verification.get("triton_match_pytorch") + cusparse_ok = verification.get("cusparse_match_pytorch") + overall_ok = bool(triton_ok) and (cusparse_ok is None or bool(cusparse_ok)) + return "PASS" if 
overall_ok else "FAIL" + + +def _to_real_imag(value): + if hasattr(value, "is_complex") and value.is_complex(): + return float(value.real.item()), float(value.imag.item()) + scalar = float(value.item()) + return scalar, 0.0 + + +def _collect_samples(case_id, expected, triton, limit): + rows = [] + if expected is None or triton is None: + return rows + max_items = min(int(limit), int(expected.numel()), int(triton.numel())) + for pos in range(max_items): + exp_val = expected[pos] + tri_val = triton[pos] + exp_real, exp_imag = _to_real_imag(exp_val) + tri_real, tri_imag = _to_real_imag(tri_val) + abs_error = float(torch.abs(tri_val - exp_val).item()) + rows.append( + { + "case_id": case_id, + "pos": pos, + "expected_real": exp_real, + "expected_imag": exp_imag, + "triton_real": tri_real, + "triton_imag": tri_imag, + "abs_error": abs_error, + } + ) + return rows + + +def _ensure_parent_dir(path): + parent = os.path.dirname(os.path.abspath(path)) + if parent: + os.makedirs(parent, exist_ok=True) + + +def _write_csv(path, rows, fieldnames): + _ensure_parent_dir(path) + with open(path, "w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore") + writer.writeheader() + for row in rows: + writer.writerow({k: ("" if v is None else v) for k, v in row.items()}) + + +def _print_header(): + print("-" * 196) + print( + f"{'ValueReq':>10} {'ValueEff':>10} {'Index':>6} {'Dense':>10} {'NNZ':>10} " + f"{'Reset':>6} {'FB':>4} {'IFB':>4} {'PT(ms)':>10} {'FS(ms)':>10} {'CS(ms)':>10} " + f"{'FS/PT':>8} {'FS/CS':>8} {'Status':>6} {'Err(FS)':>12} {'Err(CS)':>12}" + ) + print("-" * 196) + + +def _print_row(row): + print( + f"{row['value_dtype_req']:>10} {row['value_dtype_compute']:>10} {row['index_dtype']:>6} " + f"{row['dense_size']:>10,d} {row['nnz']:>10,d} " + f"{str(row['reset_output']):>6} {str(row['fallback_applied']):>4} " + f"{str(row['index_fallback_applied']):>4} " + f"{_fmt_ms(row['pytorch_ms']):>10} {_fmt_ms(row['triton_ms']):>10} {_fmt_ms(row['cusparse_ms']):>10} " + f"{_fmt_speedup(row['triton_speedup_vs_pytorch']):>8} {_fmt_speedup(row['triton_speedup_vs_cusparse']):>8} " + f"{row['status']:>6} {_fmt_err(row['triton_max_error']):>12} {_fmt_err(row['cusparse_max_error']):>12}" + ) + + +def run_cli(args): + if not torch.cuda.is_available(): + print("CUDA is not available. 
Please run on a GPU-enabled system.") + return + + value_dtype_tokens = _parse_value_dtypes(args.value_dtypes) + index_dtype_pairs = _parse_index_dtypes(args.index_dtypes) + cases = _parse_cases(args.cases) + reset_modes = _iter_reset_output_modes(args.reset_output) + unique_indices = _parse_bool_token(args.unique_indices, "unique-indices") + run_cusparse = not args.no_cusparse + + print("=" * 180) + print("FLAGSPARSE SCATTER BENCHMARK/VALIDATION") + print("=" * 180) + print(f"GPU: {torch.cuda.get_device_name(0)}") + print( + f"Warmup: {args.warmup} | Iterations: {args.iters} | " + f"dtype_policy: {args.dtype_policy} | index_fallback_policy: {args.index_fallback_policy} | " + f"unique_indices: {unique_indices}" + ) + print() + _print_header() + + summary_rows = [] + sample_rows = [] + total_cases = 0 + failed_cases = 0 + + for value_dtype in value_dtype_tokens: + for index_name, index_dtype in index_dtype_pairs: + for reset_output in reset_modes: + for dense_size, nnz in cases: + total_cases += 1 + case_id = ( + f"{value_dtype}|{index_name}|dense={dense_size}|nnz={nnz}|" + f"reset={str(reset_output).lower()}|unique={str(unique_indices).lower()}|" + f"policy={args.dtype_policy}|ifb_policy={args.index_fallback_policy}" + ) + try: + result = ast.benchmark_scatter_case( + dense_size=dense_size, + nnz=nnz, + value_dtype=value_dtype, + index_dtype=index_dtype, + warmup=args.warmup, + iters=args.iters, + run_cusparse=run_cusparse, + unique_indices=unique_indices, + reset_output=reset_output, + dtype_policy=args.dtype_policy, + index_fallback_policy=args.index_fallback_policy, + ) + perf = result["performance"] + verify = result["verification"] + params = result["parameters"] + backend = result["backend_status"] + status = _status_from_result(verify) + if status != "PASS": + failed_cases += 1 + row = { + "case_id": case_id, + "gpu": torch.cuda.get_device_name(0), + "value_dtype_req": params.get("value_dtype"), + "value_dtype_compute": str(params.get("effective_value_dtype")), + "index_dtype": str(params.get("index_dtype")).replace("torch.", ""), + "dense_size": int(params.get("dense_size")), + "nnz": int(params.get("nnz")), + "unique_indices": bool(params.get("unique_indices")), + "reset_output": bool(params.get("reset_output")), + "dtype_policy": params.get("dtype_policy"), + "fallback_applied": bool(params.get("fallback_applied")), + "index_fallback_applied": bool( + backend.get("index_fallback_applied") + ), + "triton_ms": perf.get("triton_ms"), + "pytorch_ms": perf.get("pytorch_ms"), + "cusparse_ms": perf.get("cusparse_ms"), + "triton_speedup_vs_pytorch": perf.get("triton_speedup_vs_pytorch"), + "triton_speedup_vs_cusparse": perf.get("triton_speedup_vs_cusparse"), + "triton_match_pytorch": verify.get("triton_match_pytorch"), + "cusparse_match_pytorch": verify.get("cusparse_match_pytorch"), + "triton_max_error": verify.get("triton_max_error"), + "cusparse_max_error": verify.get("cusparse_max_error"), + "cusparse_unavailable_reason": backend.get("cusparse_unavailable_reason"), + "fallback_reason": backend.get("fallback_reason"), + "index_fallback_reason": backend.get("index_fallback_reason"), + "status": status, + } + summary_rows.append(row) + _print_row(row) + + if args.csv_samples: + sample_rows.extend( + _collect_samples( + case_id, + result["samples"].get("pytorch"), + result["samples"].get("triton"), + args.sample_limit, + ) + ) + except Exception as exc: + failed_cases += 1 + row = { + "case_id": case_id, + "gpu": torch.cuda.get_device_name(0), + "value_dtype_req": value_dtype, + 
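+                            # The compute dtype is unknown here: the case raised before running.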
"value_dtype_compute": "N/A", + "index_dtype": index_name, + "dense_size": dense_size, + "nnz": nnz, + "unique_indices": unique_indices, + "reset_output": reset_output, + "dtype_policy": args.dtype_policy, + "fallback_applied": False, + "index_fallback_applied": False, + "triton_ms": None, + "pytorch_ms": None, + "cusparse_ms": None, + "triton_speedup_vs_pytorch": None, + "triton_speedup_vs_cusparse": None, + "triton_match_pytorch": None, + "cusparse_match_pytorch": None, + "triton_max_error": None, + "cusparse_max_error": None, + "cusparse_unavailable_reason": str(exc), + "fallback_reason": None, + "index_fallback_reason": str(exc), + "status": "ERROR", + } + summary_rows.append(row) + _print_row(row) + + print("-" * 196) + print(f"Total cases: {total_cases}") + print(f"Failed cases: {failed_cases}") + print(f"Passed cases: {total_cases - failed_cases}") + + if args.csv_summary: + summary_fields = [ + "case_id", + "gpu", + "value_dtype_req", + "value_dtype_compute", + "index_dtype", + "dense_size", + "nnz", + "unique_indices", + "reset_output", + "dtype_policy", + "fallback_applied", + "index_fallback_applied", + "triton_ms", + "pytorch_ms", + "cusparse_ms", + "triton_speedup_vs_pytorch", + "triton_speedup_vs_cusparse", + "triton_match_pytorch", + "cusparse_match_pytorch", + "triton_max_error", + "cusparse_max_error", + "cusparse_unavailable_reason", + "fallback_reason", + "index_fallback_reason", + "status", + ] + _write_csv(args.csv_summary, summary_rows, summary_fields) + print(f"Wrote summary CSV: {args.csv_summary}") + + if args.csv_samples: + sample_fields = [ + "case_id", + "pos", + "expected_real", + "expected_imag", + "triton_real", + "triton_imag", + "abs_error", + ] + _write_csv(args.csv_samples, sample_rows, sample_fields) + print(f"Wrote samples CSV: {args.csv_samples}") + + +def build_parser(): + parser = argparse.ArgumentParser( + description="Scatter benchmark/validation with dtype fallback and CSV export." + ) + parser.add_argument("--value-dtypes", default=DEFAULT_VALUE_DTYPES) + parser.add_argument("--index-dtypes", default=DEFAULT_INDEX_DTYPES) + parser.add_argument( + "--cases", + default=",".join(f"{dense}:{nnz}" for dense, nnz in DEFAULT_CASES), + help="Comma-separated dense:nnz pairs, e.g. 32768:1024,131072:4096", + ) + parser.add_argument("--warmup", type=int, default=WARMUP) + parser.add_argument("--iters", type=int, default=ITERS) + parser.add_argument("--no-cusparse", action="store_true") + parser.add_argument("--unique-indices", default="true") + parser.add_argument("--reset-output", choices=["true", "false", "both"], default="true") + parser.add_argument("--dtype-policy", choices=["auto", "strict"], default="auto") + parser.add_argument("--index-fallback-policy", choices=["auto", "strict"], default="auto") + parser.add_argument("--csv-summary", default=None) + parser.add_argument("--csv-samples", default=None) + parser.add_argument("--sample-limit", type=int, default=32) + return parser + + +if __name__ == "__main__": + cli_parser = build_parser() + run_cli(cli_parser.parse_args()) diff --git a/tests/test_sddmm.py b/tests/test_sddmm.py new file mode 100644 index 0000000..4aac47c --- /dev/null +++ b/tests/test_sddmm.py @@ -0,0 +1,692 @@ +""" +SDDMM tests: load SuiteSparse .mtx as CSR pattern and benchmark +out = alpha * dot(X[row], Y[col]) + beta * in. +CuPy baseline uses sampled-dot on CSR pattern (not dense X@Y^T). + +acc_mode notes: +- acc_mode=f32 keeps the native float32 accumulate path for float32 inputs. 
+- acc_mode=f64 upgrades only the internal accumulation of float32 inputs to + float64 while still returning float32 outputs. +- float64 inputs always keep the existing float64 route; acc_mode only affects + float32 runs in this test harness. +""" + +import argparse +import csv +import glob +import os +import sys +import time +from pathlib import Path + +import torch + +_PROJECT_ROOT = Path(__file__).resolve().parents[1] +_SRC_ROOT = _PROJECT_ROOT / "src" +if str(_SRC_ROOT) not in sys.path: + sys.path.insert(0, str(_SRC_ROOT)) +_TESTS_DIR = Path(__file__).resolve().parent +if str(_TESTS_DIR) not in sys.path: + sys.path.insert(0, str(_TESTS_DIR)) + +import flagsparse as ast +import flagsparse.sparse_operations.sddmm_csr as ast_ops +from test_spmm import load_mtx_to_csr_torch + +VALUE_DTYPES = [torch.float32, torch.float64] +INDEX_DTYPES = [torch.int32] +WARMUP = 5 +ITERS = 20 +DEFAULT_K = 64 +BASELINE_ATOL = 1e-4 +BASELINE_RTOL = 1e-2 +ACC64_ATOL = 1e-6 +ACC64_RTOL = 1e-5 + + +def _dtype_name(dtype): + return str(dtype).replace("torch.", "") + + +def _fmt_ms(value): + return "N/A" if value is None else f"{value:.4f}" + + +def _fmt_speedup(other_ms, triton_ms): + if other_ms is None or triton_ms is None or triton_ms <= 0: + return "N/A" + return f"{other_ms / triton_ms:.2f}x" + + +def _fmt_err(value): + return "N/A" if value is None else f"{value:.2e}" + + +def _fmt_check(value): + if value is None: + return "N/A" + return "PASS" if value else "FAIL" + + +def _status_label(value): + if isinstance(value, str): + return value + if value is None: + return "N/A" + return "PASS" if value else "FAIL" + + +def _is_resource_error(message): + text = str(message).lower() + resource_tokens = ( + "out of memory", + "cudaerroroutofmemory", + "cuda error out of memory", + "insufficient resources", + "resource exhausted", + "memoryerror", + "cublas_status_alloc_failed", + "cusparse_status_insufficient_resources", + ) + return any(token in text for token in resource_tokens) + + +def _cupy_sampled_dot_chunked(x_cp, y_cp, row_ids_cp, col_ids_cp, chunk_nnz): + nnz = int(row_ids_cp.size) + out_cp = ast_ops.cp.empty((nnz,), dtype=x_cp.dtype) + for start in range(0, nnz, chunk_nnz): + end = min(nnz, start + chunk_nnz) + rows = row_ids_cp[start:end] + cols = col_ids_cp[start:end] + out_cp[start:end] = ast_ops.cp.sum(x_cp[rows] * y_cp[cols], axis=1) + return out_cp + + +def _benchmark_cupy_sampled_reference(indices, indptr, x, y, data_in, alpha, beta, warmup, iters): + n_rows = int(indptr.numel()) - 1 + row_ids = torch.repeat_interleave( + torch.arange(n_rows, dtype=torch.int64, device=x.device), + indptr.to(torch.int64)[1:] - indptr.to(torch.int64)[:-1], + ) + x_cp = ast_ops._cupy_from_torch(x) + y_cp = ast_ops._cupy_from_torch(y) + row_ids_cp = ast_ops._cupy_from_torch(row_ids) + col_ids_cp = ast_ops._cupy_from_torch(indices.to(torch.int64)) + nnz = max(1, int(indices.numel())) + chunk_nnz = min(262144, nnz) + sampled_cp, cupy_ms = ast_ops._benchmark_cuda_op( + lambda: _cupy_sampled_dot_chunked(x_cp, y_cp, row_ids_cp, col_ids_cp, chunk_nnz), + warmup=warmup, + iters=iters, + ) + sampled = ast_ops._torch_from_cupy(sampled_cp) + sampled = sampled * alpha + if beta != 0.0: + sampled = sampled + beta * data_in + return sampled, cupy_ms + + +def _normalize_csv_path(csv_path): + csv_path = str(csv_path) + if not csv_path.lower().endswith(".csv"): + csv_path = f"{csv_path}.csv" + parent = os.path.dirname(os.path.abspath(csv_path)) + if parent: + os.makedirs(parent, exist_ok=True) + return csv_path + + +def 
_resolve_tolerance(value_dtype, acc_mode): + if value_dtype == torch.float32: + if acc_mode == "f64": + return ACC64_ATOL, ACC64_RTOL + return BASELINE_ATOL, BASELINE_RTOL + return ast_ops._tolerance_for_dtype(value_dtype) + + +def _scaled_allclose_error(candidate, reference, atol, rtol): + if candidate.numel() == 0: + return 0.0 + diff = torch.abs(candidate - reference) + denom = atol + rtol * torch.abs(reference) + return float(torch.max(diff / denom).item()) + + +def _benchmark_reference_sddmm(data, indices, indptr, x, y, alpha, beta, value_dtype, warmup, iters): + indptr64 = indptr.to(torch.int64) + if value_dtype == torch.float32: + x_ref = x.to(torch.float64) + y_ref = y.to(torch.float64) + data_ref = data.to(torch.float64) if data is not None else None + + op = lambda: ast_ops._sddmm_reference(indices, indptr64, x_ref, y_ref, data_ref, alpha, beta).to(torch.float32) + else: + op = lambda: ast_ops._sddmm_reference(indices, indptr64, x, y, data, alpha, beta) + ref_values, ref_ms = ast_ops._benchmark_cuda_op(op, warmup=warmup, iters=iters) + return ref_values, ref_ms + + +def _benchmark_triton_sddmm(data, indices, indptr, shape, x, y, alpha, beta, warmup, iters, acc_mode): + torch.cuda.synchronize() + t_prepare0 = time.perf_counter() + prepared = ast.prepare_sddmm_csr(indices, indptr, shape, k_hint=int(x.shape[1])) + torch.cuda.synchronize() + prepare_ms = (time.perf_counter() - t_prepare0) * 1000.0 + + torch.cuda.synchronize() + t_first0 = time.perf_counter() + if x.dtype == torch.float32 and acc_mode == "f64": + _ = ast_ops._run_sddmm_prepared( + prepared, + x.contiguous(), + y.contiguous(), + data.contiguous() if data is not None else None, + alpha, + beta, + out=None, + allow_fallback=False, + variant="acc64", + )[0] + else: + _ = ast.flagsparse_sddmm_csr(data=data, x=x, y=y, alpha=alpha, beta=beta, prepared=prepared) + torch.cuda.synchronize() + first_call_ms = (time.perf_counter() - t_first0) * 1000.0 + + if x.dtype == torch.float32 and acc_mode == "f64": + op = lambda: ast_ops._run_sddmm_prepared( + prepared, + x.contiguous(), + y.contiguous(), + data.contiguous() if data is not None else None, + alpha, + beta, + out=None, + allow_fallback=False, + variant="acc64", + )[0] + triton_values, triton_ms = ast_ops._benchmark_cuda_op(op, warmup=warmup, iters=iters) + _, meta = ast_ops._run_sddmm_prepared( + prepared, + x.contiguous(), + y.contiguous(), + data.contiguous() if data is not None else None, + alpha, + beta, + out=None, + allow_fallback=False, + variant="acc64", + ) + else: + triton_values, triton_ms = ast_ops._benchmark_cuda_op( + lambda: ast.flagsparse_sddmm_csr(data=data, x=x, y=y, alpha=alpha, beta=beta, prepared=prepared), + warmup=warmup, + iters=iters, + ) + _, meta = ast.flagsparse_sddmm_csr( + data=data, x=x, y=y, alpha=alpha, beta=beta, prepared=prepared, return_meta=True + ) + meta["prepare_ms"] = prepare_ms + return triton_values, triton_ms, first_call_ms, meta + + +def run_one_mtx( + mtx_path, + value_dtype=torch.float32, + index_dtype=torch.int32, + warmup=WARMUP, + iters=ITERS, + k_dim=DEFAULT_K, + alpha=1.0, + beta=0.0, + run_cusparse=True, + acc_mode="f32", +): + device = torch.device("cuda") + _pattern_values, indices, indptr, shape = load_mtx_to_csr_torch(mtx_path, dtype=value_dtype, device=device) + indices = indices.to(index_dtype) + n_rows, n_cols = shape + nnz = int(indices.numel()) + data_in = torch.randn(nnz, dtype=value_dtype, device=device) + x = torch.randn((n_rows, k_dim), dtype=value_dtype, device=device) + y = torch.randn((n_cols, k_dim), 
dtype=value_dtype, device=device) + + result = { + "path": mtx_path, + "shape": shape, + "nnz": nnz, + "nnz_pattern": nnz, + "k": int(k_dim), + "alpha": float(alpha), + "beta": float(beta), + "error": None, + "triton_ms": None, + "triton_first_call_ms": None, + "prepare_ms": None, + "pytorch_ms": None, + "cupy_ms": None, + "cusparse_ms": None, + "err_pt": None, + "err_cu": None, + "triton_ok_pt": None, + "triton_ok_cu": None, + "cu_status": "REF_UNAVAILABLE", + "cu_reason": None, + "cusparse_reason": None, + "triton_started": False, + "cu_started": False, + "fallback_used": False, + "status": "UNKNOWN", + } + + triton_values = None + try: + result["triton_started"] = True + triton_values, triton_ms, triton_first_ms, meta = _benchmark_triton_sddmm( + data_in, indices, indptr, shape, x, y, alpha, beta, warmup, iters, acc_mode + ) + result["triton_ms"] = triton_ms + result["triton_first_call_ms"] = triton_first_ms + result["prepare_ms"] = meta.get("prepare_ms") + result["fallback_used"] = bool(meta.get("fallback_used", False)) + except Exception as exc: + result["error"] = f"triton: {exc}" + + try: + ref, result["pytorch_ms"] = _benchmark_reference_sddmm( + data_in, + indices, + indptr, + x, + y, + alpha, + beta, + value_dtype, + warmup, + iters, + ) + except Exception as exc: + result["error"] = str(exc) if result["error"] is None else f"{result['error']}; ref: {exc}" + result["status"] = "REF_FAIL" + return result + + if triton_values is not None: + atol, rtol = _resolve_tolerance(value_dtype, acc_mode) + result["triton_ok_pt"] = bool(torch.allclose(triton_values, ref, atol=atol, rtol=rtol)) + result["err_pt"] = _scaled_allclose_error(triton_values, ref, atol, rtol) + else: + result["triton_ok_pt"] = False + + if run_cusparse: + if ast_ops.cp is None: + result["cu_status"] = "PERF_ONLY" + result["cu_reason"] = "CuPy is not available" + else: + try: + result["cu_started"] = True + _cu_vals, cupy_ms = _benchmark_cupy_sampled_reference( + indices=indices, + indptr=indptr, + x=x, + y=y, + data_in=data_in, + alpha=alpha, + beta=beta, + warmup=warmup, + iters=iters, + ) + result["cupy_ms"] = cupy_ms + result["cusparse_ms"] = cupy_ms + result["cu_status"] = "PERF_ONLY" + except Exception as exc: + result["cu_status"] = "PERF_RESOURCE" if _is_resource_error(exc) else "PERF_UNAVAILABLE" + result["cu_reason"] = str(exc) + else: + result["cu_status"] = "PERF_ONLY" + result["cu_reason"] = "CuPy reference is disabled by CLI" + + result["cusparse_reason"] = result["cu_reason"] + result["status"] = "PASS" if result["triton_ok_pt"] else "FAIL" + return result + + +def run_mtx_batch( + mtx_paths, + value_dtype=torch.float32, + index_dtype=torch.int32, + warmup=WARMUP, + iters=ITERS, + k_dim=DEFAULT_K, + alpha=1.0, + beta=0.0, + run_cusparse=True, + on_result=None, + acc_mode="f32", +): + results = [] + for path in mtx_paths: + entry = run_one_mtx( + path, + value_dtype=value_dtype, + index_dtype=index_dtype, + warmup=warmup, + iters=iters, + k_dim=k_dim, + alpha=alpha, + beta=beta, + run_cusparse=run_cusparse, + acc_mode=acc_mode, + ) + results.append(entry) + if on_result is not None: + on_result(entry) + return results + + +def _print_sddmm_mtx_header(value_dtype, index_dtype, k_dim, alpha, beta, acc_mode): + print(f"Value dtype: {_dtype_name(value_dtype)} | Index dtype: {_dtype_name(index_dtype)}") + print( + "Formats: FlagSparse=CSR SDDMM, CuPy sampled-dot performance baseline (not cuSPARSE API), " + "PyTorch correctness reference." 
+ ) + print( + f"Equation: out = alpha*dot(x[row], y[col]) + beta*in | K={k_dim} alpha={alpha} beta={beta} acc_mode={acc_mode}" + ) + print("-" * 196) + print( + f"{'Matrix':<28} {'N_rows':>7} {'N_cols':>7} {'NNZ':>10} {'K':>6} " + f"{'FlagSparse(ms)':>14} {'CuPy(ms)':>11} {'PyTorch(ms)':>11} " + f"{'FS/CU':>7} {'FS/PT':>7} {'PT':>6} {'CU_Status':>12} {'Err(PT)':>10} {'Err(CU)':>10} {'Prep(ms)':>9}" + ) + print("-" * 196) + + +def _print_sddmm_mtx_row(entry): + name = os.path.basename(entry["path"])[:27] + n_rows, n_cols = entry["shape"] + cupy_ms = entry.get("cupy_ms") + if cupy_ms is None: + cupy_ms = entry.get("cusparse_ms") + print( + f"{name:<28} {n_rows:>7} {n_cols:>7} {entry['nnz_pattern']:>10} {entry['k']:>6} " + f"{_fmt_ms(entry.get('triton_ms')):>14} {_fmt_ms(cupy_ms):>11} {_fmt_ms(entry.get('pytorch_ms')):>11} " + f"{_fmt_speedup(cupy_ms, entry.get('triton_ms')):>7} {_fmt_speedup(entry.get('pytorch_ms'), entry.get('triton_ms')):>7} " + f"{_fmt_check(entry.get('triton_ok_pt')):>6} {_status_label(entry.get('cu_status')):>12} " + f"{_fmt_err(entry.get('err_pt')):>10} {_fmt_err(entry.get('err_cu')):>10} {_fmt_ms(entry.get('prepare_ms')):>9}" + ) + err = entry.get("error") + cu_reason = entry.get("cu_reason") + if err: + msg = str(err).replace("\n", " ") + if len(msg) > 220: + msg = msg[:217] + "..." + print(f" NOTE: {msg}") + if cu_reason: + msg = str(cu_reason).replace("\n", " ") + if len(msg) > 220: + msg = msg[:217] + "..." + print(f" CU_NOTE: {msg}") + + +def print_mtx_results(results, value_dtype, index_dtype, k_dim, alpha, beta, acc_mode): + _print_sddmm_mtx_header(value_dtype, index_dtype, k_dim, alpha, beta, acc_mode) + for entry in results: + _print_sddmm_mtx_row(entry) + print("-" * 196) + + +def run_all_dtypes_export_csv( + paths, + csv_path, + value_dtype=torch.float32, + index_dtype=torch.int32, + warmup=WARMUP, + iters=ITERS, + k_dim=DEFAULT_K, + alpha=1.0, + beta=0.0, + run_cusparse=True, + acc_mode="f32", +): + csv_path = _normalize_csv_path(csv_path) + rows = [] + print("=" * 164) + _print_sddmm_mtx_header(value_dtype, index_dtype, k_dim, alpha, beta, acc_mode) + results = run_mtx_batch( + paths, + value_dtype=value_dtype, + index_dtype=index_dtype, + warmup=warmup, + iters=iters, + k_dim=k_dim, + alpha=alpha, + beta=beta, + run_cusparse=run_cusparse, + on_result=_print_sddmm_mtx_row, + acc_mode=acc_mode, + ) + print("-" * 196) + for entry in results: + n_rows, n_cols = entry["shape"] + rows.append( + { + "matrix": os.path.basename(entry["path"]), + "value_dtype": _dtype_name(value_dtype), + "index_dtype": _dtype_name(index_dtype), + "n_rows": n_rows, + "n_cols": n_cols, + "nnz": entry["nnz"], + "cupy_ms": entry.get("cupy_ms"), + "triton_ms": entry.get("triton_ms"), + "cusparse_ms": entry.get("cusparse_ms"), + "pytorch_ms": entry.get("pytorch_ms"), + "pt_status": _status_label(entry.get("triton_ok_pt")), + "cu_status": _status_label(entry.get("cu_status")), + "status": entry.get("status"), + "err_pt": entry.get("err_pt"), + "err_cu": entry.get("err_cu"), + "error": entry.get("error"), + "cu_reason": entry.get("cu_reason"), + "triton_started": entry.get("triton_started"), + "cu_started": entry.get("cu_started"), + "fallback_used": entry.get("fallback_used"), + "nnz_pattern": entry.get("nnz_pattern"), + "k": entry.get("k"), + "alpha": entry.get("alpha"), + "beta": entry.get("beta"), + "prepare_ms": entry.get("prepare_ms"), + } + ) + fieldnames = [ + "matrix", "value_dtype", "index_dtype", "n_rows", "n_cols", "nnz", + "triton_ms", "cupy_ms", "cusparse_ms", 
"pytorch_ms", + "pt_status", "cu_status", "status", "err_pt", "err_cu", "error", + "cu_reason", "triton_started", "cu_started", "fallback_used", + "nnz_pattern", "k", "alpha", "beta", "prepare_ms", + ] + with open(csv_path, "w", newline="", encoding="utf-8") as handle: + writer = csv.DictWriter(handle, fieldnames=fieldnames, extrasaction="ignore") + writer.writeheader() + for row in rows: + writer.writerow({key: ("" if value is None else value) for key, value in row.items()}) + print(f"Wrote {len(rows)} rows to {csv_path}") + + +def run_api_validation_checks(): + if not torch.cuda.is_available(): + print("API checks skipped: CUDA is not available.") + return 0 + device = torch.device("cuda") + indices = torch.tensor([0, 1, 1], dtype=torch.int32, device=device) + indptr = torch.tensor([0, 2, 3], dtype=torch.int64, device=device) + shape = (2, 2) + x = torch.randn((2, 8), dtype=torch.float32, device=device) + y = torch.randn((2, 8), dtype=torch.float32, device=device) + data = torch.randn(3, dtype=torch.float32, device=device) + + negative_cases = [ + ("indices must int32", lambda: ast.flagsparse_sddmm_csr(indices=indices.to(torch.int64), indptr=indptr, x=x, y=y, shape=shape), TypeError), + ("x/y K mismatch", lambda: ast.flagsparse_sddmm_csr(indices=indices, indptr=indptr, x=x, y=y[:, :4], shape=shape), ValueError), + ("data length mismatch", lambda: ast.flagsparse_sddmm_csr(data=torch.randn(2, dtype=torch.float32, device=device), indices=indices, indptr=indptr, x=x, y=y, shape=shape), ValueError), + ("beta needs data", lambda: ast.flagsparse_sddmm_csr(indices=indices, indptr=indptr, x=x, y=y, shape=shape, beta=0.5), ValueError), + ( + "K=0 out shape mismatch", + lambda: ast.flagsparse_sddmm_csr( + data=data, + indices=indices, + indptr=indptr, + x=x[:, :0], + y=y[:, :0], + shape=shape, + out=torch.empty(2, dtype=torch.float32, device=device), + ), + ValueError, + ), + ] + failed = 0 + print("-" * 96) + print("API validation checks (SDDMM)") + print("-" * 96) + for name, fn, exc_type in negative_cases: + try: + fn() + print(f"FAIL {name:<32} expected {exc_type.__name__}") + failed += 1 + except exc_type: + print(f"PASS {name:<32} raised {exc_type.__name__}") + except Exception as exc: + print(f"FAIL {name:<32} raised {type(exc).__name__}: {exc}") + failed += 1 + + try: + out = ast.flagsparse_sddmm_csr(data=data, indices=indices, indptr=indptr, x=x, y=y, shape=shape, alpha=1.25, beta=0.5) + if out.shape != (3,): + raise AssertionError("unexpected output shape") + print("PASS positive path returned correct output shape") + except Exception as exc: + print(f"FAIL positive path raised {type(exc).__name__}: {exc}") + failed += 1 + print("-" * 96) + return failed + + +def _expand_mtx_paths(raw_paths): + paths = [] + for p in raw_paths: + if os.path.isfile(p) and p.lower().endswith(".mtx"): + paths.append(p) + elif os.path.isdir(p): + paths.extend(sorted(glob.glob(os.path.join(p, "*.mtx")))) + seen = set() + uniq = [] + for path in paths: + ap = os.path.abspath(path) + if ap not in seen: + uniq.append(ap) + seen.add(ap) + return uniq + + +def main(): + parser = argparse.ArgumentParser(description="FlagSparse SDDMM CSR tests") + parser.add_argument("mtx", nargs="*", help=".mtx files or directories") + parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "float64"]) + parser.add_argument("--index-dtype", type=str, default="int32", choices=["int32"]) + parser.add_argument( + "--acc_mode", + type=str, + default="f32", + choices=["f32", "f64"], + help="For float32 runs, choose 
native f32 accumulation or float64 accumulation.", + ) + parser.add_argument("--warmup", type=int, default=WARMUP) + parser.add_argument("--iters", type=int, default=ITERS) + parser.add_argument("--k", type=int, default=DEFAULT_K, help="Dense feature dimension K") + parser.add_argument("--alpha", type=float, default=1.0) + parser.add_argument("--beta", type=float, default=0.0) + parser.add_argument( + "--no-cupy-ref", + action="store_true", + help="Skip CuPy sampled-dot performance baseline", + ) + parser.add_argument( + "--no-cusparse", + action="store_true", + help="Deprecated alias of --no-cupy-ref", + ) + parser.add_argument("--csv", type=str, default=None, metavar="FILE") + parser.add_argument("--skip-api-checks", action="store_true") + args = parser.parse_args() + + if args.k < 0: + raise ValueError("--k must be non-negative") + if not torch.cuda.is_available(): + print("CUDA is not available.") + return + + if not args.skip_api_checks: + failed = run_api_validation_checks() + if failed > 0: + raise SystemExit(1) + run_cupy_ref = not (args.no_cupy_ref or args.no_cusparse) + + value_dtype = torch.float32 if args.dtype == "float32" else torch.float64 + index_dtype = torch.int32 + paths = _expand_mtx_paths(args.mtx) + if not paths and not args.csv: + print("No .mtx files given. Use: python test_sddmm.py file1.mtx [file2.mtx ...] or a directory.") + print("Or export the current dtype to CSV: python test_sddmm.py --csv results.csv") + return + + if args.csv is not None: + if not paths: + paths = sorted(glob.glob("*.mtx")) + if not paths: + print("No .mtx files found. Specify files or a directory.") + return + csv_path = _normalize_csv_path(args.csv) + print("=" * 110) + print("FLAGSPARSE SDDMM - export to CSV") + print("=" * 110) + print( + f"GPU: {torch.cuda.get_device_name(0)} | Files: {len(paths)} | dtype: {args.dtype} | acc_mode: {args.acc_mode} | K: {args.k} | alpha: {args.alpha} | beta: {args.beta} | CSV: {csv_path}" + ) + run_all_dtypes_export_csv( + paths, + csv_path, + value_dtype=value_dtype, + index_dtype=index_dtype, + warmup=args.warmup, + iters=args.iters, + k_dim=args.k, + alpha=args.alpha, + beta=args.beta, + run_cusparse=run_cupy_ref, + acc_mode=args.acc_mode, + ) + return + + print("=" * 150) + print("FLAGSPARSE SDDMM - SuiteSparse .mtx batch (CSR pattern-guided)") + print("=" * 150) + print(f"GPU: {torch.cuda.get_device_name(0)} | Files: {len(paths)}") + print( + f"dtype: {args.dtype} index_dtype: {args.index_dtype} acc_mode: {args.acc_mode} K: {args.k} alpha: {args.alpha} beta: {args.beta} warmup: {args.warmup} iters: {args.iters}" + ) + print() + results = run_mtx_batch( + paths, + value_dtype=value_dtype, + index_dtype=index_dtype, + warmup=args.warmup, + iters=args.iters, + k_dim=args.k, + alpha=args.alpha, + beta=args.beta, + run_cusparse=run_cupy_ref, + acc_mode=args.acc_mode, + ) + print_mtx_results(results, value_dtype, index_dtype, args.k, args.alpha, args.beta, args.acc_mode) + + +if __name__ == "__main__": + main() diff --git a/tests/test_spgemm.py b/tests/test_spgemm.py new file mode 100644 index 0000000..9cb3e61 --- /dev/null +++ b/tests/test_spgemm.py @@ -0,0 +1,1779 @@ +""" +SpGEMM tests: load SuiteSparse .mtx, run CSR SpGEMM(A@B), and report +error/performance in a SpMM-like table and CSV format.
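+
+Typical invocations (file and directory names below are placeholders; the flags
+match the argparse definitions in main()):
+    python test_spgemm.py matrices/                 # every .mtx under a directory
+    python test_spgemm.py A.mtx --dtype float64 --csv spgemm_results.csv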
+""" + +import argparse +import csv +import gc +import glob +import os +import subprocess +import sys +import tempfile +import time +from pathlib import Path + +import torch + +_PROJECT_ROOT = Path(__file__).resolve().parents[1] +_SRC_ROOT = _PROJECT_ROOT / "src" +if str(_SRC_ROOT) not in sys.path: + sys.path.insert(0, str(_SRC_ROOT)) +_TESTS_DIR = Path(__file__).resolve().parent +if str(_TESTS_DIR) not in sys.path: + sys.path.insert(0, str(_TESTS_DIR)) + +import flagsparse as ast +import flagsparse.sparse_operations.spgemm_csr as ast_ops +from test_spmm import load_mtx_to_csr_torch + +VALUE_DTYPES = [torch.float32, torch.float64] +INDEX_DTYPES = [torch.int32] +CSV_VALUE_DTYPES = [torch.float32, torch.float64] +CSV_INDEX_DTYPES = [torch.int32] +WARMUP = 10 +ITERS = 50 +DEFAULT_INPUT_MODE = "auto" +TARGET_TIMED_WINDOW_SECONDS = 8.0 +DEFAULT_REF_BLOCK_ROWS = 0 # auto +DEFAULT_COMPARE_BLOCK_ROWS = 2048 +DEFAULT_COMPARE_DEVICE = "cpu" + + +def _dtype_name(dtype): + return str(dtype).replace("torch.", "") + + +def _fmt_ms(value): + return "N/A" if value is None else f"{value:.4f}" + + +def _fmt_speedup(other_ms, triton_ms): + if other_ms is None or triton_ms is None or triton_ms <= 0: + return "N/A" + return f"{other_ms / triton_ms:.2f}x" + + +def _fmt_err(value): + return "N/A" if value is None else f"{value:.2e}" + + +def _fmt_check(value): + if value is None: + return "N/A" + return "PASS" if value else "FAIL" + + +def _status_label(value): + if value is None: + return "N/A" + return "PASS" if value else "FAIL" + + +def _append_error(current, message): + msg = str(message) + if not current: + return msg + return f"{current}; {msg}" + + +def _is_resource_error(message): + text = str(message).lower() + tokens = ( + "out of memory", + "cudaerroroutofmemory", + "cuda out of memory", + "insufficient resources", + "resource exhausted", + "cusparsespgemm_workestimation", + "cusparsespgemm_compute", + "cusparse_status_insufficient_resources", + "cublas_status_alloc_failed", + "cannot allocate memory", + ) + return any(tok in text for tok in tokens) + + +def _cleanup_reference_pools(): + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + if hasattr(torch.cuda, "ipc_collect"): + torch.cuda.ipc_collect() + if ast_ops.cp is not None: + try: + ast_ops.cp.get_default_memory_pool().free_all_blocks() + except Exception: + pass + try: + ast_ops.cp.get_default_pinned_memory_pool().free_all_blocks() + except Exception: + pass + + +def _parse_ref_block_rows(raw_value): + if raw_value is None: + return DEFAULT_REF_BLOCK_ROWS + s = str(raw_value).strip().lower() + if s in ("auto", "0", ""): + return 0 + v = int(s) + if v <= 0: + raise ValueError("--ref-block-rows must be positive or 'auto'") + return v + + +def _candidate_block_rows(n_rows, requested): + if requested and requested > 0: + return [int(requested)] + candidates = [8192, 4096, 2048, 1024, 512, 256] + out = [] + for c in candidates: + if c < n_rows: + out.append(c) + if not out: + out = [max(1, n_rows)] + return out + + +def _slice_csr_rows(data, indices, indptr, shape, row_start, row_end): + ptr_start = int(indptr[row_start].item()) + ptr_end = int(indptr[row_end].item()) + sub_data = data[ptr_start:ptr_end] + sub_indices = indices[ptr_start:ptr_end] + sub_indptr = (indptr[row_start:row_end + 1] - ptr_start).to(torch.int64) + sub_shape = (int(row_end - row_start), int(shape[1])) + return sub_data, sub_indices, sub_indptr, sub_shape + + +def _concat_csr_row_blocks(blocks, n_rows, n_cols, device, data_dtype=torch.float32): + if 
not blocks: + return ( + torch.empty(0, dtype=data_dtype, device=device), + torch.empty(0, dtype=torch.int32, device=device), + torch.zeros(n_rows + 1, dtype=torch.int64, device=device), + (n_rows, n_cols), + ) + data_dtype = blocks[0][0].dtype + data_parts = [] + idx_parts = [] + indptr_parts = [torch.zeros(1, dtype=torch.int64, device=device)] + nnz_acc = 0 + row_acc = 0 + for data_b, idx_b, indptr_b, shape_b in blocks: + rows_b = int(shape_b[0]) + row_acc += rows_b + data_parts.append(data_b) + idx_parts.append(idx_b.to(torch.int32)) + if rows_b > 0: + indptr_parts.append(indptr_b[1:].to(torch.int64) + nnz_acc) + nnz_acc += int(data_b.numel()) + if row_acc != int(n_rows): + raise RuntimeError(f"blocked CSR concat row mismatch: got {row_acc}, expected {n_rows}") + data = torch.cat(data_parts) if data_parts else torch.empty(0, dtype=data_dtype, device=device) + indices = torch.cat(idx_parts) if idx_parts else torch.empty(0, dtype=torch.int32, device=device) + indptr = torch.cat(indptr_parts) + return data, indices, indptr, (int(n_rows), int(n_cols)) + + +def _allclose_error_ratio(actual, reference, atol, rtol): + if actual.numel() == 0: + return 0.0 + diff = torch.abs(actual - reference).to(torch.float64) + tol = (atol + rtol * torch.abs(reference)).to(torch.float64) + return float(torch.max(diff / tol).item()) + + +def _csr_sorted_pairs_block(data, indices, indptr, n_cols, row_start, row_end): + row_start = int(row_start) + row_end = int(row_end) + if row_end <= row_start: + return ( + torch.empty(0, dtype=torch.int64, device=data.device), + torch.empty(0, dtype=data.dtype, device=data.device), + ) + ptr = indptr[row_start:row_end + 1].to(torch.int64) + start = int(ptr[0].item()) + end = int(ptr[-1].item()) + if end <= start: + return ( + torch.empty(0, dtype=torch.int64, device=data.device), + torch.empty(0, dtype=data.dtype, device=data.device), + ) + row_counts = ptr[1:] - ptr[:-1] + rows = torch.repeat_interleave( + torch.arange(row_start, row_end, device=data.device, dtype=torch.int64), + row_counts, + ) + cols = indices[start:end].to(torch.int64) + vals = data[start:end] + keys = rows * max(1, int(n_cols)) + cols + order = torch.argsort(keys) + return keys[order], vals[order] + + +def _spgemm_compare_metrics(candidate, reference, value_dtype): + c_data, c_indices, c_indptr, c_shape = candidate + r_data, r_indices, r_indptr, r_shape = reference + if c_shape != r_shape: + return { + "pattern_ok": False, + "pass": False, + "err_ratio": float("inf"), + "max_abs_error": float("inf"), + "max_relative_error": float("inf"), + "reason": f"shape mismatch {c_shape} vs {r_shape}", + } + + c_nnz = int(c_indptr[-1].item()) if c_indptr.numel() > 0 else 0 + r_nnz = int(r_indptr[-1].item()) if r_indptr.numel() > 0 else 0 + if c_nnz != r_nnz: + return { + "pattern_ok": False, + "pass": False, + "err_ratio": float("inf"), + "max_abs_error": float("inf"), + "max_relative_error": float("inf"), + "reason": f"nnz mismatch {c_nnz} vs {r_nnz}", + } + if c_nnz == 0: + return { + "pattern_ok": True, + "pass": True, + "err_ratio": 0.0, + "max_abs_error": 0.0, + "max_relative_error": 0.0, + "reason": "ok", + } + + atol, rtol = ast_ops._tolerance_for_dtype(value_dtype) + n_rows = int(c_shape[0]) + compare_rows = max(1, int(DEFAULT_COMPARE_BLOCK_ROWS)) + err_ratio = 0.0 + max_abs = 0.0 + ref_max = 0.0 + row = 0 + while row < n_rows: + chunk_rows = min(compare_rows, n_rows - row) + while True: + try: + c_keys, c_vals = _csr_sorted_pairs_block( + c_data, c_indices, c_indptr, c_shape[1], row, row + chunk_rows + ) + 
r_keys, r_vals = _csr_sorted_pairs_block( + r_data, r_indices, r_indptr, r_shape[1], row, row + chunk_rows + ) + break + except Exception as exc: + if _is_resource_error(exc) and chunk_rows > 1: + chunk_rows = max(1, chunk_rows // 2) + continue + raise + + if c_keys.numel() != r_keys.numel() or not torch.equal(c_keys, r_keys): + return { + "pattern_ok": False, + "pass": False, + "err_ratio": float("inf"), + "max_abs_error": float("inf"), + "max_relative_error": float("inf"), + "reason": f"sparsity pattern mismatch in rows [{row}, {row + chunk_rows})", + } + if c_vals.numel() > 0: + err_ratio = max(err_ratio, _allclose_error_ratio(c_vals, r_vals, atol, rtol)) + abs_diff = torch.abs(c_vals - r_vals) + max_abs = max(max_abs, float(torch.max(abs_diff).item())) + ref_max = max(ref_max, float(torch.max(torch.abs(r_vals)).item())) + row += chunk_rows + + max_rel = 0.0 if ref_max == 0.0 else max_abs / ref_max + ok = (not torch.isnan(torch.tensor(err_ratio)).item()) and err_ratio <= 1.0 + return { + "pattern_ok": True, + "pass": bool(ok), + "err_ratio": err_ratio, + "max_abs_error": max_abs, + "max_relative_error": max_rel, + "reason": "ok" if ok else "value mismatch", + } + + +def _compare_spgemm_cpu(candidate, reference, value_dtype): + return _spgemm_compare_metrics(candidate, reference, value_dtype) + + +def _classify_reference_reason(*messages): + merged = " ".join(str(m).lower() for m in messages if m) + if not merged: + return "REF_UNAVAILABLE" + if ( + "out of memory" in merged + or "cudaerroroutofmemory" in merged + or "cuda out of memory" in merged + or "cannot allocate memory" in merged + or "cublas_status_alloc_failed" in merged + ): + return "REF_OOM" + if ( + "insufficient resources" in merged + or "cusparsespgemm_workestimation" in merged + or "cusparsespgemm_compute" in merged + or "cusparse_status_insufficient_resources" in merged + or "resource exhausted" in merged + or "resource" in merged + ): + return "REF_RESOURCE" + return "REF_UNAVAILABLE" + + +def _normalize_csv_path(csv_path): + csv_path = str(csv_path) + if not csv_path.lower().endswith(".csv"): + csv_path = f"{csv_path}.csv" + parent = os.path.dirname(os.path.abspath(csv_path)) + if parent: + os.makedirs(parent, exist_ok=True) + return csv_path + + +def _log_stage(path, stage, start_time): + elapsed = time.perf_counter() - start_time + print(f"[SpGEMM][{os.path.basename(path)}] {stage} (elapsed={elapsed:.2f}s)", flush=True) + + +def _resolve_input_mode(requested_mode, shape): + n_rows, n_cols = shape + if requested_mode == "a_equals_b": + if n_rows != n_cols: + raise ValueError( + f"input_mode=a_equals_b requires square matrix, got {n_rows}x{n_cols}" + ) + return "A_EQUALS_B" + if requested_mode == "a_at": + return "A_AT" + if requested_mode == "auto": + return "A_EQUALS_B" if n_rows == n_cols else "A_AT" + raise ValueError(f"unsupported input_mode: {requested_mode}") + + +def _build_spgemm_rhs(a_data, a_indices, a_indptr, a_shape, mode): + if mode == "A_EQUALS_B": + return a_data, a_indices, a_indptr, a_shape + if mode != "A_AT": + raise ValueError(f"unsupported resolved mode: {mode}") + a_t = ast_ops._to_torch_csr(a_data, a_indices, a_indptr, a_shape) + # CSR^T may materialize as CSC; convert through COO so downstream always receives CSR. 
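+    # (transpose(0, 1) of a torch sparse CSR tensor returns a CSC-layout view;
+    # .to_sparse_coo().coalesce() re-sorts the entries row-major so that
+    # _torch_sparse_to_csr below rebuilds a well-ordered CSR triple for B = A^T.)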
+ b_t = a_t.transpose(0, 1).to_sparse_coo().coalesce() + return ast_ops._torch_sparse_to_csr(b_t) + + +def _build_torch_spgemm_reference( + a_data, + a_indices, + a_indptr, + a_shape, + b_data, + b_indices, + b_indptr, + b_shape, +): + a_csr = torch.sparse_csr_tensor( + a_indptr.to(torch.int64), + a_indices.to(torch.int64), + a_data, + size=a_shape, + device=a_data.device, + ) + b_csr = torch.sparse_csr_tensor( + b_indptr.to(torch.int64), + b_indices.to(torch.int64), + b_data, + size=b_shape, + device=b_data.device, + ) + ref_format = "CSR" + try: + op = lambda: torch.sparse.mm(a_csr, b_csr) + ref_sparse = op() + except Exception: + ref_format = "COO" + a_coo = a_csr.to_sparse_coo().coalesce() + b_coo = b_csr.to_sparse_coo().coalesce() + op = lambda: torch.sparse.mm(a_coo, b_coo) + ref_sparse = op() + if ref_sparse.layout not in (torch.sparse_coo, torch.sparse_csr): + raise RuntimeError(f"Unexpected torch sparse.mm result layout: {ref_sparse.layout}") + return ast_ops._torch_sparse_to_csr(ref_sparse), ref_format, op + + +def _build_torch_spgemm_reference_blocked( + a_data, + a_indices, + a_indptr, + a_shape, + b_data, + b_indices, + b_indptr, + b_shape, + block_rows, +): + n_rows = int(a_shape[0]) + + def _op(): + blocks = [] + formats = set() + for row_start in range(0, n_rows, block_rows): + row_end = min(row_start + block_rows, n_rows) + a_blk = _slice_csr_rows(a_data, a_indices, a_indptr, a_shape, row_start, row_end) + ref_blk, fmt_blk, _ = _build_torch_spgemm_reference( + a_blk[0], a_blk[1], a_blk[2], a_blk[3], + b_data, b_indices, b_indptr, b_shape, + ) + blocks.append(ref_blk) + formats.add(fmt_blk) + fmt = "BLOCKED_CSR" if formats == {"CSR"} else "BLOCKED_MIXED" + csr = _concat_csr_row_blocks( + blocks, + n_rows=n_rows, + n_cols=int(b_shape[1]), + device=a_data.device, + data_dtype=a_data.dtype, + ) + return csr, fmt + + csr_result, fmt_result = _op() + + def _bench_op(): + csr, _ = _op() + return csr + + return csr_result, fmt_result, _bench_op + + +def _build_cupy_spgemm_reference( + a_data, + a_indices, + a_indptr, + a_shape, + b_data, + b_indices, + b_indptr, + b_shape, +): + if ast_ops.cp is None or ast_ops.cpx_sparse is None: + raise RuntimeError("CuPy/cuSPARSE is not available") + a_cp = ast_ops.cpx_sparse.csr_matrix( + ( + ast_ops._cupy_from_torch(a_data), + ast_ops._cupy_from_torch(a_indices.to(torch.int64)), + ast_ops._cupy_from_torch(a_indptr.to(torch.int64)), + ), + shape=a_shape, + ) + b_cp = ast_ops.cpx_sparse.csr_matrix( + ( + ast_ops._cupy_from_torch(b_data), + ast_ops._cupy_from_torch(b_indices.to(torch.int64)), + ast_ops._cupy_from_torch(b_indptr.to(torch.int64)), + ), + shape=b_shape, + ) + + def _op(): + c_cp = a_cp @ b_cp + c_coo = c_cp.tocoo() + rows = ast_ops._torch_from_cupy(c_coo.row).to(torch.int64) + cols = ast_ops._torch_from_cupy(c_coo.col).to(torch.int64) + vals = ast_ops._torch_from_cupy(c_coo.data).to(a_data.dtype) + c_t = torch.sparse_coo_tensor( + torch.stack([rows, cols]), vals, (a_shape[0], b_shape[1]), device=a_data.device + ).coalesce() + return ast_ops._torch_sparse_to_csr(c_t) + + return _op(), "CSR", _op + + +def _build_cupy_spgemm_reference_blocked( + a_data, + a_indices, + a_indptr, + a_shape, + b_data, + b_indices, + b_indptr, + b_shape, + block_rows, +): + n_rows = int(a_shape[0]) + + def _op(): + blocks = [] + for row_start in range(0, n_rows, block_rows): + row_end = min(row_start + block_rows, n_rows) + a_blk = _slice_csr_rows(a_data, a_indices, a_indptr, a_shape, row_start, row_end) + ref_blk, _, _ = _build_cupy_spgemm_reference( + 
a_blk[0], a_blk[1], a_blk[2], a_blk[3], + b_data, b_indices, b_indptr, b_shape, + ) + blocks.append(ref_blk) + return _concat_csr_row_blocks( + blocks, + n_rows=n_rows, + n_cols=int(b_shape[1]), + device=a_data.device, + data_dtype=a_data.dtype, + ) + + return _op(), "BLOCKED_CSR", _op + + +def _run_reference_worker_subprocess( + backend, + mtx_path, + value_dtype, + input_mode, + warmup, + iters, + blocked_retry, + block_rows, + ref_cleanup, +): + py = sys.executable + if not py: + return { + "success": False, + "reason": "isolated retry skipped: python executable is unavailable", + "fail_stage": "isolated", + "exec_mode": "isolated", + } + tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".pt") + tmp_path = tmp.name + tmp.close() + cmd = [ + py, + str(Path(__file__).resolve()), + "--_ref-worker", + backend, + "--_worker-mtx", + str(mtx_path), + "--_worker-output", + tmp_path, + "--dtype", + _dtype_name(value_dtype), + "--index-dtype", + "int32", + "--warmup", + str(int(warmup)), + "--iters", + str(int(iters)), + "--_worker-input-mode", + str(input_mode).lower(), + ] + if block_rows > 0: + cmd.extend(["--_worker-block-rows", str(int(block_rows))]) + if not blocked_retry: + cmd.append("--_worker-no-blocked") + if not ref_cleanup: + cmd.append("--_worker-no-cleanup") + proc = subprocess.run(cmd, capture_output=True, text=True) + payload = None + try: + if os.path.exists(tmp_path): + payload = torch.load(tmp_path, map_location="cpu") + finally: + if os.path.exists(tmp_path): + os.remove(tmp_path) + if isinstance(payload, dict) and payload.get("success"): + payload["exec_mode"] = "isolated" + payload["retry_count"] = int(payload.get("retry_count", 0)) + return payload + stderr = proc.stderr.strip() if proc.stderr else "" + stdout = proc.stdout.strip() if proc.stdout else "" + reason = None + if isinstance(payload, dict): + reason = payload.get("reason") + if not reason: + reason = stderr or stdout or f"isolated worker failed with code {proc.returncode}" + return { + "success": False, + "reason": reason, + "fail_stage": "isolated", + "exec_mode": "isolated", + } + + +def _run_reference_with_retries( + backend, + a_data, + a_indices, + a_indptr, + a_shape, + b_data, + b_indices, + b_indptr, + b_shape, + warmup, + iters, + blocked_retry, + block_rows, + isolated_retry, + ref_cleanup, + mtx_path, + value_dtype, + input_mode, + result_device="gpu", +): + run_direct = _build_torch_spgemm_reference if backend == "torch" else _build_cupy_spgemm_reference + run_blocked = _build_torch_spgemm_reference_blocked if backend == "torch" else _build_cupy_spgemm_reference_blocked + attempted_modes = ["direct"] + + def _mark_mode(mode): + if mode not in attempted_modes: + attempted_modes.append(mode) + + def _finalize(): + state["attempted_modes"] = ">".join(attempted_modes) + return state + + state = { + "success": False, + "result": None, + "format": None, + "ms": None, + "exec_mode": "direct", + "retry_count": 0, + "peak_block_rows": None, + "reason": None, + "fail_stage": None, + "attempted_modes": "direct", + } + try: + ref_result, ref_format, ref_op = run_direct( + a_data, + a_indices, + a_indptr, + a_shape, + b_data, + b_indices, + b_indptr, + b_shape, + ) + _, ref_ms = ast_ops._benchmark_cuda_op(ref_op, warmup=warmup, iters=iters) + compare_result = _convert_result_for_compare(ref_result, result_device, a_data.device) + ref_result = None + if result_device == "cpu" and ref_cleanup: + _cleanup_reference_pools() + state.update( + { + "success": True, + "result": compare_result, + "format": 
ref_format, + "ms": ref_ms, + "exec_mode": "direct", + } + ) + return _finalize() + except Exception as exc: + state["reason"] = str(exc) + state["fail_stage"] = "direct" + if ref_cleanup: + _cleanup_reference_pools() + + if blocked_retry and _is_resource_error(state["reason"]): + _mark_mode("blocked") + state["exec_mode"] = "blocked" + for br in _candidate_block_rows(int(a_shape[0]), block_rows): + try: + state["retry_count"] += 1 + ref_result, ref_format, ref_op = run_blocked( + a_data, + a_indices, + a_indptr, + a_shape, + b_data, + b_indices, + b_indptr, + b_shape, + block_rows=int(br), + ) + _, ref_ms = ast_ops._benchmark_cuda_op(ref_op, warmup=warmup, iters=iters) + compare_result = _convert_result_for_compare(ref_result, result_device, a_data.device) + ref_result = None + if result_device == "cpu" and ref_cleanup: + _cleanup_reference_pools() + state.update( + { + "success": True, + "result": compare_result, + "format": ref_format, + "ms": ref_ms, + "exec_mode": "blocked", + "peak_block_rows": int(br), + "reason": None, + "fail_stage": None, + } + ) + return _finalize() + except Exception as blk_exc: + state["reason"] = str(blk_exc) + state["fail_stage"] = "blocked" + state["peak_block_rows"] = int(br) + if ref_cleanup: + _cleanup_reference_pools() + + if isolated_retry and _is_resource_error(state["reason"]): + _mark_mode("isolated") + state["exec_mode"] = "isolated" + state["retry_count"] += 1 + iso = _run_reference_worker_subprocess( + backend=backend, + mtx_path=mtx_path, + value_dtype=value_dtype, + input_mode=input_mode, + warmup=warmup, + iters=iters, + blocked_retry=blocked_retry, + block_rows=block_rows, + ref_cleanup=ref_cleanup, + ) + if iso.get("success"): + ref_payload = iso.get("result") + if not isinstance(ref_payload, (tuple, list)) or len(ref_payload) != 4: + state["reason"] = "isolated worker produced invalid CSR payload" + state["fail_stage"] = "isolated" + return _finalize() + state.update( + { + "success": True, + "result": _convert_result_for_compare( + ( + ref_payload[0], + ref_payload[1], + ref_payload[2], + tuple(ref_payload[3]), + ), + result_device, + a_data.device, + ), + "format": iso.get("format", "CSR"), + "ms": iso.get("ms"), + "exec_mode": "isolated", + "peak_block_rows": iso.get("peak_block_rows"), + "reason": None, + "fail_stage": None, + } + ) + return _finalize() + state["reason"] = iso.get("reason") + state["fail_stage"] = iso.get("fail_stage", "isolated") + return _finalize() + + +def _pick_effective_benchmark_loops(warmup, iters, first_call_ms, target_window_seconds): + warmup = max(0, int(warmup)) + iters = max(1, int(iters)) + per_call_s = max(float(first_call_ms) / 1000.0, 1e-4) + target_iters = max(1, int(float(target_window_seconds) / per_call_s)) + eff_iters = min(iters, target_iters) + warmup_cap = max(1, eff_iters // 2) + eff_warmup = min(warmup, warmup_cap) + return eff_warmup, eff_iters + + +def _benchmark_flagsparse_spgemm( + a_data, + a_indices, + a_indptr, + a_shape, + b_data, + b_indices, + b_indptr, + b_shape, + warmup, + iters, + adaptive_loops, + target_window_seconds, + start_time, + mtx_path, +): + _log_stage(mtx_path, "prepare", start_time) + torch.cuda.synchronize() + t_prepare0 = time.perf_counter() + prepared = ast.prepare_spgemm_csr( + a_data, a_indices, a_indptr, a_shape, + b_data, b_indices, b_indptr, b_shape, + ) + torch.cuda.synchronize() + prepare_ms = (time.perf_counter() - t_prepare0) * 1000.0 + + _log_stage(mtx_path, "first-call", start_time) + torch.cuda.synchronize() + t_first0 = time.perf_counter() + 
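+    # The first call is timed on its own: it absorbs Triton JIT compilation (and
+    # any autotuning), so the steady-state triton_ms measured below is not
+    # polluted by these one-off costs.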
first_result, first_meta = ast.flagsparse_spgemm_csr(prepared=prepared, return_meta=True) + torch.cuda.synchronize() + first_call_ms = (time.perf_counter() - t_first0) * 1000.0 + first_meta = dict(first_meta) + first_meta["prepare_ms"] = prepare_ms + + if adaptive_loops: + eff_warmup, eff_iters = _pick_effective_benchmark_loops( + warmup=warmup, + iters=iters, + first_call_ms=first_call_ms, + target_window_seconds=target_window_seconds, + ) + else: + eff_warmup = max(0, int(warmup)) + eff_iters = max(1, int(iters)) + first_meta["effective_warmup"] = eff_warmup + first_meta["effective_iters"] = eff_iters + + out_buffers = (first_result[0], first_result[1], first_result[2]) + _log_stage(mtx_path, f"timed-run warmup={eff_warmup} iters={eff_iters}", start_time) + triton_result, triton_ms = ast_ops._benchmark_cuda_op( + lambda: ast.flagsparse_spgemm_csr(prepared=prepared, out=out_buffers), + warmup=eff_warmup, + iters=eff_iters, + ) + return triton_result, triton_ms, first_call_ms, first_meta + + +def run_one_mtx( + mtx_path, + value_dtype=torch.float32, + index_dtype=torch.int32, + warmup=WARMUP, + iters=ITERS, + run_cusparse=True, + input_mode=DEFAULT_INPUT_MODE, + adaptive_loops=False, + target_window_seconds=TARGET_TIMED_WINDOW_SECONDS, + ref_blocked_retry=True, + ref_block_rows=DEFAULT_REF_BLOCK_ROWS, + ref_isolated_retry=True, + ref_cleanup=True, + compare_device=DEFAULT_COMPARE_DEVICE, +): + case_start = time.perf_counter() + device = torch.device("cuda") + + _log_stage(mtx_path, "load", case_start) + a_data, a_indices, a_indptr, a_shape = load_mtx_to_csr_torch( + mtx_path, dtype=value_dtype, device=device + ) + a_indices = a_indices.to(index_dtype) + resolved_mode = _resolve_input_mode(input_mode, a_shape) + b_data, b_indices, b_indptr, b_shape = _build_spgemm_rhs( + a_data, a_indices, a_indptr, a_shape, resolved_mode + ) + b_indices = b_indices.to(index_dtype) + + nnz_a = int(a_data.numel()) + nnz_b = int(b_data.numel()) + print( + f"[SpGEMM][{os.path.basename(mtx_path)}] preflight mode={resolved_mode} " + f"shape_a={a_shape} shape_b={b_shape} nnz_a={nnz_a} nnz_b={nnz_b}", + flush=True, + ) + result = { + "path": mtx_path, + "shape": a_shape, + "shape_a": a_shape, + "shape_b": b_shape, + "nnz": nnz_a, + "nnz_a": nnz_a, + "nnz_b": nnz_b, + "nnz_c": None, + "input_mode": resolved_mode, + "error": None, + "triton_started": False, + "ref_started": False, + "ref_reason_code": None, + "triton_ms": None, + "triton_first_call_ms": None, + "prepare_ms": None, + "count_ms": None, + "fill_ms": None, + "bucket_nrows_short": None, + "bucket_nrows_medium": None, + "bucket_nrows_long": None, + "bucket_ms_short": None, + "bucket_ms_medium": None, + "bucket_ms_long": None, + "long_row_sliced_count": None, + "effective_warmup": None, + "effective_iters": None, + "pytorch_ms": None, + "cusparse_ms": None, + "err_pt": None, + "err_cu": None, + "max_abs_err_pt": None, + "max_rel_err_pt": None, + "max_abs_err_cu": None, + "max_rel_err_cu": None, + "triton_ok_pt": None, + "triton_ok_cu": None, + "pytorch_reason": None, + "cusparse_reason": None, + "pytorch_format": None, + "pt_exec_mode": None, + "cu_exec_mode": None, + "attempted_modes_pt": None, + "attempted_modes_cu": None, + "compare_status": "OK", + "compare_device": compare_device, + "pt_retry_count": 0, + "cu_retry_count": 0, + "ref_peak_block_rows": None, + "ref_fail_stage": None, + "status": "UNKNOWN", + } + + triton_result = None + triton_compare_result = None + _log_stage(mtx_path, "reference", case_start) + result["ref_started"] = True + 
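+    # Run the torch (and optionally CuPy) references before the Triton benchmark;
+    # the flags below record which references succeeded and which comparisons
+    # actually ran, since either side may fail independently.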
pt_compared = False + cu_compared = False + pt_ref_success = False + cu_ref_success = False + pt_ref_result = None + cu_ref_result = None + ref_warmup = result["effective_warmup"] if result["effective_warmup"] is not None else warmup + ref_iters = result["effective_iters"] if result["effective_iters"] is not None else iters + ref_warmup = max(0, int(ref_warmup)) + ref_iters = max(1, int(ref_iters)) + + pt_ref = _run_reference_with_retries( + backend="torch", + a_data=a_data, + a_indices=a_indices, + a_indptr=a_indptr, + a_shape=a_shape, + b_data=b_data, + b_indices=b_indices, + b_indptr=b_indptr, + b_shape=b_shape, + warmup=ref_warmup, + iters=ref_iters, + blocked_retry=ref_blocked_retry, + block_rows=ref_block_rows, + isolated_retry=ref_isolated_retry, + ref_cleanup=ref_cleanup, + mtx_path=mtx_path, + value_dtype=value_dtype, + input_mode=input_mode, + result_device=compare_device, + ) + result["pt_exec_mode"] = pt_ref.get("exec_mode") + result["attempted_modes_pt"] = pt_ref.get("attempted_modes") + result["pt_retry_count"] = int(pt_ref.get("retry_count", 0)) + if pt_ref.get("peak_block_rows") is not None: + result["ref_peak_block_rows"] = int(pt_ref["peak_block_rows"]) + if pt_ref.get("success"): + pt_ref_success = True + pt_ref_result = pt_ref.get("result") + result["pytorch_format"] = pt_ref.get("format") + result["pytorch_ms"] = pt_ref.get("ms") + else: + result["pytorch_reason"] = pt_ref.get("reason") + result["ref_fail_stage"] = pt_ref.get("fail_stage") + result["error"] = _append_error(result["error"], f"pt_ref: {pt_ref.get('reason')}") + if ref_cleanup: + _cleanup_reference_pools() + + if run_cusparse: + cu_ref = _run_reference_with_retries( + backend="cupy", + a_data=a_data, + a_indices=a_indices, + a_indptr=a_indptr, + a_shape=a_shape, + b_data=b_data, + b_indices=b_indices, + b_indptr=b_indptr, + b_shape=b_shape, + warmup=ref_warmup, + iters=ref_iters, + blocked_retry=ref_blocked_retry, + block_rows=ref_block_rows, + isolated_retry=ref_isolated_retry, + ref_cleanup=ref_cleanup, + mtx_path=mtx_path, + value_dtype=value_dtype, + input_mode=input_mode, + result_device=compare_device, + ) + result["cu_exec_mode"] = cu_ref.get("exec_mode") + result["attempted_modes_cu"] = cu_ref.get("attempted_modes") + result["cu_retry_count"] = int(cu_ref.get("retry_count", 0)) + if cu_ref.get("peak_block_rows") is not None: + cur_peak = result.get("ref_peak_block_rows") + result["ref_peak_block_rows"] = int(cu_ref["peak_block_rows"]) if cur_peak is None else max(int(cur_peak), int(cu_ref["peak_block_rows"])) + if cu_ref.get("success"): + cu_ref_success = True + result["cusparse_ms"] = cu_ref.get("ms") + cu_ref_result = cu_ref.get("result") + else: + result["cusparse_reason"] = cu_ref.get("reason") + if result.get("ref_fail_stage") is None: + result["ref_fail_stage"] = cu_ref.get("fail_stage") + result["error"] = _append_error(result["error"], f"cu_ref: {cu_ref.get('reason')}") + else: + result["cusparse_reason"] = "CuPy/cuSPARSE reference is disabled" + result["cu_exec_mode"] = "disabled" + result["attempted_modes_cu"] = "disabled" + # Ensure reference-side temporary allocations are released before Triton run. 
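+    # Both torch's caching allocator and CuPy's memory pools retain freed blocks,
+    # so an explicit flush here lowers the chance of a spurious OOM during the
+    # Triton run on large matrices.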
+ _cleanup_reference_pools() + + _log_stage(mtx_path, "triton", case_start) + try: + result["triton_started"] = True + triton_result, triton_ms, triton_first_ms, meta = _benchmark_flagsparse_spgemm( + a_data, + a_indices, + a_indptr, + a_shape, + b_data, + b_indices, + b_indptr, + b_shape, + warmup=warmup, + iters=iters, + adaptive_loops=adaptive_loops, + target_window_seconds=target_window_seconds, + start_time=case_start, + mtx_path=mtx_path, + ) + result["triton_ms"] = triton_ms + result["triton_first_call_ms"] = triton_first_ms + result["prepare_ms"] = meta.get("prepare_ms") + result["count_ms"] = meta.get("count_ms") + result["fill_ms"] = meta.get("fill_ms") + result["bucket_nrows_short"] = meta.get("bucket_nrows_short") + result["bucket_nrows_medium"] = meta.get("bucket_nrows_medium") + result["bucket_nrows_long"] = meta.get("bucket_nrows_long") + result["bucket_ms_short"] = meta.get("bucket_ms_short") + result["bucket_ms_medium"] = meta.get("bucket_ms_medium") + result["bucket_ms_long"] = meta.get("bucket_ms_long") + result["long_row_sliced_count"] = meta.get("long_row_sliced_count") + result["effective_warmup"] = meta.get("effective_warmup") + result["effective_iters"] = meta.get("effective_iters") + result["nnz_c"] = int(triton_result[0].numel()) if triton_result is not None else None + triton_compare_result = _convert_result_for_compare( + triton_result, + compare_device, + device=device, + ) + if compare_device == "cpu": + triton_result = None + _cleanup_reference_pools() + except Exception as exc: + result["error"] = _append_error(result["error"], f"triton: {exc}") + + _log_stage(mtx_path, "compare", case_start) + if triton_compare_result is not None and pt_ref_result is not None: + try: + compare_fn = _compare_spgemm_cpu if compare_device == "cpu" else _spgemm_compare_metrics + pt_metrics = compare_fn(triton_compare_result, pt_ref_result, value_dtype) + result["triton_ok_pt"] = pt_metrics["pass"] + result["err_pt"] = pt_metrics["err_ratio"] + result["max_abs_err_pt"] = pt_metrics["max_abs_error"] + result["max_rel_err_pt"] = pt_metrics["max_relative_error"] + if not pt_metrics["pattern_ok"]: + result["error"] = _append_error(result["error"], f"pt_ref: {pt_metrics['reason']}") + pt_compared = True + except Exception as cmp_exc: + cmp_msg = str(cmp_exc) + result["error"] = _append_error(result["error"], f"pt_compare: {cmp_msg}") + if _is_resource_error(cmp_msg): + result["compare_status"] = "COMPARE_OOM" + elif result.get("compare_status") == "OK": + result["compare_status"] = "COMPARE_FAIL" + finally: + if compare_device == "cpu": + pt_ref_result = None + if ref_cleanup: + _cleanup_reference_pools() + if triton_compare_result is not None and cu_ref_result is not None: + try: + compare_fn = _compare_spgemm_cpu if compare_device == "cpu" else _spgemm_compare_metrics + cu_metrics = compare_fn(triton_compare_result, cu_ref_result, value_dtype) + result["triton_ok_cu"] = cu_metrics["pass"] + result["err_cu"] = cu_metrics["err_ratio"] + result["max_abs_err_cu"] = cu_metrics["max_abs_error"] + result["max_rel_err_cu"] = cu_metrics["max_relative_error"] + if not cu_metrics["pattern_ok"]: + result["error"] = _append_error(result["error"], f"cu_ref: {cu_metrics['reason']}") + cu_compared = True + except Exception as cmp_exc: + cmp_msg = str(cmp_exc) + result["error"] = _append_error(result["error"], f"cu_compare: {cmp_msg}") + if _is_resource_error(cmp_msg): + result["compare_status"] = "COMPARE_OOM" + elif result.get("compare_status") == "OK": + result["compare_status"] = "COMPARE_FAIL" 
+ finally: + if compare_device == "cpu": + cu_ref_result = None + if ref_cleanup: + _cleanup_reference_pools() + + if triton_compare_result is None: + result["status"] = "FAIL" + result["ref_reason_code"] = _classify_reference_reason( + result.get("pytorch_reason"), + result.get("cusparse_reason"), + result.get("error"), + ) + if ref_cleanup: + _cleanup_reference_pools() + return result + + if pt_compared or cu_compared: + result["status"] = "PASS" if (result["triton_ok_pt"] or result["triton_ok_cu"]) else "FAIL" + else: + had_ref_success = pt_ref_success or cu_ref_success + if had_ref_success and result.get("compare_status") != "OK": + result["status"] = "FAIL" + result["ref_reason_code"] = result.get("compare_status") + else: + ref_code = _classify_reference_reason( + result.get("pytorch_reason"), + result.get("cusparse_reason"), + ) + result["ref_reason_code"] = ref_code + result["status"] = ref_code + if ref_cleanup: + _cleanup_reference_pools() + return result + + +def run_mtx_batch( + mtx_paths, + value_dtype=torch.float32, + index_dtype=torch.int32, + warmup=WARMUP, + iters=ITERS, + run_cusparse=True, + input_mode=DEFAULT_INPUT_MODE, + adaptive_loops=False, + target_window_seconds=TARGET_TIMED_WINDOW_SECONDS, + ref_blocked_retry=True, + ref_block_rows=DEFAULT_REF_BLOCK_ROWS, + ref_isolated_retry=True, + ref_cleanup=True, + compare_device=DEFAULT_COMPARE_DEVICE, + on_result=None, +): + results = [] + total = len(mtx_paths) + for idx, path in enumerate(mtx_paths, start=1): + print(f"[SpGEMM] ({idx}/{total}) {path}", flush=True) + entry = run_one_mtx( + path, + value_dtype=value_dtype, + index_dtype=index_dtype, + warmup=warmup, + iters=iters, + run_cusparse=run_cusparse, + input_mode=input_mode, + adaptive_loops=adaptive_loops, + target_window_seconds=target_window_seconds, + ref_blocked_retry=ref_blocked_retry, + ref_block_rows=ref_block_rows, + ref_isolated_retry=ref_isolated_retry, + ref_cleanup=ref_cleanup, + compare_device=compare_device, + ) + results.append(entry) + if on_result is not None: + on_result(entry) + return results + + +def _print_spgemm_mtx_header(value_dtype, index_dtype): + print(f"Value dtype: {_dtype_name(value_dtype)} | Index dtype: {_dtype_name(index_dtype)}") + print("Formats: FlagSparse=CSR SpGEMM(A@B), cuSPARSE=CSR@CSR, PyTorch=sparse.mm.") + print("Err(PT/CU)=max(|diff|/(atol+rtol*|ref|)); MaxRel=max(|diff|)/max(|ref|).") + print("-" * 320) + print( + f"{'Matrix':<28} {'Mode':<10} {'A_rows':>7} {'A_cols':>7} {'B_cols':>7} {'NNZ_A':>10} {'NNZ_B':>10} {'NNZ_C':>10} " + f"{'FlagSparse(ms)':>14} {'cuSPARSE(ms)':>13} {'PyTorch(ms)':>11} " + f"{'FS/CU':>7} {'FS/PT':>7} {'PT':>6} {'CU':>6} {'Status':>13} {'RefCode':>14} " + f"{'Err(PT)':>10} {'Err(CU)':>10} {'MaxAbs(PT)':>12} {'MaxRel(PT)':>12} {'MaxAbs(CU)':>12} {'MaxRel(CU)':>12} " + f"{'Prep(ms)':>9} {'Count(ms)':>10} {'Fill(ms)':>9}" + ) + print("-" * 320) + + +def _print_spgemm_mtx_row(entry): + name = os.path.basename(entry["path"])[:27] + a_rows, a_cols = entry["shape_a"] + b_cols = entry["shape_b"][1] + print( + f"{name:<28} {entry.get('input_mode', 'N/A'):<10} {a_rows:>7} {a_cols:>7} {b_cols:>7} " + f"{entry['nnz_a']:>10} {entry['nnz_b']:>10} {str(entry['nnz_c'] if entry['nnz_c'] is not None else 'N/A'):>10} " + f"{_fmt_ms(entry.get('triton_ms')):>14} {_fmt_ms(entry.get('cusparse_ms')):>13} {_fmt_ms(entry.get('pytorch_ms')):>11} " + f"{_fmt_speedup(entry.get('cusparse_ms'), entry.get('triton_ms')):>7} {_fmt_speedup(entry.get('pytorch_ms'), entry.get('triton_ms')):>7} " + 
f"{_fmt_check(entry.get('triton_ok_pt')):>6} {_fmt_check(entry.get('triton_ok_cu')):>6} {entry.get('status', 'N/A'):>13} {str(entry.get('ref_reason_code') or 'N/A'):>14} " + f"{_fmt_err(entry.get('err_pt')):>10} {_fmt_err(entry.get('err_cu')):>10} " + f"{_fmt_err(entry.get('max_abs_err_pt')):>12} {_fmt_err(entry.get('max_rel_err_pt')):>12} {_fmt_err(entry.get('max_abs_err_cu')):>12} {_fmt_err(entry.get('max_rel_err_cu')):>12} " + f"{_fmt_ms(entry.get('prepare_ms')):>9} {_fmt_ms(entry.get('count_ms')):>10} {_fmt_ms(entry.get('fill_ms')):>9}" + ) + err = entry.get("error") + if err: + msg = str(err).replace("\n", " ") + if len(msg) > 320: + msg = msg[:317] + "..." + print(f" NOTE: {msg}") + pt_mode = entry.get("pt_exec_mode") + cu_mode = entry.get("cu_exec_mode") + pt_retry = entry.get("pt_retry_count") or 0 + cu_retry = entry.get("cu_retry_count") or 0 + if ( + pt_mode not in (None, "direct") + or cu_mode not in (None, "direct", "disabled") + or int(pt_retry) > 0 + or int(cu_retry) > 0 + ): + peak_rows = entry.get("ref_peak_block_rows") + fail_stage = entry.get("ref_fail_stage") + print( + f" REF: pt_mode={pt_mode or 'N/A'} cu_mode={cu_mode or 'N/A'} " + f"pt_retry={pt_retry} cu_retry={cu_retry} " + f"peak_block_rows={peak_rows if peak_rows is not None else 'N/A'} " + f"fail_stage={fail_stage or 'N/A'} " + f"pt_attempted={entry.get('attempted_modes_pt') or 'N/A'} " + f"cu_attempted={entry.get('attempted_modes_cu') or 'N/A'}" + ) + if entry.get("compare_status") not in (None, "OK"): + print(f" COMPARE: {entry.get('compare_status')}") + b_rows = ( + entry.get("bucket_nrows_short"), + entry.get("bucket_nrows_medium"), + entry.get("bucket_nrows_long"), + ) + b_ms = ( + entry.get("bucket_ms_short"), + entry.get("bucket_ms_medium"), + entry.get("bucket_ms_long"), + ) + if any(v is not None for v in (*b_rows, *b_ms, (entry.get("long_row_sliced_count")))): + print( + " PERF: " + f"bucket_nrows(s/m/l)={b_rows[0] if b_rows[0] is not None else 'N/A'}/" + f"{b_rows[1] if b_rows[1] is not None else 'N/A'}/" + f"{b_rows[2] if b_rows[2] is not None else 'N/A'} " + f"bucket_ms(s/m/l)={_fmt_ms(b_ms[0])}/{_fmt_ms(b_ms[1])}/{_fmt_ms(b_ms[2])} " + f"long_row_sliced={entry.get('long_row_sliced_count') if entry.get('long_row_sliced_count') is not None else 'N/A'}" + ) + + +def print_mtx_results(results, value_dtype, index_dtype): + _print_spgemm_mtx_header(value_dtype, index_dtype) + for entry in results: + _print_spgemm_mtx_row(entry) + print("-" * 320) + + +def run_all_dtypes_export_csv( + paths, + csv_path, + warmup=WARMUP, + iters=ITERS, + run_cusparse=True, + input_mode=DEFAULT_INPUT_MODE, + adaptive_loops=False, + target_window_seconds=TARGET_TIMED_WINDOW_SECONDS, + ref_blocked_retry=True, + ref_block_rows=DEFAULT_REF_BLOCK_ROWS, + ref_isolated_retry=True, + ref_cleanup=True, + compare_device=DEFAULT_COMPARE_DEVICE, +): + csv_path = _normalize_csv_path(csv_path) + rows = [] + for value_dtype in CSV_VALUE_DTYPES: + for index_dtype in CSV_INDEX_DTYPES: + print("=" * 180) + _print_spgemm_mtx_header(value_dtype, index_dtype) + results = run_mtx_batch( + paths, + value_dtype=value_dtype, + index_dtype=index_dtype, + warmup=warmup, + iters=iters, + run_cusparse=run_cusparse, + input_mode=input_mode, + adaptive_loops=adaptive_loops, + target_window_seconds=target_window_seconds, + ref_blocked_retry=ref_blocked_retry, + ref_block_rows=ref_block_rows, + ref_isolated_retry=ref_isolated_retry, + ref_cleanup=ref_cleanup, + compare_device=compare_device, + on_result=_print_spgemm_mtx_row, + ) + print("-" * 320) + for 
entry in results: + n_rows, n_cols = entry["shape"] + rows.append( + { + "matrix": os.path.basename(entry["path"]), + "value_dtype": _dtype_name(value_dtype), + "index_dtype": _dtype_name(index_dtype), + "n_rows": n_rows, + "n_cols": n_cols, + "nnz": entry["nnz"], + "triton_ms": entry.get("triton_ms"), + "cusparse_ms": entry.get("cusparse_ms"), + "pytorch_ms": entry.get("pytorch_ms"), + "pt_status": _status_label(entry.get("triton_ok_pt")), + "cu_status": _status_label(entry.get("triton_ok_cu")), + "status": entry.get("status"), + "ref_reason_code": entry.get("ref_reason_code"), + "err_pt": entry.get("err_pt"), + "err_cu": entry.get("err_cu"), + "max_abs_err_pt": entry.get("max_abs_err_pt"), + "max_rel_err_pt": entry.get("max_rel_err_pt"), + "max_abs_err_cu": entry.get("max_abs_err_cu"), + "max_rel_err_cu": entry.get("max_rel_err_cu"), + "pytorch_reason": entry.get("pytorch_reason"), + "cusparse_reason": entry.get("cusparse_reason"), + "error": entry.get("error"), + "pt_exec_mode": entry.get("pt_exec_mode"), + "cu_exec_mode": entry.get("cu_exec_mode"), + "attempted_modes_pt": entry.get("attempted_modes_pt"), + "attempted_modes_cu": entry.get("attempted_modes_cu"), + "compare_status": entry.get("compare_status"), + "pt_retry_count": entry.get("pt_retry_count"), + "cu_retry_count": entry.get("cu_retry_count"), + "ref_peak_block_rows": entry.get("ref_peak_block_rows"), + "ref_fail_stage": entry.get("ref_fail_stage"), + "nnz_a": entry.get("nnz_a"), + "nnz_b": entry.get("nnz_b"), + "nnz_c": entry.get("nnz_c"), + "input_mode": entry.get("input_mode"), + "shape_a": str(entry.get("shape_a")), + "shape_b": str(entry.get("shape_b")), + "prepare_ms": entry.get("prepare_ms"), + "count_ms": entry.get("count_ms"), + "fill_ms": entry.get("fill_ms"), + "bucket_nrows_short": entry.get("bucket_nrows_short"), + "bucket_nrows_medium": entry.get("bucket_nrows_medium"), + "bucket_nrows_long": entry.get("bucket_nrows_long"), + "bucket_ms_short": entry.get("bucket_ms_short"), + "bucket_ms_medium": entry.get("bucket_ms_medium"), + "bucket_ms_long": entry.get("bucket_ms_long"), + "long_row_sliced_count": entry.get("long_row_sliced_count"), + "triton_started": entry.get("triton_started"), + "ref_started": entry.get("ref_started"), + "effective_warmup": entry.get("effective_warmup"), + "effective_iters": entry.get("effective_iters"), + "compare_device": entry.get("compare_device"), + } + ) + fieldnames = [ + "matrix", "value_dtype", "index_dtype", "n_rows", "n_cols", "nnz", + "triton_ms", "cusparse_ms", "pytorch_ms", + "pt_status", "cu_status", "status", "ref_reason_code", "err_pt", "err_cu", + "max_abs_err_pt", "max_rel_err_pt", "max_abs_err_cu", "max_rel_err_cu", + "pytorch_reason", "cusparse_reason", "error", + "pt_exec_mode", "cu_exec_mode", "attempted_modes_pt", "attempted_modes_cu", + "compare_status", "pt_retry_count", "cu_retry_count", + "ref_peak_block_rows", "ref_fail_stage", + "nnz_a", "nnz_b", "nnz_c", "input_mode", "shape_a", "shape_b", + "prepare_ms", "count_ms", "fill_ms", + "bucket_nrows_short", "bucket_nrows_medium", "bucket_nrows_long", + "bucket_ms_short", "bucket_ms_medium", "bucket_ms_long", + "long_row_sliced_count", + "triton_started", "ref_started", + "effective_warmup", "effective_iters", + "compare_device", + ] + with open(csv_path, "w", newline="", encoding="utf-8") as handle: + writer = csv.DictWriter(handle, fieldnames=fieldnames, extrasaction="ignore") + writer.writeheader() + for row in rows: + writer.writerow({key: ("" if value is None else value) for key, value in row.items()}) + 
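+    # None values are serialized as empty cells rather than the literal string
+    # "None", keeping the CSV friendly to spreadsheet filters and pandas dtype
+    # inference.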
print(f"Wrote {len(rows)} rows to {csv_path}") + + +def run_api_validation_checks(): + if not torch.cuda.is_available(): + print("API checks skipped: CUDA is not available.") + return 0 + device = torch.device("cuda") + a_data = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32, device=device) + a_indices = torch.tensor([0, 1, 1], dtype=torch.int32, device=device) + a_indptr = torch.tensor([0, 2, 3], dtype=torch.int64, device=device) + shape = (2, 2) + c_data, c_indices, c_indptr, _ = ast.flagsparse_spgemm_csr( + a_data, a_indices, a_indptr, shape, + a_data, a_indices, a_indptr, shape, + ) + negative_cases = [ + ( + "shape mismatch", + lambda: ast.flagsparse_spgemm_csr( + a_data, a_indices, a_indptr, (2, 3), + a_data, a_indices, a_indptr, shape, + ), + ValueError, + ), + ( + "dtype mismatch", + lambda: ast.flagsparse_spgemm_csr( + a_data, a_indices, a_indptr, shape, + a_data.to(torch.float64), a_indices, a_indptr, shape, + ), + TypeError, + ), + ( + "indices dtype must int32", + lambda: ast.flagsparse_spgemm_csr( + a_data, a_indices.to(torch.int64), a_indptr, shape, + a_data, a_indices, a_indptr, shape, + ), + TypeError, + ), + ( + "out data must be CUDA", + lambda: ast.flagsparse_spgemm_csr( + a_data, a_indices, a_indptr, shape, + a_data, a_indices, a_indptr, shape, + out=( + torch.empty(c_data.shape, dtype=c_data.dtype), + torch.empty(c_indices.shape, dtype=c_indices.dtype, device=device), + torch.empty(c_indptr.shape, dtype=c_indptr.dtype, device=device), + ), + ), + ValueError, + ), + ] + failed = 0 + print("-" * 96) + print("API validation checks (SpGEMM)") + print("-" * 96) + for name, fn, exc_type in negative_cases: + try: + fn() + print(f"FAIL {name:<32} expected {exc_type.__name__}") + failed += 1 + except exc_type: + print(f"PASS {name:<32} raised {exc_type.__name__}") + except Exception as exc: + print(f"FAIL {name:<32} raised {type(exc).__name__}: {exc}") + failed += 1 + + try: + out = ast.flagsparse_spgemm_csr( + a_data, a_indices, a_indptr, shape, + a_data, a_indices, a_indptr, shape, + ) + if len(out) != 4: + raise AssertionError("unexpected result tuple length") + print("PASS positive path returned CSR tuple") + except Exception as exc: + print(f"FAIL positive path raised {type(exc).__name__}: {exc}") + failed += 1 + print("-" * 96) + return failed + + +def _expand_mtx_paths(raw_paths): + paths = [] + for p in raw_paths: + if os.path.isfile(p) and p.lower().endswith(".mtx"): + paths.append(p) + elif os.path.isdir(p): + paths.extend(sorted(glob.glob(os.path.join(p, "*.mtx")))) + seen = set() + uniq = [] + for path in paths: + ap = os.path.abspath(path) + if ap not in seen: + uniq.append(ap) + seen.add(ap) + return uniq + + +def _csr_to_cpu_payload(csr_tuple): + data, indices, indptr, shape = csr_tuple + return ( + data.detach().to("cpu"), + indices.detach().to("cpu"), + indptr.detach().to("cpu"), + (int(shape[0]), int(shape[1])), + ) + + +def _csr_to_device_payload(csr_tuple, device): + data, indices, indptr, shape = csr_tuple + return ( + data.to(device), + indices.to(device), + indptr.to(device), + (int(shape[0]), int(shape[1])), + ) + + +def _convert_result_for_compare(csr_tuple, compare_device, device=None): + if csr_tuple is None: + return None + if compare_device == "cpu": + if csr_tuple[0].device.type == "cpu": + return csr_tuple + return _csr_to_cpu_payload(csr_tuple) + if compare_device == "gpu": + if csr_tuple[0].device.type == "cuda": + return csr_tuple + if device is None: + raise ValueError("device is required when converting CPU CSR payload back to CUDA") + 
return _csr_to_device_payload(csr_tuple, device) + raise ValueError(f"unsupported compare_device: {compare_device}") + + +def _run_reference_worker(args): + if not torch.cuda.is_available(): + payload = { + "success": False, + "reason": "CUDA is not available in worker", + "fail_stage": "direct", + "exec_mode": "direct", + } + torch.save(payload, args._worker_output) + return 1 + + value_dtype = torch.float32 if args.dtype == "float32" else torch.float64 + device = torch.device("cuda") + a_data, a_indices, a_indptr, a_shape = load_mtx_to_csr_torch( + args._worker_mtx, dtype=value_dtype, device=device + ) + a_indices = a_indices.to(torch.int32) + resolved_mode = _resolve_input_mode(args._worker_input_mode, a_shape) + b_data, b_indices, b_indptr, b_shape = _build_spgemm_rhs( + a_data, a_indices, a_indptr, a_shape, resolved_mode + ) + b_indices = b_indices.to(torch.int32) + ref_state = _run_reference_with_retries( + backend=args._ref_worker, + a_data=a_data, + a_indices=a_indices, + a_indptr=a_indptr, + a_shape=a_shape, + b_data=b_data, + b_indices=b_indices, + b_indptr=b_indptr, + b_shape=b_shape, + warmup=max(0, int(args.warmup)), + iters=max(1, int(args.iters)), + blocked_retry=not args._worker_no_blocked, + block_rows=max(0, int(args._worker_block_rows)), + isolated_retry=False, + ref_cleanup=not args._worker_no_cleanup, + mtx_path=args._worker_mtx, + value_dtype=value_dtype, + input_mode=resolved_mode, + ) + payload = { + "success": bool(ref_state.get("success")), + "reason": ref_state.get("reason"), + "fail_stage": ref_state.get("fail_stage"), + "exec_mode": ref_state.get("exec_mode"), + "format": ref_state.get("format"), + "ms": ref_state.get("ms"), + "retry_count": ref_state.get("retry_count", 0), + "peak_block_rows": ref_state.get("peak_block_rows"), + "result": None, + } + if ref_state.get("success") and ref_state.get("result") is not None: + payload["result"] = _csr_to_cpu_payload(ref_state["result"]) + torch.save(payload, args._worker_output) + return 0 if payload["success"] else 1 + + +def main(): + parser = argparse.ArgumentParser(description="FlagSparse SpGEMM CSR tests") + parser.add_argument("mtx", nargs="*", help=".mtx files or directories") + parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "float64"]) + parser.add_argument("--index-dtype", type=str, default="int32", choices=["int32"]) + parser.add_argument("--warmup", type=int, default=WARMUP) + parser.add_argument("--iters", type=int, default=ITERS) + parser.add_argument( + "--adaptive-loops", + action="store_true", + help="enable adaptive effective_warmup/effective_iters based on first-call runtime", + ) + parser.add_argument( + "--target-window-seconds", + type=float, + default=TARGET_TIMED_WINDOW_SECONDS, + help="adaptive target runtime window per matrix (only used with --adaptive-loops)", + ) + parser.add_argument("--no-cusparse", action="store_true") + parser.add_argument( + "--ref-blocked-retry", + dest="ref_blocked_retry", + action="store_true", + default=True, + help="enable blocked retry for torch/cupy references when direct call hits resource/OOM", + ) + parser.add_argument( + "--no-ref-blocked-retry", + dest="ref_blocked_retry", + action="store_false", + help="disable blocked retry for references", + ) + parser.add_argument( + "--ref-block-rows", + type=str, + default="auto", + help="row block size for blocked reference retry, integer or 'auto'", + ) + parser.add_argument( + "--ref-isolated-retry", + dest="ref_isolated_retry", + action="store_true", + default=True, + help="enable 
isolated subprocess retry for failed references", + ) + parser.add_argument( + "--no-ref-isolated-retry", + dest="ref_isolated_retry", + action="store_false", + help="disable isolated subprocess retry for failed references", + ) + parser.add_argument( + "--ref-cleanup", + dest="ref_cleanup", + action="store_true", + default=True, + help="enable allocator cleanup between reference attempts", + ) + parser.add_argument( + "--no-ref-cleanup", + dest="ref_cleanup", + action="store_false", + help="disable allocator cleanup between reference attempts", + ) + parser.add_argument("--csv", type=str, default=None, metavar="FILE") + parser.add_argument( + "--run-api-checks", + action="store_true", + help="run API validation checks before matrix benchmark (disabled by default)", + ) + parser.add_argument("--skip-api-checks", action="store_true", help=argparse.SUPPRESS) + parser.add_argument( + "--input-mode", + type=str, + default=DEFAULT_INPUT_MODE, + choices=["auto", "a_equals_b", "a_at"], + help="input build mode: auto uses A@A for square matrices and A@A^T for rectangular ones to avoid shape mismatch", + ) + parser.add_argument( + "--compare-device", + type=str, + default=DEFAULT_COMPARE_DEVICE, + choices=["cpu", "gpu"], + help="where to compare Triton/reference results; cpu mode offloads each result before compare", + ) + parser.add_argument("--_ref-worker", type=str, choices=["torch", "cupy"], default=None, help=argparse.SUPPRESS) + parser.add_argument("--_worker-mtx", type=str, default=None, help=argparse.SUPPRESS) + parser.add_argument("--_worker-output", type=str, default=None, help=argparse.SUPPRESS) + parser.add_argument("--_worker-block-rows", type=int, default=0, help=argparse.SUPPRESS) + parser.add_argument("--_worker-input-mode", type=str, default=DEFAULT_INPUT_MODE, choices=["auto", "a_equals_b", "a_at"], help=argparse.SUPPRESS) + parser.add_argument("--_worker-no-blocked", action="store_true", help=argparse.SUPPRESS) + parser.add_argument("--_worker-no-cleanup", action="store_true", help=argparse.SUPPRESS) + args = parser.parse_args() + + if args._ref_worker is not None: + if not args._worker_mtx or not args._worker_output: + raise SystemExit("worker mode requires --_worker-mtx and --_worker-output") + raise SystemExit(_run_reference_worker(args)) + + if not torch.cuda.is_available(): + print("CUDA is not available.") + return + + if args.run_api_checks and not args.skip_api_checks: + failed = run_api_validation_checks() + if failed > 0: + raise SystemExit(1) + + value_dtype = torch.float32 if args.dtype == "float32" else torch.float64 + index_dtype = torch.int32 + ref_block_rows = _parse_ref_block_rows(args.ref_block_rows) + paths = _expand_mtx_paths(args.mtx) + if not paths and not args.csv: + print("No .mtx files given. Use: python test_spgemm.py file1.mtx [file2.mtx ...] or a directory") + print("Or run all dtypes and export CSV: python test_spgemm.py --csv results.csv") + return + + if args.csv is not None: + if not paths: + paths = sorted(glob.glob("*.mtx")) + if not paths: + print("No .mtx files found. 
Specify files or a directory.") + return + csv_path = _normalize_csv_path(args.csv) + print("=" * 120) + print("FLAGSPARSE SpGEMM - f32/f64 with int32, export to CSV") + print("=" * 120) + print( + f"GPU: {torch.cuda.get_device_name(0)} | Files: {len(paths)} | " + f"input_mode: {args.input_mode} | adaptive_loops: {args.adaptive_loops} | " + f"ref_blocked_retry: {args.ref_blocked_retry} | ref_isolated_retry: {args.ref_isolated_retry} | " + f"ref_block_rows: {args.ref_block_rows} | compare_device: {args.compare_device} | CSV: {csv_path}" + ) + run_all_dtypes_export_csv( + paths, + csv_path, + warmup=args.warmup, + iters=args.iters, + run_cusparse=not args.no_cusparse, + input_mode=args.input_mode, + adaptive_loops=args.adaptive_loops, + target_window_seconds=args.target_window_seconds, + ref_blocked_retry=args.ref_blocked_retry, + ref_block_rows=ref_block_rows, + ref_isolated_retry=args.ref_isolated_retry, + ref_cleanup=args.ref_cleanup, + compare_device=args.compare_device, + ) + return + + print("=" * 160) + print("FLAGSPARSE SpGEMM - SuiteSparse .mtx batch (CSR)") + print("=" * 160) + print(f"GPU: {torch.cuda.get_device_name(0)} | Files: {len(paths)}") + print( + f"dtype: {args.dtype} index_dtype: {args.index_dtype} warmup: {args.warmup} " + f"iters: {args.iters} adaptive_loops: {args.adaptive_loops} input_mode: {args.input_mode} " + f"ref_blocked_retry: {args.ref_blocked_retry} ref_isolated_retry: {args.ref_isolated_retry} " + f"ref_block_rows: {args.ref_block_rows} compare_device: {args.compare_device}" + ) + print() + results = run_mtx_batch( + paths, + value_dtype=value_dtype, + index_dtype=index_dtype, + warmup=args.warmup, + iters=args.iters, + run_cusparse=not args.no_cusparse, + input_mode=args.input_mode, + adaptive_loops=args.adaptive_loops, + target_window_seconds=args.target_window_seconds, + ref_blocked_retry=args.ref_blocked_retry, + ref_block_rows=ref_block_rows, + ref_isolated_retry=args.ref_isolated_retry, + ref_cleanup=args.ref_cleanup, + compare_device=args.compare_device, + ) + print_mtx_results(results, value_dtype, index_dtype) + + +if __name__ == "__main__": + main() diff --git a/tests/test_spmm.py b/tests/test_spmm.py new file mode 100644 index 0000000..0fdabae --- /dev/null +++ b/tests/test_spmm.py @@ -0,0 +1,1112 @@ +""" +SpMM tests: load SuiteSparse .mtx, batch run, output error and performance. +Supports: multi .mtx files, value_dtype / index_dtype, CSV export, synthetic cases, +API validation checks, and PyTorch / CuPy comparison baselines. + +This test module targets the current FlagSparse CSR SpMM implementation, which maps +AlphaSparse CSR ALG1 (row-balance / seq-reduce) onto Triton for the CSR + non-transpose ++ row-major dense-B/C subset. 
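+
+Example invocations (a minimal sketch; the matrix paths are placeholders, while
+every flag shown is defined in main() below):
+ python test_spmm.py --synthetic
+ python test_spmm.py my_matrix.mtx --dtype float32 --dense-cols 32
+ python test_spmm.py matrices_dir --csv results.csv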
+""" +import argparse +import csv +import glob +import os +import sys +import time +from pathlib import Path + +import torch + +_PROJECT_ROOT = Path(__file__).resolve().parents[1] +_SRC_ROOT = _PROJECT_ROOT / "src" +if str(_SRC_ROOT) not in sys.path: + sys.path.insert(0, str(_SRC_ROOT)) + +import flagsparse as ast +import flagsparse.sparse_operations.spmm_csr as ast_ops + + + +VALUE_DTYPES = [ + torch.float16, + torch.bfloat16, + torch.float32, + torch.float64, + torch.complex64, + torch.complex128, +] +INDEX_DTYPES = [torch.int32, torch.int64] +CSV_VALUE_DTYPES = [torch.float32, torch.float64] +CSV_INDEX_DTYPES = [torch.int32] +TEST_CASES = [ + (512, 512, 4096, 16), + (1024, 1024, 16384, 32), + (2048, 2048, 65536, 64), + (4096, 4096, 131072, 64), +] +ALG1_TILE_CASES = [ + (256, 256, 4096, 4), + (256, 256, 4096, 5), + (256, 256, 4096, 12), + (256, 256, 4096, 24), + (256, 256, 4096, 48), + (256, 256, 4096, 96), +] +WARMUP = 10 +ITERS = 50 +DEFAULT_BLOCK_N = None +DEFAULT_BLOCK_NNZ = None +DEFAULT_MAX_SEGMENTS = None +LONG_ROW_NNZ = 1536 +LONG_ROW_SHAPE = (2, 2048) +LONG_ROW_DENSE_COLS = 48 + + + + +def _dtype_name(dtype): + return str(dtype).replace("torch.", "") + + +def _fmt_ms(value): + return "N/A" if value is None else f"{value:.4f}" + + +def _fmt_speedup(other_ms, triton_ms): + if other_ms is None or triton_ms is None or triton_ms <= 0: + return "N/A" + return f"{other_ms / triton_ms:.2f}x" + + +def _fmt_err(value): + return "N/A" if value is None else f"{value:.2e}" + + +def _fmt_check(value): + if value is None: + return "N/A" + return "PASS" if value else "FAIL" + +def _status_label(value): + if value is None: + return "N/A" + return "PASS" if value else "FAIL" + +def _normalize_csv_path(csv_path): + csv_path = str(csv_path) + if not csv_path.lower().endswith(".csv"): + csv_path = f"{csv_path}.csv" + parent = os.path.dirname(os.path.abspath(csv_path)) + if parent: + os.makedirs(parent, exist_ok=True) + return csv_path + +def _fmt_launch_value(value): + return "auto" if value is None else str(value) + + +def _build_values(length, value_dtype, device): + shape = (length,) + if value_dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64): + return torch.randn(shape, dtype=value_dtype, device=device) + if value_dtype == torch.complex64: + real = torch.randn(shape, dtype=torch.float32, device=device) + imag = torch.randn(shape, dtype=torch.float32, device=device) + return torch.complex(real, imag) + if value_dtype == torch.complex128: + real = torch.randn(shape, dtype=torch.float64, device=device) + imag = torch.randn(shape, dtype=torch.float64, device=device) + return torch.complex(real, imag) + raise TypeError(f"Unsupported value dtype: {value_dtype}") + + +def _build_dense_matrix(n_rows, n_cols, value_dtype, device): + shape = (n_rows, n_cols) + if value_dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64): + return torch.randn(shape, dtype=value_dtype, device=device) + if value_dtype == torch.complex64: + real = torch.randn(shape, dtype=torch.float32, device=device) + imag = torch.randn(shape, dtype=torch.float32, device=device) + return torch.complex(real, imag) + if value_dtype == torch.complex128: + real = torch.randn(shape, dtype=torch.float64, device=device) + imag = torch.randn(shape, dtype=torch.float64, device=device) + return torch.complex(real, imag) + raise TypeError(f"Unsupported value dtype: {value_dtype}") + + +def _tolerance_for_dtype(value_dtype): + if value_dtype == torch.float16: + return 2e-3, 2e-3 + if value_dtype == 
torch.bfloat16: + return 1e-1, 1e-1 + if value_dtype in (torch.float32, torch.complex64): + return 1e-6, 1e-5 + if value_dtype in (torch.float64, torch.complex128): + return 1e-10, 1e-8 + # If we ever need to mirror the looser SpMV test-script policy instead of the + # stricter library defaults, switch the float32/complex64 branch to + # `return 1e-4, 1e-2` and the float64/complex128 branch to + # `return 1e-12, 1e-10`. + return 1e-6, 1e-5 + + +def _scaled_allclose_error(candidate, reference, value_dtype=None): + if candidate.numel() == 0: + return 0.0 + dtype = reference.dtype if value_dtype is None else value_dtype + atol, rtol = _tolerance_for_dtype(dtype) + diff = torch.abs(candidate - reference) + denom = atol + rtol * torch.abs(reference) + return float(torch.max(diff / denom).item()) + +def load_mtx_to_csr_torch(file_path, dtype=torch.float32, device=None): + """Load SuiteSparse / Matrix Market .mtx file into CSR as torch tensors.""" + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + with open(file_path, "r", encoding="utf-8") as handle: + lines = handle.readlines() + + mm_field = "real" + mm_symmetry = "general" + data_lines = [] + header_info = None + for line in lines: + line = line.strip() + if line.startswith("%%MatrixMarket"): + parts = line.split() + if len(parts) >= 5: + mm_field = parts[3].lower() + mm_symmetry = parts[4].lower() + continue + if line.startswith("%"): + continue + if not header_info and line: + parts = line.split() + n_rows = int(parts[0]) + n_cols = int(parts[1]) + nnz = int(parts[2]) if len(parts) > 2 else 0 + header_info = (n_rows, n_cols, nnz) + continue + if line: + data_lines.append(line) + + if header_info is None: + raise ValueError(f"Cannot parse .mtx header: {file_path}") + + n_rows, n_cols, nnz = header_info + if nnz == 0: + data = torch.tensor([], dtype=dtype, device=device) + indices = torch.tensor([], dtype=torch.int64, device=device) + indptr = torch.zeros(n_rows + 1, dtype=torch.int64, device=device) + return data, indices, indptr, (n_rows, n_cols) + + if mm_field == "complex" and dtype not in (torch.complex64, torch.complex128): + raise TypeError( + f"Matrix Market file {file_path} stores complex values but requested dtype {dtype}" + ) + + is_pattern = mm_field == "pattern" + is_complex = mm_field == "complex" + is_symmetric = mm_symmetry == "symmetric" + is_skew = mm_symmetry == "skew-symmetric" + is_hermitian = mm_symmetry == "hermitian" + + entries = {} + + def _accumulate(row_idx, col_idx, value): + key = (row_idx, col_idx) + entries[key] = entries.get(key, 0.0) + value + + for line in data_lines[:nnz]: + parts = line.split() + if len(parts) < 2: + continue + row_idx = int(parts[0]) - 1 + col_idx = int(parts[1]) - 1 + if not (0 <= row_idx < n_rows and 0 <= col_idx < n_cols): + continue + + if is_pattern: + value = 1.0 + elif is_complex: + if len(parts) < 4: + raise ValueError(f"Complex Matrix Market entry is missing an imaginary part: {line}") + value = complex(float(parts[2]), float(parts[3])) + else: + if len(parts) < 3: + raise ValueError(f"Matrix Market entry is missing a numeric value: {line}") + value = float(parts[2]) + + _accumulate(row_idx, col_idx, value) + if row_idx != col_idx: + if is_symmetric and 0 <= col_idx < n_rows and 0 <= row_idx < n_cols: + _accumulate(col_idx, row_idx, value) + elif is_skew and 0 <= col_idx < n_rows and 0 <= row_idx < n_cols: + _accumulate(col_idx, row_idx, -value) + elif is_hermitian and 0 <= col_idx < n_rows and 0 <= row_idx < n_cols: + twin = 
value.conjugate() if isinstance(value, complex) else value + _accumulate(col_idx, row_idx, twin) + + sorted_entries = sorted(entries.items(), key=lambda item: item[0]) + cols_sorted = [] + vals_sorted = [] + indptr_list = [0] + current_row = 0 + for (row_idx, col_idx), value in sorted_entries: + while current_row < row_idx: + indptr_list.append(len(cols_sorted)) + current_row += 1 + cols_sorted.append(col_idx) + vals_sorted.append(value) + while len(indptr_list) < n_rows + 1: + indptr_list.append(len(cols_sorted)) + + data = torch.tensor(vals_sorted, dtype=dtype, device=device) + indices = torch.tensor(cols_sorted, dtype=torch.int64, device=device) + indptr = torch.tensor(indptr_list, dtype=torch.int64, device=device) + return data, indices, indptr, (n_rows, n_cols) + +def _build_pytorch_reference(data, indices, indptr, shape, B): + device = data.device + n_rows = shape[0] + indptr64 = indptr.to(torch.int64) + indices64 = indices.to(torch.int64) + row_ind = torch.repeat_interleave( + torch.arange(n_rows, device=device, dtype=torch.int64), + indptr64[1:] - indptr64[:-1], + ) + + try: + csr_pt = torch.sparse_csr_tensor(indptr64, indices64, data, size=shape, device=device) + timing_op = lambda: torch.sparse.mm(csr_pt, B) + if data.dtype in (torch.float16, torch.bfloat16): + csr_ref = torch.sparse_csr_tensor(indptr64, indices64, data.to(torch.float32), size=shape, device=device) + ref = torch.sparse.mm(csr_ref, B.to(torch.float32)).to(data.dtype) + elif data.dtype == torch.float32: + csr_ref = torch.sparse_csr_tensor(indptr64, indices64, data.to(torch.float64), size=shape, device=device) + ref = torch.sparse.mm(csr_ref, B.to(torch.float64)).to(data.dtype) + elif data.dtype == torch.complex64: + csr_ref = torch.sparse_csr_tensor(indptr64, indices64, data.to(torch.complex128), size=shape, device=device) + ref = torch.sparse.mm(csr_ref, B.to(torch.complex128)).to(data.dtype) + else: + ref = torch.sparse.mm(csr_pt, B) + return ref, timing_op, "CSR" + except Exception: + coo = torch.sparse_coo_tensor( + torch.stack([row_ind, indices64]), + data, + shape, + device=device, + ).coalesce() + timing_op = lambda: torch.sparse.mm(coo, B) + if data.dtype in (torch.float16, torch.bfloat16): + ref = torch.sparse.mm(coo.to(torch.float32), B.to(torch.float32)).to(data.dtype) + elif data.dtype == torch.float32: + ref = torch.sparse.mm(coo.to(torch.float64), B.to(torch.float64)).to(data.dtype) + elif data.dtype == torch.complex64: + ref = torch.sparse.mm(coo.to(torch.complex128), B.to(torch.complex128)).to(data.dtype) + else: + ref = torch.sparse.mm(coo, B) + return ref, timing_op, "COO" +def _benchmark_triton_spmm( + data, + indices, + indptr, + B, + shape, + warmup, + iters, + block_n=None, + block_nnz=None, + max_segments=None, +): + kwargs = { + "data": data, + "indices": indices, + "indptr": indptr, + "B": B, + "shape": shape, + "block_n": block_n, + "block_nnz": block_nnz, + "max_segments": max_segments, + } + torch.cuda.synchronize() + t0 = time.perf_counter() + _ = ast.flagsparse_spmm_csr(**kwargs) + torch.cuda.synchronize() + first_call_ms = (time.perf_counter() - t0) * 1000.0 + result, steady_ms = ast_ops._benchmark_cuda_op( + lambda: ast.flagsparse_spmm_csr(**kwargs), + warmup=warmup, + iters=iters, + ) + return result, steady_ms, first_call_ms + +def _assert_spmm_matches_reference( + data, + indices, + indptr, + B, + shape, + value_dtype, + block_n=None, + block_nnz=None, + max_segments=None, + out=None, +): + result = ast.flagsparse_spmm_csr( + data, + indices, + indptr, + B, + shape, + 
block_n=block_n, + block_nnz=block_nnz, + max_segments=max_segments, + out=out, + ) + ref_C, _, _ = _build_pytorch_reference(data, indices, indptr, shape, B) + atol, rtol = _tolerance_for_dtype(value_dtype) + if not torch.allclose(result, ref_C, atol=atol, rtol=rtol): + metrics = ast_ops._spmm_validation_metrics(result, ref_C) + raise AssertionError( + "reference mismatch: " + f"err={_scaled_allclose_error(result, ref_C, value_dtype):.3e}, " + f"max_abs={metrics['max_abs_error']:.3e}, " + f"atol={atol:.3e}, " + f"rtol={rtol:.3e}" + ) + if out is not None and result.data_ptr() != out.data_ptr(): + raise AssertionError("flagsparse_spmm_csr did not return the provided out tensor") + return result, ref_C +def _build_long_row_case(value_dtype, index_dtype, device, n_dense_cols=LONG_ROW_DENSE_COLS): + n_rows, n_cols = LONG_ROW_SHAPE + row0_cols = torch.arange(LONG_ROW_NNZ, dtype=torch.int64, device=device) + row1_cols = torch.tensor([7, 129, 511, 1024], dtype=torch.int64, device=device) + indices = torch.cat([row0_cols, row1_cols]).to(index_dtype) + data = _build_values(indices.numel(), value_dtype, device) + indptr = torch.tensor( + [0, LONG_ROW_NNZ, LONG_ROW_NNZ + row1_cols.numel()], + dtype=torch.int64, + device=device, + ) + B = _build_dense_matrix(n_cols, n_dense_cols, value_dtype, device) + return data, indices, indptr, B, (n_rows, n_cols) + +def run_one_mtx( + mtx_path, + value_dtype=torch.float32, + index_dtype=torch.int32, + warmup=10, + iters=50, + run_cusparse=True, + n_dense_cols=32, + block_n=DEFAULT_BLOCK_N, + block_nnz=DEFAULT_BLOCK_NNZ, + max_segments=DEFAULT_MAX_SEGMENTS, +): + """Run SpMM on one .mtx and compare against PyTorch/CuPy baselines.""" + device = torch.device("cuda") + data, indices, indptr, shape = load_mtx_to_csr_torch(mtx_path, dtype=value_dtype, device=device) + indices = indices.to(index_dtype) + n_rows, n_cols = shape + nnz = data.numel() + B = _build_dense_matrix(n_cols, n_dense_cols, value_dtype, device) + atol, rtol = _tolerance_for_dtype(value_dtype) + + result = { + "path": mtx_path, + "shape": shape, + "nnz": nnz, + "dense_cols": n_dense_cols, + "error": None, + "triton_ms": None, + "triton_first_call_ms": None, + "cusparse_ms": None, + "pytorch_ms": None, + "err_pt": None, + "err_cu": None, + "triton_abs_err": None, + "cusparse_abs_err": None, + "triton_relative_error_diag": None, + "cusparse_relative_error_diag": None, + "triton_ok_pt": None, + "triton_ok_cu": None, + "cusparse_reason": None, + "pytorch_reason": None, + "pytorch_format": None, + "status": "UNKNOWN", + } + + try: + ref_C, pytorch_op, pytorch_format = _build_pytorch_reference(data, indices, indptr, shape, B) + result["pytorch_format"] = pytorch_format + except Exception as exc: + result["error"] = f"ref: {exc}" + result["status"] = "REF_FAIL" + return result + + triton_C = None + try: + triton_C, triton_ms, triton_first_call_ms = _benchmark_triton_spmm( + data, + indices, + indptr, + B, + shape, + warmup=warmup, + iters=iters, + block_n=block_n, + block_nnz=block_nnz, + max_segments=max_segments, + ) + result["triton_ms"] = triton_ms + result["triton_first_call_ms"] = triton_first_call_ms + except Exception as exc: + # Do not return: still time PyTorch / CuPy so CSV shows baseline ms when Triton fails. 
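+ # The untruncated message is also exported through the CSV "error" column;
+ # the console NOTE line printed by _print_spmm_csr_mtx_row clips it to 200 chars.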
+ result["error"] = f"triton: {exc}" + result["triton_ok_pt"] = False + + if triton_C is not None: + triton_metrics = ast_ops._spmm_validation_metrics(triton_C, ref_C) + result["triton_abs_err"] = triton_metrics["max_abs_error"] + result["triton_relative_error_diag"] = triton_metrics["max_relative_error"] + result["err_pt"] = _scaled_allclose_error(triton_C, ref_C, value_dtype) + result["triton_ok_pt"] = torch.allclose(triton_C, ref_C, atol=atol, rtol=rtol) + else: + result["triton_ok_pt"] = False + + try: + _, result["pytorch_ms"] = ast_ops._benchmark_cuda_op( + pytorch_op, + warmup=warmup, + iters=iters, + ) + except Exception as exc: + result["pytorch_reason"] = str(exc) + + _cupy_supported_dtypes = ( + torch.float32, + torch.float64, + torch.complex64, + torch.complex128, + ) + if run_cusparse: + if value_dtype not in _cupy_supported_dtypes: + result["cusparse_reason"] = "float16/bfloat16 not supported by CuPy sparse; skipped" + else: + try: + import cupy as cp + import cupyx.scipy.sparse as cpx + + data_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(data)) + ind_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(indices.to(torch.int64))) + ptr_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(indptr)) + B_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(B)) + A_csr = cpx.csr_matrix((data_cp, ind_cp, ptr_cp), shape=shape) + + torch.cuda.synchronize() + for _ in range(warmup): + _ = A_csr @ B_cp + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(iters): + _ = A_csr @ B_cp + end.record() + torch.cuda.synchronize() + result["cusparse_ms"] = start.elapsed_time(end) / iters + + cs_C = A_csr @ B_cp + cs_C_t = torch.utils.dlpack.from_dlpack(cs_C.toDlpack()) + cusparse_metrics = ast_ops._spmm_validation_metrics(cs_C_t, ref_C) + result["cusparse_abs_err"] = cusparse_metrics["max_abs_error"] + result["cusparse_relative_error_diag"] = cusparse_metrics["max_relative_error"] + if triton_C is not None: + result["err_cu"] = _scaled_allclose_error(triton_C, cs_C_t, value_dtype) + result["triton_ok_cu"] = torch.allclose(triton_C, cs_C_t, atol=atol, rtol=rtol) + except Exception as exc: + result["cusparse_ms"] = None + result["err_cu"] = None + result["cusparse_abs_err"] = None + result["cusparse_relative_error_diag"] = None + result["triton_ok_cu"] = None + result["cusparse_reason"] = str(exc) + + result["status"] = "PASS" if (result["triton_ok_pt"] or result["triton_ok_cu"]) else "FAIL" + return result +def run_mtx_batch( + mtx_paths, + value_dtype=torch.float32, + index_dtype=torch.int32, + warmup=10, + iters=50, + run_cusparse=True, + n_dense_cols=32, + block_n=DEFAULT_BLOCK_N, + block_nnz=DEFAULT_BLOCK_NNZ, + max_segments=DEFAULT_MAX_SEGMENTS, + on_result=None, +): + results = [] + for path in mtx_paths: + entry = run_one_mtx( + path, + value_dtype=value_dtype, + index_dtype=index_dtype, + warmup=warmup, + iters=iters, + run_cusparse=run_cusparse, + n_dense_cols=n_dense_cols, + block_n=block_n, + block_nnz=block_nnz, + max_segments=max_segments, + ) + results.append(entry) + if on_result is not None: + on_result(entry) + return results + + +def _print_spmm_csr_mtx_header(value_dtype, index_dtype): + print( + f"Value dtype: {_dtype_name(value_dtype)} | Index dtype: {_dtype_name(index_dtype)}" + ) + print("Formats: FlagSparse=CSR ALG1, cuSPARSE=CSR dense-mm, PyTorch=CSR or COO.") + print("Timing stays in native dtype. 
For float32, correctness references use float64 compute then cast.") + print("PT/CU show per-reference correctness. Err(PT)/Err(CU)=max(|diff| / (atol + rtol*|ref|)).") + print("For float32, PT checks the float64-based correctness reference while CU checks consistency with native cuSPARSE float32, so PT and CU may differ.") + print("-" * 186) + print( + f"{'Matrix':<28} {'N_rows':>7} {'N_cols':>7} {'NNZ':>10} {'DenseN':>8} " + f"{'FlagSparse(ms)':>14} {'cuSPARSE(ms)':>13} {'PyTorch(ms)':>11} " + f"{'FS/CU':>7} {'FS/PT':>7} {'PT':>6} {'CU':>6} {'Err(PT)':>10} {'Err(CU)':>10}" + ) + print("-" * 186) + + +def _print_spmm_csr_mtx_row(entry): + name = os.path.basename(entry["path"])[:27] + n_rows, n_cols = entry["shape"] + triton_ms = entry.get("triton_ms") + cu_ms = entry.get("cusparse_ms") + pt_ms = entry.get("pytorch_ms") + print( + f"{name:<28} {n_rows:>7} {n_cols:>7} {entry['nnz']:>10} {entry['dense_cols']:>8} " + f"{_fmt_ms(triton_ms):>14} {_fmt_ms(cu_ms):>13} {_fmt_ms(pt_ms):>11} " + f"{_fmt_speedup(cu_ms, triton_ms):>7} {_fmt_speedup(pt_ms, triton_ms):>7} " + f"{_fmt_check(entry.get('triton_ok_pt')):>6} {_fmt_check(entry.get('triton_ok_cu')):>6} " + f"{_fmt_err(entry.get('err_pt')):>10} {_fmt_err(entry.get('err_cu')):>10}" + ) + err = entry.get("error") + if err: + msg = str(err).replace("\n", " ") + if len(msg) > 200: + msg = msg[:197] + "..." + print(f" NOTE: {msg}") + + +def print_mtx_results(results, value_dtype, index_dtype): + _print_spmm_csr_mtx_header(value_dtype, index_dtype) + for entry in results: + _print_spmm_csr_mtx_row(entry) + print("-" * 186) + + + +def run_all_dtypes_export_csv( + paths, + csv_path, + warmup=10, + iters=50, + run_cusparse=True, + n_dense_cols=32, + block_n=DEFAULT_BLOCK_N, + block_nnz=DEFAULT_BLOCK_NNZ, + max_segments=DEFAULT_MAX_SEGMENTS, +): + csv_path = _normalize_csv_path(csv_path) + rows = [] + for value_dtype in CSV_VALUE_DTYPES: + for index_dtype in CSV_INDEX_DTYPES: + print("=" * 150) + _print_spmm_csr_mtx_header(value_dtype, index_dtype) + results = run_mtx_batch( + paths, + value_dtype=value_dtype, + index_dtype=index_dtype, + warmup=warmup, + iters=iters, + run_cusparse=run_cusparse, + n_dense_cols=n_dense_cols, + block_n=block_n, + block_nnz=block_nnz, + max_segments=max_segments, + on_result=_print_spmm_csr_mtx_row, + ) + print("-" * 186) + for entry in results: + n_rows, n_cols = entry["shape"] + rows.append({ + "matrix": os.path.basename(entry["path"]), + "value_dtype": _dtype_name(value_dtype), + "index_dtype": _dtype_name(index_dtype), + "n_rows": n_rows, + "n_cols": n_cols, + "nnz": entry["nnz"], + "triton_ms": entry.get("triton_ms"), + "cusparse_ms": entry.get("cusparse_ms"), + "pytorch_ms": entry.get("pytorch_ms"), + "pt_status": _status_label(entry.get("triton_ok_pt")), + "cu_status": _status_label(entry.get("triton_ok_cu")), + "status": ( + "PASS" + if (entry.get("triton_ok_pt") or entry.get("triton_ok_cu")) + else "FAIL" + ), + "err_pt": entry.get("err_pt"), + "err_cu": entry.get("err_cu"), + "error": entry.get("error"), + }) + fieldnames = [ + "matrix", "value_dtype", "index_dtype", "n_rows", "n_cols", "nnz", + "triton_ms", "cusparse_ms", "pytorch_ms", + "pt_status", "cu_status", "status", "err_pt", "err_cu", "error", + ] + with open(csv_path, "w", newline="", encoding="utf-8") as handle: + writer = csv.DictWriter(handle, fieldnames=fieldnames, extrasaction="ignore") + writer.writeheader() + for row in rows: + writer.writerow({key: ("" if value is None else value) for key, value in row.items()}) + print(f"Wrote {len(rows)} rows 
to {csv_path}") + +def run_api_validation_checks(): + if not torch.cuda.is_available(): + print("API checks skipped: CUDA is not available.") + return 0 + + device = torch.device("cuda") + data = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32, device=device) + indices = torch.tensor([0, 1, 1], dtype=torch.int32, device=device) + indptr = torch.tensor([0, 2, 3], dtype=torch.int64, device=device) + B = torch.randn((2, 4), dtype=torch.float32, device=device) + long_data, long_indices, long_indptr, long_B, long_shape = _build_long_row_case( + torch.float32, torch.int32, device + ) + + negative_cases = [ + ("shape must be 2D", lambda: ast.flagsparse_spmm_csr(data, indices, indptr, B, (2,)), ValueError), + ("B must be 2D", lambda: ast.flagsparse_spmm_csr(data, indices, indptr, B[0], (2, 2)), ValueError), + ("dtype mismatch", lambda: ast.flagsparse_spmm_csr(data, indices, indptr, B.to(torch.float64), (2, 2)), TypeError), + ("shape mismatch", lambda: ast.flagsparse_spmm_csr(data, indices, indptr, torch.randn((3, 4), dtype=torch.float32, device=device), (2, 2)), ValueError), + ("indptr length mismatch", lambda: ast.flagsparse_spmm_csr(data, indices, indptr[:-1], B, (2, 2)), ValueError), + ("indptr must start at 0", lambda: ast.flagsparse_spmm_csr(data, indices, torch.tensor([1, 2, 3], dtype=torch.int64, device=device), B, (2, 2)), ValueError), + ("indptr last must equal nnz", lambda: ast.flagsparse_spmm_csr(data, indices, torch.tensor([0, 2, 2], dtype=torch.int64, device=device), B, (2, 2)), ValueError), + ("indptr monotonic", lambda: ast.flagsparse_spmm_csr(data, indices, torch.tensor([0, 3, 2], dtype=torch.int64, device=device), B, (2, 2)), ValueError), + ("indices out of range", lambda: ast.flagsparse_spmm_csr(data, torch.tensor([0, 3, 1], dtype=torch.int32, device=device), indptr, B, (2, 2)), IndexError), + ("block_n positive", lambda: ast.flagsparse_spmm_csr(data, indices, indptr, B, (2, 2), block_n=0), ValueError), + ("block_nnz positive", lambda: ast.flagsparse_spmm_csr(data, indices, indptr, B, (2, 2), block_nnz=0), ValueError), + ("max_segments positive", lambda: ast.flagsparse_spmm_csr(data, indices, indptr, B, (2, 2), max_segments=0), ValueError), + ("out shape mismatch", lambda: ast.flagsparse_spmm_csr(data, indices, indptr, B, (2, 2), out=torch.empty((3, 4), dtype=torch.float32, device=device)), ValueError), + ("out device mismatch", lambda: ast.flagsparse_spmm_csr(data, indices, indptr, B, (2, 2), out=torch.empty((2, 4), dtype=torch.float32)), ValueError), + ( + "segment overflow override", + lambda: ast.flagsparse_spmm_csr(long_data, long_indices, long_indptr, long_B, long_shape, block_nnz=128, max_segments=4), + ValueError, + ), + ] + + failed = 0 + print("-" * 96) + print("API validation checks") + print("-" * 96) + for name, fn, exc_type in negative_cases: + try: + fn() + print(f"FAIL {name:<32} expected {exc_type.__name__}") + failed += 1 + except exc_type: + print(f"PASS {name:<32} raised {exc_type.__name__}") + except Exception as exc: + print(f"FAIL {name:<32} raised {type(exc).__name__}: {exc}") + failed += 1 + + positive_checks = [] + + def _positive_out_path(): + out = torch.empty((2, 4), dtype=torch.float32, device=device) + _assert_spmm_matches_reference(data, indices, indptr, B, (2, 2), torch.float32, out=out) + + positive_checks.append(("out path success", _positive_out_path)) + + def _positive_empty_matrix(): + empty_data = torch.tensor([], dtype=torch.float32, device=device) + empty_indices = torch.tensor([], dtype=torch.int32, device=device) + empty_indptr = 
torch.zeros(3, dtype=torch.int64, device=device) + dense = torch.randn((2, 4), dtype=torch.float32, device=device) + result, _ = _assert_spmm_matches_reference( + empty_data, + empty_indices, + empty_indptr, + dense, + (2, 2), + torch.float32, + ) + if result.shape != (2, 4): + raise AssertionError(f"unexpected empty-matrix result shape: {tuple(result.shape)}") + + positive_checks.append(("empty matrix success", _positive_empty_matrix)) + + def _positive_empty_dense_cols(): + dense = torch.empty((2, 0), dtype=torch.float32, device=device) + result, _ = _assert_spmm_matches_reference( + data, + indices, + indptr, + dense, + (2, 2), + torch.float32, + ) + if result.shape != (2, 0): + raise AssertionError(f"unexpected empty-dense result shape: {tuple(result.shape)}") + + positive_checks.append(("empty dense cols success", _positive_empty_dense_cols)) + + def _positive_noncontiguous_b(): + dense = _build_dense_matrix(4, 2, torch.float32, device).transpose(0, 1) + if dense.is_contiguous(): + raise AssertionError("expected non-contiguous test matrix") + _assert_spmm_matches_reference(data, indices, indptr, dense, (2, 2), torch.float32) + + positive_checks.append(("noncontiguous B success", _positive_noncontiguous_b)) + + def _positive_long_row_default(): + _assert_spmm_matches_reference( + long_data, + long_indices, + long_indptr, + long_B, + long_shape, + torch.float32, + ) + + positive_checks.append(("long-row default success", _positive_long_row_default)) + + + for name, fn in positive_checks: + try: + fn() + print(f"PASS {name:<32} returned correct result") + except Exception as exc: + print(f"FAIL {name:<32} raised {type(exc).__name__}: {exc}") + failed += 1 + + print("-" * 96) + return failed + + +def run_alg1_tile_branch_coverage(warmup=WARMUP, iters=ITERS, run_cusparse=True): + if not torch.cuda.is_available(): + print("ALG1 branch coverage skipped: CUDA is not available.") + return 0 + + print("=" * 132) + print("ALG1 dense-column heuristic coverage") + print("=" * 132) + print( + f"{'DenseN':>8} {'BLOCK_N':>8} {'NNZTile':>8} {'ReqSeg':>7} {'Warp':>6} {'Factor':>7} " + f"{'PyTorch(ms)':>12} {'FlagSparse(ms)':>14} {'cuSPARSE(ms)':>12} {'PT':>6} {'CU':>6} {'Err(FS)':>11}" + ) + print("-" * 132) + + failed = 0 + note = None + for n_rows, n_cols, nnz, n_dense_cols in ALG1_TILE_CASES: + result = ast.benchmark_spmm_case( + n_rows=n_rows, + n_cols=n_cols, + nnz=nnz, + n_dense_cols=n_dense_cols, + value_dtype=torch.float32, + index_dtype=torch.int32, + warmup=warmup, + iters=iters, + run_cusparse=run_cusparse, + ) + params = result["parameters"] + perf = result["performance"] + verify = result["verification"] + backend = result["backend_status"] + samples = result["samples"] + triton_ok = verify.get("triton_strict_allclose_match", verify.get("triton_match_reference")) + cusparse_ok = verify.get("cusparse_strict_allclose_match", verify.get("cusparse_match_reference")) + status = "PASS" if triton_ok and (cusparse_ok is None or cusparse_ok) else "FAIL" + if status != "PASS": + failed += 1 + if backend.get("cusparse_unavailable_reason"): + note = backend["cusparse_unavailable_reason"] + triton_err = _scaled_allclose_error(samples["triton"], samples["reference"], torch.float32) + print( + f"{n_dense_cols:>8} {params['block_n']:>8} {params['block_nnz']:>8} {params['required_segments']:>7} " + f"{params['alg1_warp_size']:>6} {params['alg1_factor']:>7} " + f"{_fmt_ms(perf.get('pytorch_ms')):>12} {_fmt_ms(perf.get('triton_ms')):>14} {_fmt_ms(perf.get('cusparse_ms')):>12} " + f"{_fmt_check(triton_ok):>6} 
{_fmt_check(cusparse_ok):>6} {_fmt_err(triton_err):>11}" + ) + print("-" * 132) + if note: + print(f"cuSPARSE note: {note}") + print() + return failed + +def run_comprehensive_synthetic( + warmup=WARMUP, + iters=ITERS, + run_cusparse=True, + run_api_checks=True, + run_alg1_coverage=True, + block_n=DEFAULT_BLOCK_N, + block_nnz=DEFAULT_BLOCK_NNZ, + max_segments=DEFAULT_MAX_SEGMENTS, +): + if not torch.cuda.is_available(): + print("CUDA is not available.") + return + + print("=" * 144) + print("FLAGSPARSE SpMM BENCHMARK (synthetic CSR @ dense)") + print("=" * 144) + print( + f"GPU: {torch.cuda.get_device_name(0)} | Warmup: {warmup} Iters: {iters} " + f"BLOCK_N: {_fmt_launch_value(block_n)} BLOCK_NNZ: {_fmt_launch_value(block_nnz)} " + f"MAX_SEGMENTS: {_fmt_launch_value(max_segments)}" + ) + print("Formats: FlagSparse=CSR ALG1, cuSPARSE=CSR dense-mm (when supported), PyTorch=CSR or COO.") + print("For float32, PT checks the float64-based correctness reference while CU reflects native cuSPARSE float32 consistency.") + print() + + total = 0 + failed = 0 + for value_dtype in VALUE_DTYPES: + for index_dtype in INDEX_DTYPES: + print("-" * 144) + print( + f"Value dtype: {_dtype_name(value_dtype):<12} | Index dtype: {_dtype_name(index_dtype):<6}" + ) + print("-" * 144) + print( + f"{'N_rows':>7} {'N_cols':>7} {'NNZ':>10} {'DenseN':>8} {'BN':>4} {'BNNZ':>6} {'Seg':>4} " + f"{'PyTorch(ms)':>12} {'FlagSparse(ms)':>14} {'cuSPARSE(ms)':>12} {'FS/PT':>8} {'FS/CU':>8} {'PT':>6} {'CU':>6} {'Err(FS)':>11} {'Err(CU)':>12}" + ) + print("-" * 144) + combo_reason = None + for n_rows, n_cols, nnz, n_dense_cols in TEST_CASES: + result = ast.benchmark_spmm_case( + n_rows=n_rows, + n_cols=n_cols, + nnz=nnz, + n_dense_cols=n_dense_cols, + value_dtype=value_dtype, + index_dtype=index_dtype, + warmup=warmup, + iters=iters, + block_n=block_n, + block_nnz=block_nnz, + max_segments=max_segments, + run_cusparse=run_cusparse, + ) + total += 1 + params = result["parameters"] + perf = result["performance"] + verify = result["verification"] + backend = result["backend_status"] + samples = result["samples"] + triton_ok = verify.get("triton_strict_allclose_match", verify.get("triton_match_reference")) + cusparse_ok = verify.get("cusparse_strict_allclose_match", verify.get("cusparse_match_reference")) + status = "PASS" if triton_ok and (cusparse_ok is None or cusparse_ok) else "FAIL" + if status != "PASS": + failed += 1 + if backend.get("cusparse_unavailable_reason"): + combo_reason = backend["cusparse_unavailable_reason"] + triton_err = _scaled_allclose_error(samples["triton"], samples["reference"], value_dtype) + cusparse_err = None + if samples.get("cusparse") is not None: + cusparse_err = _scaled_allclose_error(samples["triton"], samples["cusparse"], value_dtype) + print( + f"{n_rows:>7} {n_cols:>7} {nnz:>10} {n_dense_cols:>8} {params['block_n']:>4} {params['block_nnz']:>6} {params['required_segments']:>4} " + f"{_fmt_ms(perf.get('pytorch_ms')):>12} {_fmt_ms(perf.get('triton_ms')):>14} {_fmt_ms(perf.get('cusparse_ms')):>12} " + f"{_fmt_speedup(perf.get('pytorch_ms'), perf.get('triton_ms')):>8} {_fmt_speedup(perf.get('cusparse_ms'), perf.get('triton_ms')):>8} " + f"{_fmt_check(triton_ok):>6} {_fmt_check(cusparse_ok):>6} {_fmt_err(triton_err):>11} {_fmt_err(cusparse_err):>12}" + ) + print("-" * 144) + if combo_reason: + print(f" cuSPARSE: {combo_reason}") + print() + + alg1_failed = run_alg1_tile_branch_coverage(warmup=warmup, iters=iters, run_cusparse=run_cusparse) if run_alg1_coverage else 0 + api_failed = 
run_api_validation_checks() if run_api_checks else 0 + print("=" * 144) + print( + f"Total synthetic cases: {total} Failed synthetic cases: {failed} " + f"Failed ALG1 branch cases: {alg1_failed} Failed API checks: {api_failed}" + ) + print("=" * 144) + +def main(): + parser = argparse.ArgumentParser( + description="SpMM test: SuiteSparse .mtx batch run, error and performance." + ) + parser.add_argument( + "mtx", + nargs="*", + help=".mtx file path(s), or directory(ies) to glob for *.mtx", + ) + parser.add_argument( + "--synthetic", + action="store_true", + help="Run synthetic benchmark instead of .mtx", + ) + parser.add_argument( + "--dtype", + default="float32", + choices=["float16", "bfloat16", "float32", "float64", "complex64", "complex128"], + help="Value dtype (default: float32)", + ) + parser.add_argument( + "--index-dtype", + default="int32", + choices=["int32", "int64"], + help="Index dtype (default: int32)", + ) + parser.add_argument("--dense-cols", type=int, default=32, help="Dense RHS column count") + parser.add_argument( + "--block-n", + type=int, + default=DEFAULT_BLOCK_N, + help="Output column tile override (default: auto from ALG1 heuristic)", + ) + parser.add_argument( + "--block-nnz", + type=int, + default=DEFAULT_BLOCK_NNZ, + help="CSR segment width override (default: auto from ALG1 heuristic)", + ) + parser.add_argument( + "--max-segments", + type=int, + default=DEFAULT_MAX_SEGMENTS, + help="CSR segment count override (default: auto from matrix max row nnz)", + ) + parser.add_argument("--warmup", type=int, default=10, help="Warmup runs") + parser.add_argument("--iters", type=int, default=50, help="Timing iterations") + parser.add_argument("--no-cusparse", action="store_true", help="Skip cuSPARSE baseline") + parser.add_argument("--skip-api-checks", action="store_true", help="Skip API validation checks in synthetic mode") + parser.add_argument("--skip-alg1-coverage", action="store_true", help="Skip dense-column ALG1 heuristic coverage in synthetic mode") + parser.add_argument( + "--csv", + type=str, + default=None, + metavar="FILE", + help="Run float32/float64 with int32 indices on all .mtx and write results to one CSV", + ) + args = parser.parse_args() + + if not torch.cuda.is_available(): + print("CUDA is not available.") + return + + dtype_map = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float32": torch.float32, + "float64": torch.float64, + "complex64": torch.complex64, + "complex128": torch.complex128, + } + index_map = {"int32": torch.int32, "int64": torch.int64} + value_dtype = dtype_map[args.dtype] + index_dtype = index_map[args.index_dtype] + + if args.synthetic: + run_comprehensive_synthetic( + warmup=args.warmup, + iters=args.iters, + run_cusparse=not args.no_cusparse, + run_api_checks=not args.skip_api_checks, + run_alg1_coverage=not args.skip_alg1_coverage, + block_n=args.block_n, + block_nnz=args.block_nnz, + max_segments=args.max_segments, + ) + return + + paths = [] + for path in args.mtx: + if os.path.isfile(path) and path.endswith(".mtx"): + paths.append(path) + elif os.path.isdir(path): + paths.extend(sorted(glob.glob(os.path.join(path, "*.mtx")))) + + if not paths and not args.csv: + print("No .mtx files given. Use: python test_spmm.py file1.mtx [file2.mtx ...] or a directory") + print("Or run synthetic: python test_spmm.py --synthetic") + print("Or run all dtypes and export CSV: python test_spmm.py --csv results.csv") + return + + if args.csv is not None: + if not paths: + paths = sorted(glob.glob("*.mtx")) + if not paths: + print("No .mtx files found. 
Specify files or a directory.") + return + csv_path = _normalize_csv_path(args.csv) + print("=" * 100) + print("FLAGSPARSE SpMM - f32/f64 with int32, export to CSV") + print("=" * 100) + print( + f"GPU: {torch.cuda.get_device_name(0)} | Files: {len(paths)} | DenseN: {args.dense_cols} | CSV: {csv_path}" + ) + if args.dtype != "float32" or args.index_dtype != "int32": + print("Note: --csv export ignores --dtype/--index-dtype and always writes float32/float64 with int32 indices.") + run_all_dtypes_export_csv( + paths, + csv_path, + warmup=args.warmup, + iters=args.iters, + run_cusparse=not args.no_cusparse, + n_dense_cols=args.dense_cols, + block_n=args.block_n, + block_nnz=args.block_nnz, + max_segments=args.max_segments, + ) + return + + print("=" * 140) + print("FLAGSPARSE SpMM - SuiteSparse .mtx batch (error + performance)") + print("=" * 140) + print(f"GPU: {torch.cuda.get_device_name(0)} | Files: {len(paths)}") + print( + f"dtype: {args.dtype} index_dtype: {args.index_dtype} dense_cols: {args.dense_cols} " + f"block_n: {_fmt_launch_value(args.block_n)} block_nnz: {_fmt_launch_value(args.block_nnz)} " + f"max_segments: {_fmt_launch_value(args.max_segments)} warmup: {args.warmup} iters: {args.iters}" + ) + print() + results = run_mtx_batch( + paths, + value_dtype=value_dtype, + index_dtype=index_dtype, + warmup=args.warmup, + iters=args.iters, + run_cusparse=not args.no_cusparse, + n_dense_cols=args.dense_cols, + block_n=args.block_n, + block_nnz=args.block_nnz, + max_segments=args.max_segments, + ) + print_mtx_results(results, value_dtype, index_dtype) + + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/test_spmm_coo.py b/tests/test_spmm_coo.py new file mode 100644 index 0000000..b6d0980 --- /dev/null +++ b/tests/test_spmm_coo.py @@ -0,0 +1,1376 @@ +""" +COO SpMM tests: load SuiteSparse .mtx, batch run, output error and performance. +Supports: multi .mtx files, value_dtype / index_dtype, CSV export, synthetic cases, +API validation checks, and PyTorch / CuPy comparison baselines. + +This test module targets the current FlagSparse native COO SpMM implementation. +The default public route is a sorted row-run Triton COO kernel. A second native +atomic COO route is retained for internal parity checks and debug. 
+""" +import argparse +import csv +import glob +import os +import sys +import time +from pathlib import Path + +import torch + +_PROJECT_ROOT = Path(__file__).resolve().parents[1] +_SRC_ROOT = _PROJECT_ROOT / "src" +if str(_SRC_ROOT) not in sys.path: + sys.path.insert(0, str(_SRC_ROOT)) + +import flagsparse as ast +import flagsparse.sparse_operations.spmm_coo as ast_ops + + + +VALUE_DTYPES = [ + torch.float16, + torch.bfloat16, + torch.float32, + torch.float64, + torch.complex64, + torch.complex128, +] +INDEX_DTYPES = [torch.int32, torch.int64] +CSV_VALUE_DTYPES = [torch.float32, torch.float64] +CSV_INDEX_DTYPES = [torch.int32] +TEST_CASES = [ + (512, 512, 4096, 16), + (1024, 1024, 16384, 32), + (2048, 2048, 65536, 64), + (4096, 4096, 131072, 64), +] +COO_TILE_CASES = [ + (256, 256, 4096, 4), + (256, 256, 4096, 5), + (256, 256, 4096, 12), + (256, 256, 4096, 24), + (256, 256, 4096, 48), + (256, 256, 4096, 96), +] +WARMUP = 10 +ITERS = 50 +DEFAULT_BLOCK_N = None +DEFAULT_BLOCK_NNZ = 256 +DUPLICATE_CASE_DENSE_COLS = 48 + + + + +def _dtype_name(dtype): + return str(dtype).replace("torch.", "") + + +def _fmt_ms(value): + return "N/A" if value is None else f"{value:.4f}" + + +def _fmt_speedup(other_ms, triton_ms): + if other_ms is None or triton_ms is None or triton_ms <= 0: + return "N/A" + return f"{other_ms / triton_ms:.2f}x" + + +def _fmt_err(value): + return "N/A" if value is None else f"{value:.2e}" + + +def _fmt_check(value): + if value is None: + return "N/A" + return "PASS" if value else "FAIL" + +def _status_label(value): + if value is None: + return "N/A" + return "PASS" if value else "FAIL" + +def _normalize_csv_path(csv_path): + csv_path = str(csv_path) + if not csv_path.lower().endswith(".csv"): + csv_path = f"{csv_path}.csv" + parent = os.path.dirname(os.path.abspath(csv_path)) + if parent: + os.makedirs(parent, exist_ok=True) + return csv_path + +def _fmt_launch_value(value): + return "auto" if value is None else str(value) + + +def _build_values(length, value_dtype, device): + shape = (length,) + if value_dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64): + return torch.randn(shape, dtype=value_dtype, device=device) + if value_dtype == torch.complex64: + real = torch.randn(shape, dtype=torch.float32, device=device) + imag = torch.randn(shape, dtype=torch.float32, device=device) + return torch.complex(real, imag) + if value_dtype == torch.complex128: + real = torch.randn(shape, dtype=torch.float64, device=device) + imag = torch.randn(shape, dtype=torch.float64, device=device) + return torch.complex(real, imag) + raise TypeError(f"Unsupported value dtype: {value_dtype}") + + +def _build_dense_matrix(n_rows, n_cols, value_dtype, device): + shape = (n_rows, n_cols) + if value_dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64): + return torch.randn(shape, dtype=value_dtype, device=device) + if value_dtype == torch.complex64: + real = torch.randn(shape, dtype=torch.float32, device=device) + imag = torch.randn(shape, dtype=torch.float32, device=device) + return torch.complex(real, imag) + if value_dtype == torch.complex128: + real = torch.randn(shape, dtype=torch.float64, device=device) + imag = torch.randn(shape, dtype=torch.float64, device=device) + return torch.complex(real, imag) + raise TypeError(f"Unsupported value dtype: {value_dtype}") + + +def _tolerance_for_dtype(value_dtype): + if value_dtype == torch.float16: + return 2e-3, 2e-3 + if value_dtype == torch.bfloat16: + return 1e-1, 1e-1 + if value_dtype in (torch.float32, 
torch.complex64): + return 1e-4, 1e-2 + if value_dtype in (torch.float64, torch.complex128): + return 1e-12, 1e-10 + return 1e-6, 1e-5 + + +def _scaled_allclose_error(candidate, reference, value_dtype=None): + if candidate.numel() == 0: + return 0.0 + dtype = reference.dtype if value_dtype is None else value_dtype + atol, rtol = _tolerance_for_dtype(dtype) + diff = torch.abs(candidate - reference) + denom = atol + rtol * torch.abs(reference) + return float(torch.max(diff / denom).item()) + +def load_mtx_to_coo_torch(file_path, dtype=torch.float32, device=None): + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + with open(file_path, "r", encoding="utf-8") as handle: + lines = handle.readlines() + + mm_field = "real" + mm_symmetry = "general" + data_lines = [] + header_info = None + for line in lines: + line = line.strip() + if line.startswith("%%MatrixMarket"): + parts = line.split() + if len(parts) >= 5: + mm_field = parts[3].lower() + mm_symmetry = parts[4].lower() + continue + if line.startswith("%"): + continue + if not header_info and line: + parts = line.split() + n_rows = int(parts[0]) + n_cols = int(parts[1]) + nnz = int(parts[2]) if len(parts) > 2 else 0 + header_info = (n_rows, n_cols, nnz) + continue + if line: + data_lines.append(line) + + if header_info is None: + raise ValueError(f"Cannot parse .mtx header: {file_path}") + + n_rows, n_cols, nnz = header_info + if nnz == 0: + empty_index = torch.tensor([], dtype=torch.int64, device=device) + data = torch.tensor([], dtype=dtype, device=device) + return data, empty_index, empty_index.clone(), (n_rows, n_cols) + + if mm_field == "complex" and dtype not in (torch.complex64, torch.complex128): + raise TypeError( + f"Matrix Market file {file_path} stores complex values but requested dtype {dtype}" + ) + + is_pattern = mm_field == "pattern" + is_complex = mm_field == "complex" + is_symmetric = mm_symmetry == "symmetric" + is_skew = mm_symmetry == "skew-symmetric" + is_hermitian = mm_symmetry == "hermitian" + + entries = {} + + def _accumulate(row_idx, col_idx, value): + key = (row_idx, col_idx) + entries[key] = entries.get(key, 0.0) + value + + for line in data_lines[:nnz]: + parts = line.split() + if len(parts) < 2: + continue + row_idx = int(parts[0]) - 1 + col_idx = int(parts[1]) - 1 + if not (0 <= row_idx < n_rows and 0 <= col_idx < n_cols): + continue + + if is_pattern: + value = 1.0 + elif is_complex: + if len(parts) < 4: + raise ValueError(f"Complex Matrix Market entry is missing an imaginary part: {line}") + value = complex(float(parts[2]), float(parts[3])) + else: + if len(parts) < 3: + raise ValueError(f"Matrix Market entry is missing a numeric value: {line}") + value = float(parts[2]) + + _accumulate(row_idx, col_idx, value) + if row_idx != col_idx: + if is_symmetric and 0 <= col_idx < n_rows and 0 <= row_idx < n_cols: + _accumulate(col_idx, row_idx, value) + elif is_skew and 0 <= col_idx < n_rows and 0 <= row_idx < n_cols: + _accumulate(col_idx, row_idx, -value) + elif is_hermitian and 0 <= col_idx < n_rows and 0 <= row_idx < n_cols: + twin = value.conjugate() if isinstance(value, complex) else value + _accumulate(col_idx, row_idx, twin) + + sorted_entries = sorted(entries.items(), key=lambda item: item[0]) + rows = [key[0] for key, _ in sorted_entries] + cols = [key[1] for key, _ in sorted_entries] + vals = [value for _, value in sorted_entries] + + data = torch.tensor(vals, dtype=dtype, device=device) + row = torch.tensor(rows, dtype=torch.int64, device=device) + col = 
torch.tensor(cols, dtype=torch.int64, device=device) + return data, row, col, (n_rows, n_cols) + +def _normalize_route(route): + route = str(route).strip().lower() + if route not in ("rowrun", "atomic", "compare"): + raise ValueError("route must be one of: rowrun, atomic, compare") + return route + + + +def _selected_route(route): + route = _normalize_route(route) + return "rowrun" if route == "compare" else route + + + +def _route_label(route): + labels = { + "rowrun": "COO native row-run", + "atomic": "COO native atomic", + "compare": "COO native row-run (compare mode)", + } + if route not in labels: + raise ValueError(f"Unsupported route label: {route}") + return labels[route] + + + +def _empty_pairwise_summary(): + return { + "match": None, + "error_ratio": None, + "max_abs_error": None, + "max_relative_error": None, + "sum_relative_error": None, + } + + + +def _prepare_canonical_case(data, row, col, shape, B): + native_data, native_row, native_col, native_B, n_rows, n_cols, n_dense_cols = ast_ops._prepare_spmm_coo_inputs( + data, row, col, B, shape + ) + ( + canonical_data, + canonical_row, + canonical_col, + canonical_B, + _, + _, + _, + output_dtype, + _, + ) = ast_ops._prepare_spmm_coo_canonical_prepared( + native_data, + native_row, + native_col, + native_B, + n_rows, + n_cols, + n_dense_cols, + ) + native_coo = ast_ops._build_torch_sparse_coo(native_data, native_row, native_col, shape) + return { + "native_data": native_data, + "native_row": native_row, + "native_col": native_col, + "native_B": native_B, + "native_coo": native_coo, + "canonical_data": canonical_data, + "canonical_row": canonical_row, + "canonical_col": canonical_col, + "canonical_B": canonical_B, + "n_rows": n_rows, + "n_cols": n_cols, + "n_dense_cols": n_dense_cols, + "output_dtype": output_dtype, + } + + + +def _build_pytorch_reference(data, row, col, shape, B, prepared=None): + prepared = _prepare_canonical_case(data, row, col, shape, B) if prepared is None else prepared + expected = ast_ops._build_spmm_coo_pytorch_reference_from_canonical( + prepared["canonical_data"], + prepared["canonical_row"], + prepared["canonical_col"], + prepared["canonical_B"], + shape, + prepared["output_dtype"], + ) + pytorch_op = lambda: torch.sparse.mm(prepared["native_coo"], prepared["native_B"]) + return expected, pytorch_op, "COO", None + + + +def _benchmark_spmm_coo_route( + data, + row, + col, + B, + shape, + warmup, + iters, + route="rowrun", + block_n=None, + block_nnz=DEFAULT_BLOCK_NNZ, + prepared=None, +): + selected_route = _selected_route(route) + prepared = _prepare_canonical_case(data, row, col, shape, B) if prepared is None else prepared + return ast_ops._benchmark_spmm_coo_canonical_route( + prepared["canonical_data"], + prepared["canonical_row"], + prepared["canonical_col"], + prepared["canonical_B"], + prepared["n_rows"], + prepared["n_dense_cols"], + prepared["output_dtype"], + warmup, + iters, + block_n, + block_nnz, + selected_route, + ) + +def _summarize_route_output(values, reference, value_dtype, ms=None, first_call_ms=None, cusparse_values=None): + metrics = ast_ops._spmm_validation_metrics(values, reference) + atol, rtol = _tolerance_for_dtype(value_dtype) + summary = { + "ms": ms, + "first_call_ms": first_call_ms, + "ok_pt": torch.allclose(values, reference, atol=atol, rtol=rtol), + "err_pt": _scaled_allclose_error(values, reference, value_dtype), + "max_abs_error": metrics["max_abs_error"], + "max_relative_error": metrics["max_relative_error"], + "ok_cu": None, + "err_cu": None, + "error": None, + } + if 
cusparse_values is not None: + summary["ok_cu"] = torch.allclose(values, cusparse_values, atol=atol, rtol=rtol) + summary["err_cu"] = _scaled_allclose_error(values, cusparse_values, value_dtype) + return summary + + + +def _pairwise_route_summary(candidate, reference, value_dtype): + return ast_ops._spmm_coo_pairwise_summary(candidate, reference, value_dtype) + + + +def _format_debug_scalar(value): + if value is None: + return "-" + if torch.is_tensor(value): + value = value.item() + if isinstance(value, complex): + return f"{value.real:.16e}{value.imag:+.16e}j" + return f"{float(value):.16e}" + + + +def _build_compare_debug_summary(row, reference, route_outputs, cusparse_values, value_dtype): + if reference is None or reference.numel() == 0: + return None + + atol, rtol = _tolerance_for_dtype(value_dtype) + candidates = [] + for label in ("rowrun", "atomic"): + values = route_outputs.get(label) + if values is not None: + candidates.append((label, values)) + if cusparse_values is not None: + candidates.append(("cusparse", cusparse_values)) + + best = None + for label, candidate in candidates: + if candidate is None or candidate.shape != reference.shape or candidate.numel() == 0: + continue + diff = torch.abs(candidate - reference) + denom = atol + rtol * torch.abs(reference) + ratio = diff / denom + flat_idx = int(torch.argmax(ratio).item()) + error_ratio = float(ratio.reshape(-1)[flat_idx].item()) + if best is None or error_ratio > best["error_ratio"]: + row_idx = flat_idx // reference.shape[1] + dense_col = flat_idx % reference.shape[1] + best = { + "route": label, + "row": row_idx, + "dense_col": dense_col, + "error_ratio": error_ratio, + } + + if best is None: + return None + + row_idx = best["row"] + dense_col = best["dense_col"] + row64 = row.to(torch.int64) + row_nnz = int((row64 == row_idx).sum().item()) + + def _scalar_at(values): + if values is None: + return None + return values[row_idx, dense_col] + + return { + "route": best["route"], + "row": row_idx, + "dense_col": dense_col, + "row_nnz": row_nnz, + "error_ratio": best["error_ratio"], + "rowrun": _format_debug_scalar(_scalar_at(route_outputs.get("rowrun"))), + "atomic": _format_debug_scalar(_scalar_at(route_outputs.get("atomic"))), + "pt": _format_debug_scalar(reference[row_idx, dense_col]), + "cu": _format_debug_scalar(_scalar_at(cusparse_values)), + } +def _assert_spmm_coo_matches_reference(data, row, col, B, shape, value_dtype, out=None, block_n=None, block_nnz=DEFAULT_BLOCK_NNZ): + result = ast.flagsparse_spmm_coo( + data, + row, + col, + B, + shape, + block_n=block_n, + block_nnz=block_nnz, + out=out, + ) + ref_C, _, _, _ = _build_pytorch_reference(data, row, col, shape, B) + atol, rtol = _tolerance_for_dtype(value_dtype) + if not torch.allclose(result, ref_C, atol=atol, rtol=rtol): + metrics = ast_ops._spmm_validation_metrics(result, ref_C) + raise AssertionError( + "reference mismatch: " + f"err={_scaled_allclose_error(result, ref_C, value_dtype):.3e}, " + f"max_abs={metrics['max_abs_error']:.3e}, " + f"atol={atol:.3e}, " + f"rtol={rtol:.3e}" + ) + if out is not None and result.data_ptr() != out.data_ptr(): + raise AssertionError("flagsparse_spmm_coo did not return the provided out tensor") + return result, ref_C + + + +def _build_duplicate_unsorted_case(value_dtype, index_dtype, device, n_dense_cols=DUPLICATE_CASE_DENSE_COLS): + shape = (4, 6) + row = torch.tensor([2, 0, 2, 1, 2, 0, 3, 2], dtype=index_dtype, device=device) + col = torch.tensor([1, 4, 1, 3, 0, 4, 2, 5], dtype=index_dtype, device=device) + data = 
_build_values(row.numel(), value_dtype, device) + B = _build_dense_matrix(shape[1], n_dense_cols, value_dtype, device) + return data, row, col, B, shape + + + +def run_one_mtx( + mtx_path, + value_dtype=torch.float32, + index_dtype=torch.int32, + warmup=10, + iters=50, + run_cusparse=True, + n_dense_cols=32, + block_n=DEFAULT_BLOCK_N, + block_nnz=DEFAULT_BLOCK_NNZ, + route="rowrun", +): + route = _normalize_route(route) + selected_route = _selected_route(route) + device = torch.device("cuda") + data, row, col, shape = load_mtx_to_coo_torch(mtx_path, dtype=value_dtype, device=device) + row = row.to(index_dtype) + col = col.to(index_dtype) + n_rows, n_cols = shape + nnz = data.numel() + B = _build_dense_matrix(n_cols, n_dense_cols, value_dtype, device) + prepared = None + atol, rtol = _tolerance_for_dtype(value_dtype) + + result = { + "path": mtx_path, + "shape": shape, + "nnz": nnz, + "dense_cols": n_dense_cols, + "route": selected_route, + "error": None, + "triton_ms": None, + "triton_first_call_ms": None, + "cusparse_ms": None, + "pytorch_ms": None, + "err_pt": None, + "err_cu": None, + "triton_abs_err": None, + "cusparse_abs_err": None, + "triton_relative_error_diag": None, + "cusparse_relative_error_diag": None, + "triton_ok_pt": None, + "triton_ok_cu": None, + "cusparse_reason": None, + "pytorch_reason": None, + "pytorch_format": None, + "status": "UNKNOWN", + "compare": None, + } + + try: + prepared = _prepare_canonical_case(data, row, col, shape, B) + ref_C, pytorch_op, pytorch_format, pytorch_reason = _build_pytorch_reference(data, row, col, shape, B, prepared=prepared) + result["pytorch_format"] = pytorch_format + result["pytorch_reason"] = pytorch_reason + except Exception as exc: + result["error"] = f"ref: {exc}" + result["status"] = "REF_FAIL" + return result + + triton_C = None + try: + triton_C, triton_ms, triton_first_call_ms = _benchmark_spmm_coo_route( + data, + row, + col, + B, + shape, + warmup, + iters, + route=selected_route, + block_n=block_n, + block_nnz=block_nnz, + prepared=prepared, + ) + result["triton_ms"] = triton_ms + result["triton_first_call_ms"] = triton_first_call_ms + except Exception as exc: + # Continue to PyTorch / CuPy timing when Triton fails (same as CSR SpMM test). 
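+        # The message is stored in result["error"] and echoed as a NOTE line under this matrix's table row.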
+ result["error"] = f"triton: {exc}" + result["triton_ok_pt"] = False + + if triton_C is not None: + triton_summary = _summarize_route_output(triton_C, ref_C, value_dtype) + result["triton_abs_err"] = triton_summary["max_abs_error"] + result["triton_relative_error_diag"] = triton_summary["max_relative_error"] + result["err_pt"] = triton_summary["err_pt"] + result["triton_ok_pt"] = triton_summary["ok_pt"] + else: + result["triton_ok_pt"] = False + + try: + _, result["pytorch_ms"] = ast_ops._benchmark_cuda_op( + pytorch_op, + warmup=warmup, + iters=iters, + ) + except Exception as exc: + reason = str(exc) + if result["pytorch_reason"]: + result["pytorch_reason"] = f"{result['pytorch_reason']}; timing: {reason}" + else: + result["pytorch_reason"] = reason + + cs_C_t = None + _cupy_supported_dtypes = ( + torch.float32, + torch.float64, + torch.complex64, + torch.complex128, + ) + if run_cusparse: + if value_dtype not in _cupy_supported_dtypes: + result["cusparse_reason"] = "float16/bfloat16 not supported by CuPy sparse; skipped" + else: + try: + import cupy as cp + import cupyx.scipy.sparse as cpx + + data_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(prepared["native_data"])) + row_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(prepared["native_row"].to(torch.int64))) + col_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(prepared["native_col"].to(torch.int64))) + B_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(prepared["native_B"])) + A_coo = cpx.coo_matrix((data_cp, (row_cp, col_cp)), shape=shape) + + torch.cuda.synchronize() + for _ in range(warmup): + _ = A_coo @ B_cp + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(iters): + _ = A_coo @ B_cp + end.record() + torch.cuda.synchronize() + result["cusparse_ms"] = start.elapsed_time(end) / iters + + cs_C = A_coo @ B_cp + cs_C_t = torch.utils.dlpack.from_dlpack(cs_C.toDlpack()) + cusparse_metrics = ast_ops._spmm_validation_metrics(cs_C_t, ref_C) + result["cusparse_abs_err"] = cusparse_metrics["max_abs_error"] + result["cusparse_relative_error_diag"] = cusparse_metrics["max_relative_error"] + if triton_C is not None: + result["err_cu"] = _scaled_allclose_error(triton_C, cs_C_t, value_dtype) + result["triton_ok_cu"] = torch.allclose(triton_C, cs_C_t, atol=atol, rtol=rtol) + except Exception as exc: + result["cusparse_ms"] = None + result["err_cu"] = None + result["cusparse_abs_err"] = None + result["cusparse_relative_error_diag"] = None + result["triton_ok_cu"] = None + result["cusparse_reason"] = str(exc) + + if route == "compare": + route_outputs = {} + route_summaries = {} + if triton_C is not None: + route_outputs[selected_route] = triton_C + route_summaries[selected_route] = _summarize_route_output( + triton_C, + ref_C, + value_dtype, + ms=triton_ms, + first_call_ms=triton_first_call_ms, + cusparse_values=cs_C_t, + ) + for extra_route in ("rowrun", "atomic"): + if extra_route in route_outputs: + continue + try: + extra_C, extra_ms, extra_first_call_ms = _benchmark_spmm_coo_route( + data, + row, + col, + B, + shape, + warmup, + iters, + route=extra_route, + block_n=block_n, + block_nnz=block_nnz, + prepared=prepared, + ) + route_outputs[extra_route] = extra_C + route_summaries[extra_route] = _summarize_route_output( + extra_C, + ref_C, + value_dtype, + ms=extra_ms, + first_call_ms=extra_first_call_ms, + cusparse_values=cs_C_t, + ) + except Exception as exc: + route_summaries[extra_route] = { + "ms": None, + "first_call_ms": None, + 
"ok_pt": False, + "err_pt": None, + "max_abs_error": None, + "max_relative_error": None, + "ok_cu": None, + "err_cu": None, + "error": str(exc), + } + + parity = { + "rowrun_vs_atomic": _empty_pairwise_summary(), + } + if "rowrun" in route_outputs and "atomic" in route_outputs: + parity["rowrun_vs_atomic"] = _pairwise_route_summary(route_outputs["rowrun"], route_outputs["atomic"], value_dtype) + + cu_match = None if cs_C_t is None else torch.allclose(cs_C_t, ref_C, atol=atol, rtol=rtol) + compare_debug = None + rowrun_summary = route_summaries.get("rowrun") or {} + atomic_summary = route_summaries.get("atomic") or {} + if rowrun_summary.get("ok_pt") is False or atomic_summary.get("ok_pt") is False or cu_match is False: + compare_debug = _build_compare_debug_summary(prepared["canonical_row"], ref_C, route_outputs, cs_C_t, value_dtype) + + result["compare"] = { + "routes": route_summaries, + "parity": parity, + "cusparse_reference_match": cu_match, + "cusparse_reference_error": ( + None if cs_C_t is None else _scaled_allclose_error(cs_C_t, ref_C, value_dtype) + ), + "debug": compare_debug, + } + + result["status"] = "PASS" if (result["triton_ok_pt"] or result["triton_ok_cu"]) else "FAIL" + return result + + + +def run_mtx_batch( + paths, + value_dtype=torch.float32, + index_dtype=torch.int32, + warmup=10, + iters=50, + run_cusparse=True, + n_dense_cols=32, + block_n=DEFAULT_BLOCK_N, + block_nnz=DEFAULT_BLOCK_NNZ, + route="rowrun", + on_result=None, +): + results = [] + for path in paths: + entry = run_one_mtx( + path, + value_dtype=value_dtype, + index_dtype=index_dtype, + warmup=warmup, + iters=iters, + run_cusparse=run_cusparse, + n_dense_cols=n_dense_cols, + block_n=block_n, + block_nnz=block_nnz, + route=route, + ) + results.append(entry) + if on_result is not None: + on_result(entry) + return results + + +def _print_spmm_coo_mtx_header(value_dtype, index_dtype, route): + route = _normalize_route(route) + print(f"Value dtype: {_dtype_name(value_dtype)} | Index dtype: {_dtype_name(index_dtype)}") + print(f"Formats: FlagSparse={_route_label(route)}, cuSPARSE=COO dense-mm, PyTorch=COO.") + print("Timing stays in native dtype. For float32, correctness references use float64 compute then cast.") + print("PT/CU show per-reference correctness. 
Err(PT)/Err(CU)=max(|diff| / (atol + rtol*|ref|)).") + print("PyTorch uses COO sparse.mm as the only correctness reference path.") + if route == "compare": + print("Compare mode also benchmarks native atomic (debug-only) after the main table.") + print("-" * 186) + print( + f"{'Matrix':<28} {'N_rows':>7} {'N_cols':>7} {'NNZ':>10} {'DenseN':>8} " + f"{'FlagSparse(ms)':>14} {'cuSPARSE(ms)':>13} {'PyTorch(ms)':>11} " + f"{'FS/CU':>7} {'FS/PT':>7} {'PT':>6} {'CU':>6} {'Err(PT)':>10} {'Err(CU)':>10}" + ) + print("-" * 186) + + +def _print_spmm_coo_mtx_row(entry): + name = os.path.basename(entry["path"])[:27] + n_rows, n_cols = entry["shape"] + triton_ms = entry.get("triton_ms") + cu_ms = entry.get("cusparse_ms") + pt_ms = entry.get("pytorch_ms") + print( + f"{name:<28} {n_rows:>7} {n_cols:>7} {entry['nnz']:>10} {entry['dense_cols']:>8} " + f"{_fmt_ms(triton_ms):>14} {_fmt_ms(cu_ms):>13} {_fmt_ms(pt_ms):>11} " + f"{_fmt_speedup(cu_ms, triton_ms):>7} {_fmt_speedup(pt_ms, triton_ms):>7} " + f"{_fmt_check(entry.get('triton_ok_pt')):>6} {_fmt_check(entry.get('triton_ok_cu')):>6} " + f"{_fmt_err(entry.get('err_pt')):>10} {_fmt_err(entry.get('err_cu')):>10}" + ) + err = entry.get("error") + if err: + msg = str(err).replace("\n", " ") + if len(msg) > 200: + msg = msg[:197] + "..." + print(f" NOTE: {msg}") + + +def print_mtx_results(results, value_dtype, index_dtype, route="rowrun"): + route = _normalize_route(route) + _print_spmm_coo_mtx_header(value_dtype, index_dtype, route) + for entry in results: + _print_spmm_coo_mtx_row(entry) + print("-" * 186) + + + +def print_compare_results(results, value_dtype, index_dtype): + if not any(entry.get("compare") for entry in results): + return + + print("Compare details (PT-COO / CU-COO / native parity)") + print("Row/PT is the main default-route diagnostic; Atomic/PT is debug-only.") + print("-" * 166) + print( + f"{'Matrix':<28} {'Row/PT':>7} {'Atomic/PT':>9} {'CU/PT':>7} {'Row/Atomic':>11} " + f"{'Err(Row/PT)':>12} {'Err(Atomic/PT)':>14} {'Err(CU/PT)':>10} {'Err(Row/Atomic)':>15}" + ) + print("-" * 166) + for entry in results: + compare = entry.get("compare") or {} + routes = compare.get("routes") or {} + parity = compare.get("parity") or {} + rowrun = routes.get("rowrun") or {} + atomic = routes.get("atomic") or {} + row_atomic = parity.get("rowrun_vs_atomic") or {} + print( + f"{os.path.basename(entry['path'])[:27]:<28} " + f"{_fmt_check(rowrun.get('ok_pt')):>7} {_fmt_check(atomic.get('ok_pt')):>9} {_fmt_check(compare.get('cusparse_reference_match')):>7} " + f"{_fmt_check(row_atomic.get('match')):>11} " + f"{_fmt_err(rowrun.get('err_pt')):>12} {_fmt_err(atomic.get('err_pt')):>14} {_fmt_err(compare.get('cusparse_reference_error')):>10} " + f"{_fmt_err(row_atomic.get('error_ratio')):>15}" + ) + print("-" * 166) + + debug_rows = [] + for entry in results: + compare = entry.get("compare") or {} + debug = compare.get("debug") + if debug is not None: + debug_rows.append((os.path.basename(entry["path"])[:27], debug)) + if not debug_rows: + return + + print("Worst mismatch summary for failing compare cases") + print("-" * 178) + print( + f"{'Matrix':<28} {'Route':>8} {'Row':>8} {'DenseCol':>9} {'RowNNZ':>8} {'Err':>10} " + f"{'Rowrun':>18} {'Atomic':>18} {'PT':>18} {'CU':>18}" + ) + print("-" * 178) + for name, debug in debug_rows: + print( + f"{name:<28} {debug['route']:>8} {debug['row']:>8} {debug['dense_col']:>9} {debug['row_nnz']:>8} {debug['error_ratio']:>10.2e} " + f"{debug['rowrun']:>18} {debug['atomic']:>18} {debug['pt']:>18} {debug['cu']:>18}" + ) + 
print("-" * 178) +def run_all_dtypes_export_csv( + paths, + csv_path, + warmup=10, + iters=50, + run_cusparse=True, + n_dense_cols=32, + block_n=DEFAULT_BLOCK_N, + block_nnz=DEFAULT_BLOCK_NNZ, + route="rowrun", +): + route = _normalize_route(route) + if route == "compare": + raise ValueError("CSV export only supports route='rowrun' or route='atomic'") + selected_route = _selected_route(route) + csv_path = _normalize_csv_path(csv_path) + rows = [] + for value_dtype in CSV_VALUE_DTYPES: + for index_dtype in CSV_INDEX_DTYPES: + print("=" * 150) + _print_spmm_coo_mtx_header(value_dtype, index_dtype, route) + results = run_mtx_batch( + paths, + value_dtype=value_dtype, + index_dtype=index_dtype, + warmup=warmup, + iters=iters, + run_cusparse=run_cusparse, + n_dense_cols=n_dense_cols, + block_n=block_n, + block_nnz=block_nnz, + route=selected_route, + on_result=_print_spmm_coo_mtx_row, + ) + print("-" * 186) + for entry in results: + n_rows, n_cols = entry["shape"] + rows.append({ + "matrix": os.path.basename(entry["path"]), + "value_dtype": _dtype_name(value_dtype), + "index_dtype": _dtype_name(index_dtype), + "n_rows": n_rows, + "n_cols": n_cols, + "nnz": entry["nnz"], + "triton_ms": entry.get("triton_ms"), + "cusparse_ms": entry.get("cusparse_ms"), + "pytorch_ms": entry.get("pytorch_ms"), + "pt_status": _status_label(entry.get("triton_ok_pt")), + "cu_status": _status_label(entry.get("triton_ok_cu")), + "status": ( + "PASS" + if (entry.get("triton_ok_pt") or entry.get("triton_ok_cu")) + else "FAIL" + ), + "err_pt": entry.get("err_pt"), + "err_cu": entry.get("err_cu"), + "error": entry.get("error"), + }) + fieldnames = [ + "matrix", "value_dtype", "index_dtype", "n_rows", "n_cols", "nnz", + "triton_ms", "cusparse_ms", "pytorch_ms", + "pt_status", "cu_status", "status", "err_pt", "err_cu", "error", + ] + with open(csv_path, "w", newline="", encoding="utf-8") as handle: + writer = csv.DictWriter(handle, fieldnames=fieldnames, extrasaction="ignore") + writer.writeheader() + for row in rows: + writer.writerow({key: ("" if value is None else value) for key, value in row.items()}) + print(f"Wrote {len(rows)} rows to {csv_path}") + +def run_api_validation_checks(): + if not torch.cuda.is_available(): + print("API checks skipped: CUDA is not available.") + return 0 + + device = torch.device("cuda") + data = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32, device=device) + row = torch.tensor([0, 0, 1], dtype=torch.int32, device=device) + col = torch.tensor([0, 1, 1], dtype=torch.int32, device=device) + B = torch.randn((2, 4), dtype=torch.float32, device=device) + dup_data, dup_row, dup_col, dup_B, dup_shape = _build_duplicate_unsorted_case( + torch.float32, torch.int32, device + ) + + negative_cases = [ + ("shape must be 2D", lambda: ast.flagsparse_spmm_coo(data, row, col, B, (2,)), ValueError), + ("B must be 2D", lambda: ast.flagsparse_spmm_coo(data, row, col, B[0], (2, 2)), ValueError), + ("dtype mismatch", lambda: ast.flagsparse_spmm_coo(data, row, col, B.to(torch.float64), (2, 2)), TypeError), + ("shape mismatch", lambda: ast.flagsparse_spmm_coo(data, row, col, torch.randn((3, 4), dtype=torch.float32, device=device), (2, 2)), ValueError), + ("row length mismatch", lambda: ast.flagsparse_spmm_coo(data, row[:-1], col, B, (2, 2)), ValueError), + ("col length mismatch", lambda: ast.flagsparse_spmm_coo(data, row, col[:-1], B, (2, 2)), ValueError), + ("row out of range", lambda: ast.flagsparse_spmm_coo(data, torch.tensor([0, 2, 1], dtype=torch.int32, device=device), col, B, (2, 2)), IndexError), + ("col 
out of range", lambda: ast.flagsparse_spmm_coo(data, row, torch.tensor([0, 3, 1], dtype=torch.int32, device=device), B, (2, 2)), IndexError), + ("block_n positive", lambda: ast.flagsparse_spmm_coo(data, row, col, B, (2, 2), block_n=0), ValueError), + ("block_nnz positive", lambda: ast.flagsparse_spmm_coo(data, row, col, B, (2, 2), block_nnz=0), ValueError), + ("out shape mismatch", lambda: ast.flagsparse_spmm_coo(data, row, col, B, (2, 2), out=torch.empty((3, 4), dtype=torch.float32, device=device)), ValueError), + ("out device mismatch", lambda: ast.flagsparse_spmm_coo(data, row, col, B, (2, 2), out=torch.empty((2, 4), dtype=torch.float32)), ValueError), + ] + + failed = 0 + print("-" * 96) + print("API validation checks") + print("-" * 96) + for name, fn, exc_type in negative_cases: + try: + fn() + print(f"FAIL {name:<32} expected {exc_type.__name__}") + failed += 1 + except exc_type: + print(f"PASS {name:<32} raised {exc_type.__name__}") + except Exception as exc: + print(f"FAIL {name:<32} raised {type(exc).__name__}: {exc}") + failed += 1 + + positive_checks = [] + + def _positive_out_path(): + out = torch.empty((2, 4), dtype=torch.float32, device=device) + _assert_spmm_coo_matches_reference(data, row, col, B, (2, 2), torch.float32, out=out) + + positive_checks.append(("out path success", _positive_out_path)) + + def _positive_empty_matrix(): + empty_data = torch.tensor([], dtype=torch.float32, device=device) + empty_row = torch.tensor([], dtype=torch.int32, device=device) + empty_col = torch.tensor([], dtype=torch.int32, device=device) + dense = torch.randn((2, 4), dtype=torch.float32, device=device) + result, _ = _assert_spmm_coo_matches_reference( + empty_data, + empty_row, + empty_col, + dense, + (2, 2), + torch.float32, + ) + if result.shape != (2, 4): + raise AssertionError(f"unexpected empty-matrix result shape: {tuple(result.shape)}") + + positive_checks.append(("empty matrix success", _positive_empty_matrix)) + + def _positive_empty_dense_cols(): + dense = torch.empty((2, 0), dtype=torch.float32, device=device) + result, _ = _assert_spmm_coo_matches_reference( + data, + row, + col, + dense, + (2, 2), + torch.float32, + ) + if result.shape != (2, 0): + raise AssertionError(f"unexpected empty-dense result shape: {tuple(result.shape)}") + + positive_checks.append(("empty dense cols success", _positive_empty_dense_cols)) + + def _positive_noncontiguous_b(): + dense = _build_dense_matrix(4, 2, torch.float32, device).transpose(0, 1) + if dense.is_contiguous(): + raise AssertionError("expected non-contiguous test matrix") + _assert_spmm_coo_matches_reference(data, row, col, dense, (2, 2), torch.float32) + + positive_checks.append(("noncontiguous B success", _positive_noncontiguous_b)) + + def _positive_unsorted_duplicate(): + _assert_spmm_coo_matches_reference( + dup_data, + dup_row, + dup_col, + dup_B, + dup_shape, + torch.float32, + ) + + positive_checks.append(("unsorted duplicate success", _positive_unsorted_duplicate)) + + for name, fn in positive_checks: + try: + fn() + print(f"PASS {name:<32} returned correct result") + except Exception as exc: + print(f"FAIL {name:<32} raised {type(exc).__name__}: {exc}") + failed += 1 + + print("-" * 96) + return failed + +def run_coo_tile_branch_coverage(warmup=WARMUP, iters=ITERS, run_cusparse=True): + if not torch.cuda.is_available(): + print("COO branch coverage skipped: CUDA is not available.") + return 0 + + print("=" * 144) + print("COO native row-run dense-column coverage") + print("=" * 144) + print( + f"{'DenseN':>8} {'BLOCK_N':>8} 
{'NNZTile':>8} {'Runs':>7} {'Tiles':>7} {'Warp':>6} {'Factor':>7} " + f"{'PyTorch(ms)':>12} {'FlagSparse(ms)':>14} {'cuSPARSE(ms)':>12} {'PT':>6} {'CU':>6} {'Err(FS)':>11}" + ) + print("-" * 144) + + failed = 0 + note = None + for n_rows, n_cols, nnz, n_dense_cols in COO_TILE_CASES: + result = ast_ops.benchmark_spmm_coo_case( + n_rows=n_rows, + n_cols=n_cols, + nnz=nnz, + n_dense_cols=n_dense_cols, + value_dtype=torch.float32, + index_dtype=torch.int32, + warmup=warmup, + iters=iters, + block_n=DEFAULT_BLOCK_N, + block_nnz=DEFAULT_BLOCK_NNZ, + run_cusparse=run_cusparse, + ) + params = result["parameters"] + perf = result["performance"] + verify = result["verification"] + backend = result["backend_status"] + samples = result["samples"] + triton_ok = verify.get("triton_strict_allclose_match", verify.get("triton_match_reference")) + cusparse_ok = verify.get("cusparse_strict_allclose_match", verify.get("cusparse_match_reference")) + status = "PASS" if triton_ok else "FAIL" + if status != "PASS": + failed += 1 + if backend.get("cusparse_unavailable_reason"): + note = backend["cusparse_unavailable_reason"] + triton_err = _scaled_allclose_error(samples["triton"], samples["reference"], torch.float32) + print( + f"{n_dense_cols:>8} {params['block_n']:>8} {params['block_nnz']:>8} {params['n_row_runs']:>7} {params['required_nnz_tiles']:>7} {params['heuristic_warp_size']:>6} {params['heuristic_factor']:>7} " + f"{_fmt_ms(perf.get('pytorch_ms')):>12} {_fmt_ms(perf.get('triton_ms')):>14} {_fmt_ms(perf.get('cusparse_ms')):>12} " + f"{_fmt_check(triton_ok):>6} {_fmt_check(cusparse_ok):>6} {_fmt_err(triton_err):>11}" + ) + print("-" * 144) + if note: + print(f"cuSPARSE note: {note}") + print() + return failed + + + +def _print_synthetic_compare_results(compare_rows): + if not compare_rows: + return + + print("Compare details (PT-COO / CU-COO / native parity)") + print("Row/PT is the main default-route diagnostic; Atomic/PT is debug-only.") + print("-" * 160) + print( + f"{'N_rows':>7} {'N_cols':>7} {'NNZ':>10} {'DenseN':>8} {'Row/PT':>7} {'Atomic/PT':>9} {'CU/PT':>7} {'Row/Atomic':>11} " + f"{'Err(Row/PT)':>12} {'Err(Atomic/PT)':>14} {'Err(CU/PT)':>10} {'Err(Row/Atomic)':>15}" + ) + print("-" * 160) + for entry in compare_rows: + print( + f"{entry['n_rows']:>7} {entry['n_cols']:>7} {entry['nnz']:>10} {entry['dense_cols']:>8} " + f"{_fmt_check(entry.get('row_pt')):>7} {_fmt_check(entry.get('atomic_pt')):>9} {_fmt_check(entry.get('cu_pt')):>7} " + f"{_fmt_check(entry.get('row_atomic')):>11} " + f"{_fmt_err(entry.get('err_row_pt')):>12} {_fmt_err(entry.get('err_atomic_pt')):>14} {_fmt_err(entry.get('err_cu_pt')):>10} " + f"{_fmt_err(entry.get('err_row_atomic')):>15}" + ) + print("-" * 160) + print() +def run_comprehensive_synthetic( + warmup=WARMUP, + iters=ITERS, + run_cusparse=True, + run_api_checks=True, + run_coo_coverage=True, + block_n=DEFAULT_BLOCK_N, + block_nnz=DEFAULT_BLOCK_NNZ, + route="rowrun", +): + if not torch.cuda.is_available(): + print("CUDA is not available.") + return + + route = _normalize_route(route) + selected_route = _selected_route(route) + + print("=" * 150) + print("FLAGSPARSE SpMM BENCHMARK (synthetic COO @ dense)") + print("=" * 150) + print( + f"GPU: {torch.cuda.get_device_name(0)} | Warmup: {warmup} Iters: {iters} " + f"BLOCK_N: {_fmt_launch_value(block_n)} BLOCK_NNZ: {_fmt_launch_value(block_nnz)} Route: {route}" + ) + print(f"Formats: FlagSparse={_route_label(route)}, cuSPARSE=COO dense-mm (when supported), PyTorch=COO.") + print("For float32, PT checks the float64-based 
correctness reference while CU reflects native cuSPARSE float32 consistency.") + if route == "compare": + print("Compare mode also benchmarks native atomic (debug-only) for each synthetic case.") + print() + + total = 0 + failed = 0 + for value_dtype in VALUE_DTYPES: + for index_dtype in INDEX_DTYPES: + compare_rows = [] + print("-" * 150) + print(f"Value dtype: {_dtype_name(value_dtype):<12} | Index dtype: {_dtype_name(index_dtype):<6}") + print("-" * 150) + print( + f"{'N_rows':>7} {'N_cols':>7} {'NNZ':>10} {'DenseN':>8} {'BN':>4} {'BNNZ':>6} {'Runs':>5} {'Tiles':>5} " + f"{'PyTorch(ms)':>12} {'FlagSparse(ms)':>14} {'cuSPARSE(ms)':>12} {'FS/PT':>8} {'FS/CU':>8} {'PT':>6} {'CU':>6} {'Err(FS)':>11} {'Err(CU)':>12}" + ) + print("-" * 150) + combo_reason = None + for n_rows, n_cols, nnz, n_dense_cols in TEST_CASES: + result = ast_ops.benchmark_spmm_coo_case( + n_rows=n_rows, + n_cols=n_cols, + nnz=nnz, + n_dense_cols=n_dense_cols, + value_dtype=value_dtype, + index_dtype=index_dtype, + warmup=warmup, + iters=iters, + block_n=block_n, + block_nnz=block_nnz, + run_cusparse=run_cusparse, + route=selected_route, + compare_routes=(route == "compare"), + ) + total += 1 + params = result["parameters"] + perf = result["performance"] + verify = result["verification"] + backend = result["backend_status"] + samples = result["samples"] + triton_ok = verify.get("triton_strict_allclose_match", verify.get("triton_match_reference")) + cusparse_ok = verify.get("cusparse_strict_allclose_match", verify.get("cusparse_match_reference")) + status = "PASS" if triton_ok else "FAIL" + if status != "PASS": + failed += 1 + if backend.get("cusparse_unavailable_reason"): + combo_reason = backend["cusparse_unavailable_reason"] + triton_err = _scaled_allclose_error(samples["triton"], samples["reference"], value_dtype) + cusparse_err = None + if samples.get("cusparse") is not None: + cusparse_err = _scaled_allclose_error(samples["triton"], samples["cusparse"], value_dtype) + print( + f"{n_rows:>7} {n_cols:>7} {nnz:>10} {n_dense_cols:>8} {params['block_n']:>4} {params['block_nnz']:>6} {params['n_row_runs']:>5} {params['required_nnz_tiles']:>5} " + f"{_fmt_ms(perf.get('pytorch_ms')):>12} {_fmt_ms(perf.get('triton_ms')):>14} {_fmt_ms(perf.get('cusparse_ms')):>12} " + f"{_fmt_speedup(perf.get('pytorch_ms'), perf.get('triton_ms')):>8} {_fmt_speedup(perf.get('cusparse_ms'), perf.get('triton_ms')):>8} " + f"{_fmt_check(triton_ok):>6} {_fmt_check(cusparse_ok):>6} {_fmt_err(triton_err):>11} {_fmt_err(cusparse_err):>12}" + ) + if route == "compare": + route_results = result.get("route_results") or {} + parity = result.get("parity") or {} + compare_rows.append({ + "n_rows": n_rows, + "n_cols": n_cols, + "nnz": nnz, + "dense_cols": n_dense_cols, + "row_pt": (route_results.get("rowrun") or {}).get("match_reference"), + "atomic_pt": (route_results.get("atomic") or {}).get("match_reference"), + "cu_pt": verify.get("cusparse_match_reference"), + "row_atomic": (parity.get("rowrun_vs_atomic") or {}).get("match"), + "err_row_pt": (route_results.get("rowrun") or {}).get("error_ratio"), + "err_atomic_pt": (route_results.get("atomic") or {}).get("error_ratio"), + "err_cu_pt": (verify.get("cusparse_max_relative_error") if verify.get("cusparse_match_reference") is not None else None), + "err_row_atomic": (parity.get("rowrun_vs_atomic") or {}).get("error_ratio"), + }) + print("-" * 150) + if combo_reason: + print(f" cuSPARSE: {combo_reason}") + print() + if route == "compare": + _print_synthetic_compare_results(compare_rows) + + coo_failed = 0 + 
if run_coo_coverage:
+        if route == "rowrun":
+            coo_failed = run_coo_tile_branch_coverage(warmup=warmup, iters=iters, run_cusparse=run_cusparse)
+        else:
+            print(f"COO dense-column coverage is row-run specific; skipped for route {route}.")
+            print()
+    api_failed = run_api_validation_checks() if run_api_checks else 0
+    print("=" * 150)
+    print(
+        f"Total synthetic cases: {total} Failed synthetic cases: {failed} "
+        f"Failed COO branch cases: {coo_failed} Failed API checks: {api_failed}"
+    )
+    print("=" * 150)
+
+
+def main():
+    parser = argparse.ArgumentParser(description="COO SpMM test: SuiteSparse .mtx batch run, error and performance.")
+    parser.add_argument("mtx", nargs="*", help=".mtx file path(s), or directory(ies) to glob for *.mtx")
+    parser.add_argument("--synthetic", action="store_true", help="Run synthetic benchmark instead of .mtx")
+    parser.add_argument("--dtype", default="float32", choices=["float16", "bfloat16", "float32", "float64", "complex64", "complex128"], help="Value dtype (default: float32)")
+    parser.add_argument("--index-dtype", default="int32", choices=["int32", "int64"], help="Index dtype (default: int32)")
+    parser.add_argument("--dense-cols", type=int, default=32, help="Dense RHS column count")
+    parser.add_argument("--block-n", type=int, default=DEFAULT_BLOCK_N, help="Output column tile override (default: auto from dense-column heuristic)")
+    parser.add_argument("--block-nnz", type=int, default=DEFAULT_BLOCK_NNZ, help="COO nnz tile width override (default: 256)")
+    parser.add_argument("--route", default="rowrun", choices=["rowrun", "atomic", "compare"], help="Native COO route to benchmark/test (default: rowrun)")
+    parser.add_argument("--warmup", type=int, default=10, help="Warmup runs")
+    parser.add_argument("--iters", type=int, default=50, help="Timing iterations")
+    parser.add_argument("--no-cusparse", action="store_true", help="Skip cuSPARSE baseline")
+    parser.add_argument("--skip-api-checks", action="store_true", help="Skip API validation checks in synthetic mode")
+    parser.add_argument("--skip-coo-coverage", action="store_true", help="Skip dense-column COO heuristic coverage in synthetic mode")
+    parser.add_argument("--csv", type=str, default=None, metavar="FILE", help="Run float32/float64 with int32 indices on all .mtx and write results to one CSV")
+    args = parser.parse_args()
+
+    if not torch.cuda.is_available():
+        print("CUDA is not available.")
+        return
+
+    dtype_map = {
+        "float16": torch.float16,
+        "bfloat16": torch.bfloat16,
+        "float32": torch.float32,
+        "float64": torch.float64,
+        "complex64": torch.complex64,
+        "complex128": torch.complex128,
+    }
+    index_map = {"int32": torch.int32, "int64": torch.int64}
+    value_dtype = dtype_map[args.dtype]
+    index_dtype = index_map[args.index_dtype]
+
+    if args.synthetic:
+        run_comprehensive_synthetic(
+            warmup=args.warmup,
+            iters=args.iters,
+            run_cusparse=not args.no_cusparse,
+            run_api_checks=not args.skip_api_checks,
+            run_coo_coverage=not args.skip_coo_coverage,
+            block_n=args.block_n,
+            block_nnz=args.block_nnz,
+            route=args.route,
+        )
+        return
+
+    paths = []
+    for path in args.mtx:
+        if os.path.isfile(path) and path.endswith(".mtx"):
+            paths.append(path)
+        elif os.path.isdir(path):
+            paths.extend(sorted(glob.glob(os.path.join(path, "*.mtx"))))
+
+    if not paths and not args.csv:
+        print("No .mtx files given. Use: python test_spmm_coo.py file1.mtx [file2.mtx ...] or a directory")
or ") + print("Or run synthetic: python test_spmm_coo.py --synthetic") + print("Or run all dtypes and export CSV: python test_spmm_coo.py --csv results.csv") + return + + if args.csv is not None: + if not paths: + paths = sorted(glob.glob("*.mtx")) + if not paths: + print("No .mtx files found. Specify files or a directory.") + return + if args.route == "compare": + print("CSV export only supports --route rowrun or --route atomic.") + return + csv_path = _normalize_csv_path(args.csv) + print("=" * 100) + print("FLAGSPARSE COO SpMM - f32/f64 with int32, export to CSV") + print("=" * 100) + print(f"GPU: {torch.cuda.get_device_name(0)} | Files: {len(paths)} | DenseN: {args.dense_cols} | Route: {args.route} | CSV: {csv_path}") + if args.dtype != "float32" or args.index_dtype != "int32": + print("Note: --csv export ignores --dtype/--index-dtype and always writes float32/float64 with int32 indices.") + run_all_dtypes_export_csv( + paths, + csv_path, + warmup=args.warmup, + iters=args.iters, + run_cusparse=not args.no_cusparse, + n_dense_cols=args.dense_cols, + block_n=args.block_n, + block_nnz=args.block_nnz, + route=args.route, + ) + return + + print("=" * 140) + print("FLAGSPARSE COO SpMM - SuiteSparse .mtx batch (error + performance)") + print("=" * 140) + print(f"GPU: {torch.cuda.get_device_name(0)} | Files: {len(paths)}") + print( + f"dtype: {args.dtype} index_dtype: {args.index_dtype} dense_cols: {args.dense_cols} " + f"warmup: {args.warmup} iters: {args.iters} block_n: {_fmt_launch_value(args.block_n)} " + f"block_nnz: {_fmt_launch_value(args.block_nnz)} route: {args.route}" + ) + print() + results = run_mtx_batch( + paths, + value_dtype=value_dtype, + index_dtype=index_dtype, + warmup=args.warmup, + iters=args.iters, + run_cusparse=not args.no_cusparse, + n_dense_cols=args.dense_cols, + block_n=args.block_n, + block_nnz=args.block_nnz, + route=args.route, + ) + print_mtx_results(results, value_dtype, index_dtype, route=args.route) + if args.route == "compare": + print_compare_results(results, value_dtype, index_dtype) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/test_spmm_opt.py b/tests/test_spmm_opt.py new file mode 100644 index 0000000..c340b52 --- /dev/null +++ b/tests/test_spmm_opt.py @@ -0,0 +1,449 @@ +""" +SpMM opt A/B test: compare base vs opt side-by-side with PyTorch and cuSPARSE timings. 
+ +Usage: + python tests/test_spmm_opt.py --dense-cols 32 + python tests/test_spmm_opt.py --csv spmm_opt.csv +""" + +import argparse +import csv +import glob +import os +import sys +from pathlib import Path + +import torch + +_PROJECT_ROOT = Path(__file__).resolve().parents[1] +_SRC_ROOT = _PROJECT_ROOT / "src" +if str(_SRC_ROOT) not in sys.path: + sys.path.insert(0, str(_SRC_ROOT)) + +import flagsparse as fs + +VALUE_DTYPES = [torch.float32, torch.float64] +INDEX_DTYPES = [torch.int32] +WARMUP = 10 +ITERS = 50 +DEFAULT_DENSE_COLS = 32 +DEFAULT_SEED = None + + +def load_mtx_to_csr_torch(file_path, dtype=torch.float32, device=None): + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + with open(file_path, "r", encoding="utf-8") as handle: + lines = handle.readlines() + + mm_field = "real" + mm_symmetry = "general" + data_lines = [] + header_info = None + for line in lines: + stripped = line.strip() + if stripped.startswith("%%MatrixMarket"): + tokens = stripped.split() + if len(tokens) >= 5: + mm_field = tokens[3].lower() + mm_symmetry = tokens[4].lower() + continue + if stripped.startswith("%"): + continue + if not header_info and stripped: + parts = stripped.split() + n_rows = int(parts[0]) + n_cols = int(parts[1]) + nnz = int(parts[2]) if len(parts) > 2 else 0 + header_info = (n_rows, n_cols, nnz) + continue + if stripped: + data_lines.append(stripped) + if header_info is None: + raise ValueError(f"Cannot parse .mtx header: {file_path}") + + n_rows, n_cols, nnz = header_info + if nnz == 0: + data = torch.tensor([], dtype=dtype, device=device) + indices = torch.tensor([], dtype=torch.int64, device=device) + indptr = torch.zeros(n_rows + 1, dtype=torch.int64, device=device) + return data, indices, indptr, (n_rows, n_cols) + + is_pattern = mm_field == "pattern" + is_symmetric = mm_symmetry in ("symmetric", "hermitian") + is_skew = mm_symmetry == "skew-symmetric" + row_maps = [dict() for _ in range(n_rows)] + for line in data_lines[:nnz]: + parts = line.split() + row = int(parts[0]) - 1 + col = int(parts[1]) - 1 + value = 1.0 if is_pattern else float(parts[2]) + if 0 <= row < n_rows and 0 <= col < n_cols: + row_maps[row][col] = row_maps[row].get(col, 0.0) + value + if row != col: + if is_symmetric and 0 <= col < n_rows and 0 <= row < n_cols: + row_maps[col][row] = row_maps[col].get(row, 0.0) + value + elif is_skew and 0 <= col < n_rows and 0 <= row < n_cols: + row_maps[col][row] = row_maps[col].get(row, 0.0) - value + + cols_sorted = [] + vals_sorted = [] + indptr_list = [0] + for row in range(n_rows): + row_map = row_maps[row] + for col in sorted(row_map.keys()): + cols_sorted.append(col) + vals_sorted.append(row_map[col]) + indptr_list.append(len(cols_sorted)) + + data = torch.tensor(vals_sorted, dtype=dtype, device=device) + indices = torch.tensor(cols_sorted, dtype=torch.int64, device=device) + indptr = torch.tensor(indptr_list, dtype=torch.int64, device=device) + return data, indices, indptr, (n_rows, n_cols) + + +def _timed_spmm_base(data, indices, indptr, B, shape, warmup, iters): + op = lambda: fs.flagsparse_spmm_csr(data, indices, indptr, B, shape) + out = op() + torch.cuda.synchronize() + for _ in range(warmup): + out = op() + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(iters): + out = op() + end.record() + torch.cuda.synchronize() + return out, start.elapsed_time(end) / iters + + +def _timed_spmm_opt(data, indices, indptr, B, shape, 
warmup, iters): + prepared = fs.prepare_spmm_csr_opt(data, indices, indptr, shape) + op = lambda: fs.flagsparse_spmm_csr_opt(B=B, prepared=prepared) + out = op() + torch.cuda.synchronize() + for _ in range(warmup): + out = op() + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(iters): + out = op() + end.record() + torch.cuda.synchronize() + return out, start.elapsed_time(end) / iters + + +def _timed_pytorch(data, indices, indptr, B, shape, warmup, iters): + device = data.device + try: + sparse = torch.sparse_csr_tensor( + indptr.to(torch.int64), + indices.to(torch.int64), + data, + size=shape, + device=device, + ) + except Exception: + n_rows = int(shape[0]) + row_ind = torch.repeat_interleave( + torch.arange(n_rows, device=device, dtype=torch.int64), + indptr[1:] - indptr[:-1], + ) + sparse = torch.sparse_coo_tensor( + torch.stack([row_ind, indices.to(torch.int64)]), + data, + shape, + device=device, + ).coalesce() + op = lambda: torch.sparse.mm(sparse, B) + out = op() + torch.cuda.synchronize() + for _ in range(warmup): + out = op() + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(iters): + out = op() + end.record() + torch.cuda.synchronize() + return out, start.elapsed_time(end) / iters + + +def _timed_cusparse(data, indices, indptr, B, shape, warmup, iters): + import cupy as cp + import cupyx.scipy.sparse as cpx + + data_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(data)) + ind_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(indices.to(torch.int64))) + ptr_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(indptr)) + B_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(B)) + sparse = cpx.csr_matrix((data_cp, ind_cp, ptr_cp), shape=shape) + torch.cuda.synchronize() + for _ in range(warmup): + _ = sparse @ B_cp + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(iters): + _ = sparse @ B_cp + end.record() + torch.cuda.synchronize() + out_cp = sparse @ B_cp + out = torch.utils.dlpack.from_dlpack(out_cp.toDlpack()) + return out, start.elapsed_time(end) / iters + + +def _build_reference(data, indices, indptr, B, shape, dtype): + device = data.device + ref_dtype = torch.float64 if dtype == torch.float32 else dtype + sparse = torch.sparse_csr_tensor( + indptr.to(torch.int64), + indices.to(torch.int64), + data.to(ref_dtype), + size=shape, + device=device, + ) + return torch.sparse.mm(sparse, B.to(ref_dtype)).to(dtype) + + +def _error_ratio(candidate, reference, dtype): + if dtype == torch.float32: + atol, rtol = 1e-4, 1e-2 + else: + atol, rtol = 1e-12, 1e-10 + if candidate.numel() == 0: + return 0.0 + diff = torch.abs(candidate - reference).to(torch.float64) + denom = (atol + rtol * torch.abs(reference)).to(torch.float64) + return float(torch.max(diff / denom).item()) + + +def _fmt(v): + return "N/A" if v is None else f"{v:.4f}" + + +def _spd(base, other): + if base is None or other is None or other <= 0: + return "N/A" + return f"{base / other:.2f}x" + + +def _err(v): + return "N/A" if v is None else f"{v:.2e}" + + +HEADER = ( + f"{'Matrix':<28} {'N_rows':>7} {'N_cols':>7} {'NNZ':>10} {'DenseN':>8} " + f"{'Base(ms)':>9} {'Opt(ms)':>9} {'PT(ms)':>9} {'CU(ms)':>9} " + f"{'Opt/Base':>8} {'Opt/PT':>8} {'Opt/CU':>8} " + f"{'Err(Base)':>10} {'Err(Opt)':>10} {'Status':>6}" +) +SEP = "-" * 182 + + 
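+# Sketch of a single-matrix run built from the helpers below ("matrix.mtx" is a
+# placeholder path):
+#
+#     row = run_one_mtx("matrix.mtx", torch.float32, torch.int32,
+#                       dense_cols=32, warmup=WARMUP, iters=ITERS, seed=0)
+#     print(HEADER)
+#     print_row(row)
+#
+# Err(Base)/Err(Opt) report max(|diff| / (atol + rtol*|ref|)) against the high-precision
+# reference, so a value <= 1.0 is within tolerance and the Status column shows PASS.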
+def _seeded_dense_matrix(shape, dtype, device, seed): + if seed is None: + return torch.randn(shape, dtype=dtype, device=device) + torch.manual_seed(int(seed)) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(int(seed)) + return torch.randn(shape, dtype=dtype, device=device) + + +def run_one_mtx(path, dtype, index_dtype, dense_cols, warmup, iters, seed=None): + device = torch.device("cuda") + data, indices, indptr, shape = load_mtx_to_csr_torch(path, dtype=dtype, device=device) + indices = indices.to(index_dtype) + n_rows, n_cols = shape + nnz = data.numel() + B = _seeded_dense_matrix((n_cols, dense_cols), dtype, device, seed) + ref = _build_reference(data, indices, indptr, B, shape, dtype) + + y_base, base_ms = _timed_spmm_base(data, indices, indptr, B, shape, warmup, iters) + y_opt, opt_ms = _timed_spmm_opt(data, indices, indptr, B, shape, warmup, iters) + + pt_ms = None + try: + _, pt_ms = _timed_pytorch(data, indices, indptr, B, shape, warmup, iters) + except Exception: + pass + + cu_ms = None + try: + _, cu_ms = _timed_cusparse(data, indices, indptr, B, shape, warmup, iters) + except Exception: + pass + + err_base = _error_ratio(y_base, ref, dtype) + err_opt = _error_ratio(y_opt, ref, dtype) + base_ok = err_base <= 1.0 + opt_ok = err_opt <= 1.0 + status = "PASS" if opt_ok else "FAIL" + return { + "path": path, + "shape": shape, + "nnz": nnz, + "dense_cols": dense_cols, + "base_ms": base_ms, + "opt_ms": opt_ms, + "pt_ms": pt_ms, + "cu_ms": cu_ms, + "err_base": err_base, + "err_opt": err_opt, + "base_ok": base_ok, + "opt_ok": opt_ok, + "seed": seed, + "status": status, + } + + +def print_row(row): + name = os.path.basename(row["path"])[:27] + n_rows, n_cols = row["shape"] + print( + f"{name:<28} {n_rows:>7} {n_cols:>7} {row['nnz']:>10} {row['dense_cols']:>8} " + f"{_fmt(row['base_ms']):>9} {_fmt(row['opt_ms']):>9} " + f"{_fmt(row['pt_ms']):>9} {_fmt(row['cu_ms']):>9} " + f"{_spd(row['base_ms'], row['opt_ms']):>8} " + f"{_spd(row['pt_ms'], row['opt_ms']):>8} " + f"{_spd(row['cu_ms'], row['opt_ms']):>8} " + f"{_err(row['err_base']):>10} {_err(row['err_opt']):>10} {row['status']:>6}" + ) + + +def run_batch(paths, dtype, index_dtype, dense_cols, warmup, iters, seed=None): + results = [] + for path in paths: + try: + row = run_one_mtx(path, dtype, index_dtype, dense_cols, warmup, iters, seed=seed) + except Exception as exc: + print(f" ERROR on {os.path.basename(path)}: {exc}") + continue + results.append(row) + print_row(row) + return results + + +def run_all_csv(paths, csv_path, dense_cols, warmup, iters, seed=None): + rows = [] + for dtype in VALUE_DTYPES: + for index_dtype in INDEX_DTYPES: + dname = str(dtype).replace("torch.", "") + iname = str(index_dtype).replace("torch.", "") + print("=" * 182) + print(f"Value dtype: {dname} | Index dtype: {iname} | Dense cols: {dense_cols}") + print( + "Base = existing CSR SpMM baseline (fp64-accum for fp32). " + "Opt = bucketed CSR SpMM native path. " + "Speedup = Base/Opt or Ref/Opt." 
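+                # Err threshold matches _error_ratio: values <= 1.0 count as PASS.
+                " Err = max(|diff| / (atol + rtol*|ref|)); PASS needs Err(Opt) <= 1.0."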
+ ) + print(SEP) + print(HEADER) + print(SEP) + results = run_batch(paths, dtype, index_dtype, dense_cols, warmup, iters, seed=seed) + print(SEP) + for row in results: + n_rows, n_cols = row["shape"] + rows.append({ + "matrix": os.path.basename(row["path"]), + "value_dtype": dname, + "index_dtype": iname, + "n_rows": n_rows, + "n_cols": n_cols, + "nnz": row["nnz"], + "dense_cols": row["dense_cols"], + "seed": row["seed"], + "base_ms": row["base_ms"], + "opt_ms": row["opt_ms"], + "pt_ms": row["pt_ms"], + "cu_ms": row["cu_ms"], + "opt_vs_base": (row["base_ms"] / row["opt_ms"] if row["opt_ms"] and row["opt_ms"] > 0 else None), + "opt_vs_pt": (row["pt_ms"] / row["opt_ms"] if row["pt_ms"] and row["opt_ms"] and row["opt_ms"] > 0 else None), + "opt_vs_cu": (row["cu_ms"] / row["opt_ms"] if row["cu_ms"] and row["opt_ms"] and row["opt_ms"] > 0 else None), + "err_base": row["err_base"], + "err_opt": row["err_opt"], + "status": row["status"], + }) + fields = [ + "matrix", + "value_dtype", + "index_dtype", + "n_rows", + "n_cols", + "nnz", + "dense_cols", + "seed", + "base_ms", + "opt_ms", + "pt_ms", + "cu_ms", + "opt_vs_base", + "opt_vs_pt", + "opt_vs_cu", + "err_base", + "err_opt", + "status", + ] + with open(csv_path, "w", newline="", encoding="utf-8") as handle: + writer = csv.DictWriter(handle, fieldnames=fields, extrasaction="ignore") + writer.writeheader() + for row in rows: + writer.writerow({key: ("" if value is None else value) for key, value in row.items()}) + print(f"\nWrote {len(rows)} rows to {csv_path}") + + +def main(): + parser = argparse.ArgumentParser(description="SpMM opt A/B: baseline vs optimised, with PyTorch/cuSPARSE timings.") + parser.add_argument("mtx", nargs="*", help=".mtx files or directories") + parser.add_argument("--csv", type=str, default=None, metavar="FILE", help="Export all dtypes to CSV") + parser.add_argument("--dtype", default="float32", choices=["float32", "float64"]) + parser.add_argument("--dense-cols", type=int, default=DEFAULT_DENSE_COLS) + parser.add_argument("--warmup", type=int, default=WARMUP) + parser.add_argument("--iters", type=int, default=ITERS) + parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="Optional fixed seed for reproducible dense RHS generation") + args = parser.parse_args() + + paths = [] + for path in args.mtx: + if os.path.isfile(path) and path.endswith(".mtx"): + paths.append(path) + elif os.path.isdir(path): + paths.extend(sorted(glob.glob(os.path.join(path, "*.mtx")))) + if not paths: + print("No .mtx files. Usage: python test_spmm_opt.py [--csv out.csv]") + return + + if args.csv: + run_all_csv(paths, args.csv, args.dense_cols, args.warmup, args.iters, seed=args.seed) + return + + dtype_map = {"float32": torch.float32, "float64": torch.float64} + dtype = dtype_map[args.dtype] + print("=" * 182) + print("FLAGSPARSE SpMM Optimisation A/B Test") + print(f"GPU: {torch.cuda.get_device_name(0)} | dtype: {args.dtype} | Dense cols: {args.dense_cols} | Files: {len(paths)}") + if args.seed is not None: + print(f"Seed: {args.seed}") + print( + "Base = existing CSR SpMM baseline (fp64-accum for fp32). " + "Opt = bucketed CSR SpMM native path. " + "Speedup = Base/Opt or Ref/Opt." 
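+        # Err threshold matches _error_ratio: values <= 1.0 count as PASS.
+        " Err = max(|diff| / (atol + rtol*|ref|)); PASS needs Err(Opt) <= 1.0."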
+ ) + print(SEP) + print(HEADER) + print(SEP) + results = run_batch(paths, dtype, torch.int32, args.dense_cols, args.warmup, args.iters, seed=args.seed) + print(SEP) + passed = sum(1 for row in results if row["status"] == "PASS") + print(f"Passed: {passed} / {len(results)}") + + +if __name__ == "__main__": + main() diff --git a/tests/test_spmv.py b/tests/test_spmv.py new file mode 100644 index 0000000..6349427 --- /dev/null +++ b/tests/test_spmv.py @@ -0,0 +1,693 @@ +""" +SpMV tests (CSR): load SuiteSparse .mtx, batch run, output error and performance. +Supports: multi .mtx files, value_dtype / index_dtype, --csv-csr to run all dtypes and export CSV. +""" +import argparse +import csv +import glob +import math +import os + +import torch +import flagsparse as ast + +VALUE_DTYPES = [ + torch.float32, + torch.float64, +] +INDEX_DTYPES = [torch.int32] +TEST_CASES = [ + (512, 512, 4096), + (1024, 1024, 16384), + (2048, 2048, 65536), + (4096, 4096, 131072), +] +WARMUP = 10 +ITERS = 50 + + +def load_mtx_to_csr_torch(file_path, dtype=torch.float32, device=None): + """ + Load SuiteSparse / Matrix Market .mtx file into CSR as torch tensors. + Correctly handles *pattern* matrices and *symmetric/skew-symmetric* expansions. + Returns (data, indices, indptr, shape) on device. + """ + import math as _math + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + with open(file_path, "r", encoding="utf-8") as f: + lines = f.readlines() + + mm_field = "real" + mm_symmetry = "general" + data_lines = [] + header_info = None + for line in lines: + stripped = line.strip() + if stripped.startswith("%%MatrixMarket"): + tokens = stripped.split() + if len(tokens) >= 5: + mm_field = tokens[3].lower() + mm_symmetry = tokens[4].lower() + continue + if stripped.startswith("%"): + continue + if not header_info and stripped: + parts = stripped.split() + n_rows = int(parts[0]) + n_cols = int(parts[1]) + nnz = int(parts[2]) if len(parts) > 2 else 0 + header_info = (n_rows, n_cols, nnz) + continue + if stripped: + data_lines.append(stripped) + if header_info is None: + raise ValueError(f"Cannot parse .mtx header: {file_path}") + n_rows, n_cols, nnz = header_info + if nnz == 0: + data = torch.tensor([], dtype=dtype, device=device) + indices = torch.tensor([], dtype=torch.int64, device=device) + indptr = torch.zeros(n_rows + 1, dtype=torch.int64, device=device) + return data, indices, indptr, (n_rows, n_cols) + + is_pattern = (mm_field == "pattern") + is_symmetric = mm_symmetry in ("symmetric", "hermitian") + is_skew = (mm_symmetry == "skew-symmetric") + + row_maps = [dict() for _ in range(n_rows)] + for line in data_lines[:nnz]: + parts = line.split() + r = int(parts[0]) - 1 + c = int(parts[1]) - 1 + v = 1.0 if is_pattern else float(parts[2]) + if 0 <= r < n_rows and 0 <= c < n_cols: + row_maps[r][c] = row_maps[r].get(c, 0.0) + v + if r != c: + if is_symmetric and 0 <= c < n_rows and 0 <= r < n_cols: + row_maps[c][r] = row_maps[c].get(r, 0.0) + v + elif is_skew and 0 <= c < n_rows and 0 <= r < n_cols: + row_maps[c][r] = row_maps[c].get(r, 0.0) - v + + cols_s = [] + vals_s = [] + indptr_list = [0] + for r in range(n_rows): + row = row_maps[r] + for c in sorted(row.keys()): + cols_s.append(c) + vals_s.append(row[c]) + indptr_list.append(len(cols_s)) + data = torch.tensor(vals_s, dtype=dtype, device=device) + indices = torch.tensor(cols_s, dtype=torch.int64, device=device) + indptr = torch.tensor(indptr_list, dtype=torch.int64, device=device) + return data, indices, indptr, (n_rows, n_cols) 
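+
+
+# Sketch of the per-matrix flow implemented by the helpers in this file ("matrix.mtx"
+# is a placeholder path): load the CSR triplet, build a float64-based reference for
+# float32 inputs, and score results with _allclose_error_ratio (<= 1.0 means within
+# tolerance):
+#
+#     data, indices, indptr, shape = load_mtx_to_csr_torch("matrix.mtx", torch.float32)
+#     x = torch.randn(shape[1], dtype=torch.float32, device=data.device)
+#     y_ref = _pytorch_spmv_reference(data, indices, indptr, x, shape, torch.float32)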
+ + +def _allclose_error_ratio(actual, reference, atol, rtol): + if actual.numel() == 0: + return 0.0 + diff = torch.abs(actual - reference).to(torch.float64) + tol = (atol + rtol * torch.abs(reference)).to(torch.float64) + return float(torch.max(diff / tol).item()) + + +def _benchmark_flagsparse_spmv(data, indices, indptr, x, shape, warmup, iters, block_nnz, max_segments): + prepared = ast.prepare_spmv_csr( + data, + indices, + indptr, + shape, + block_nnz=block_nnz, + max_segments=max_segments, + ) + op = lambda: ast.flagsparse_spmv_csr( + x=x, + prepared=prepared, + return_time=False, + ) + y = op() + torch.cuda.synchronize() + for _ in range(warmup): + _ = op() + torch.cuda.synchronize() + start_ev = torch.cuda.Event(enable_timing=True) + end_ev = torch.cuda.Event(enable_timing=True) + start_ev.record() + for _ in range(iters): + y = op() + end_ev.record() + torch.cuda.synchronize() + return y, start_ev.elapsed_time(end_ev) / iters + + +def _reference_dtype(dtype): + return torch.float64 if dtype == torch.float32 else dtype + + +def _pytorch_spmv_reference(data, indices, indptr, x, shape, out_dtype): + device = data.device + ref_dtype = _reference_dtype(out_dtype) + data_ref = data.to(ref_dtype) + x_ref = x.to(ref_dtype) + try: + csr_ref = torch.sparse_csr_tensor( + indptr.to(torch.int64), + indices.to(torch.int64), + data_ref, + size=shape, + device=device, + ) + y_ref = torch.sparse.mm(csr_ref, x_ref.unsqueeze(1)).squeeze(1) + except Exception: + n_rows = int(shape[0]) + row_ind = torch.repeat_interleave( + torch.arange(n_rows, device=device, dtype=torch.int64), + indptr[1:] - indptr[:-1], + ) + coo_ref = torch.sparse_coo_tensor( + torch.stack([row_ind, indices.to(torch.int64)]), + data_ref, + shape, + device=device, + ).coalesce() + y_ref = torch.sparse.mm(coo_ref, x_ref.unsqueeze(1)).squeeze(1) + return y_ref.to(out_dtype) if ref_dtype != out_dtype else y_ref + + +def _cupy_csr_reference(data, indices, indptr, x, shape, out_dtype): + import cupy as cp + import cupyx.scipy.sparse as cpx + + ref_dtype = _reference_dtype(out_dtype) + data_ref = data.to(ref_dtype) + x_ref = x.to(ref_dtype) + data_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(data_ref)) + ind_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(indices.to(torch.int64))) + ptr_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(indptr)) + x_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(x_ref)) + A_csr_ref = cpx.csr_matrix((data_cp, ind_cp, ptr_cp), shape=shape) + y_ref = A_csr_ref @ x_cp + y_ref_t = torch.utils.dlpack.from_dlpack(y_ref.toDlpack()) + return y_ref_t.to(out_dtype) if ref_dtype != out_dtype else y_ref_t + + +def run_one_mtx( + mtx_path, + value_dtype=torch.float32, + index_dtype=torch.int32, + warmup=10, + iters=50, + run_cusparse=True, + block_nnz=256, + max_segments=None, +): + """Run SpMV on one .mtx: load, compute ref, run Triton and optional cuSPARSE, return errors and timings.""" + device = torch.device("cuda") + data, indices, indptr, shape = load_mtx_to_csr_torch(mtx_path, dtype=value_dtype, device=device) + indices = indices.to(index_dtype) + n_rows, n_cols = shape + nnz = data.numel() + x = torch.randn(n_cols, dtype=value_dtype, device=device) + atol, rtol = 1e-6, 1e-5 + if value_dtype in (torch.float16, torch.bfloat16): + atol, rtol = 2e-3, 2e-3 + elif value_dtype == torch.float32 or value_dtype == torch.complex64: + # Relaxed for float32: reduction order differs from PyTorch/cuSPARSE on some irregular matrices. 
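+        # _allclose_error_ratio scales |y - ref| by (atol + rtol*|ref|); ratios <= 1.0 count as PASS.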
+ atol, rtol = 1.25e-4, 1.25e-2 + elif value_dtype == torch.float64 or value_dtype == torch.complex128: + atol, rtol = 1e-12, 1e-10 + triton_y, triton_ms = _benchmark_flagsparse_spmv( + data, + indices, + indptr, + x, + shape, + warmup=warmup, + iters=iters, + block_nnz=block_nnz, + max_segments=max_segments, + ) + pt_y = None + pt_ref_y = None + pytorch_ms = None + err_pt = None + triton_ok_pt = False + try: + start_ev = torch.cuda.Event(enable_timing=True) + end_ev = torch.cuda.Event(enable_timing=True) + try: + csr_pt = torch.sparse_csr_tensor( + indptr.to(torch.int64), + indices.to(torch.int64), + data, + size=shape, + device=device, + ) + pt_y = torch.sparse.mm(csr_pt, x.unsqueeze(1)).squeeze(1) + torch.cuda.synchronize() + for _ in range(warmup): + _ = torch.sparse.mm(csr_pt, x.unsqueeze(1)) + torch.cuda.synchronize() + start_ev.record() + for _ in range(iters): + _ = torch.sparse.mm(csr_pt, x.unsqueeze(1)) + end_ev.record() + except Exception: + row_ind = torch.repeat_interleave( + torch.arange(n_rows, device=device, dtype=torch.int64), + indptr[1:] - indptr[:-1], + ) + coo = torch.sparse_coo_tensor( + torch.stack([row_ind, indices.to(torch.int64)]), + data, + shape, + device=device, + ).coalesce() + pt_y = torch.sparse.mm(coo, x.unsqueeze(1)).squeeze(1) + torch.cuda.synchronize() + for _ in range(warmup): + _ = torch.sparse.mm(coo, x.unsqueeze(1)) + torch.cuda.synchronize() + start_ev.record() + for _ in range(iters): + _ = torch.sparse.mm(coo, x.unsqueeze(1)) + end_ev.record() + if pt_y is not None: + try: + pt_ref_y = _pytorch_spmv_reference( + data, indices, indptr, x, shape, value_dtype + ) + except Exception: + pt_ref_y = pt_y + if pt_ref_y is not None and n_rows: + err_pt = _allclose_error_ratio(triton_y, pt_ref_y, atol, rtol) + triton_ok_pt = (not math.isnan(err_pt)) and err_pt <= 1.0 + torch.cuda.synchronize() + pytorch_ms = start_ev.elapsed_time(end_ev) / iters + except Exception: + pytorch_ms = None + cs_y_t = None + cs_ref_t = None + cusparse_ms = None + err_cu = None + triton_ok_cu = False + csc_ms = None + if run_cusparse and value_dtype not in (torch.bfloat16,): + try: + import cupy as cp + import cupyx.scipy.sparse as cpx + data_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(data)) + ind_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(indices.to(torch.int64))) + ptr_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(indptr)) + x_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(x)) + A_csr = cpx.csr_matrix((data_cp, ind_cp, ptr_cp), shape=shape) + torch.cuda.synchronize() + for _ in range(warmup): + _ = A_csr @ x_cp + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(iters): + _ = A_csr @ x_cp + end.record() + torch.cuda.synchronize() + cusparse_ms = start.elapsed_time(end) / iters + cs_y = A_csr @ x_cp + cs_y_t = torch.utils.dlpack.from_dlpack(cs_y.toDlpack()) + try: + cs_ref_t = _cupy_csr_reference( + data, indices, indptr, x, shape, value_dtype + ) + except Exception: + cs_ref_t = cs_y_t + if cs_ref_t is not None and n_rows: + err_cu = _allclose_error_ratio(triton_y, cs_ref_t, atol, rtol) + triton_ok_cu = (not math.isnan(err_cu)) and err_cu <= 1.0 + A_csc = A_csr.tocsc() + torch.cuda.synchronize() + for _ in range(warmup): + _ = A_csc @ x_cp + torch.cuda.synchronize() + start.record() + for _ in range(iters): + _ = A_csc @ x_cp + end.record() + torch.cuda.synchronize() + csc_ms = start.elapsed_time(end) / iters + except Exception: + cusparse_ms = None + err_cu = 
None + csc_ms = None + if pt_y is None and err_cu is None: + return { + "path": mtx_path, + "shape": shape, + "nnz": nnz, + "error": "ref: no PyTorch or cuSPARSE result", + "triton_ms": triton_ms, + "cusparse_ms": None, + "pytorch_ms": None, + "csc_ms": None, + "err_pt": None, + "err_cu": None, + "triton_ok_pt": False, + "triton_ok_cu": False, + "status": "REF_FAIL", + } + status = "PASS" if (triton_ok_pt or triton_ok_cu) else "FAIL" + return { + "path": mtx_path, + "shape": shape, + "nnz": nnz, + "error": None, + "triton_ms": triton_ms, + "cusparse_ms": cusparse_ms, + "pytorch_ms": pytorch_ms, + "csc_ms": csc_ms, + "err_pt": err_pt, + "err_cu": err_cu, + "triton_ok_pt": triton_ok_pt, + "triton_ok_cu": triton_ok_cu, + "status": status, + } + + +def run_mtx_batch( + mtx_paths, + value_dtype=torch.float32, + index_dtype=torch.int32, + warmup=10, + iters=50, + run_cusparse=True, + on_result=None, +): + """Batch run SpMV on multiple .mtx files; return list of result dicts.""" + results = [] + for path in mtx_paths: + r = run_one_mtx( + path, + value_dtype=value_dtype, + index_dtype=index_dtype, + warmup=warmup, + iters=iters, + run_cusparse=run_cusparse, + ) + results.append(r) + if on_result is not None: + on_result(r) + return results + + +def _dtype_name(dtype): + return str(dtype).replace("torch.", "") + + +def _fmt_ms(v): + return "N/A" if v is None else f"{v:.4f}" + + +def _fmt_speedup(other_ms, triton_ms): + if other_ms is None or triton_ms is None or triton_ms <= 0: + return "N/A" + return f"{other_ms / triton_ms:.2f}x" + + +def _fmt_err(v): + return "N/A" if v is None else f"{v:.2e}" + + +def _status_str(ok, available): + if not available: + return "N/A" + return "PASS" if ok else "FAIL" + + +def _print_mtx_header(value_dtype, index_dtype): + print( + f"Value dtype: {_dtype_name(value_dtype)} | Index dtype: {_dtype_name(index_dtype)}" + ) + print("Formats: FlagSparse=CSR, cuSPARSE=CSR/CSC, PyTorch=CSR or COO.") + print("Timing stays in native dtype. For float32, correctness references use float64 compute then cast.") + print("PT/CU show per-reference correctness. 
Err(PT)/Err(CU)=max(|diff| / (atol + rtol*|ref|)).") + print("-" * 150) + print( + f"{'Matrix':<28} {'N_rows':>7} {'N_cols':>7} {'NNZ':>10} " + f"{'FlagSparse(ms)':>10} {'CSR(ms)':>10} {'CSC(ms)':>10} {'PyTorch(ms)':>11} " + f"{'FS/CSR':>7} {'FS/PT':>7} {'PT':>6} {'CU':>6} {'Err(PT)':>10} {'Err(CU)':>10}" + ) + print("-" * 150) + + +def _print_mtx_row(r): + name = os.path.basename(r["path"])[:27] + if len(os.path.basename(r["path"])) > 27: + name = name + "…" + n_rows, n_cols = r["shape"] + triton_ms = r.get("triton_ms") + csr_ms = r.get("cusparse_ms") + csc_ms = r.get("csc_ms") + pt_ms = r.get("pytorch_ms") + err_pt_str = _fmt_err(r.get("err_pt")) + err_cu_str = _fmt_err(r.get("err_cu")) + pt_status = _status_str(r.get("triton_ok_pt", False), r.get("err_pt") is not None) + cu_status = _status_str(r.get("triton_ok_cu", False), r.get("err_cu") is not None) + print( + f"{name:<28} {n_rows:>7} {n_cols:>7} {r['nnz']:>10} " + f"{_fmt_ms(triton_ms):>10} {_fmt_ms(csr_ms):>10} {_fmt_ms(csc_ms):>10} {_fmt_ms(pt_ms):>11} " + f"{_fmt_speedup(csr_ms, triton_ms):>7} {_fmt_speedup(pt_ms, triton_ms):>7} " + f"{pt_status:>6} {cu_status:>6} {err_pt_str:>10} {err_cu_str:>10}" + ) + + +def print_mtx_results(results, value_dtype, index_dtype): + _print_mtx_header(value_dtype, index_dtype) + for r in results: + _print_mtx_row(r) + print("-" * 150) + + +def _dtype_str(d): + return str(d).replace("torch.", "") + + +def run_all_dtypes_export_csv(paths, csv_path, warmup=10, iters=50, run_cusparse=True): + """Run SpMV for all VALUE_DTYPES x INDEX_DTYPES on each .mtx and write results to CSV.""" + rows = [] + for value_dtype in VALUE_DTYPES: + for index_dtype in INDEX_DTYPES: + print("=" * 150) + _print_mtx_header(value_dtype, index_dtype) + results = run_mtx_batch( + paths, + value_dtype=value_dtype, + index_dtype=index_dtype, + warmup=warmup, + iters=iters, + run_cusparse=run_cusparse, + on_result=_print_mtx_row, + ) + print("-" * 150) + for r in results: + n_rows, n_cols = r["shape"] + rows.append({ + "matrix": os.path.basename(r["path"]), + "value_dtype": _dtype_str(value_dtype), + "index_dtype": _dtype_str(index_dtype), + "n_rows": n_rows, + "n_cols": n_cols, + "nnz": r["nnz"], + "triton_ms": r.get("triton_ms"), + "cusparse_ms": r.get("cusparse_ms"), + "pytorch_ms": r.get("pytorch_ms"), + "csc_ms": r.get("csc_ms"), + "pt_status": _status_str(r.get("triton_ok_pt", False), r.get("err_pt") is not None), + "cu_status": _status_str(r.get("triton_ok_cu", False), r.get("err_cu") is not None), + "status": r.get("status", r.get("error", "")), + "err_pt": r.get("err_pt"), + "err_cu": r.get("err_cu"), + }) + fieldnames = [ + "matrix", "value_dtype", "index_dtype", "n_rows", "n_cols", "nnz", + "triton_ms", "cusparse_ms", "pytorch_ms", "csc_ms", + "pt_status", "cu_status", "status", "err_pt", "err_cu", + ] + with open(csv_path, "w", newline="", encoding="utf-8") as f: + w = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore") + w.writeheader() + for row in rows: + w.writerow({k: ("" if v is None else v) for k, v in row.items()}) + print(f"Wrote {len(rows)} rows to {csv_path}") + + +def run_comprehensive_synthetic(): + """Synthetic benchmark with per-case table (like test_gather).""" + if not torch.cuda.is_available(): + print("CUDA is not available.") + return + print("=" * 110) + print("FLAGSPARSE SpMV BENCHMARK (synthetic CSR)") + print("=" * 110) + print(f"GPU: {torch.cuda.get_device_name(0)} | Warmup: {WARMUP} Iters: {ITERS}") + print("Formats: FlagSparse=CSR, cuSPARSE=CSR (when supported), Reference=CuPy CSR 
or PyTorch COO") + print("When CuPy does not support dtype (e.g. bfloat16/float16), reference = PyTorch (float32 then cast).") + print() + total = 0 + failed = 0 + for value_dtype in VALUE_DTYPES: + for index_dtype in INDEX_DTYPES: + print("-" * 110) + print( + f"Value dtype: {_dtype_name(value_dtype):<12} | Index dtype: {_dtype_name(index_dtype):<6}" + ) + print("-" * 110) + print( + f"{'N_rows':>7} {'N_cols':>7} {'NNZ':>10} " + f"{'FlagSparse(ms)':>11} {'cuSPARSE(ms)':>12} {'FS/CS':>8} " + f"{'Status':>6} {'Err(FS)':>10} {'Err(CS)':>10}" + ) + print("-" * 110) + for n_rows, n_cols, nnz in TEST_CASES: + result = ast.benchmark_spmv_case( + n_rows=n_rows, + n_cols=n_cols, + nnz=nnz, + value_dtype=value_dtype, + index_dtype=index_dtype, + warmup=WARMUP, + iters=ITERS, + run_cusparse=True, + ) + total += 1 + perf = result["performance"] + verify = result["verification"] + backend = result["backend_status"] + ok = verify["triton_match_reference"] + cs_ok = verify.get("cusparse_match_reference") + status = "PASS" if (ok and (cs_ok is None or cs_ok)) else "FAIL" + if not ok or (cs_ok is False): + failed += 1 + triton_ms = perf["triton_ms"] + cusparse_ms = perf["cusparse_ms"] + speedup = perf.get("triton_speedup_vs_cusparse") + if speedup is not None and speedup != "N/A": + speedup_str = f"{speedup:.2f}x" + else: + speedup_str = "N/A" + print( + f"{n_rows:>7} {n_cols:>7} {nnz:>10} " + f"{_fmt_ms(triton_ms):>11} {_fmt_ms(cusparse_ms):>12} {speedup_str:>8} " + f"{status:>6} {_fmt_err(verify.get('triton_max_error')):>10} " + f"{_fmt_err(verify.get('cusparse_max_error')):>10}" + ) + print("-" * 110) + if backend.get("cusparse_unavailable_reason"): + print(f" cuSPARSE: {backend['cusparse_unavailable_reason']}") + print(" Reference: PyTorch (float32 compute then cast to value dtype).") + print() + print("=" * 110) + print(f"Total: {total} Failed: {failed}") + print("=" * 110) + + +def main(): + parser = argparse.ArgumentParser( + description="SpMV test: SuiteSparse .mtx batch run, error and performance." 
+    )
+    parser.add_argument(
+        "mtx",
+        nargs="*",
+        help=".mtx file path(s), or directory(ies) to glob for *.mtx",
+    )
+    parser.add_argument(
+        "--synthetic",
+        action="store_true",
+        help="Run synthetic benchmark instead of .mtx",
+    )
+    parser.add_argument(
+        "--dtype",
+        default="float32",
+        choices=["float16", "bfloat16", "float32", "float64", "complex64", "complex128"],
+        help="Value dtype (default: float32)",
+    )
+    parser.add_argument(
+        "--index-dtype",
+        default="int32",
+        choices=["int32", "int64"],
+        help="Index dtype (default: int32)",
+    )
+    parser.add_argument("--warmup", type=int, default=10, help="Warmup runs")
+    parser.add_argument("--iters", type=int, default=50, help="Timing iterations")
+    parser.add_argument("--no-cusparse", action="store_true", help="Skip cuSPARSE baseline")
+    parser.add_argument(
+        "--csv-csr",
+        type=str,
+        default=None,
+        metavar="FILE",
+        help="Run all value_dtype x index_dtype on all .mtx (CSR) and write results to CSV",
+    )
+    args = parser.parse_args()
+    dtype_map = {
+        "float16": torch.float16,
+        "bfloat16": torch.bfloat16,
+        "float32": torch.float32,
+        "float64": torch.float64,
+        "complex64": torch.complex64,
+        "complex128": torch.complex128,
+    }
+    index_map = {"int32": torch.int32, "int64": torch.int64}
+    value_dtype = dtype_map[args.dtype]
+    index_dtype = index_map[args.index_dtype]
+    if args.synthetic:
+        run_comprehensive_synthetic()
+        return
+    paths = []
+    for p in args.mtx:
+        if os.path.isfile(p) and p.endswith(".mtx"):
+            paths.append(p)
+        elif os.path.isdir(p):
+            paths.extend(sorted(glob.glob(os.path.join(p, "*.mtx"))))
+    if not paths and not args.csv_csr:
+        print("No .mtx files given. Usage: python test_spmv.py file1.mtx [file2.mtx ...] or <directory>")
+        print("Or run synthetic: python test_spmv.py --synthetic")
+        print("Or run all dtypes and export CSR CSV: python test_spmv.py --csv-csr results.csv")
+        return
+    if args.csv_csr is not None:
+        if not paths:
+            paths = sorted(glob.glob("*.mtx"))
+        if not paths:
+            print("No .mtx files found. 
Specify files or a directory.") + return + print("=" * 80) + print("FLAGSPARSE SpMV (CSR) — all dtypes, export to CSV") + print("=" * 80) + print(f"GPU: {torch.cuda.get_device_name(0)} | Files: {len(paths)} | CSV: {args.csv_csr}") + run_all_dtypes_export_csv( + paths, + args.csv_csr, + warmup=args.warmup, + iters=args.iters, + run_cusparse=not args.no_cusparse, + ) + return + print("=" * 120) + print("FLAGSPARSE SpMV — SuiteSparse .mtx batch (error + performance)") + print("=" * 120) + print(f"GPU: {torch.cuda.get_device_name(0)} | Files: {len(paths)}") + print(f"dtype: {args.dtype} index_dtype: {args.index_dtype} warmup: {args.warmup} iters: {args.iters}") + print() + results = run_mtx_batch( + paths, + value_dtype=value_dtype, + index_dtype=index_dtype, + warmup=args.warmup, + iters=args.iters, + run_cusparse=not args.no_cusparse, + ) + print_mtx_results(results, value_dtype, index_dtype) + passed = sum(1 for r in results if r.get("status") == "PASS") + print(f"Passed: {passed} / {len(results)}") + + +if __name__ == "__main__": + main() diff --git a/tests/test_spmv_coo.py b/tests/test_spmv_coo.py new file mode 100644 index 0000000..9dcc1cf --- /dev/null +++ b/tests/test_spmv_coo.py @@ -0,0 +1,589 @@ +"""SpMV COO tests: float32/float64 + int32, synthetic + optional .mtx, compare FlagSparse vs PyTorch/CuPy.""" +import argparse +import glob +import csv +import math +import os + +import torch +import flagsparse as fs +try: + import cupy as cp + import cupyx.scipy.sparse as cpx_sparse +except Exception: + cp = None + cpx_sparse = None + +VALUE_DTYPES = [torch.float32, torch.float64] +INDEX_DTYPE = torch.int32 +TEST_SIZES = [(512, 512), (1024, 1024), (2048, 2048)] +WARMUP = 10 +ITERS = 50 + + +def _dtype_name(dtype): + return str(dtype).replace("torch.", "") + + +def _fmt_ms(v): + return "N/A" if v is None else f"{v:.4f}" + + +def _fmt_speedup(other_ms, triton_ms): + if other_ms is None or triton_ms is None or triton_ms <= 0: + return "N/A" + return f"{other_ms / triton_ms:.2f}x" + + +def _fmt_err(v): + return "N/A" if v is None else f"{v:.2e}" + + +def _status_str(ok, available): + if not available: + return "N/A" + return "PASS" if ok else "FAIL" + + +def _allclose_error_ratio(actual, reference, atol, rtol): + if actual.numel() == 0: + return 0.0 + diff = torch.abs(actual - reference).to(torch.float64) + tol = (atol + rtol * torch.abs(reference)).to(torch.float64) + return float(torch.max(diff / tol).item()) + + +def _reference_dtype(dtype): + return torch.float64 if dtype == torch.float32 else dtype + + +def _pytorch_coo_reference(data, row, col, x, shape, out_dtype): + ref_dtype = _reference_dtype(out_dtype) + data_ref = data.to(ref_dtype) + x_ref = x.to(ref_dtype) + coo_ref = torch.sparse_coo_tensor( + torch.stack([row.to(torch.int64), col.to(torch.int64)]), + data_ref, + shape, + device=data.device, + ).coalesce() + y_ref = torch.sparse.mm(coo_ref, x_ref.unsqueeze(1)).squeeze(1) + return y_ref.to(out_dtype) if ref_dtype != out_dtype else y_ref + + +def _cupy_coo_reference(data, row, col, x, shape, out_dtype): + ref_dtype = _reference_dtype(out_dtype) + data_ref = data.to(ref_dtype) + x_ref = x.to(ref_dtype) + data_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(data_ref)) + row_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(row.to(torch.int64))) + col_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(col.to(torch.int64))) + x_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(x_ref)) + A_cp_ref = cpx_sparse.coo_matrix((data_cp, (row_cp, col_cp)), shape=shape) + y_ref = A_cp_ref @ x_cp + 
y_ref_t = torch.utils.dlpack.from_dlpack(y_ref.toDlpack()) + return y_ref_t.to(out_dtype) if ref_dtype != out_dtype else y_ref_t + + +def _dense_to_coo(A): + rows, cols = A.nonzero(as_tuple=True) + data = A[rows, cols] + return data, rows, cols + + +COO_SEP = "-" * 172 +COO_HEADER = ( + f"{'Matrix':<28} {'N_rows':>7} {'N_cols':>7} {'NNZ':>10} " + f"{'Base(ms)':>9} {'Opt(ms)':>9} {'PT(ms)':>9} {'CU(ms)':>9} " + f"{'Opt/Base':>8} {'Opt/PT':>8} {'Opt/CU':>8} " + f"{'Err(Base)':>10} {'Err(Opt)':>10} {'Status':>6}" +) + + +def _spd(num, den): + if num is None or den is None or den <= 0: + return "N/A" + return f"{num / den:.2f}x" + + +# FlagSparse native COO SpMV: see sparse_operations.spmv_coo +COO_ATOMIC_BLOCK = 256 +COO_ATOMIC_WARPS = 4 +COO_SEG_BLOCK_INNER = 128 + +def _timed_flagsparse_coo(prepared, x, warmup, iters): + op = lambda: fs.flagsparse_spmv_coo( + x=x, + prepared=prepared, + return_time=False, + block_inner=COO_SEG_BLOCK_INNER, + block_size=COO_ATOMIC_BLOCK, + num_warps=COO_ATOMIC_WARPS, + ) + y = op() + torch.cuda.synchronize() + for _ in range(warmup): + op() + torch.cuda.synchronize() + e0 = torch.cuda.Event(True) + e1 = torch.cuda.Event(True) + e0.record() + for _ in range(iters): + y = op() + e1.record() + torch.cuda.synchronize() + return y, e0.elapsed_time(e1) / iters + + +def run_synthetic(): + if not torch.cuda.is_available(): + print("CUDA is not available. Please run on a GPU-enabled system.") + return + device = torch.device("cuda") + print("=" * 172) + print("FLAGSPARSE SpMV COO BENCHMARK (synthetic dense -> COO). All backends stay COO.") + print("=" * 172) + print(f"GPU: {torch.cuda.get_device_name(0)}") + print(f"Warmup: {WARMUP} | Iters: {ITERS}") + print() + + for dtype in VALUE_DTYPES: + atol, rtol = _tol_for_dtype(dtype) + print(COO_SEP) + print(f"dtype: {_dtype_name(dtype)} index_dtype: int32") + print(COO_SEP) + print( + "FlagSparse: prepare_spmv_coo + Triton COO SpMV (no CSR). " + "Base(ms) = row-run (seg) kernel; Opt(ms) = NNZ atomic kernel." + ) + print( + "PyTorch/CuPy: COO sparse.mm / cupyx coo_matrix @ x (no tocsr). " + "Err vs PyTorch COO reference in fp64 (fp32 casts back)." 
+ ) + print(COO_SEP) + print( + f"{'M':>6} {'N':>6} {'NNZ':>10} " + f"{'Base(ms)':>9} {'Opt(ms)':>9} {'PT(ms)':>9} {'CU(ms)':>9} " + f"{'Opt/Base':>8} {'Opt/PT':>8} {'Opt/CU':>8} " + f"{'Err(Base)':>10} {'Err(Opt)':>10} {'Status':>6}" + ) + print(COO_SEP) + for m, n in TEST_SIZES: + A = torch.randn(m, n, dtype=dtype, device=device) + A *= (torch.rand_like(A) < 0.1) + data, row, col = _dense_to_coo(A) + nnz = int(data.numel()) + x = torch.randn(n, dtype=dtype, device=device) + prepared_seg = fs.prepare_spmv_coo( + data, row, col, (m, n), sort_by_row=True + ) + prepared_at = fs.prepare_spmv_coo( + data, row, col, (m, n), sort_by_row=False + ) + y_base, base_ms = _timed_flagsparse_coo( + prepared_seg, x, WARMUP, ITERS + ) + y_opt, opt_ms = _timed_flagsparse_coo( + prepared_at, x, WARMUP, ITERS + ) + y_ref = _pytorch_coo_reference(data, row, col, x, (m, n), dtype) + err_base = _allclose_error_ratio(y_base, y_ref, atol, rtol) + err_opt = _allclose_error_ratio(y_opt, y_ref, atol, rtol) + st = ( + "PASS" + if ( + (not math.isnan(err_base)) + and (not math.isnan(err_opt)) + and err_base <= 1.0 + and err_opt <= 1.0 + ) + else "FAIL" + ) + if nnz > 0: + coo = torch.sparse_coo_tensor( + torch.stack([row.to(torch.int64), col.to(torch.int64)]), + data, + (m, n), + device=device, + ).coalesce() + torch.cuda.synchronize() + for _ in range(WARMUP): + _ = torch.sparse.mm(coo, x.unsqueeze(1)).squeeze(1) + torch.cuda.synchronize() + e0 = torch.cuda.Event(True) + e1 = torch.cuda.Event(True) + e0.record() + for _ in range(ITERS): + y_pt = torch.sparse.mm(coo, x.unsqueeze(1)).squeeze(1) + e1.record() + torch.cuda.synchronize() + pt_ms = e0.elapsed_time(e1) / ITERS + else: + y_pt = torch.zeros(m, dtype=dtype, device=device) + pt_ms = 0.0 + cu_ms = None + if cp is not None and cpx_sparse is not None: + try: + data_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(data)) + row_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(row.to(torch.int64))) + col_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(col.to(torch.int64))) + A_cp = cpx_sparse.coo_matrix( + (data_cp, (row_cp, col_cp)), shape=(m, n) + ) + x_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(x)) + for _ in range(WARMUP): + _ = A_cp @ x_cp + cp.cuda.runtime.deviceSynchronize() + c0 = cp.cuda.Event() + c1 = cp.cuda.Event() + c0.record() + for _ in range(ITERS): + y_cu = A_cp @ x_cp + c1.record() + c1.synchronize() + cu_ms = cp.cuda.get_elapsed_time(c0, c1) / ITERS + except Exception: + cu_ms = None + print( + f"{m:>6} {n:>6} {nnz:>10} " + f"{_fmt_ms(base_ms):>9} {_fmt_ms(opt_ms):>9} {_fmt_ms(pt_ms):>9} {_fmt_ms(cu_ms):>9} " + f"{_spd(base_ms, opt_ms):>8} {_spd(pt_ms, opt_ms):>8} {_spd(cu_ms, opt_ms):>8} " + f"{_fmt_err(err_base):>10} {_fmt_err(err_opt):>10} {st:>6}" + ) + print(COO_SEP) + print() + + +def _load_mtx_to_coo_torch(file_path, dtype=torch.float32, device=None): + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + with open(file_path, "r", encoding="utf-8") as f: + lines = f.readlines() + + mm_field = "real" + mm_symmetry = "general" + data_lines = [] + header_info = None + for line in lines: + stripped = line.strip() + if stripped.startswith("%%MatrixMarket"): + tokens = stripped.split() + if len(tokens) >= 5: + mm_field = tokens[3].lower() + mm_symmetry = tokens[4].lower() + continue + if stripped.startswith("%"): + continue + if not header_info and stripped: + parts = stripped.split() + n_rows = int(parts[0]) + n_cols = int(parts[1]) + nnz = int(parts[2]) if len(parts) > 2 else 0 + header_info = 
(n_rows, n_cols, nnz) + continue + if stripped: + data_lines.append(stripped) + if header_info is None: + raise ValueError(f"Cannot parse .mtx header: {file_path}") + n_rows, n_cols, nnz = header_info + + is_pattern = (mm_field == "pattern") + is_symmetric = mm_symmetry in ("symmetric", "hermitian") + is_skew = (mm_symmetry == "skew-symmetric") + + rows_host = [] + cols_host = [] + vals_host = [] + for line in data_lines[:nnz]: + parts = line.split() + if len(parts) < 2: + continue + r = int(parts[0]) - 1 + c = int(parts[1]) - 1 + v = 1.0 if is_pattern else (float(parts[2]) if len(parts) >= 3 else 0.0) + if 0 <= r < n_rows and 0 <= c < n_cols: + rows_host.append(r) + cols_host.append(c) + vals_host.append(v) + if r != c: + if is_symmetric and 0 <= c < n_rows and 0 <= r < n_cols: + rows_host.append(c) + cols_host.append(r) + vals_host.append(v) + elif is_skew and 0 <= c < n_rows and 0 <= r < n_cols: + rows_host.append(c) + cols_host.append(r) + vals_host.append(-v) + rows = torch.tensor(rows_host, dtype=torch.int64, device=device) + cols = torch.tensor(cols_host, dtype=torch.int64, device=device) + vals = torch.tensor(vals_host, dtype=dtype, device=device) + return vals, rows, cols, (n_rows, n_cols) + + +def _tol_for_dtype(dtype): + if dtype == torch.float32: + return 1e-4, 1e-2 + return 1e-12, 1e-10 + + +# Dense PyTorch reference for SpSV can OOM on large matrices. +DENSE_REF_MAX_BYTES = 2 * 1024 * 1024 * 1024 # 2 GiB + + +def _allow_dense_pytorch_ref(shape, dtype): + n_rows, n_cols = shape + elem_bytes = torch.empty((), dtype=dtype).element_size() + dense_bytes = int(n_rows) * int(n_cols) * int(elem_bytes) + return dense_bytes <= DENSE_REF_MAX_BYTES + + +def run_all_dtypes_coo_csv(mtx_paths, csv_path): + if not torch.cuda.is_available(): + print("CUDA is not available.") + return + device = torch.device("cuda") + rows_out = [] + for dtype in VALUE_DTYPES: + atol, rtol = _tol_for_dtype(dtype) + print("=" * 172) + print(f"Value dtype: {_dtype_name(dtype)} | Index dtype: int32") + print("Input: MatrixMarket → COO. FlagSparse: COO Triton only (seg + atomic), no CSR.") + print("PyTorch = COO sparse.mm; CuPy = COO matvec (coo_matrix @ x, no tocsr).") + print( + "Base(ms) = FlagSparse COO row-run (seg); Opt(ms) = FlagSparse COO NNZ atomic; " + "PT/CU = COO baselines." + ) + print( + f"prepare_spmv_coo once per variant + {WARMUP} warmup + " + f"{ITERS} CUDA-event-averaged SpMV per backend." 
+ ) + print(COO_SEP) + print(COO_HEADER) + print(COO_SEP) + for path in mtx_paths: + try: + data, row, col, shape = _load_mtx_to_coo_torch( + path, dtype=dtype, device=device + ) + m, n = shape + x = torch.randn(n, dtype=dtype, device=device) + prepared_seg = fs.prepare_spmv_coo( + data, row, col, shape, sort_by_row=True + ) + prepared_at = fs.prepare_spmv_coo( + data, row, col, shape, sort_by_row=False + ) + y_base, base_ms = _timed_flagsparse_coo( + prepared_seg, x, WARMUP, ITERS + ) + y_opt, opt_ms = _timed_flagsparse_coo( + prepared_at, x, WARMUP, ITERS + ) + y_ref = _pytorch_coo_reference(data, row, col, x, shape, dtype) + err_base = _allclose_error_ratio(y_base, y_ref, atol, rtol) + err_opt = _allclose_error_ratio(y_opt, y_ref, atol, rtol) + err_pt = None + err_cu = None + opt_ok = ( + (not math.isnan(err_base)) + and (not math.isnan(err_opt)) + and err_base <= 1.0 + and err_opt <= 1.0 + ) + if data.numel() > 0: + coo = torch.sparse_coo_tensor( + torch.stack([row, col]), + data, + shape, + device=device, + ).coalesce() + torch.cuda.synchronize() + for _ in range(WARMUP): + _ = torch.sparse.mm(coo, x.unsqueeze(1)).squeeze(1) + torch.cuda.synchronize() + e0 = torch.cuda.Event(True) + e1 = torch.cuda.Event(True) + e0.record() + for _ in range(ITERS): + y_pt = torch.sparse.mm(coo, x.unsqueeze(1)).squeeze(1) + e1.record() + torch.cuda.synchronize() + pt_ms = e0.elapsed_time(e1) / ITERS + err_pt = _allclose_error_ratio(y_opt, y_pt, atol, rtol) + else: + pt_ms = 0.0 + cu_ms = None + triton_ok_pt = False + triton_ok_cu = False + if cp is not None and cpx_sparse is not None: + try: + data_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(data)) + row_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(row)) + col_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(col)) + A_cp = cpx_sparse.coo_matrix( + (data_cp, (row_cp, col_cp)), shape=shape + ) + x_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(x)) + for _ in range(WARMUP): + _ = A_cp @ x_cp + cp.cuda.runtime.deviceSynchronize() + c0 = cp.cuda.Event() + c1 = cp.cuda.Event() + c0.record() + for _ in range(ITERS): + y_cu = A_cp @ x_cp + c1.record() + c1.synchronize() + cu_ms = cp.cuda.get_elapsed_time(c0, c1) / ITERS + y_cu_t = torch.utils.dlpack.from_dlpack(y_cu.toDlpack()) + err_cu = _allclose_error_ratio(y_opt, y_cu_t, atol, rtol) + triton_ok_cu = (not math.isnan(err_cu)) and err_cu <= 1.0 + except Exception: + cu_ms = None + err_cu = None + if err_pt is not None: + triton_ok_pt = (not math.isnan(err_pt)) and err_pt <= 1.0 + status = "PASS" if opt_ok else "FAIL" + pt_status = _status_str(triton_ok_pt, err_pt is not None) + cu_status = _status_str(triton_ok_cu, err_cu is not None) + rows_out.append( + { + "matrix": os.path.basename(path), + "value_dtype": _dtype_name(dtype), + "index_dtype": "torch.int32", + "n_rows": m, + "n_cols": n, + "nnz": int(data.numel()), + "base_ms": base_ms, + "opt_ms": opt_ms, + "triton_ms": opt_ms, + "cusparse_ms": cu_ms, + "pytorch_ms": pt_ms, + "csc_ms": None, + "pt_status": pt_status, + "cu_status": cu_status, + "status": status, + "err_base": err_base, + "err_opt": err_opt, + "err_pt": err_pt, + "err_cu": err_cu, + } + ) + name = os.path.basename(path)[:27] + if len(os.path.basename(path)) > 27: + name = name + "…" + print( + f"{name:<28} {m:>7} {n:>7} {int(data.numel()):>10} " + f"{_fmt_ms(base_ms):>9} {_fmt_ms(opt_ms):>9} {_fmt_ms(pt_ms):>9} {_fmt_ms(cu_ms):>9} " + f"{_spd(base_ms, opt_ms):>8} {_spd(pt_ms, opt_ms):>8} {_spd(cu_ms, opt_ms):>8} " + f"{_fmt_err(err_base):>10} {_fmt_err(err_opt):>10} {status:>6}" 
+ ) + except Exception as e: + rows_out.append( + { + "matrix": os.path.basename(path), + "value_dtype": _dtype_name(dtype), + "index_dtype": "torch.int32", + "n_rows": "ERR", + "n_cols": "ERR", + "nnz": "ERR", + "base_ms": None, + "opt_ms": None, + "triton_ms": None, + "cusparse_ms": None, + "pytorch_ms": None, + "csc_ms": None, + "status": "ERROR", + "err_base": None, + "err_opt": None, + "err_pt": None, + "err_cu": None, + "pt_status": "N/A", + "cu_status": "N/A", + } + ) + name = os.path.basename(path)[:27] + if len(os.path.basename(path)) > 27: + name = name + "…" + print( + f"{name:<28} {'ERR':>7} {'ERR':>7} {'ERR':>10} " + f"{_fmt_ms(None):>9} {_fmt_ms(None):>9} {_fmt_ms(None):>9} {_fmt_ms(None):>9} " + f"{'N/A':>8} {'N/A':>8} {'N/A':>8} " + f"{_fmt_err(None):>10} {_fmt_err(None):>10} {'ERROR':>6}" + ) + print(f" ERROR: {e}") + print(COO_SEP) + fieldnames = [ + "matrix", + "value_dtype", + "index_dtype", + "n_rows", + "n_cols", + "nnz", + "base_ms", + "opt_ms", + "triton_ms", + "cusparse_ms", + "pytorch_ms", + "csc_ms", + "pt_status", + "cu_status", + "status", + "err_base", + "err_opt", + "err_pt", + "err_cu", + ] + with open(csv_path, "w", newline="", encoding="utf-8") as f: + w = csv.DictWriter(f, fieldnames=fieldnames) + w.writeheader() + for r in rows_out: + w.writerow(r) + print(f"Wrote {len(rows_out)} rows to {csv_path}") + + +def main(): + parser = argparse.ArgumentParser( + description="SpMV COO test: synthetic dense->COO and optional .mtx, export CSV." + ) + parser.add_argument( + "mtx", + nargs="*", + help=".mtx file path(s), or directory(ies) to glob for *.mtx", + ) + parser.add_argument( + "--synthetic", action="store_true", help="Run synthetic dense->COO tests" + ) + parser.add_argument( + "--csv-coo", + type=str, + default=None, + metavar="FILE", + help="Run all dtypes on given .mtx and export COO SpMV results to CSV", + ) + args = parser.parse_args() + + if args.synthetic: + run_synthetic() + return + + paths = [] + for p in args.mtx: + if os.path.isfile(p) and p.endswith(".mtx"): + paths.append(p) + elif os.path.isdir(p): + paths.extend(sorted(glob.glob(os.path.join(p, "*.mtx")))) + if args.csv_coo: + if not paths: + paths = sorted(glob.glob("*.mtx")) + if not paths: + print("No .mtx files found for --csv-coo") + return + run_all_dtypes_coo_csv(paths, args.csv_coo) + return + + print("Use --synthetic or --csv-coo to run COO SpMV tests.") + + +if __name__ == "__main__": + main() diff --git a/tests/test_spmv_opt.py b/tests/test_spmv_opt.py new file mode 100644 index 0000000..4ce1f3a --- /dev/null +++ b/tests/test_spmv_opt.py @@ -0,0 +1,404 @@ +""" +SpMV optimisation A/B test: compare _impl (baseline) vs _impl_opt (optimised) +side-by-side, together with PyTorch and cuSPARSE baselines. 
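+The two code paths are selected via the use_opt flag of fs.flagsparse_spmv_csr
+(see _timed_spmv below).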
+ +Usage: + python tests/test_spmv_opt.py # batch run, default float32 + python tests/test_spmv_opt.py --csv opt.csv # all dtypes, export CSV +""" +import argparse +import csv +import glob +import math +import os + +import torch +import flagsparse as fs + +VALUE_DTYPES = [torch.float32, torch.float64] +INDEX_DTYPES = [torch.int32] +WARMUP = 10 +ITERS = 50 + + +def load_mtx_to_csr_torch(file_path, dtype=torch.float32, device=None): + import math as _math + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + with open(file_path, "r", encoding="utf-8") as f: + lines = f.readlines() + mm_field = "real" + mm_symmetry = "general" + data_lines = [] + header_info = None + for line in lines: + stripped = line.strip() + if stripped.startswith("%%MatrixMarket"): + tokens = stripped.split() + if len(tokens) >= 5: + mm_field = tokens[3].lower() + mm_symmetry = tokens[4].lower() + continue + if stripped.startswith("%"): + continue + if not header_info and stripped: + parts = stripped.split() + n_rows = int(parts[0]) + n_cols = int(parts[1]) + nnz = int(parts[2]) if len(parts) > 2 else 0 + header_info = (n_rows, n_cols, nnz) + continue + if stripped: + data_lines.append(stripped) + if header_info is None: + raise ValueError(f"Cannot parse .mtx header: {file_path}") + n_rows, n_cols, nnz = header_info + if nnz == 0: + data = torch.tensor([], dtype=dtype, device=device) + indices = torch.tensor([], dtype=torch.int64, device=device) + indptr = torch.zeros(n_rows + 1, dtype=torch.int64, device=device) + return data, indices, indptr, (n_rows, n_cols) + is_pattern = (mm_field == "pattern") + is_symmetric = mm_symmetry in ("symmetric", "hermitian") + is_skew = (mm_symmetry == "skew-symmetric") + row_maps = [dict() for _ in range(n_rows)] + for line in data_lines[:nnz]: + parts = line.split() + r = int(parts[0]) - 1 + c = int(parts[1]) - 1 + v = 1.0 if is_pattern else float(parts[2]) + if 0 <= r < n_rows and 0 <= c < n_cols: + row_maps[r][c] = row_maps[r].get(c, 0.0) + v + if r != c: + if is_symmetric and 0 <= c < n_rows and 0 <= r < n_cols: + row_maps[c][r] = row_maps[c].get(r, 0.0) + v + elif is_skew and 0 <= c < n_rows and 0 <= r < n_cols: + row_maps[c][r] = row_maps[c].get(r, 0.0) - v + cols_s = [] + vals_s = [] + indptr_list = [0] + for r in range(n_rows): + row = row_maps[r] + for c in sorted(row.keys()): + cols_s.append(c) + vals_s.append(row[c]) + indptr_list.append(len(cols_s)) + data = torch.tensor(vals_s, dtype=dtype, device=device) + indices = torch.tensor(cols_s, dtype=torch.int64, device=device) + indptr = torch.tensor(indptr_list, dtype=torch.int64, device=device) + return data, indices, indptr, (n_rows, n_cols) + + +def _timed_spmv(prepared, x, warmup, iters, use_opt): + op = lambda: fs.flagsparse_spmv_csr( + x=x, + prepared=prepared, + return_time=False, use_opt=use_opt, + ) + y = op() + torch.cuda.synchronize() + for _ in range(warmup): + op() + torch.cuda.synchronize() + e0 = torch.cuda.Event(enable_timing=True) + e1 = torch.cuda.Event(enable_timing=True) + e0.record() + for _ in range(iters): + y = op() + e1.record() + torch.cuda.synchronize() + return y, e0.elapsed_time(e1) / iters + + +def _timed_pytorch(data, indices, indptr, x, shape, warmup, iters): + device = data.device + n_rows = int(shape[0]) + try: + A = torch.sparse_csr_tensor( + indptr.to(torch.int64), indices.to(torch.int64), data, + size=shape, device=device, + ) + y = torch.sparse.mm(A, x.unsqueeze(1)).squeeze(1) + op = lambda: torch.sparse.mm(A, x.unsqueeze(1)).squeeze(1) + except 
Exception: + row_ind = torch.repeat_interleave( + torch.arange(n_rows, device=device, dtype=torch.int64), + indptr[1:] - indptr[:-1], + ) + A = torch.sparse_coo_tensor( + torch.stack([row_ind, indices.to(torch.int64)]), + data, shape, device=device, + ).coalesce() + y = torch.sparse.mm(A, x.unsqueeze(1)).squeeze(1) + op = lambda: torch.sparse.mm(A, x.unsqueeze(1)).squeeze(1) + torch.cuda.synchronize() + for _ in range(warmup): + op() + torch.cuda.synchronize() + e0 = torch.cuda.Event(enable_timing=True) + e1 = torch.cuda.Event(enable_timing=True) + e0.record() + for _ in range(iters): + op() + e1.record() + torch.cuda.synchronize() + return y, e0.elapsed_time(e1) / iters + + +def _timed_cusparse(data, indices, indptr, x, shape, warmup, iters): + import cupy as cp + import cupyx.scipy.sparse as cpx + data_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(data)) + ind_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(indices.to(torch.int64))) + ptr_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(indptr)) + x_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(x)) + A = cpx.csr_matrix((data_cp, ind_cp, ptr_cp), shape=shape) + torch.cuda.synchronize() + for _ in range(warmup): + _ = A @ x_cp + torch.cuda.synchronize() + e0 = torch.cuda.Event(enable_timing=True) + e1 = torch.cuda.Event(enable_timing=True) + e0.record() + for _ in range(iters): + _ = A @ x_cp + e1.record() + torch.cuda.synchronize() + y_cp = A @ x_cp + y = torch.utils.dlpack.from_dlpack(y_cp.toDlpack()) + return y, e0.elapsed_time(e1) / iters + + +def _fmt(v): + return "N/A" if v is None else f"{v:.4f}" + + +def _spd(base, other): + if base is None or other is None or other <= 0: + return "N/A" + return f"{base / other:.2f}x" + + +def _err(v): + return "N/A" if v is None else f"{v:.2e}" + + +HEADER = ( + f"{'Matrix':<28} {'N_rows':>7} {'N_cols':>7} {'NNZ':>10} " + f"{'Base(ms)':>9} {'Opt(ms)':>9} {'PT(ms)':>9} {'CU(ms)':>9} " + f"{'Opt/Base':>8} {'Opt/PT':>8} {'Opt/CU':>8} " + f"{'Err(Base)':>10} {'Err(Opt)':>10} {'Status':>6}" +) +SEP = "-" * 170 + + +def run_one_mtx(path, dtype, index_dtype, warmup, iters): + device = torch.device("cuda") + data, indices, indptr, shape = load_mtx_to_csr_torch(path, dtype=dtype, device=device) + indices = indices.to(index_dtype) + n_rows, n_cols = shape + nnz = data.numel() + x = torch.randn(n_cols, dtype=dtype, device=device) + prepared = fs.prepare_spmv_csr(data, indices, indptr, shape) + + if dtype == torch.float32: + atol, rtol = 1e-4, 1e-2 + else: + atol, rtol = 1e-12, 1e-10 + + # ── Reference (float64 accumulation via PyTorch) ── + try: + ref_dtype = torch.float64 if dtype == torch.float32 else dtype + data_ref = data.to(ref_dtype) + x_ref = x.to(ref_dtype) + try: + A_ref = torch.sparse_csr_tensor( + indptr.to(torch.int64), indices.to(torch.int64), data_ref, + size=shape, device=device, + ) + y_ref = torch.sparse.mm(A_ref, x_ref.unsqueeze(1)).squeeze(1).to(dtype) + except Exception: + row_ind = torch.repeat_interleave( + torch.arange(n_rows, device=device, dtype=torch.int64), + indptr[1:] - indptr[:-1], + ) + A_ref = torch.sparse_coo_tensor( + torch.stack([row_ind, indices.to(torch.int64)]), + data_ref, shape, device=device, + ).coalesce() + y_ref = torch.sparse.mm(A_ref, x_ref.unsqueeze(1)).squeeze(1).to(dtype) + except Exception: + y_ref = None + + # ── Baseline (use_opt=False) ── + y_base, base_ms = _timed_spmv(prepared, x, warmup, iters, use_opt=False) + + # ── Optimised (use_opt=True) ── + y_opt, opt_ms = _timed_spmv(prepared, x, warmup, iters, use_opt=True) + + # ── PyTorch ── + pt_ms = 
None + try: + _, pt_ms = _timed_pytorch(data, indices, indptr, x, shape, warmup, iters) + except Exception: + pass + + # ── cuSPARSE ── + cu_ms = None + try: + _, cu_ms = _timed_cusparse(data, indices, indptr, x, shape, warmup, iters) + except Exception: + pass + + # ── Correctness vs reference ── + err_base = None + err_opt = None + if y_ref is not None and n_rows > 0: + diff_b = torch.abs(y_base - y_ref).to(torch.float64) + diff_o = torch.abs(y_opt - y_ref).to(torch.float64) + tol = (atol + rtol * torch.abs(y_ref).to(torch.float64)) + err_base = float(torch.max(diff_b / tol).item()) + err_opt = float(torch.max(diff_o / tol).item()) + + base_ok = err_base is not None and (not math.isnan(err_base)) and err_base <= 1.0 + opt_ok = err_opt is not None and (not math.isnan(err_opt)) and err_opt <= 1.0 + status = "PASS" if opt_ok else "FAIL" + + return { + "path": path, "shape": shape, "nnz": nnz, + "base_ms": base_ms, "opt_ms": opt_ms, + "pt_ms": pt_ms, "cu_ms": cu_ms, + "err_base": err_base, "err_opt": err_opt, + "base_ok": base_ok, "opt_ok": opt_ok, + "status": status, + } + + +def print_row(r): + name = os.path.basename(r["path"])[:27] + n_rows, n_cols = r["shape"] + print( + f"{name:<28} {n_rows:>7} {n_cols:>7} {r['nnz']:>10} " + f"{_fmt(r['base_ms']):>9} {_fmt(r['opt_ms']):>9} " + f"{_fmt(r['pt_ms']):>9} {_fmt(r['cu_ms']):>9} " + f"{_spd(r['base_ms'], r['opt_ms']):>8} " + f"{_spd(r['pt_ms'], r['opt_ms']):>8} " + f"{_spd(r['cu_ms'], r['opt_ms']):>8} " + f"{_err(r['err_base']):>10} {_err(r['err_opt']):>10} {r['status']:>6}" + ) + + +def run_batch(paths, dtype, index_dtype, warmup, iters): + results = [] + for p in paths: + try: + r = run_one_mtx(p, dtype, index_dtype, warmup, iters) + except Exception as e: + print(f" ERROR on {os.path.basename(p)}: {e}") + continue + results.append(r) + print_row(r) + return results + + +def run_all_csv(paths, csv_path, warmup, iters): + all_rows = [] + for dtype in VALUE_DTYPES: + for idx_dtype in INDEX_DTYPES: + dname = str(dtype).replace("torch.", "") + iname = str(idx_dtype).replace("torch.", "") + print("=" * 170) + print(f"Value dtype: {dname} | Index dtype: {iname}") + print( + "Base = FlagSparse baseline (fp64-accum for fp32). " + "Opt = FlagSparse CSR-Vector (fp32/fp64 native accum, wide tiles, few launches). " + "Speedup = Base/Opt or Ref/Opt." 
+ ) + print(SEP) + print(HEADER) + print(SEP) + results = run_batch(paths, dtype, idx_dtype, warmup, iters) + print(SEP) + for r in results: + n_rows, n_cols = r["shape"] + all_rows.append({ + "matrix": os.path.basename(r["path"]), + "value_dtype": dname, + "index_dtype": iname, + "n_rows": n_rows, "n_cols": n_cols, "nnz": r["nnz"], + "base_ms": r["base_ms"], "opt_ms": r["opt_ms"], + "pt_ms": r["pt_ms"], "cu_ms": r["cu_ms"], + "opt_vs_base": r["base_ms"] / r["opt_ms"] if r["opt_ms"] and r["opt_ms"] > 0 else None, + "opt_vs_pt": r["pt_ms"] / r["opt_ms"] if r["pt_ms"] and r["opt_ms"] and r["opt_ms"] > 0 else None, + "opt_vs_cu": r["cu_ms"] / r["opt_ms"] if r["cu_ms"] and r["opt_ms"] and r["opt_ms"] > 0 else None, + "err_base": r["err_base"], "err_opt": r["err_opt"], + "status": r["status"], + }) + fields = [ + "matrix", "value_dtype", "index_dtype", + "n_rows", "n_cols", "nnz", + "base_ms", "opt_ms", "pt_ms", "cu_ms", + "opt_vs_base", "opt_vs_pt", "opt_vs_cu", + "err_base", "err_opt", "status", + ] + with open(csv_path, "w", newline="", encoding="utf-8") as f: + w = csv.DictWriter(f, fieldnames=fields, extrasaction="ignore") + w.writeheader() + for row in all_rows: + w.writerow({k: ("" if v is None else v) for k, v in row.items()}) + print(f"\nWrote {len(all_rows)} rows to {csv_path}") + + +def main(): + parser = argparse.ArgumentParser( + description="SpMV opt A/B: baseline vs optimised, with PyTorch/cuSPARSE." + ) + parser.add_argument("mtx", nargs="*", help=".mtx files or directories") + parser.add_argument("--csv", type=str, default=None, metavar="FILE", + help="Export all dtypes to CSV") + parser.add_argument("--dtype", default="float32", + choices=["float32", "float64"]) + parser.add_argument("--warmup", type=int, default=WARMUP) + parser.add_argument("--iters", type=int, default=ITERS) + args = parser.parse_args() + + paths = [] + for p in args.mtx: + if os.path.isfile(p) and p.endswith(".mtx"): + paths.append(p) + elif os.path.isdir(p): + paths.extend(sorted(glob.glob(os.path.join(p, "*.mtx")))) + if not paths: + print("No .mtx files. Usage: python test_spmv_opt.py [--csv out.csv]") + return + + if args.csv: + print("=" * 80) + print("FLAGSPARSE SpMV Optimisation A/B Test — all dtypes, export CSV") + print("=" * 80) + print(f"GPU: {torch.cuda.get_device_name(0)} | Files: {len(paths)} | CSV: {args.csv}") + run_all_csv(paths, args.csv, args.warmup, args.iters) + return + + dtype_map = {"float32": torch.float32, "float64": torch.float64} + dtype = dtype_map[args.dtype] + dname = str(dtype).replace("torch.", "") + print("=" * 170) + print(f"FLAGSPARSE SpMV Optimisation A/B Test") + print(f"GPU: {torch.cuda.get_device_name(0)} | dtype: {dname} | Files: {len(paths)}") + print( + "Base = FlagSparse baseline (fp64-accum for fp32). " + "Opt = FlagSparse CSR-Vector (fp32/fp64 native accum, wide tiles, few launches). " + "Speedup = Base/Opt or Ref/Opt." 
+ ) + print(SEP) + print(HEADER) + print(SEP) + results = run_batch(paths, dtype, torch.int32, args.warmup, args.iters) + print(SEP) + passed = sum(1 for r in results if r["status"] == "PASS") + print(f"Passed: {passed} / {len(results)}") + + +if __name__ == "__main__": + main() diff --git a/tests/test_spsm.py b/tests/test_spsm.py new file mode 100644 index 0000000..5784b20 --- /dev/null +++ b/tests/test_spsm.py @@ -0,0 +1,529 @@ +"""SpSM tests: synthetic and optional .mtx CSV export (CSR/COO).""" + +import argparse +import csv +import glob +import os + +import torch + +import flagsparse as fs + +try: + import cupy as cp + import cupyx.scipy.sparse as cpx_sparse + from cupyx.scipy.sparse.linalg import spsolve_triangular as cpx_spsolve_triangular +except Exception: + cp = None + cpx_sparse = None + cpx_spsolve_triangular = None + + +FORMATS = ("csr", "coo") +VALUE_DTYPES = (torch.float32, torch.float64) +INDEX_DTYPES = [torch.int32] +CSV_VALUE_DTYPES = [torch.float32, torch.float64] +CSV_INDEX_DTYPES = [torch.int32] +WARMUP = 10 +ITERS = 50 + + +def _dtype_name(dtype): + return str(dtype).replace("torch.", "") + + +def _tol(dtype): + if dtype == torch.float32: + return 1e-4, 1e-3 + return 1e-12, 1e-10 + + +def _fmt_ms(v): + return "N/A" if v is None else f"{v:.4f}" + + +def _fmt_speedup(other_ms, fs_ms): + if other_ms is None or fs_ms is None or fs_ms <= 0: + return "N/A" + return f"{other_ms / fs_ms:.2f}x" + + +def _fmt_err(v): + return "N/A" if v is None else f"{v:.2e}" + + +def _build_triangular_case(n=512, n_rhs=32, value_dtype=torch.float32): + device = torch.device("cuda") + A = torch.tril(torch.randn((n, n), dtype=value_dtype, device=device) * 0.02) + diag = torch.rand((n,), dtype=value_dtype, device=device) + 2.0 + A = A + torch.diag(diag) + coo = A.to_sparse().coalesce() + row = coo.indices()[0].to(torch.int64) + col = coo.indices()[1].to(torch.int64) + data = coo.values().to(value_dtype) + _, order = torch.sort(row * n + col) + row = row[order] + col = col[order] + data = data[order] + nnz_per_row = torch.bincount(row, minlength=n) + indptr = torch.zeros(n + 1, dtype=torch.int64, device=device) + indptr[1:] = torch.cumsum(nnz_per_row, dim=0) + B = torch.randn((n, n_rhs), dtype=value_dtype, device=device).contiguous() + return data, row, col, indptr, B, (n, n) + + +def _csr_to_coo(data, indices, indptr, shape): + n_rows = int(shape[0]) + row = torch.repeat_interleave( + torch.arange(n_rows, device=data.device, dtype=torch.int64), + indptr[1:] - indptr[:-1], + ) + col = indices.to(torch.int64) + return data, row, col + + +def _load_mtx_to_csr_torch(file_path, dtype=torch.float32, device=None): + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + with open(file_path, "r", encoding="utf-8") as f: + lines = f.readlines() + data_lines = [] + header_info = None + mm_field = "real" + mm_symmetry = "general" + for line in lines: + line = line.strip() + if line.startswith("%%MatrixMarket"): + parts = line.split() + if len(parts) >= 5: + mm_field = parts[3].lower() + mm_symmetry = parts[4].lower() + continue + if line.startswith("%"): + continue + if not header_info and line: + parts = line.split() + n_rows = int(parts[0]) + n_cols = int(parts[1]) + nnz = int(parts[2]) if len(parts) > 2 else 0 + header_info = (n_rows, n_cols, nnz) + continue + if line: + data_lines.append(line) + if header_info is None: + raise ValueError(f"Cannot parse .mtx header: {file_path}") + n_rows, n_cols, nnz = header_info + if n_rows != n_cols: + raise ValueError("SpSM 
requires square matrices") + + row_maps = [dict() for _ in range(n_rows)] + + def _accum(r, c, v): + row = row_maps[r] + row[c] = row.get(c, 0.0) + v + + for line in data_lines[:nnz]: + parts = line.split() + if len(parts) < 2: + continue + r = int(parts[0]) - 1 + c = int(parts[1]) - 1 + if len(parts) >= 3: + v = float(parts[2]) + elif mm_field == "pattern": + v = 1.0 + else: + continue + _accum(r, c, v) + if mm_symmetry in ("symmetric", "hermitian") and r != c: + _accum(c, r, v) + elif mm_symmetry == "skew-symmetric" and r != c: + _accum(c, r, -v) + + # Force lower-triangular + strong diagonal so triangular solve is well-defined. + for r in range(n_rows): + row = row_maps[r] + lower_row = {} + off_abs_sum = 0.0 + for c, v in row.items(): + if c < r: + lower_row[c] = lower_row.get(c, 0.0) + v + off_abs_sum += abs(v) + lower_row[r] = off_abs_sum + 2.0 + row_maps[r] = lower_row + + cols_s = [] + vals_s = [] + indptr_list = [0] + for r in range(n_rows): + row = row_maps[r] + for c in sorted(row.keys()): + cols_s.append(c) + vals_s.append(row[c]) + indptr_list.append(len(cols_s)) + + data = torch.tensor(vals_s, dtype=dtype, device=device) + indices = torch.tensor(cols_s, dtype=torch.int64, device=device) + indptr = torch.tensor(indptr_list, dtype=torch.int64, device=device) + return data, indices, indptr, (n_rows, n_cols) + + +def _cupy_spsm_ref(data, row, col, indptr, B, shape, fmt="csr"): + if cp is None or cpx_sparse is None or cpx_spsolve_triangular is None: + return None, None, "cupy/cusparse unavailable" + try: + data_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(data.contiguous())) + B_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(B.contiguous())) + if fmt == "coo": + row_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(row.to(torch.int64).contiguous())) + col_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(col.to(torch.int64).contiguous())) + A_cp = cpx_sparse.coo_matrix((data_cp, (row_cp, col_cp)), shape=shape) + else: + idx_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(col.to(torch.int64).contiguous())) + ptr_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(indptr.contiguous())) + A_cp = cpx_sparse.csr_matrix((data_cp, idx_cp, ptr_cp), shape=shape) + for _ in range(max(0, int(WARMUP))): + _ = cpx_spsolve_triangular(A_cp, B_cp, lower=True, unit_diagonal=False) + cp.cuda.runtime.deviceSynchronize() + c0 = cp.cuda.Event() + c1 = cp.cuda.Event() + c0.record() + for _ in range(max(1, int(ITERS))): + X_cp = cpx_spsolve_triangular(A_cp, B_cp, lower=True, unit_diagonal=False) + c1.record() + c1.synchronize() + ms = cp.cuda.get_elapsed_time(c0, c1) / max(1, int(ITERS)) + X_t = torch.utils.dlpack.from_dlpack(X_cp.toDlpack()).to(B.dtype) + return X_t, ms, None + except Exception as exc: + return None, None, str(exc) + + +def _run_one_spsm_case(data, indices, indptr, shape, value_dtype, index_dtype, n_rhs, fmt): + n_rows, _ = shape + B = torch.randn((n_rows, n_rhs), dtype=value_dtype, device=data.device).contiguous() + atol, rtol = _tol(value_dtype) + + if fmt == "csr": + X_fs, fs_ms = fs.flagsparse_spsm_csr( + data=data, + indices=indices.to(index_dtype), + indptr=indptr, + B=B, + shape=shape, + alpha=1.0, + lower=True, + unit_diagonal=False, + opA="NON_TRANS", + opB="NON_TRANS", + major="row", + return_time=True, + ) + data_c, row_c, col_c = _csr_to_coo(data, indices, indptr, shape) + else: + data_c, row_c, col_c = _csr_to_coo(data, indices, indptr, shape) + X_fs, fs_ms = fs.flagsparse_spsm_coo( + data=data_c, + row=row_c.to(index_dtype), + col=col_c.to(index_dtype), + B=B, + 
shape=shape, + alpha=1.0, + lower=True, + unit_diagonal=False, + opA="NON_TRANS", + opB="NON_TRANS", + major="row", + return_time=True, + ) + + X_cu, cu_ms, cu_reason = _cupy_spsm_ref( + data, row_c, col_c, indptr, B, shape, fmt=fmt + ) + ok_cu = None + err_cu = None + if X_cu is not None: + ok_cu = torch.allclose(X_fs, X_cu, atol=atol, rtol=rtol) + err_cu = float(torch.max(torch.abs(X_fs - X_cu)).item()) if X_fs.numel() else 0.0 + + if ok_cu is None: + status = "SKIP" + elif ok_cu: + status = "PASS" + else: + status = "FAIL" + + note_parts = [] + if cu_reason: + note_parts.append(cu_reason) + + return { + "fmt": fmt, + "n_rows": int(shape[0]), + "n_cols": int(shape[1]), + "nnz": int(data.numel()), + "rhs": int(n_rhs), + "flagsparse_ms": fs_ms, + "cusparse_ms": cu_ms, + "fs_vs_cu": _fmt_speedup(cu_ms, fs_ms), + "status": status, + "err_cu": err_cu, + "note": " | ".join(note_parts), + } + + +def run_spsm_synthetic_all(n=512, n_rhs=32): + if not torch.cuda.is_available(): + print("CUDA is not available.") + return + total = 0 + failed = 0 + print("=" * 120) + print("FLAGSPARSE SpSM (NON_TRANS/NON_TRANS, row-major) synthetic test") + print("=" * 120) + print( + f"{'Fmt':>5} {'dtype':>9} {'index':>7} {'N':>6} {'RHS':>6} {'NNZ':>10} " + f"{'FS(ms)':>10} {'CU(ms)':>10} {'FS/CU':>8} " + f"{'Status':>8} {'Err(CU)':>12}" + ) + print("-" * 120) + for fmt in FORMATS: + for value_dtype in VALUE_DTYPES: + for index_dtype in INDEX_DTYPES: + data, row, col, indptr, B, shape = _build_triangular_case( + n=n, n_rhs=n_rhs, value_dtype=value_dtype + ) + atol, rtol = _tol(value_dtype) + if fmt == "csr": + X_fs, fs_ms = fs.flagsparse_spsm_csr( + data=data, + indices=col.to(index_dtype), + indptr=indptr, + B=B, + shape=shape, + alpha=1.0, + lower=True, + unit_diagonal=False, + opA="NON_TRANS", + opB="NON_TRANS", + major="row", + return_time=True, + ) + else: + X_fs, fs_ms = fs.flagsparse_spsm_coo( + data=data, + row=row.to(index_dtype), + col=col.to(index_dtype), + B=B, + shape=shape, + alpha=1.0, + lower=True, + unit_diagonal=False, + opA="NON_TRANS", + opB="NON_TRANS", + major="row", + return_time=True, + ) + + X_cu, cu_ms, cu_reason = _cupy_spsm_ref( + data, row, col, indptr, B, shape, fmt=fmt + ) + ok_cu = None + err_cu = None + if X_cu is not None: + ok_cu = torch.allclose(X_fs, X_cu, atol=atol, rtol=rtol) + err_cu = float(torch.max(torch.abs(X_fs - X_cu)).item()) if X_fs.numel() else 0.0 + + if ok_cu is None: + status = "SKIP" + elif ok_cu: + status = "PASS" + else: + status = "FAIL" + total += 1 + if status != "PASS": + failed += 1 + + print( + f"{fmt:>5} {_dtype_name(value_dtype):>9} {_dtype_name(index_dtype):>7} " + f"{shape[0]:>6} {B.shape[1]:>6} {int(data.numel()):>10} " + f"{_fmt_ms(fs_ms):>10} {_fmt_ms(cu_ms):>10} {_fmt_speedup(cu_ms, fs_ms):>8} " + f"{status:>8} {_fmt_err(err_cu):>12}" + ) + if cu_reason is not None: + print(f" NOTE: {cu_reason}") + print("-" * 120) + print(f"Total cases: {total} Failed: {failed}") + print("=" * 120) + + +def run_all_dtypes_spsm_csv(mtx_paths, csv_path, use_coo=False, n_rhs=32): + if not torch.cuda.is_available(): + print("CUDA is not available.") + return + device = torch.device("cuda") + rows_out = [] + fmt = "coo" if use_coo else "csr" + + print("=" * 132) + print(f"FLAGSPARSE SpSM .mtx batch ({fmt.upper()}), NON_TRANS/NON_TRANS, row-major") + print("=" * 132) + print( + f"{'Matrix':<28} {'dtype':>9} {'index':>7} {'N':>7} {'RHS':>6} {'NNZ':>10} " + f"{'FS(ms)':>10} {'CU(ms)':>10} {'FS/CU':>8} " + f"{'Status':>8} {'Err(CU)':>12}" + ) + print("-" * 132) + + for 
value_dtype in CSV_VALUE_DTYPES: + for index_dtype in CSV_INDEX_DTYPES: + for path in mtx_paths: + base = { + "matrix": os.path.basename(path), + "value_dtype": _dtype_name(value_dtype), + "index_dtype": _dtype_name(index_dtype), + } + try: + data, indices, indptr, shape = _load_mtx_to_csr_torch( + path, dtype=value_dtype, device=device + ) + one = _run_one_spsm_case( + data, + indices, + indptr, + shape, + value_dtype, + index_dtype, + n_rhs, + fmt, + ) + row = { + **base, + **one, + } + rows_out.append(row) + short = base["matrix"][:27] + ("…" if len(base["matrix"]) > 27 else "") + print( + f"{short:<28} {base['value_dtype']:>9} {base['index_dtype']:>7} " + f"{row['n_rows']:>7} {row['rhs']:>6} {row['nnz']:>10} " + f"{_fmt_ms(row['flagsparse_ms']):>10} {_fmt_ms(row['cusparse_ms']):>10} " + f"{row['fs_vs_cu']:>8} {row['status']:>8} " + f"{_fmt_err(row['err_cu']):>12}" + ) + if row.get("note"): + print(f" NOTE: {row['note']}") + except Exception as exc: + row = { + **base, + "fmt": fmt, + "n_rows": "ERR", + "n_cols": "ERR", + "nnz": "ERR", + "rhs": int(n_rhs), + "flagsparse_ms": None, + "cusparse_ms": None, + "fs_vs_cu": "N/A", + "status": "ERROR", + "err_cu": None, + "note": str(exc), + } + rows_out.append(row) + short = base["matrix"][:27] + ("…" if len(base["matrix"]) > 27 else "") + print( + f"{short:<28} {base['value_dtype']:>9} {base['index_dtype']:>7} " + f"{'ERR':>7} {int(n_rhs):>6} {'ERR':>10} " + f"{_fmt_ms(None):>10} {_fmt_ms(None):>10} " + f"{'N/A':>8} {'ERROR':>8} " + f"{_fmt_err(None):>12}" + ) + print(f" ERROR: {exc}") + + print("-" * 132) + fieldnames = [ + "matrix", + "value_dtype", + "index_dtype", + "fmt", + "n_rows", + "n_cols", + "nnz", + "rhs", + "flagsparse_ms", + "cusparse_ms", + "fs_vs_cu", + "status", + "err_cu", + "note", + ] + with open(csv_path, "w", newline="", encoding="utf-8") as f: + w = csv.DictWriter(f, fieldnames=fieldnames) + w.writeheader() + for r in rows_out: + w.writerow(r) + print(f"Wrote {len(rows_out)} rows to {csv_path}") + + +def main(): + parser = argparse.ArgumentParser( + description="SpSM test: synthetic and optional .mtx CSV export (CSR/COO)." 
+ ) + parser.add_argument( + "mtx", + nargs="*", + help=".mtx file path(s), or directory(ies) to glob for *.mtx", + ) + parser.add_argument( + "--synthetic", action="store_true", help="Run synthetic triangular tests" + ) + parser.add_argument("--n", type=int, default=512, help="matrix size (synthetic)") + parser.add_argument("--rhs", type=int, default=32, help="number of RHS columns") + parser.add_argument( + "--csv-csr", + type=str, + default=None, + metavar="FILE", + help="Run .mtx batch in CSR mode and export CSV", + ) + parser.add_argument( + "--csv-coo", + type=str, + default=None, + metavar="FILE", + help="Run .mtx batch in COO mode and export CSV", + ) + args = parser.parse_args() + + if args.synthetic: + run_spsm_synthetic_all(n=args.n, n_rhs=args.rhs) + return + + paths = [] + for p in args.mtx: + if os.path.isfile(p) and p.endswith(".mtx"): + paths.append(p) + elif os.path.isdir(p): + paths.extend(sorted(glob.glob(os.path.join(p, "*.mtx")))) + + if args.csv_csr: + if not paths: + paths = sorted(glob.glob("*.mtx")) + if not paths: + print("No .mtx files found for --csv-csr") + return + run_all_dtypes_spsm_csv(paths, args.csv_csr, use_coo=False, n_rhs=args.rhs) + return + + if args.csv_coo: + if not paths: + paths = sorted(glob.glob("*.mtx")) + if not paths: + print("No .mtx files found for --csv-coo") + return + run_all_dtypes_spsm_csv(paths, args.csv_coo, use_coo=True, n_rhs=args.rhs) + return + + print("Use --synthetic, --csv-csr, or --csv-coo to run SpSM tests.") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/test_spsv.py b/tests/test_spsv.py new file mode 100644 index 0000000..8853dbd --- /dev/null +++ b/tests/test_spsv.py @@ -0,0 +1,1932 @@ +"""SpSV tests: synthetic triangular systems and optional .mtx (CSR/COO). 
+
+Same timing columns, PyTorch dense reference, CSV fields, and PASS criteria as
+the CSR tests; for COO tests the CuPy baseline uses ``coo_matrix`` (isomorphic
+to the FlagSparse COO input), while CSR tests still use ``csr_matrix``.
+"""
+
+import argparse
+import csv
+import glob
+import hashlib
+import os
+import sys
+from pathlib import Path
+
+import torch
+
+_PROJECT_ROOT = Path(__file__).resolve().parents[1]
+_SRC_ROOT = _PROJECT_ROOT / "src"
+if str(_SRC_ROOT) not in sys.path:
+    sys.path.insert(0, str(_SRC_ROOT))
+
+import flagsparse as fs
+import flagsparse.sparse_operations.spsv as fs_spsv_impl
+
+try:
+    import cupy as cp
+    import cupyx.scipy.sparse as cpx_sparse
+    from cupyx.scipy.sparse.linalg import spsolve_triangular as cpx_spsolve_triangular
+except Exception:
+    cp = None
+    cpx_sparse = None
+    cpx_spsolve_triangular = None
+
+VALUE_DTYPES = [torch.float32, torch.float64, torch.complex64, torch.complex128]
+INDEX_DTYPES = [torch.int32, torch.int64]
+TEST_SIZES = [256, 512, 1024, 2048]
+WARMUP = 5
+ITERS = 20
+
+DENSE_REF_MAX_BYTES = 2 * 1024 * 1024 * 1024  # 2 GiB
+SPSV_TRIANGULAR_DIAG_DOMINANCE = 4.0
+# Full CSR combination coverage (new, alongside the original csv-csr logic;
+# the original entry point is unaffected).
+CSR_FULL_VALUE_DTYPES = [
+    torch.float32,
+    torch.float64,
+    torch.complex64,
+    torch.complex128,
+]
+CSR_FULL_INDEX_DTYPES = [torch.int32, torch.int64]
+SPSV_OP_MODES = ["NON", "TRANS", "CONJ"]
+
+
+def _dtype_name(dtype):
+    return str(dtype).replace("torch.", "")
+
+
+VALUE_DTYPE_NAME_MAP = {
+    _dtype_name(dtype): dtype for dtype in CSR_FULL_VALUE_DTYPES
+}
+VALUE_DTYPE_NAME_MAP.update({
+    "float": torch.float32,
+    "double": torch.float64,
+})
+INDEX_DTYPE_NAME_MAP = {
+    _dtype_name(dtype): dtype for dtype in CSR_FULL_INDEX_DTYPES
+}
+
+
+def _parse_csv_tokens(raw):
+    return [tok.strip() for tok in str(raw).split(",") if tok.strip()]
+
+
+def _parse_value_dtypes_filter(raw):
+    tokens = [tok.lower() for tok in _parse_csv_tokens(raw)]
+    invalid = [tok for tok in tokens if tok not in VALUE_DTYPE_NAME_MAP]
+    if invalid:
+        raise ValueError(f"unsupported value dtypes: {invalid}")
+    return [VALUE_DTYPE_NAME_MAP[tok] for tok in tokens]
+
+
+def _parse_index_dtypes_filter(raw):
+    tokens = [tok.lower() for tok in _parse_csv_tokens(raw)]
+    invalid = [tok for tok in tokens if tok not in INDEX_DTYPE_NAME_MAP]
+    if invalid:
+        raise ValueError(f"unsupported index dtypes: {invalid}")
+    return [INDEX_DTYPE_NAME_MAP[tok] for tok in tokens]
+
+
+def _parse_op_modes_filter(raw):
+    tokens = [tok.upper() for tok in _parse_csv_tokens(raw)]
+    invalid = [tok for tok in tokens if tok not in SPSV_OP_MODES]
+    if invalid:
+        raise ValueError(f"unsupported ops: {invalid}")
+    return tokens
+
+
+def _fmt_ms(v):
+    return "N/A" if v is None else f"{v:.4f}"
+
+
+def _fmt_speedup(other_ms, triton_ms):
+    if other_ms is None or triton_ms is None or triton_ms <= 0:
+        return "N/A"
+    return f"{other_ms / triton_ms:.2f}x"
+
+
+def _fmt_err(v):
+    return "N/A" if v is None else f"{v:.2e}"
+
+
+def _safe_ratio(other_ms, base_ms):
+    if other_ms is None or base_ms is None or base_ms <= 0:
+        return None
+    return other_ms / base_ms
+
+
+def _sum_ms(*values):
+    if any(v is None for v in values):
+        return None
+    return sum(values)
+
+
+def _tol_for_dtype(dtype):
+    if dtype in (torch.float32, torch.complex64):
+        return 1e-4, 1e-2
+    return 1e-12, 1e-10
+
+
+def _stable_case_seed(*parts):
+    raw = "|".join(str(part) for part in parts).encode("utf-8")
+    return int.from_bytes(hashlib.sha256(raw).digest()[:8], "little") % (2**63)
+
+
+def _generator_for_seed(seed):
+    if seed is None:
+        return None
+    gen = torch.Generator()
+    gen.manual_seed(int(seed))
+    return gen
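+
+
+# Reproducibility sketch (illustrative, not exercised by the tests themselves):
+# the sha256-based seed is a pure function of the case key, so reruns of the
+# same case draw the same RHS.
+#   >>> s = _stable_case_seed("csv-csr", "demo.mtx", "LOWER", "NON")
+#   >>> s == _stable_case_seed("csv-csr", "demo.mtx", "LOWER", "NON")
+#   True
+#   >>> gen = _generator_for_seed(s)  # CPU torch.Generator with a fixed state
+
+
+def _randn_by_dtype(n, dtype, 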
device, generator=None): + if dtype in (torch.float32, torch.float64): + return torch.randn(n, dtype=dtype, device=device, generator=generator) + base = torch.float32 if dtype == torch.complex64 else torch.float64 + real = torch.randn(n, dtype=base, device=device, generator=generator) + imag = torch.randn(n, dtype=base, device=device, generator=generator) + return torch.complex(real, imag) + + +def _dense_ref_dtype(dtype): + return dtype + + +def _tensor_from_scalar_values(values, dtype, device): + return torch.tensor(values, dtype=dtype, device=device) + + +def _safe_cast_tensor(tensor, dtype): + return tensor.to(dtype) + + +def _cast_real_tensor_to_value_dtype(values, value_dtype): + return values.to(value_dtype) + + +def _matrix_market_value(parts, mm_field): + if mm_field == "complex": + if len(parts) < 4: + raise ValueError("MatrixMarket complex entry requires real and imag parts") + return complex(float(parts[2]), float(parts[3])) + if len(parts) >= 3: + return float(parts[2]) + if mm_field == "pattern": + return 1.0 + raise ValueError("MatrixMarket entry is missing a numeric value") + + +def _triangular_solve_reference(A, b, *, lower, op_mode="NON"): + if op_mode == "TRANS": + A_eff = A.transpose(0, 1) + upper = lower + elif op_mode == "CONJ": + A_eff = A.transpose(0, 1).conj() if torch.is_complex(A) else A.transpose(0, 1) + upper = lower + else: + A_eff = A + upper = not lower + return torch.linalg.solve_triangular( + A_eff, b.unsqueeze(1), upper=upper + ).squeeze(1) + + +def _cupy_ref_inputs(data, b): + return data, b + + +def _compare_view(tensor, value_dtype): + return tensor + + +def _supported_csr_full_ops(value_dtype, index_dtype): + if value_dtype not in CSR_FULL_VALUE_DTYPES: + return [] + if index_dtype == torch.int32: + return ["NON", "TRANS", "CONJ"] + if index_dtype == torch.int64: + return ["NON", "TRANS", "CONJ"] + return [] + + +def _fill_mode_name(lower): + return "LOWER" if lower else "UPPER" + + +def _diag_type_name(unit_diagonal): + return "UNIT" if unit_diagonal else "NON_UNIT" + + +def _triton_alg_name(fmt, op_mode=None, coo_mode=None): + if fmt == "CSR": + if op_mode in ("TRANS", "CONJ"): + return "TRITON_CSR_LEVEL_TRANSPOSE_FAMILY" + return "TRITON_CSR_LEVEL" + if coo_mode: + return f"TRITON_COO_{str(coo_mode).upper()}" + return "TRITON_COO_LEVEL" + + +def _cusparse_alg_name(): + return "CUPY_SPSOLVE_TRIANGULAR" + + +def _allow_dense_pytorch_ref(shape, dtype): + n_rows, n_cols = int(shape[0]), int(shape[1]) + elem_bytes = torch.empty((), dtype=dtype).element_size() + dense_bytes = n_rows * n_cols * elem_bytes + return dense_bytes <= DENSE_REF_MAX_BYTES + + +def _build_random_triangular_csr(n, value_dtype, index_dtype, device, lower=True): + """Build a well-conditioned triangular CSR for real and complex dtypes.""" + max_bandwidth = max(4, min(n, 16)) + rows_host = [] + cols_host = [] + vals_host = [] + row_off_abs = [0.0] * n + col_off_abs = [0.0] * n + if value_dtype == torch.float32: + base_real_dtype = torch.float32 + elif value_dtype == torch.float64: + base_real_dtype = torch.float64 + elif value_dtype == torch.complex64: + base_real_dtype = torch.float32 + else: + base_real_dtype = torch.float64 + + for i in range(n): + if lower: + cand_cols = list(range(0, i + 1)) + else: + cand_cols = list(range(i, n)) + if not cand_cols: + cand_cols = [i] + diag_col = i + off_cand = [c for c in cand_cols if c != diag_col] + k_off = min(len(off_cand), max_bandwidth - 1) + if k_off > 0: + perm = torch.randperm(len(off_cand))[:k_off].tolist() + off_cols = [off_cand[j] 
for j in perm] + else: + off_cols = [] + if value_dtype in (torch.complex64, torch.complex128): + off_vals = torch.complex( + torch.randn(len(off_cols), dtype=base_real_dtype, device=device).mul_(0.01), + torch.randn(len(off_cols), dtype=base_real_dtype, device=device).mul_(0.01), + ) + off_vals_host = [complex(v) for v in off_vals.cpu().tolist()] + else: + off_vals = torch.randn(len(off_cols), dtype=base_real_dtype, device=device).mul_(0.01) + off_vals_host = off_vals.cpu().tolist() + for c, v in zip(off_cols, off_vals_host): + rows_host.append(i) + cols_host.append(int(c)) + vals_host.append(v) + mag = abs(v) + row_off_abs[i] += mag + col_off_abs[int(c)] += mag + + for i in range(n): + diag_mag = ( + SPSV_TRIANGULAR_DIAG_DOMINANCE * max(row_off_abs[i], col_off_abs[i]) + 1.0 + ) + diag_val = ( + complex(diag_mag, 0.0) + if value_dtype in (torch.complex64, torch.complex128) + else diag_mag + ) + rows_host.append(i) + cols_host.append(i) + vals_host.append(diag_val) + + rows_t = torch.tensor(rows_host, dtype=torch.int64, device=device) + cols_t = torch.tensor(cols_host, dtype=torch.int64, device=device) + vals_t = torch.tensor(vals_host, dtype=value_dtype, device=device) + order = torch.argsort(rows_t * max(1, n) + cols_t) + rows_t = rows_t[order] + cols_t = cols_t[order] + vals_t = vals_t[order] + nnz_per_row = torch.bincount(rows_t, minlength=n) + indptr = torch.zeros(n + 1, dtype=torch.int64, device=device) + indptr[1:] = torch.cumsum(nnz_per_row, dim=0) + indices = cols_t.to(index_dtype) + return vals_t, indices, indptr, (n, n) + + +def _csr_to_dense(data, indices, indptr, shape): + n_rows, n_cols = shape + row_ind = torch.repeat_interleave( + torch.arange(n_rows, device=data.device, dtype=torch.int64), + indptr[1:] - indptr[:-1], + ) + coo = torch.sparse_coo_tensor( + torch.stack([row_ind, indices.to(torch.int64)]), + data, + (n_rows, n_cols), + device=data.device, + ).coalesce() + return coo.to_dense() + + +def _csr_to_coo(data, indices, indptr, shape, index_dtype=torch.int64): + n_rows = int(shape[0]) + row = torch.repeat_interleave( + torch.arange(n_rows, device=data.device, dtype=index_dtype), + indptr[1:] - indptr[:-1], + ) + col = indices.to(index_dtype) + return data, row, col + + +def _csr_transpose(data, indices, indptr, shape, conjugate=False): + n_rows, n_cols = int(shape[0]), int(shape[1]) + if data.numel() == 0: + return ( + data, + torch.empty(0, dtype=torch.int64, device=data.device), + torch.zeros(n_cols + 1, dtype=torch.int64, device=data.device), + ) + + row, col = _csr_to_coo(data, indices, indptr, shape)[1:] + row_t = col + col_t = row + key = row_t * max(1, n_rows) + col_t + try: + order = torch.argsort(key, stable=True) + except TypeError: + order = torch.argsort(key) + + row_t = row_t[order] + col_t = col_t[order] + data_eff = data.conj() if conjugate and torch.is_complex(data) else data + data_t = data_eff[order] + nnz_per_row = torch.bincount(row_t, minlength=n_cols) + indptr_t = torch.zeros(n_cols + 1, dtype=torch.int64, device=data.device) + indptr_t[1:] = torch.cumsum(nnz_per_row, dim=0) + return data_t, col_t.to(torch.int64), indptr_t + + +def _load_mtx_to_csr_torch(file_path, dtype=torch.float32, device=None, lower=True): + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + with open(file_path, "r", encoding="utf-8") as f: + lines = f.readlines() + data_lines = [] + header_info = None + mm_field = "real" + mm_symmetry = "general" + for line in lines: + line = line.strip() + if line.startswith("%%MatrixMarket"): + 
parts = line.split()
+            if len(parts) >= 5:
+                mm_field = parts[3].lower()
+                mm_symmetry = parts[4].lower()
+            continue
+        if line.startswith("%"):
+            continue
+        if not header_info and line:
+            parts = line.split()
+            n_rows = int(parts[0])
+            n_cols = int(parts[1])
+            nnz = int(parts[2]) if len(parts) > 2 else 0
+            header_info = (n_rows, n_cols, nnz)
+            continue
+        if line:
+            data_lines.append(line)
+    if header_info is None:
+        raise ValueError(f"Cannot parse .mtx header: {file_path}")
+    n_rows, n_cols, nnz = header_info
+    if n_rows != n_cols:
+        raise ValueError("SpSV requires square matrices")
+    row_maps = [dict() for _ in range(n_rows)]
+
+    def _accum(r, c, v):
+        row = row_maps[r]
+        row[c] = row.get(c, 0.0) + v
+
+    for line in data_lines[:nnz]:
+        parts = line.split()
+        if len(parts) < 2:
+            continue
+        r = int(parts[0]) - 1
+        c = int(parts[1]) - 1
+        v = _matrix_market_value(parts, mm_field)
+        _accum(r, c, v)
+        if mm_symmetry == "symmetric" and r != c:
+            _accum(c, r, v)
+        elif mm_symmetry == "hermitian" and r != c:
+            _accum(c, r, v.conjugate() if isinstance(v, complex) else v)
+        elif mm_symmetry == "skew-symmetric" and r != c:
+            _accum(c, r, -v)
+
+    tri_rows = [dict() for _ in range(n_rows)]
+    row_off_abs = [0.0] * n_rows
+    col_off_abs = [0.0] * n_cols
+    for r in range(n_rows):
+        for c, v in row_maps[r].items():
+            keep = c < r if lower else c > r
+            if keep:
+                tri_rows[r][c] = tri_rows[r].get(c, 0.0) + v
+
+    for r, row in enumerate(tri_rows):
+        for c, v in row.items():
+            mag = abs(v)
+            row_off_abs[r] += mag
+            col_off_abs[c] += mag
+
+    for r in range(n_rows):
+        # Make the generated triangular system stable for both A and op(A).
+        tri_rows[r][r] = (
+            SPSV_TRIANGULAR_DIAG_DOMINANCE * max(row_off_abs[r], col_off_abs[r]) + 1.0
+        )
+    row_maps = tri_rows
+
+    cols_s = []
+    vals_s = []
+    indptr_list = [0]
+    for r in range(n_rows):
+        row = row_maps[r]
+        for c in sorted(row.keys()):
+            cols_s.append(c)
+            vals_s.append(row[c])
+        indptr_list.append(len(cols_s))
+    data = _tensor_from_scalar_values(vals_s, dtype, device)
+    indices = torch.tensor(cols_s, dtype=torch.int64, device=device)
+    indptr = torch.tensor(indptr_list, dtype=torch.int64, device=device)
+    return data, indices, indptr, (n_rows, n_cols)
+
+
+def _coo_inputs_for_csv(data, indices, indptr, shape, coo_mode, index_dtype=torch.int64):
+    """Sorted COO from CSR; optional shuffle/duplicate for csr|auto (matches the original CSV behavior)."""
+    data_c, row_c, col_c = _csr_to_coo(
+        data, indices, indptr, shape, index_dtype=index_dtype
+    )
+    if coo_mode in ("csr", "auto"):
+        if data_c.numel() == 0:
+            return data_c, row_c, col_c
+        if data_c.numel() <= 2_000_000:
+            left = data_c * 0.25
+            right = data_c - left
+            data_dup = torch.cat([left, right], dim=0)
+            row_dup = torch.cat([row_c, row_c], dim=0)
+            col_dup = torch.cat([col_c, col_c], dim=0)
+            perm = torch.randperm(data_dup.numel(), device=data_c.device)
+            return data_dup[perm], row_dup[perm], col_dup[perm]
+        perm = torch.randperm(data_c.numel(), device=data_c.device)
+        return data_c[perm], row_c[perm], col_c[perm]
+    return data_c, row_c, col_c
+
+
+def _random_rhs_for_spsv(shape, value_dtype, device, op_mode="NON", seed=None):
+    n_rows, n_cols = int(shape[0]), int(shape[1])
+    rhs_size = n_rows if op_mode == "NON" else n_cols
+    if seed is None:
+        return _randn_by_dtype(rhs_size, value_dtype, device)
+    rhs = _randn_by_dtype(
+        rhs_size,
+        value_dtype,
+        torch.device("cpu"),
+        generator=_generator_for_seed(seed),
+    )
+    return rhs.to(device)
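+
+
+# Duplicate-entry sketch (illustrative): in the csr/auto COO modes above, each
+# stored value v is split into v * 0.25 and v - v * 0.25 at the same
+# (row, col), so summing the duplicates reconstructs A exactly and a correct
+# COO solver must accumulate them.
+#   >>> v = torch.tensor([4.0])
+#   >>> left = v * 0.25
+#   >>> torch.cat([left, v - left]).sum()
+#   tensor(4.)
+
+
+def _apply_csr_op(data, indices, indptr, x, shape, op_mode):
+    n_rows, n_cols = 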
int(shape[0]), int(shape[1]) + row = torch.repeat_interleave( + torch.arange(n_rows, device=data.device, dtype=torch.int64), + indptr.to(torch.int64)[1:] - indptr.to(torch.int64)[:-1], + ) + col = indices.to(torch.int64) + if op_mode == "NON": + b = torch.zeros(n_rows, dtype=data.dtype, device=data.device) + b.scatter_add_(0, row, data * x[col]) + return b + if op_mode == "TRANS": + b = torch.zeros(n_cols, dtype=data.dtype, device=data.device) + b.scatter_add_(0, col, data * x[row]) + return b + if op_mode == "CONJ": + b = torch.zeros(n_cols, dtype=data.dtype, device=data.device) + data_eff = data.conj() if torch.is_complex(data) else data + b.scatter_add_(0, col, data_eff * x[row]) + return b + raise ValueError("op_mode must be 'NON', 'TRANS', or 'CONJ'") + + +def _solution_residual_metrics(data, indices, indptr, shape, x, b, value_dtype, op_mode): + atol, rtol = _tol_for_dtype(value_dtype) + b_recon = _apply_csr_op(data, indices, indptr, x, shape, op_mode) + err_res = ( + float(torch.max(torch.abs(b_recon - b)).item()) + if b.numel() > 0 + else 0.0 + ) + ok_res = torch.allclose(b_recon, b, atol=atol, rtol=rtol) + return err_res, ok_res + + +def _benchmark_flagsparse(call): + x = None + for _ in range(WARMUP): + x = call() + torch.cuda.synchronize() + e0 = torch.cuda.Event(True) + e1 = torch.cuda.Event(True) + e0.record() + for _ in range(ITERS): + x = call() + e1.record() + torch.cuda.synchronize() + return x, e0.elapsed_time(e1) / ITERS + + +def _benchmark_flagsparse_spsv_csr( + data, + indices, + indptr, + b, + shape, + *, + lower=True, + transpose=False, +): + return _benchmark_flagsparse( + lambda: fs.flagsparse_spsv_csr( + data, + indices, + indptr, + b, + shape, + lower=lower, + transpose=transpose, + ) + ) + + +def _benchmark_flagsparse_spsv_csr_split( + data, + indices, + indptr, + b, + shape, + *, + lower=True, + transpose=False, +): + analysis_ms = fs_spsv_impl._analyze_spsv_csr( + data, + indices, + indptr, + b, + shape, + lower=lower, + transpose=transpose, + clear_cache=True, + return_time=True, + ) + x, solve_ms = _benchmark_flagsparse_spsv_csr( + data, + indices, + indptr, + b, + shape, + lower=lower, + transpose=transpose, + ) + return x, analysis_ms, solve_ms + + +def _benchmark_flagsparse_spsv_coo( + data, + row, + col, + b, + shape, + *, + lower=True, + coo_mode="auto", +): + return _benchmark_flagsparse( + lambda: fs.flagsparse_spsv_coo( + data, + row, + col, + b, + shape, + lower=lower, + coo_mode=coo_mode, + ) + ) + + +def _cupy_spsolve_lower_csr_or_coo( + fmt, + data, + indices, + indptr, + shape, + b, + warmup, + iters, + lower, +): + """Triangular solve via CuPy: CSR or COO storage. 
Returns (ms, x_torch) or (None, None).""" + if ( + cp is None + or cpx_sparse is None + or cpx_spsolve_triangular is None + ): + return None, None + try: + b_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(b.contiguous())) + if fmt == "COO": + dc, rr, cc = _csr_to_coo(data, indices, indptr, shape) + data_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(dc.contiguous())) + row_cp = cp.from_dlpack( + torch.utils.dlpack.to_dlpack(rr.to(torch.int64).contiguous()) + ) + col_cp = cp.from_dlpack( + torch.utils.dlpack.to_dlpack(cc.to(torch.int64).contiguous()) + ) + A_cp = cpx_sparse.coo_matrix((data_cp, (row_cp, col_cp)), shape=shape) + else: + data_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(data.contiguous())) + idx_cp = cp.from_dlpack( + torch.utils.dlpack.to_dlpack(indices.to(torch.int64).contiguous()) + ) + ptr_cp = cp.from_dlpack( + torch.utils.dlpack.to_dlpack(indptr.contiguous()) + ) + A_cp = cpx_sparse.csr_matrix((data_cp, idx_cp, ptr_cp), shape=shape) + for _ in range(warmup): + _ = cpx_spsolve_triangular( + A_cp, b_cp, lower=lower, unit_diagonal=False + ) + cp.cuda.runtime.deviceSynchronize() + t0 = cp.cuda.Event() + t1 = cp.cuda.Event() + t0.record() + for _ in range(iters): + x_cu = cpx_spsolve_triangular( + A_cp, b_cp, lower=lower, unit_diagonal=False + ) + t1.record() + t1.synchronize() + cupy_ms = cp.cuda.get_elapsed_time(t0, t1) / iters + x_cu_t = torch.utils.dlpack.from_dlpack(x_cu.toDlpack()) + x_cu_t = x_cu_t.to(b.dtype) + return cupy_ms, x_cu_t + except Exception: + return None, None + + +def _cupy_spsolve_csr_with_op(data, indices, indptr, shape, b, op_mode, lower): + if ( + cp is None + or cpx_sparse is None + or cpx_spsolve_triangular is None + ): + return None, None + try: + data_ref, b_ref = _cupy_ref_inputs(data, b) + data_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(data_ref.contiguous())) + idx_cp = cp.from_dlpack( + torch.utils.dlpack.to_dlpack(indices.to(torch.int64).contiguous()) + ) + ptr_cp = cp.from_dlpack( + torch.utils.dlpack.to_dlpack(indptr.to(torch.int64).contiguous()) + ) + b_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(b_ref.contiguous())) + A_cp = cpx_sparse.csr_matrix((data_cp, idx_cp, ptr_cp), shape=shape) + if op_mode == "TRANS": + A_eff = A_cp.transpose().tocsr() + lower_eff = not lower + elif op_mode == "CONJ": + A_eff = A_cp.transpose().conj().tocsr() + lower_eff = not lower + else: + A_eff = A_cp + lower_eff = lower + + for _ in range(WARMUP): + _ = cpx_spsolve_triangular( + A_eff, b_cp, lower=lower_eff, unit_diagonal=False + ) + cp.cuda.runtime.deviceSynchronize() + c0 = cp.cuda.Event() + c1 = cp.cuda.Event() + c0.record() + for _ in range(ITERS): + x_cp = cpx_spsolve_triangular( + A_eff, b_cp, lower=lower_eff, unit_diagonal=False + ) + c1.record() + c1.synchronize() + ms = cp.cuda.get_elapsed_time(c0, c1) / ITERS + x_t = torch.utils.dlpack.from_dlpack(x_cp.toDlpack()).to(b.dtype) + return ms, x_t + except Exception: + return None, None + + +def run_spsv_synthetic_all(lower=True): + if not torch.cuda.is_available(): + print("CUDA is not available. 
Please run on a GPU-enabled system.") + return + device = torch.device("cuda") + sep = "=" * 110 + print(sep) + print("FLAGSPARSE SpSV BENCHMARK (synthetic triangular systems, CSR + COO)") + print(sep) + print(f"GPU: {torch.cuda.get_device_name(0)}") + print(f"Warmup: {WARMUP} | Iters: {ITERS}") + print(f"Triangle: {'LOWER' if lower else 'UPPER'}") + print() + + hdr = ( + f"{'Fmt':>5} {'opA':>5} {'N':>6} {'FS.analysis':>12} {'FS.solve(ms)':>14} {'FS.total':>10} " + f"{'PyTorch(ms)':>12} {'cuSPARSE':>10} {'pt/triton':>10} {'cu/triton':>10} {'Status':>8} {'Err(PT)':>12} {'Err(CU)':>12}" + ) + + total = 0 + failed = 0 + for value_dtype in VALUE_DTYPES: + for index_dtype in INDEX_DTYPES: + print("-" * 110) + print( + f"Value dtype: {_dtype_name(value_dtype):<12} | " + f"Index dtype: {_dtype_name(index_dtype):<6}" + ) + print("-" * 110) + print(hdr) + print("-" * 110) + for n in TEST_SIZES: + for fmt in ("CSR", "COO"): + op_modes = ( + _supported_csr_full_ops(value_dtype, index_dtype) + if fmt == "CSR" + else ["NON"] + ) + for op_mode in op_modes: + data, indices, indptr, shape = _build_random_triangular_csr( + n, value_dtype, index_dtype, device, lower=lower + ) + A_dense = _csr_to_dense( + data, indices.to(torch.int64), indptr, shape + ) + rhs_op = op_mode if fmt == "CSR" else "NON" + b = _random_rhs_for_spsv( + shape, + value_dtype, + device, + op_mode=rhs_op, + seed=_stable_case_seed( + "synthetic", + "LOWER" if lower else "UPPER", + fmt, + op_mode, + n, + _dtype_name(value_dtype), + _dtype_name(index_dtype), + ), + ) + + torch.cuda.synchronize() + if fmt == "CSR": + x, analysis_ms, t_ms = _benchmark_flagsparse_spsv_csr_split( + data, + indices, + indptr, + b, + shape, + lower=lower, + transpose=op_mode, + ) + else: + analysis_ms = None + dc, rr, cc = _csr_to_coo( + data, indices, indptr, shape, index_dtype=index_dtype + ) + x, t_ms = _benchmark_flagsparse_spsv_coo( + dc, + rr, + cc, + b, + shape, + lower=lower, + coo_mode="auto", + ) + torch.cuda.synchronize() + + A_ref = A_dense + b_ref = b + torch.cuda.synchronize() + e0 = torch.cuda.Event(True) + e1 = torch.cuda.Event(True) + e0.record() + x_pt = _triangular_solve_reference( + A_ref, b_ref, lower=lower, op_mode=op_mode + ) + e1.record() + torch.cuda.synchronize() + pytorch_ms = e0.elapsed_time(e1) + err_pt = float(torch.max(torch.abs(x - x_pt)).item()) if n > 0 else 0.0 + + cupy_ms = None + err_cu = None + x_cu_t = None + if fmt == "CSR": + cupy_ms, x_cu_t = _cupy_spsolve_csr_with_op( + data, indices, indptr, shape, b, op_mode, lower + ) + elif value_dtype in ( + torch.float32, + torch.float64, + torch.complex64, + torch.complex128, + ): + cupy_ms, x_cu_t = _cupy_spsolve_lower_csr_or_coo( + fmt, + data, + indices, + indptr, + shape, + b, + WARMUP, + ITERS, + lower, + ) + if x_cu_t is not None and n > 0: + err_cu = float( + torch.max(torch.abs(x - x_cu_t)).item() + ) + + atol, rtol = _tol_for_dtype(value_dtype) + ok_pt = torch.allclose(x, x_pt, atol=atol, rtol=rtol) + ok_cu = ( + True + if x_cu_t is None + else torch.allclose(x, x_cu_t, atol=atol, rtol=rtol) + ) + ok = ok_pt or ok_cu + status = "PASS" if ok else "FAIL" + if not ok: + failed += 1 + total += 1 + + fs_vs_pt = ( + (pytorch_ms / t_ms) if (t_ms and t_ms > 0) else None + ) + fs_vs_cu = ( + (cupy_ms / t_ms) + if (cupy_ms is not None and t_ms and t_ms > 0) + else None + ) + print( + f"{fmt:>5} {op_mode:>5} {n:>6} {_fmt_ms(analysis_ms):>12} {_fmt_ms(t_ms):>14} {_fmt_ms(_sum_ms(analysis_ms, t_ms)):>10} {_fmt_ms(pytorch_ms):>12} " + f"{_fmt_ms(cupy_ms):>10} " + 
f"{(f'{fs_vs_pt:.2f}x' if fs_vs_pt is not None else 'N/A'):>10} " + f"{(f'{fs_vs_cu:.2f}x' if fs_vs_cu is not None else 'N/A'):>10} " + f"{status:>8} {_fmt_err(err_pt):>12} {_fmt_err(err_cu):>12}" + ) + print("-" * 110) + print() + + print(sep) + print(f"Total cases: {total} Failed: {failed}") + print(sep) + + +def _run_one_csv_row_coo(path, value_dtype, index_dtype, device, coo_mode, lower=True): + data, indices, indptr, shape = _load_mtx_to_csr_torch( + path, dtype=value_dtype, device=device, lower=lower + ) + indices = indices.to(index_dtype) + n_rows, n_cols = shape + b = _random_rhs_for_spsv( + shape, + value_dtype, + device, + op_mode="NON", + seed=_stable_case_seed( + "csv-coo", + os.path.basename(path), + "LOWER" if lower else "UPPER", + coo_mode, + _dtype_name(value_dtype), + _dtype_name(index_dtype), + ), + ) + d_in, r_in, c_in = _coo_inputs_for_csv( + data, indices, indptr, shape, coo_mode, index_dtype=index_dtype + ) + analysis_ms = None + x, t_ms = _benchmark_flagsparse_spsv_coo( + d_in, + r_in, + c_in, + b, + shape, + lower=lower, + coo_mode=coo_mode, + ) + return _finalize_csv_row( + path, + value_dtype, + index_dtype, + data, + indices, + indptr, + shape, + x, + analysis_ms, + t_ms, + b, + n_rows, + n_cols, + lower=lower, + coo_mode=coo_mode, + nnz_display=int(d_in.numel()), + cupy_coo_data=d_in, + cupy_coo_row=r_in, + cupy_coo_col=c_in, + ) + + +def _finalize_csv_row( + path, + value_dtype, + index_dtype, + data, + indices, + indptr, + shape, + x, + analysis_ms, + t_ms, + b, + n_rows, + n_cols, + *, + lower=True, + coo_mode="auto", + nnz_display=None, + cupy_coo_data=None, + cupy_coo_row=None, + cupy_coo_col=None, +): + atol, rtol = _tol_for_dtype(value_dtype) + err_res, _ = _solution_residual_metrics( + data, indices, indptr, shape, x, b, value_dtype, "NON" + ) + pytorch_ms = None + err_pt = None + ok_pt = False + pt_skip_reason = None + if _allow_dense_pytorch_ref(shape, value_dtype): + try: + A_dense = _csr_to_dense( + data, indices.to(torch.int64), indptr, shape + ) + ref_dtype = _dense_ref_dtype(value_dtype) + A_ref = A_dense.to(ref_dtype) + b_ref = b.to(ref_dtype) + e0 = torch.cuda.Event(True) + e1 = torch.cuda.Event(True) + torch.cuda.synchronize() + e0.record() + x_ref = _triangular_solve_reference( + A_ref, b_ref, lower=lower, op_mode="NON" + ) + x_cmp = _compare_view(x, value_dtype) + x_ref_cmp = _compare_view(x_ref, value_dtype) + e1.record() + torch.cuda.synchronize() + pytorch_ms = e0.elapsed_time(e1) + err_pt = ( + float(torch.max(torch.abs(x_cmp - x_ref_cmp)).item()) + if n_rows > 0 + else 0.0 + ) + ok_pt = torch.allclose(x_cmp, x_ref_cmp, atol=atol, rtol=rtol) + except RuntimeError as e: + if "out of memory" in str(e).lower(): + pt_skip_reason = "PyTorch dense ref OOM; skipped" + else: + raise + else: + pt_skip_reason = ( + f"PyTorch dense ref skipped (> {DENSE_REF_MAX_BYTES // (1024**3)} GiB dense matrix)" + ) + + cupy_ms = None + err_cu = None + ok_cu = False + x_cu_t = None + if ( + cp is not None + and cpx_sparse is not None + and cpx_spsolve_triangular is not None + ): + try: + data_ref, b_ref = _cupy_ref_inputs(data, b) + b_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(b_ref.contiguous())) + if ( + cupy_coo_data is not None + and cupy_coo_row is not None + and cupy_coo_col is not None + ): + coo_data_ref, _ = _cupy_ref_inputs(cupy_coo_data, b) + data_cp = cp.from_dlpack( + torch.utils.dlpack.to_dlpack(coo_data_ref.contiguous()) + ) + row_cp = cp.from_dlpack( + torch.utils.dlpack.to_dlpack( + cupy_coo_row.to(torch.int64).contiguous() + ) + ) + 
col_cp = cp.from_dlpack( + torch.utils.dlpack.to_dlpack( + cupy_coo_col.to(torch.int64).contiguous() + ) + ) + A_cp = cpx_sparse.coo_matrix( + (data_cp, (row_cp, col_cp)), shape=shape + ) + else: + data_cp = cp.from_dlpack( + torch.utils.dlpack.to_dlpack(data_ref.contiguous()) + ) + idx_cp = cp.from_dlpack( + torch.utils.dlpack.to_dlpack( + indices.to(torch.int64).contiguous() + ) + ) + ptr_cp = cp.from_dlpack( + torch.utils.dlpack.to_dlpack(indptr.contiguous()) + ) + A_cp = cpx_sparse.csr_matrix( + (data_cp, idx_cp, ptr_cp), shape=shape + ) + for _ in range(WARMUP): + _ = cpx_spsolve_triangular( + A_cp, b_cp, lower=lower, unit_diagonal=False + ) + cp.cuda.runtime.deviceSynchronize() + c0 = cp.cuda.Event() + c1 = cp.cuda.Event() + c0.record() + for _ in range(ITERS): + x_cu = cpx_spsolve_triangular( + A_cp, b_cp, lower=lower, unit_diagonal=False + ) + c1.record() + c1.synchronize() + cupy_ms = cp.cuda.get_elapsed_time(c0, c1) / ITERS + x_cu_t = torch.utils.dlpack.from_dlpack(x_cu.toDlpack()) + x_cmp = _compare_view(x, value_dtype) + x_cu_cmp = _compare_view(x_cu_t, value_dtype) + err_cu = ( + float(torch.max(torch.abs(x_cmp - x_cu_cmp)).item()) + if n_rows > 0 + else 0.0 + ) + ok_cu = torch.allclose(x_cmp, x_cu_cmp, atol=atol, rtol=rtol) + except Exception: + cupy_ms = None + err_cu = None + + status = "PASS" if (ok_pt or ok_cu) else "FAIL" + if (not ok_pt) and (not ok_cu) and (err_pt is None and err_cu is None): + status = "REF_FAIL" + ref_errors = [err for err in (err_pt, err_cu) if err is not None] + err_ref = min(ref_errors) if ref_errors else None + + nnz_out = ( + int(data.numel()) if nnz_display is None else int(nnz_display) + ) + row = { + "format": "COO", + "matrix": os.path.basename(path), + "value_dtype": _dtype_name(value_dtype), + "index_dtype": _dtype_name(index_dtype), + "opA": "NON", + "fill_mode": _fill_mode_name(lower), + "diag_type": _diag_type_name(False), + "triton_alg": _triton_alg_name("COO", op_mode="NON", coo_mode=coo_mode), + "reference_alg": _cusparse_alg_name(), + "n_rows": n_rows, + "n_cols": n_cols, + "nnz": nnz_out, + "triton_analysis_ms": analysis_ms, + "triton_solve_ms": t_ms, + "triton_time_total_ms": _sum_ms(analysis_ms, t_ms), + "pytorch_solve_ms": pytorch_ms, + "cusparse_solve_ms": cupy_ms, + "pytorch/triton": _safe_ratio(pytorch_ms, t_ms), + "cusparse/triton": _safe_ratio(cupy_ms, t_ms), + "status": status, + "err_ref": err_ref, + "err_res": err_res, + "err_pt": err_pt, + "err_cu": err_cu, + } + return row, pt_skip_reason + + +def _run_one_csv_row_csr_full(path, value_dtype, index_dtype, op_mode, device, lower=True): + data, indices, indptr, shape = _load_mtx_to_csr_torch( + path, dtype=value_dtype, device=device, lower=lower + ) + indices = indices.to(index_dtype) + indptr = indptr.to(index_dtype) + n_rows, n_cols = shape + b = _random_rhs_for_spsv( + shape, + value_dtype, + device, + op_mode=op_mode, + seed=_stable_case_seed( + "csv-csr", + os.path.basename(path), + "LOWER" if lower else "UPPER", + op_mode, + _dtype_name(value_dtype), + _dtype_name(index_dtype), + ), + ) + x, analysis_ms, t_ms = _benchmark_flagsparse_spsv_csr_split( + data, + indices, + indptr, + b, + shape, + lower=lower, + transpose=op_mode, + ) + return _finalize_csv_row_csr_full( + path, + value_dtype, + index_dtype, + op_mode, + data, + indices, + indptr, + shape, + x, + analysis_ms, + t_ms, + b, + n_rows, + n_cols, + lower=lower, + ) + + +def _finalize_csv_row_csr_full( + path, + value_dtype, + index_dtype, + op_mode, + data, + indices, + indptr, + shape, + x, + analysis_ms, + 
t_ms, + b, + n_rows, + n_cols, + lower=True, +): + atol, rtol = _tol_for_dtype(value_dtype) + err_res, _ = _solution_residual_metrics( + data, indices, indptr, shape, x, b, value_dtype, op_mode + ) + + pytorch_ms = None + err_pt = None + ok_pt = False + pt_skip_reason = None + if _allow_dense_pytorch_ref(shape, value_dtype): + try: + A_dense = _csr_to_dense( + data, indices.to(torch.int64), indptr.to(torch.int64), shape + ).to(_dense_ref_dtype(value_dtype)) + e0 = torch.cuda.Event(True) + e1 = torch.cuda.Event(True) + torch.cuda.synchronize() + e0.record() + x_ref = _triangular_solve_reference( + A_dense, b.to(A_dense.dtype), lower=lower, op_mode=op_mode + ) + x_cmp = _compare_view(x, value_dtype) + x_ref_cmp = _compare_view(x_ref, value_dtype) + e1.record() + torch.cuda.synchronize() + pytorch_ms = e0.elapsed_time(e1) + err_pt = ( + float(torch.max(torch.abs(x_cmp - x_ref_cmp)).item()) + if n_rows > 0 + else 0.0 + ) + ok_pt = torch.allclose(x_cmp, x_ref_cmp, atol=atol, rtol=rtol) + except RuntimeError as e: + if "out of memory" in str(e).lower(): + pt_skip_reason = "PyTorch dense ref OOM; skipped" + else: + raise + else: + pt_skip_reason = ( + f"PyTorch dense ref skipped (> {DENSE_REF_MAX_BYTES // (1024**3)} GiB dense matrix)" + ) + + cupy_ms = None + err_cu = None + ok_cu = False + x_cu_t = None + cupy_ms, x_cu_t = _cupy_spsolve_csr_with_op( + data, indices, indptr, shape, b, op_mode, lower + ) + if x_cu_t is not None: + x_cmp = _compare_view(x, value_dtype) + x_cu_cmp = _compare_view(x_cu_t, value_dtype) + err_cu = ( + float(torch.max(torch.abs(x_cmp - x_cu_cmp)).item()) + if n_rows > 0 + else 0.0 + ) + ok_cu = torch.allclose(x_cmp, x_cu_cmp, atol=atol, rtol=rtol) + + status = "PASS" if (ok_pt or ok_cu) else "FAIL" + if (not ok_pt) and (not ok_cu) and (err_pt is None and err_cu is None): + status = "REF_FAIL" + ref_errors = [err for err in (err_pt, err_cu) if err is not None] + err_ref = min(ref_errors) if ref_errors else None + + row = { + "format": "CSR", + "matrix": os.path.basename(path), + "value_dtype": _dtype_name(value_dtype), + "index_dtype": _dtype_name(index_dtype), + "opA": op_mode, + "fill_mode": _fill_mode_name(lower), + "diag_type": _diag_type_name(False), + "triton_alg": _triton_alg_name("CSR", op_mode=op_mode), + "reference_alg": _cusparse_alg_name(), + "n_rows": n_rows, + "n_cols": n_cols, + "nnz": int(data.numel()), + "triton_analysis_ms": analysis_ms, + "triton_solve_ms": t_ms, + "triton_time_total_ms": _sum_ms(analysis_ms, t_ms), + "pytorch_solve_ms": pytorch_ms, + "cusparse_solve_ms": cupy_ms, + "pytorch/triton": _safe_ratio(pytorch_ms, t_ms), + "cusparse/triton": _safe_ratio(cupy_ms, t_ms), + "status": status, + "err_ref": err_ref, + "err_res": err_res, + "err_pt": err_pt, + "err_cu": err_cu, + } + return row, pt_skip_reason + + +def run_all_supported_spsv_csr_csv( + mtx_paths, + csv_path, + lower=True, + value_dtypes=None, + index_dtypes=None, + op_modes=None, +): + if not torch.cuda.is_available(): + print("CUDA is not available.") + return + device = torch.device("cuda") + rows_out = [] + selected_value_dtypes = value_dtypes or CSR_FULL_VALUE_DTYPES + selected_index_dtypes = index_dtypes or CSR_FULL_INDEX_DTYPES + selected_op_modes = op_modes or SPSV_OP_MODES + for value_dtype in selected_value_dtypes: + for index_dtype in selected_index_dtypes: + supported_op_modes = [ + op for op in _supported_csr_full_ops(value_dtype, index_dtype) + if op in selected_op_modes + ] + for op_mode in supported_op_modes: + print("=" * 150) + print( + f"Value dtype: 
{_dtype_name(value_dtype)} | Index dtype: {_dtype_name(index_dtype)} | CSR | triA={'LOWER' if lower else 'UPPER'} | opA={op_mode}" + ) + print( + "Formats: FlagSparse=CSR, cuSPARSE=CSR ref, PyTorch=Dense solve." + ) + print( + "RHS is generated directly, matching Library-main's SpSV test style. " + "FlagSparse analysis is measured separately; FlagSparse(ms) below reports solve only. " + "Err(Ref)=best |FlagSparse-reference|, Err(Res)=|op(A)*x-b|, " + "Err(PT)=|FlagSparse-PyTorch|, Err(CU)=|FlagSparse-cuSPARSE|. " + "PASS if PyTorch / cuSPARSE reference passes. Residual is diagnostic only." + ) + print("-" * 150) + print( + f"{'Matrix':<28} {'N_rows':>7} {'N_cols':>7} {'NNZ':>10} " + f"{'FS.anlys':>10} {'FS.solve':>10} {'FS.total':>10} {'cu.solve':>10} {'pt.solve':>10} " + f"{'cu/triton':>10} {'pt/triton':>10} {'Status':>6} {'Err(Ref)':>10} {'Err(Res)':>10} {'Err(PT)':>10} {'Err(CU)':>10}" + ) + print("-" * 150) + for path in mtx_paths: + try: + row, pt_skip = _run_one_csv_row_csr_full( + path, value_dtype, index_dtype, op_mode, device, lower=lower + ) + rows_out.append(row) + name = os.path.basename(path)[:27] + if len(os.path.basename(path)) > 27: + name = name + "…" + n_rows, n_cols = row["n_rows"], row["n_cols"] + nnz = row["nnz"] + analysis_ms = row["triton_analysis_ms"] + t_ms = row["triton_solve_ms"] + cupy_ms = row["cusparse_solve_ms"] + pytorch_ms = row["pytorch_solve_ms"] + err_ref, err_res = row["err_ref"], row["err_res"] + err_pt, err_cu = row["err_pt"], row["err_cu"] + status = row["status"] + print( + f"{name:<28} {n_rows:>7} {n_cols:>7} {nnz:>10} " + f"{_fmt_ms(analysis_ms):>10} {_fmt_ms(t_ms):>10} {_fmt_ms(row['triton_time_total_ms']):>10} {_fmt_ms(cupy_ms):>10} {_fmt_ms(pytorch_ms):>10} " + f"{_fmt_speedup(cupy_ms, t_ms):>10} {_fmt_speedup(pytorch_ms, t_ms):>10} " + f"{status:>6} {_fmt_err(err_ref):>10} {_fmt_err(err_res):>10} {_fmt_err(err_pt):>10} {_fmt_err(err_cu):>10}" + ) + if pt_skip: + print(f" NOTE: {pt_skip}") + except Exception as e: + err_msg = str(e) + status = "SKIP" if "SpSV requires square matrices" in err_msg else "ERROR" + rows_out.append( + { + "format": "CSR", + "matrix": os.path.basename(path), + "value_dtype": _dtype_name(value_dtype), + "index_dtype": _dtype_name(index_dtype), + "opA": op_mode, + "fill_mode": _fill_mode_name(lower), + "diag_type": _diag_type_name(False), + "triton_alg": _triton_alg_name("CSR", op_mode=op_mode), + "reference_alg": _cusparse_alg_name(), + "n_rows": "ERR", + "n_cols": "ERR", + "nnz": "ERR", + "triton_analysis_ms": None, + "triton_solve_ms": None, + "triton_time_total_ms": None, + "pytorch_solve_ms": None, + "cusparse_solve_ms": None, + "pytorch/triton": None, + "cusparse/triton": None, + "status": status, + "err_ref": None, + "err_res": None, + "err_pt": None, + "err_cu": None, + } + ) + name = os.path.basename(path)[:27] + if len(os.path.basename(path)) > 27: + name = name + "…" + print( + f"{name:<28} {'ERR':>7} {'ERR':>7} {'ERR':>10} " + f"{_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>10} " + f"{'N/A':>10} {'N/A':>10} " + f"{status:>6} {_fmt_err(None):>10} {_fmt_err(None):>10} {_fmt_err(None):>10} {_fmt_err(None):>10}" + ) + print(f" {status}: {e}") + print("-" * 150) + fieldnames = [ + "format", + "matrix", + "value_dtype", + "index_dtype", + "opA", + "fill_mode", + "diag_type", + "triton_alg", + "reference_alg", + "n_rows", + "n_cols", + "nnz", + "triton_analysis_ms", + "triton_solve_ms", + "triton_time_total_ms", + "pytorch_solve_ms", + "cusparse_solve_ms", + 
"pytorch/triton", + "cusparse/triton", + "status", + "err_ref", + "err_res", + "err_pt", + "err_cu", + ] + with open(csv_path, "w", newline="", encoding="utf-8") as f: + w = csv.DictWriter(f, fieldnames=fieldnames) + w.writeheader() + for r in rows_out: + w.writerow(r) + print(f"Wrote {len(rows_out)} rows to {csv_path}") + + +def run_all_dtypes_spsv_coo_csv( + mtx_paths, + csv_path, + coo_mode="auto", + lower=True, + value_dtypes=None, + index_dtypes=None, +): + if not torch.cuda.is_available(): + print("CUDA is not available.") + return + device = torch.device("cuda") + rows_out = [] + selected_value_dtypes = value_dtypes or VALUE_DTYPES + selected_index_dtypes = index_dtypes or INDEX_DTYPES + for value_dtype in selected_value_dtypes: + for index_dtype in selected_index_dtypes: + print("=" * 150) + print( + f"Value dtype: {_dtype_name(value_dtype)} | Index dtype: {_dtype_name(index_dtype)} | COO" + f" triA={'LOWER' if lower else 'UPPER'} coo_mode={coo_mode}" + ) + print( + "Formats: FlagSparse=COO SpSV, cuSPARSE=COO ref, PyTorch=Dense solve. " + "RHS is generated directly, matching Library-main's SpSV test style." + ) + print( + "Err(Ref)=best |FlagSparse-reference|, Err(Res)=|A*x-b|, " + "Err(PT)=|FlagSparse-PyTorch|, Err(CU)=|FlagSparse-cuSPARSE|. " + "PASS if PyTorch / cuSPARSE reference passes. Residual is diagnostic only." + ) + print("-" * 150) + print( + f"{'Matrix':<28} {'N_rows':>7} {'N_cols':>7} {'NNZ':>10} " + f"{'FS.anlys':>10} {'FS.solve':>10} {'FS.total':>10} {'cu.solve':>10} {'pt.solve':>10} " + f"{'cu/triton':>10} {'pt/triton':>10} {'Status':>6} {'Err(Ref)':>10} {'Err(Res)':>10} {'Err(PT)':>10} {'Err(CU)':>10}" + ) + print("-" * 150) + for path in mtx_paths: + try: + row, pt_skip = _run_one_csv_row_coo( + path, value_dtype, index_dtype, device, coo_mode, lower=lower + ) + rows_out.append(row) + name = os.path.basename(path)[:27] + if len(os.path.basename(path)) > 27: + name = name + "…" + n_rows, n_cols = row["n_rows"], row["n_cols"] + nnz = row["nnz"] + analysis_ms = row["triton_analysis_ms"] + t_ms = row["triton_solve_ms"] + cupy_ms = row["cusparse_solve_ms"] + pytorch_ms = row["pytorch_solve_ms"] + err_ref, err_res = row["err_ref"], row["err_res"] + err_pt, err_cu = row["err_pt"], row["err_cu"] + status = row["status"] + print( + f"{name:<28} {n_rows:>7} {n_cols:>7} {nnz:>10} " + f"{_fmt_ms(analysis_ms):>10} {_fmt_ms(t_ms):>10} {_fmt_ms(row['triton_time_total_ms']):>10} {_fmt_ms(cupy_ms):>10} {_fmt_ms(pytorch_ms):>10} " + f"{_fmt_speedup(cupy_ms, t_ms):>10} {_fmt_speedup(pytorch_ms, t_ms):>10} " + f"{status:>6} {_fmt_err(err_ref):>10} {_fmt_err(err_res):>10} {_fmt_err(err_pt):>10} {_fmt_err(err_cu):>10}" + ) + if pt_skip: + print(f" NOTE: {pt_skip}") + except Exception as e: + err_msg = str(e) + status = "SKIP" if "SpSV requires square matrices" in err_msg else "ERROR" + rows_out.append( + { + "format": "COO", + "matrix": os.path.basename(path), + "value_dtype": _dtype_name(value_dtype), + "index_dtype": _dtype_name(index_dtype), + "opA": "NON", + "fill_mode": _fill_mode_name(lower), + "diag_type": _diag_type_name(False), + "triton_alg": _triton_alg_name("COO", op_mode="NON", coo_mode=coo_mode), + "reference_alg": _cusparse_alg_name(), + "n_rows": "ERR", + "n_cols": "ERR", + "nnz": "ERR", + "triton_analysis_ms": None, + "triton_solve_ms": None, + "triton_time_total_ms": None, + "pytorch_solve_ms": None, + "cusparse_solve_ms": None, + "pytorch/triton": None, + "cusparse/triton": None, + "status": status, + "err_ref": None, + "err_res": None, + "err_pt": None, + "err_cu": 
None, + } + ) + name = os.path.basename(path)[:27] + if len(os.path.basename(path)) > 27: + name = name + "…" + print( + f"{name:<28} {'ERR':>7} {'ERR':>7} {'ERR':>10} " + f"{_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>10} " + f"{'N/A':>10} {'N/A':>10} {status:>6} {_fmt_err(None):>10} {_fmt_err(None):>10} {_fmt_err(None):>10} {_fmt_err(None):>10}" + ) + print(f" {status}: {e}") + print("-" * 150) + fieldnames = [ + "format", + "matrix", + "value_dtype", + "index_dtype", + "opA", + "fill_mode", + "diag_type", + "triton_alg", + "reference_alg", + "n_rows", + "n_cols", + "nnz", + "triton_analysis_ms", + "triton_solve_ms", + "triton_time_total_ms", + "pytorch_solve_ms", + "cusparse_solve_ms", + "pytorch/triton", + "cusparse/triton", + "status", + "err_ref", + "err_res", + "err_pt", + "err_cu", + ] + with open(csv_path, "w", newline="", encoding="utf-8") as f: + w = csv.DictWriter(f, fieldnames=fieldnames) + w.writeheader() + for r in rows_out: + w.writerow(r) + print(f"Wrote {len(rows_out)} rows to {csv_path}") + + +def _check_one_csr_transpose_case(path, value_dtype, index_dtype, op_mode, device, lower=True): + data, indices, indptr, shape = _load_mtx_to_csr_torch( + path, dtype=value_dtype, device=device, lower=lower + ) + indices = indices.to(index_dtype) + indptr = indptr.to(index_dtype) + n_rows, n_cols = shape + trans_data, trans_indices64, trans_indptr64 = fs_spsv_impl._csr_transpose( + data, + indices.to(torch.int64), + indptr.to(torch.int64), + n_rows, + n_cols, + conjugate=(op_mode == "CONJ"), + ) + trans_shape = (n_cols, n_rows) + trans_indices = trans_indices64.to(index_dtype) + trans_indptr = trans_indptr64.to(index_dtype) + + probe = _random_rhs_for_spsv( + trans_shape, + value_dtype, + device, + op_mode="NON", + seed=_stable_case_seed( + "check-transpose-action", + os.path.basename(path), + "LOWER" if lower else "UPPER", + op_mode, + _dtype_name(value_dtype), + _dtype_name(index_dtype), + ), + ) + action_ref = _apply_csr_op(data, indices, indptr, probe, shape, op_mode) + action_trans = _apply_csr_op( + trans_data, trans_indices, trans_indptr, probe, trans_shape, "NON" + ) + action_err = ( + float(torch.max(torch.abs(action_trans - action_ref)).item()) + if action_ref.numel() > 0 + else 0.0 + ) + atol, rtol = _tol_for_dtype(value_dtype) + action_ok = torch.allclose(action_trans, action_ref, atol=atol, rtol=rtol) + + b = _random_rhs_for_spsv( + shape, + value_dtype, + device, + op_mode=op_mode, + seed=_stable_case_seed( + "check-transpose-solve", + os.path.basename(path), + "LOWER" if lower else "UPPER", + op_mode, + _dtype_name(value_dtype), + _dtype_name(index_dtype), + ), + ) + x_op = fs.flagsparse_spsv_csr( + data, + indices, + indptr, + b, + shape, + lower=lower, + transpose=op_mode, + ) + x_mat = fs.flagsparse_spsv_csr( + trans_data, + trans_indices, + trans_indptr, + b, + trans_shape, + lower=not lower, + transpose="NON", + ) + solve_err = ( + float(torch.max(torch.abs(x_op - x_mat)).item()) if x_op.numel() > 0 else 0.0 + ) + solve_ok = torch.allclose(x_op, x_mat, atol=atol, rtol=rtol) + + ref_err = None + ref_ok = None + if _allow_dense_pytorch_ref(shape, value_dtype): + A_dense = _csr_to_dense( + data, indices.to(torch.int64), indptr.to(torch.int64), shape + ).to(_dense_ref_dtype(value_dtype)) + x_ref = _triangular_solve_reference( + A_dense, b.to(A_dense.dtype), lower=lower, op_mode=op_mode + ) + ref_err = ( + float(torch.max(torch.abs(x_op - x_ref)).item()) if x_op.numel() > 0 else 0.0 + ) + ref_ok = 
torch.allclose(x_op, x_ref, atol=atol, rtol=rtol) + + status = "PASS" if action_ok and solve_ok and (ref_ok is not False) else "FAIL" + return { + "matrix": os.path.basename(path), + "value_dtype": _dtype_name(value_dtype), + "index_dtype": _dtype_name(index_dtype), + "opA": op_mode, + "n_rows": n_rows, + "nnz": int(data.numel()), + "action_err": action_err, + "solve_err": solve_err, + "ref_err": ref_err, + "status": status, + } + + +def run_csr_transpose_check( + mtx_paths, + lower=True, + value_dtypes=None, + index_dtypes=None, + op_modes=None, +): + if not torch.cuda.is_available(): + print("CUDA is not available.") + return + device = torch.device("cuda") + selected_value_dtypes = value_dtypes or CSR_FULL_VALUE_DTYPES + selected_index_dtypes = index_dtypes or CSR_FULL_INDEX_DTYPES + selected_op_modes = [op for op in (op_modes or ("TRANS", "CONJ")) if op in ("TRANS", "CONJ")] + if not selected_op_modes: + print("--check-transpose only checks TRANS/CONJ; no matching op selected.") + return + + print("=" * 150) + print( + "CSR TRANS/CONJ preprocessing check: " + "ActionErr compares materialized op(A) against direct CSR scatter; " + "SolveErr compares transpose path against materialized NON path." + ) + print("-" * 150) + print( + f"{'Matrix':<28} {'dtype':>10} {'index':>7} {'opA':>5} " + f"{'N':>7} {'NNZ':>10} {'Status':>6} {'ActionErr':>10} {'SolveErr':>10} {'RefErr':>10}" + ) + print("-" * 150) + total = 0 + failed = 0 + for value_dtype in selected_value_dtypes: + for index_dtype in selected_index_dtypes: + for op_mode in selected_op_modes: + for path in mtx_paths: + try: + row = _check_one_csr_transpose_case( + path, + value_dtype, + index_dtype, + op_mode, + device, + lower=lower, + ) + total += 1 + failed += int(row["status"] != "PASS") + name = row["matrix"][:27] + if len(row["matrix"]) > 27: + name += "..." + print( + f"{name:<28} {row['value_dtype']:>10} {row['index_dtype']:>7} {row['opA']:>5} " + f"{row['n_rows']:>7} {row['nnz']:>10} {row['status']:>6} " + f"{_fmt_err(row['action_err']):>10} {_fmt_err(row['solve_err']):>10} {_fmt_err(row['ref_err']):>10}" + ) + except Exception as e: + total += 1 + failed += 1 + name = os.path.basename(path)[:27] + if len(os.path.basename(path)) > 27: + name += "..." + print( + f"{name:<28} {_dtype_name(value_dtype):>10} {_dtype_name(index_dtype):>7} {op_mode:>5} " + f"{'ERR':>7} {'ERR':>10} {'ERROR':>6} " + f"{_fmt_err(None):>10} {_fmt_err(None):>10} {_fmt_err(None):>10}" + ) + print(f" ERROR: {e}") + print("-" * 150) + print(f"Total cases: {total} Failed: {failed}") + + +def main(): + parser = argparse.ArgumentParser( + description="SpSV test: synthetic triangular systems and optional .mtx (CSR/COO), same baselines as CSR." 
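+        # Filter flags compose, e.g. (illustrative):
+        #   --csv-csr out.csv --value-dtypes complex64 --index-dtypes int32 --ops TRANS
+        # --ops is rejected with --csv-coo (COO tests only run opA=NON).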
+ ) + parser.add_argument( + "mtx", + nargs="*", + help=".mtx file path(s), or directory(ies) to glob for *.mtx", + ) + parser.add_argument( + "--synthetic", action="store_true", help="Run synthetic triangular tests" + ) + parser.add_argument( + "--csv-csr", + type=str, + default=None, + metavar="FILE", + help="Run full supported CSR SpSV combinations (dtype/index/opA) on .mtx and export CSV", + ) + parser.add_argument( + "--csv-coo", + type=str, + default=None, + metavar="FILE", + help="Run all dtypes on .mtx (COO SpSV), same CSV columns as --csv-csr", + ) + parser.add_argument( + "--check-transpose", + action="store_true", + help="Check CSR TRANS/CONJ preprocessing against direct CSR scatter and materialized NON solve", + ) + parser.add_argument( + "--coo-mode", + type=str, + default="auto", + choices=["auto", "direct", "csr"], + help="COO mode for --csv-coo (default: auto)", + ) + parser.add_argument( + "--upper", + action="store_true", + help="Use upper-triangular inputs instead of the default lower-triangular inputs", + ) + parser.add_argument( + "--ops", + type=str, + default=None, + help="Comma-separated opA filter for CSR CSV, e.g. TRANS,CONJ", + ) + parser.add_argument( + "--value-dtypes", + type=str, + default=None, + help="Comma-separated value dtype filter for CSR CSV, e.g. float,double,complex64,complex128", + ) + parser.add_argument( + "--index-dtypes", + type=str, + default=None, + help="Comma-separated index dtype filter for CSR CSV, e.g. int32,int64", + ) + args = parser.parse_args() + lower = not args.upper + + if args.synthetic: + run_spsv_synthetic_all(lower=lower) + return + + paths = [] + for p in args.mtx: + if os.path.isfile(p) and p.endswith(".mtx"): + paths.append(p) + elif os.path.isdir(p): + paths.extend(sorted(glob.glob(os.path.join(p, "*.mtx")))) + if args.check_transpose: + if not paths: + paths = sorted(glob.glob("*.mtx")) + if not paths: + print("No .mtx files found for --check-transpose") + return + value_dtypes = ( + _parse_value_dtypes_filter(args.value_dtypes) + if args.value_dtypes + else None + ) + index_dtypes = ( + _parse_index_dtypes_filter(args.index_dtypes) + if args.index_dtypes + else None + ) + op_modes = ( + _parse_op_modes_filter(args.ops) + if args.ops + else None + ) + run_csr_transpose_check( + paths, + lower=lower, + value_dtypes=value_dtypes, + index_dtypes=index_dtypes, + op_modes=op_modes, + ) + return + if args.csv_csr: + if not paths: + paths = sorted(glob.glob("*.mtx")) + if not paths: + print("No .mtx files found for --csv-csr") + return + value_dtypes = ( + _parse_value_dtypes_filter(args.value_dtypes) + if args.value_dtypes + else None + ) + index_dtypes = ( + _parse_index_dtypes_filter(args.index_dtypes) + if args.index_dtypes + else None + ) + op_modes = ( + _parse_op_modes_filter(args.ops) + if args.ops + else None + ) + run_all_supported_spsv_csr_csv( + paths, + args.csv_csr, + lower=lower, + value_dtypes=value_dtypes, + index_dtypes=index_dtypes, + op_modes=op_modes, + ) + return + if args.csv_coo: + if args.ops: + parser.error("--ops is only supported with --csv-csr; COO tests only run opA=NON") + if not paths: + paths = sorted(glob.glob("*.mtx")) + if not paths: + print("No .mtx files found for --csv-coo") + return + value_dtypes = ( + _parse_value_dtypes_filter(args.value_dtypes) + if args.value_dtypes + else None + ) + index_dtypes = ( + _parse_index_dtypes_filter(args.index_dtypes) + if args.index_dtypes + else None + ) + run_all_dtypes_spsv_coo_csv( + paths, + args.csv_coo, + coo_mode=args.coo_mode, + lower=lower, + 
value_dtypes=value_dtypes, + index_dtypes=index_dtypes, + ) + return + + print("Use --synthetic, --csv-csr, or --csv-coo to run SpSV tests.") + + +if __name__ == "__main__": + main() From 655d71204da3d79c81980098f26e0f994e476a3b Mon Sep 17 00:00:00 2001 From: berlin020 <2261128688@qq.com> Date: Mon, 27 Apr 2026 13:02:24 +0800 Subject: [PATCH 10/22] pytorch --- tests/test_spsv.py | 133 ++++++++++++++++++++------------------------- 1 file changed, 59 insertions(+), 74 deletions(-) diff --git a/tests/test_spsv.py b/tests/test_spsv.py index 8853dbd..14430f3 100644 --- a/tests/test_spsv.py +++ b/tests/test_spsv.py @@ -37,7 +37,6 @@ WARMUP = 5 ITERS = 20 -DENSE_REF_MAX_BYTES = 2 * 1024 * 1024 * 1024 # 2 GiB SPSV_TRIANGULAR_DIAG_DOMINANCE = 4.0 # CSR 完整组合覆盖(在原 csv-csr 逻辑外新增,不影响原入口) CSR_FULL_VALUE_DTYPES = [ @@ -231,13 +230,6 @@ def _cusparse_alg_name(): return "CUPY_SPSOLVE_TRIANGULAR" -def _allow_dense_pytorch_ref(shape, dtype): - n_rows, n_cols = int(shape[0]), int(shape[1]) - elem_bytes = torch.empty((), dtype=dtype).element_size() - dense_bytes = n_rows * n_cols * elem_bytes - return dense_bytes <= DENSE_REF_MAX_BYTES - - def _build_random_triangular_csr(n, value_dtype, index_dtype, device, lower=True): """Build a well-conditioned triangular CSR for real and complex dtypes.""" max_bandwidth = max(4, min(n, 16)) @@ -986,41 +978,36 @@ def _finalize_csv_row( err_pt = None ok_pt = False pt_skip_reason = None - if _allow_dense_pytorch_ref(shape, value_dtype): - try: - A_dense = _csr_to_dense( - data, indices.to(torch.int64), indptr, shape - ) - ref_dtype = _dense_ref_dtype(value_dtype) - A_ref = A_dense.to(ref_dtype) - b_ref = b.to(ref_dtype) - e0 = torch.cuda.Event(True) - e1 = torch.cuda.Event(True) - torch.cuda.synchronize() - e0.record() - x_ref = _triangular_solve_reference( - A_ref, b_ref, lower=lower, op_mode="NON" - ) - x_cmp = _compare_view(x, value_dtype) - x_ref_cmp = _compare_view(x_ref, value_dtype) - e1.record() - torch.cuda.synchronize() - pytorch_ms = e0.elapsed_time(e1) - err_pt = ( - float(torch.max(torch.abs(x_cmp - x_ref_cmp)).item()) - if n_rows > 0 - else 0.0 - ) - ok_pt = torch.allclose(x_cmp, x_ref_cmp, atol=atol, rtol=rtol) - except RuntimeError as e: - if "out of memory" in str(e).lower(): - pt_skip_reason = "PyTorch dense ref OOM; skipped" - else: - raise - else: - pt_skip_reason = ( - f"PyTorch dense ref skipped (> {DENSE_REF_MAX_BYTES // (1024**3)} GiB dense matrix)" + try: + A_dense = _csr_to_dense( + data, indices.to(torch.int64), indptr, shape + ) + ref_dtype = _dense_ref_dtype(value_dtype) + A_ref = A_dense.to(ref_dtype) + b_ref = b.to(ref_dtype) + e0 = torch.cuda.Event(True) + e1 = torch.cuda.Event(True) + torch.cuda.synchronize() + e0.record() + x_ref = _triangular_solve_reference( + A_ref, b_ref, lower=lower, op_mode="NON" ) + x_cmp = _compare_view(x, value_dtype) + x_ref_cmp = _compare_view(x_ref, value_dtype) + e1.record() + torch.cuda.synchronize() + pytorch_ms = e0.elapsed_time(e1) + err_pt = ( + float(torch.max(torch.abs(x_cmp - x_ref_cmp)).item()) + if n_rows > 0 + else 0.0 + ) + ok_pt = torch.allclose(x_cmp, x_ref_cmp, atol=atol, rtol=rtol) + except RuntimeError as e: + if "out of memory" in str(e).lower(): + pt_skip_reason = "PyTorch dense ref OOM; skipped" + else: + raise cupy_ms = None err_cu = None @@ -1212,38 +1199,33 @@ def _finalize_csv_row_csr_full( err_pt = None ok_pt = False pt_skip_reason = None - if _allow_dense_pytorch_ref(shape, value_dtype): - try: - A_dense = _csr_to_dense( - data, indices.to(torch.int64), indptr.to(torch.int64), shape - 
).to(_dense_ref_dtype(value_dtype)) - e0 = torch.cuda.Event(True) - e1 = torch.cuda.Event(True) - torch.cuda.synchronize() - e0.record() - x_ref = _triangular_solve_reference( - A_dense, b.to(A_dense.dtype), lower=lower, op_mode=op_mode - ) - x_cmp = _compare_view(x, value_dtype) - x_ref_cmp = _compare_view(x_ref, value_dtype) - e1.record() - torch.cuda.synchronize() - pytorch_ms = e0.elapsed_time(e1) - err_pt = ( - float(torch.max(torch.abs(x_cmp - x_ref_cmp)).item()) - if n_rows > 0 - else 0.0 - ) - ok_pt = torch.allclose(x_cmp, x_ref_cmp, atol=atol, rtol=rtol) - except RuntimeError as e: - if "out of memory" in str(e).lower(): - pt_skip_reason = "PyTorch dense ref OOM; skipped" - else: - raise - else: - pt_skip_reason = ( - f"PyTorch dense ref skipped (> {DENSE_REF_MAX_BYTES // (1024**3)} GiB dense matrix)" + try: + A_dense = _csr_to_dense( + data, indices.to(torch.int64), indptr.to(torch.int64), shape + ).to(_dense_ref_dtype(value_dtype)) + e0 = torch.cuda.Event(True) + e1 = torch.cuda.Event(True) + torch.cuda.synchronize() + e0.record() + x_ref = _triangular_solve_reference( + A_dense, b.to(A_dense.dtype), lower=lower, op_mode=op_mode + ) + x_cmp = _compare_view(x, value_dtype) + x_ref_cmp = _compare_view(x_ref, value_dtype) + e1.record() + torch.cuda.synchronize() + pytorch_ms = e0.elapsed_time(e1) + err_pt = ( + float(torch.max(torch.abs(x_cmp - x_ref_cmp)).item()) + if n_rows > 0 + else 0.0 ) + ok_pt = torch.allclose(x_cmp, x_ref_cmp, atol=atol, rtol=rtol) + except RuntimeError as e: + if "out of memory" in str(e).lower(): + pt_skip_reason = "PyTorch dense ref OOM; skipped" + else: + raise cupy_ms = None err_cu = None @@ -1666,7 +1648,7 @@ def _check_one_csr_transpose_case(path, value_dtype, index_dtype, op_mode, devic ref_err = None ref_ok = None - if _allow_dense_pytorch_ref(shape, value_dtype): + try: A_dense = _csr_to_dense( data, indices.to(torch.int64), indptr.to(torch.int64), shape ).to(_dense_ref_dtype(value_dtype)) @@ -1677,6 +1659,9 @@ def _check_one_csr_transpose_case(path, value_dtype, index_dtype, op_mode, devic float(torch.max(torch.abs(x_op - x_ref)).item()) if x_op.numel() > 0 else 0.0 ) ref_ok = torch.allclose(x_op, x_ref, atol=atol, rtol=rtol) + except RuntimeError as e: + if "out of memory" not in str(e).lower(): + raise status = "PASS" if action_ok and solve_ok and (ref_ok is not False) else "FAIL" return { From 0413508b1a3510cd6256a5a480f34e9526d350e1 Mon Sep 17 00:00:00 2001 From: berlin020 <2261128688@qq.com> Date: Mon, 27 Apr 2026 14:21:58 +0800 Subject: [PATCH 11/22] pytorch --- tests/test_spsv.py | 209 ++++++++++++++++++++++++++++++--------------- 1 file changed, 139 insertions(+), 70 deletions(-) diff --git a/tests/test_spsv.py b/tests/test_spsv.py index 14430f3..2a4442a 100644 --- a/tests/test_spsv.py +++ b/tests/test_spsv.py @@ -100,7 +100,7 @@ def _fmt_ms(v): def _fmt_speedup(other_ms, triton_ms): if other_ms is None or triton_ms is None or triton_ms <= 0: return "N/A" - return f"{other_ms / triton_ms:.2f}x" + return f"{other_ms / triton_ms:.2f}" def _fmt_err(v): @@ -190,6 +190,92 @@ def _triangular_solve_reference(A, b, *, lower, op_mode="NON"): ).squeeze(1) +def _build_csr_tensor_for_op(data, indices, indptr, shape, op_mode): + if op_mode == "TRANS": + data_eff, indices_eff, indptr_eff = _csr_transpose( + data, indices.to(torch.int64), indptr.to(torch.int64), shape, conjugate=False + ) + elif op_mode == "CONJ": + data_eff, indices_eff, indptr_eff = _csr_transpose( + data, indices.to(torch.int64), indptr.to(torch.int64), shape, conjugate=True + ) + 
else: + data_eff = data + indices_eff = indices.to(torch.int64) + indptr_eff = indptr.to(torch.int64) + return torch.sparse_csr_tensor( + indptr_eff, + indices_eff, + data_eff, + size=shape, + device=data.device, + ) + + +def _benchmark_pytorch_sparse_reference(data, indices, indptr, shape, b, *, op_mode): + sparse_spsolve = getattr(torch.sparse, "spsolve", None) + if sparse_spsolve is None: + raise NotImplementedError("torch.sparse.spsolve is unavailable") + A_csr = _build_csr_tensor_for_op(data, indices, indptr, shape, op_mode) + torch.cuda.synchronize() + e0 = torch.cuda.Event(True) + e1 = torch.cuda.Event(True) + e0.record() + x_ref = sparse_spsolve(A_csr, b) + e1.record() + torch.cuda.synchronize() + return x_ref.to(b.dtype), e0.elapsed_time(e1) + + +def _benchmark_pytorch_dense_reference(data, indices, indptr, shape, b, *, lower, op_mode): + A_dense = _csr_to_dense( + data, indices.to(torch.int64), indptr.to(torch.int64), shape + ).to(_dense_ref_dtype(b.dtype)) + b_ref = b.to(A_dense.dtype) + torch.cuda.synchronize() + e0 = torch.cuda.Event(True) + e1 = torch.cuda.Event(True) + e0.record() + x_ref = _triangular_solve_reference( + A_dense, b_ref, lower=lower, op_mode=op_mode + ) + e1.record() + torch.cuda.synchronize() + return x_ref.to(b.dtype), e0.elapsed_time(e1) + + +def _benchmark_pytorch_reference(data, indices, indptr, shape, b, *, lower, op_mode): + sparse_err = None + try: + x_ref, ms = _benchmark_pytorch_sparse_reference( + data, indices, indptr, shape, b, op_mode=op_mode + ) + return x_ref, ms, None + except RuntimeError as e: + sparse_err = e + if "out of memory" in str(e).lower(): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + else: + sparse_err = e + except NotImplementedError as e: + sparse_err = e + + try: + x_ref, ms = _benchmark_pytorch_dense_reference( + data, indices, indptr, shape, b, lower=lower, op_mode=op_mode + ) + return x_ref, ms, None + except RuntimeError as e: + if "out of memory" in str(e).lower(): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + if sparse_err is not None: + return None, None, "PyTorch sparse ref unavailable; dense ref OOM; skipped" + return None, None, "PyTorch dense ref OOM; skipped" + raise + + def _cupy_ref_inputs(data, b): return data, b @@ -766,9 +852,6 @@ def run_spsv_synthetic_all(lower=True): data, indices, indptr, shape = _build_random_triangular_csr( n, value_dtype, index_dtype, device, lower=lower ) - A_dense = _csr_to_dense( - data, indices.to(torch.int64), indptr, shape - ) rhs_op = op_mode if fmt == "CSR" else "NON" b = _random_rhs_for_spsv( shape, @@ -813,19 +896,20 @@ def run_spsv_synthetic_all(lower=True): ) torch.cuda.synchronize() - A_ref = A_dense - b_ref = b - torch.cuda.synchronize() - e0 = torch.cuda.Event(True) - e1 = torch.cuda.Event(True) - e0.record() - x_pt = _triangular_solve_reference( - A_ref, b_ref, lower=lower, op_mode=op_mode + x_pt, pytorch_ms, pt_skip_reason = _benchmark_pytorch_reference( + data, + indices, + indptr, + shape, + b, + lower=lower, + op_mode=op_mode, + ) + err_pt = ( + float(torch.max(torch.abs(x - x_pt)).item()) + if (x_pt is not None and n > 0) + else None ) - e1.record() - torch.cuda.synchronize() - pytorch_ms = e0.elapsed_time(e1) - err_pt = float(torch.max(torch.abs(x - x_pt)).item()) if n > 0 else 0.0 cupy_ms = None err_cu = None @@ -857,7 +941,11 @@ def run_spsv_synthetic_all(lower=True): ) atol, rtol = _tol_for_dtype(value_dtype) - ok_pt = torch.allclose(x, x_pt, atol=atol, rtol=rtol) + ok_pt = ( + torch.allclose(x, x_pt, atol=atol, rtol=rtol) + if x_pt 
is not None + else False + ) ok_cu = ( True if x_cu_t is None @@ -880,10 +968,12 @@ def run_spsv_synthetic_all(lower=True): print( f"{fmt:>5} {op_mode:>5} {n:>6} {_fmt_ms(analysis_ms):>12} {_fmt_ms(t_ms):>14} {_fmt_ms(_sum_ms(analysis_ms, t_ms)):>10} {_fmt_ms(pytorch_ms):>12} " f"{_fmt_ms(cupy_ms):>10} " - f"{(f'{fs_vs_pt:.2f}x' if fs_vs_pt is not None else 'N/A'):>10} " - f"{(f'{fs_vs_cu:.2f}x' if fs_vs_cu is not None else 'N/A'):>10} " + f"{(f'{fs_vs_pt:.2f}' if fs_vs_pt is not None else 'N/A'):>10} " + f"{(f'{fs_vs_cu:.2f}' if fs_vs_cu is not None else 'N/A'):>10} " f"{status:>8} {_fmt_err(err_pt):>12} {_fmt_err(err_cu):>12}" ) + if pt_skip_reason: + print(f" NOTE: {pt_skip_reason}") print("-" * 110) print() @@ -978,36 +1068,24 @@ def _finalize_csv_row( err_pt = None ok_pt = False pt_skip_reason = None - try: - A_dense = _csr_to_dense( - data, indices.to(torch.int64), indptr, shape - ) - ref_dtype = _dense_ref_dtype(value_dtype) - A_ref = A_dense.to(ref_dtype) - b_ref = b.to(ref_dtype) - e0 = torch.cuda.Event(True) - e1 = torch.cuda.Event(True) - torch.cuda.synchronize() - e0.record() - x_ref = _triangular_solve_reference( - A_ref, b_ref, lower=lower, op_mode="NON" - ) + x_ref, pytorch_ms, pt_skip_reason = _benchmark_pytorch_reference( + data, + indices, + indptr, + shape, + b, + lower=lower, + op_mode="NON", + ) + if x_ref is not None: x_cmp = _compare_view(x, value_dtype) x_ref_cmp = _compare_view(x_ref, value_dtype) - e1.record() - torch.cuda.synchronize() - pytorch_ms = e0.elapsed_time(e1) err_pt = ( float(torch.max(torch.abs(x_cmp - x_ref_cmp)).item()) if n_rows > 0 else 0.0 ) ok_pt = torch.allclose(x_cmp, x_ref_cmp, atol=atol, rtol=rtol) - except RuntimeError as e: - if "out of memory" in str(e).lower(): - pt_skip_reason = "PyTorch dense ref OOM; skipped" - else: - raise cupy_ms = None err_cu = None @@ -1199,33 +1277,24 @@ def _finalize_csv_row_csr_full( err_pt = None ok_pt = False pt_skip_reason = None - try: - A_dense = _csr_to_dense( - data, indices.to(torch.int64), indptr.to(torch.int64), shape - ).to(_dense_ref_dtype(value_dtype)) - e0 = torch.cuda.Event(True) - e1 = torch.cuda.Event(True) - torch.cuda.synchronize() - e0.record() - x_ref = _triangular_solve_reference( - A_dense, b.to(A_dense.dtype), lower=lower, op_mode=op_mode - ) + x_ref, pytorch_ms, pt_skip_reason = _benchmark_pytorch_reference( + data, + indices, + indptr, + shape, + b, + lower=lower, + op_mode=op_mode, + ) + if x_ref is not None: x_cmp = _compare_view(x, value_dtype) x_ref_cmp = _compare_view(x_ref, value_dtype) - e1.record() - torch.cuda.synchronize() - pytorch_ms = e0.elapsed_time(e1) err_pt = ( float(torch.max(torch.abs(x_cmp - x_ref_cmp)).item()) if n_rows > 0 else 0.0 ) ok_pt = torch.allclose(x_cmp, x_ref_cmp, atol=atol, rtol=rtol) - except RuntimeError as e: - if "out of memory" in str(e).lower(): - pt_skip_reason = "PyTorch dense ref OOM; skipped" - else: - raise cupy_ms = None err_cu = None @@ -1648,20 +1717,20 @@ def _check_one_csr_transpose_case(path, value_dtype, index_dtype, op_mode, devic ref_err = None ref_ok = None - try: - A_dense = _csr_to_dense( - data, indices.to(torch.int64), indptr.to(torch.int64), shape - ).to(_dense_ref_dtype(value_dtype)) - x_ref = _triangular_solve_reference( - A_dense, b.to(A_dense.dtype), lower=lower, op_mode=op_mode - ) + x_ref, _, _ = _benchmark_pytorch_reference( + data, + indices, + indptr, + shape, + b, + lower=lower, + op_mode=op_mode, + ) + if x_ref is not None: ref_err = ( float(torch.max(torch.abs(x_op - x_ref)).item()) if x_op.numel() > 0 else 0.0 ) 
ref_ok = torch.allclose(x_op, x_ref, atol=atol, rtol=rtol) - except RuntimeError as e: - if "out of memory" not in str(e).lower(): - raise status = "PASS" if action_ok and solve_ok and (ref_ok is not False) else "FAIL" return { From 92717a6e3af495764ab7b294bfe88d0b99593cc8 Mon Sep 17 00:00:00 2001 From: berlin020 <2261128688@qq.com> Date: Mon, 27 Apr 2026 14:55:25 +0800 Subject: [PATCH 12/22] pytorch --- tests/test_spsv.py | 213 ++++++++++++++++++++++++++++----------------- 1 file changed, 135 insertions(+), 78 deletions(-) diff --git a/tests/test_spsv.py b/tests/test_spsv.py index 2a4442a..c97799f 100644 --- a/tests/test_spsv.py +++ b/tests/test_spsv.py @@ -119,6 +119,12 @@ def _sum_ms(*values): return sum(values) +def _status_str(ok_flag, has_value): + if ok_flag: + return "PASS" + return "FAIL" if has_value else "N/A" + + def _tol_for_dtype(dtype): if dtype in (torch.float32, torch.complex64): return 1e-4, 1e-2 @@ -217,14 +223,20 @@ def _benchmark_pytorch_sparse_reference(data, indices, indptr, shape, b, *, op_m if sparse_spsolve is None: raise NotImplementedError("torch.sparse.spsolve is unavailable") A_csr = _build_csr_tensor_for_op(data, indices, indptr, shape, op_mode) - torch.cuda.synchronize() - e0 = torch.cuda.Event(True) - e1 = torch.cuda.Event(True) - e0.record() - x_ref = sparse_spsolve(A_csr, b) - e1.record() - torch.cuda.synchronize() - return x_ref.to(b.dtype), e0.elapsed_time(e1) + if A_csr.is_cuda: + torch.cuda.synchronize() + e0 = torch.cuda.Event(True) + e1 = torch.cuda.Event(True) + e0.record() + x_ref = sparse_spsolve(A_csr, b) + e1.record() + torch.cuda.synchronize() + elapsed_ms = e0.elapsed_time(e1) + else: + t0 = time.perf_counter() + x_ref = sparse_spsolve(A_csr, b) + elapsed_ms = (time.perf_counter() - t0) * 1000.0 + return x_ref.to(b.dtype), elapsed_ms def _benchmark_pytorch_dense_reference(data, indices, indptr, shape, b, *, lower, op_mode): @@ -232,16 +244,24 @@ def _benchmark_pytorch_dense_reference(data, indices, indptr, shape, b, *, lower data, indices.to(torch.int64), indptr.to(torch.int64), shape ).to(_dense_ref_dtype(b.dtype)) b_ref = b.to(A_dense.dtype) - torch.cuda.synchronize() - e0 = torch.cuda.Event(True) - e1 = torch.cuda.Event(True) - e0.record() - x_ref = _triangular_solve_reference( - A_dense, b_ref, lower=lower, op_mode=op_mode - ) - e1.record() - torch.cuda.synchronize() - return x_ref.to(b.dtype), e0.elapsed_time(e1) + if A_dense.is_cuda: + torch.cuda.synchronize() + e0 = torch.cuda.Event(True) + e1 = torch.cuda.Event(True) + e0.record() + x_ref = _triangular_solve_reference( + A_dense, b_ref, lower=lower, op_mode=op_mode + ) + e1.record() + torch.cuda.synchronize() + elapsed_ms = e0.elapsed_time(e1) + else: + t0 = time.perf_counter() + x_ref = _triangular_solve_reference( + A_dense, b_ref, lower=lower, op_mode=op_mode + ) + elapsed_ms = (time.perf_counter() - t0) * 1000.0 + return x_ref.to(b.dtype), elapsed_ms def _benchmark_pytorch_reference(data, indices, indptr, shape, b, *, lower, op_mode): @@ -250,29 +270,60 @@ def _benchmark_pytorch_reference(data, indices, indptr, shape, b, *, lower, op_m x_ref, ms = _benchmark_pytorch_sparse_reference( data, indices, indptr, shape, b, op_mode=op_mode ) - return x_ref, ms, None + return x_ref, ms, None, None except RuntimeError as e: sparse_err = e if "out of memory" in str(e).lower(): if torch.cuda.is_available(): torch.cuda.empty_cache() - else: - sparse_err = e except NotImplementedError as e: sparse_err = e try: - x_ref, ms = _benchmark_pytorch_dense_reference( - data, indices, indptr, shape, b, 
lower=lower, op_mode=op_mode + data_cpu = data.detach().cpu() + indices_cpu = indices.to(torch.int64).detach().cpu() + indptr_cpu = indptr.to(torch.int64).detach().cpu() + b_cpu = b.detach().cpu() + x_ref_cpu, ms_cpu = _benchmark_pytorch_sparse_reference( + data_cpu, indices_cpu, indptr_cpu, shape, b_cpu, op_mode=op_mode + ) + return ( + x_ref_cpu.to(b.device, dtype=b.dtype), + None, + ms_cpu, + "PyTorch ref used CPU sparse.spsolve", + ) + except RuntimeError as e: + sparse_err = e + except NotImplementedError: + pass + + try: + data_cpu = data.detach().cpu() + indices_cpu = indices.to(torch.int64).detach().cpu() + indptr_cpu = indptr.to(torch.int64).detach().cpu() + b_cpu = b.detach().cpu() + x_ref_cpu, ms_cpu = _benchmark_pytorch_dense_reference( + data_cpu, + indices_cpu, + indptr_cpu, + shape, + b_cpu, + lower=lower, + op_mode=op_mode, + ) + note = "PyTorch sparse ref unavailable; CPU dense solve_triangular fallback" + if sparse_err is None: + note = "PyTorch ref used CPU dense solve_triangular" + return ( + x_ref_cpu.to(b.device, dtype=b.dtype), + None, + ms_cpu, + note, ) - return x_ref, ms, None except RuntimeError as e: if "out of memory" in str(e).lower(): - if torch.cuda.is_available(): - torch.cuda.empty_cache() - if sparse_err is not None: - return None, None, "PyTorch sparse ref unavailable; dense ref OOM; skipped" - return None, None, "PyTorch dense ref OOM; skipped" + return None, None, None, "PyTorch sparse ref unavailable; CPU dense ref OOM; skipped" raise @@ -896,7 +947,7 @@ def run_spsv_synthetic_all(lower=True): ) torch.cuda.synchronize() - x_pt, pytorch_ms, pt_skip_reason = _benchmark_pytorch_reference( + x_pt, pytorch_ms, _pytorch_cpu_ms, pt_skip_reason = _benchmark_pytorch_reference( data, indices, indptr, @@ -1068,7 +1119,7 @@ def _finalize_csv_row( err_pt = None ok_pt = False pt_skip_reason = None - x_ref, pytorch_ms, pt_skip_reason = _benchmark_pytorch_reference( + x_ref, pytorch_ms, _pytorch_cpu_ms, pt_skip_reason = _benchmark_pytorch_reference( data, indices, indptr, @@ -1174,30 +1225,33 @@ def _finalize_csv_row( int(data.numel()) if nnz_display is None else int(nnz_display) ) row = { - "format": "COO", "matrix": os.path.basename(path), "value_dtype": _dtype_name(value_dtype), "index_dtype": _dtype_name(index_dtype), "opA": "NON", - "fill_mode": _fill_mode_name(lower), - "diag_type": _diag_type_name(False), - "triton_alg": _triton_alg_name("COO", op_mode="NON", coo_mode=coo_mode), - "reference_alg": _cusparse_alg_name(), "n_rows": n_rows, "n_cols": n_cols, "nnz": nnz_out, "triton_analysis_ms": analysis_ms, "triton_solve_ms": t_ms, "triton_time_total_ms": _sum_ms(analysis_ms, t_ms), - "pytorch_solve_ms": pytorch_ms, "cusparse_solve_ms": cupy_ms, - "pytorch/triton": _safe_ratio(pytorch_ms, t_ms), + "pytorch_solve_ms": pytorch_ms, "cusparse/triton": _safe_ratio(cupy_ms, t_ms), + "pytorch/triton": _safe_ratio(pytorch_ms, t_ms), + "pt_status": _status_str(ok_pt, err_pt is not None), + "cu_status": _status_str(ok_cu, err_cu is not None), "status": status, "err_ref": err_ref, "err_res": err_res, "err_pt": err_pt, "err_cu": err_cu, + "pytorch_reason": pt_skip_reason, + "cusparse_reason": None if (cupy_ms is not None or x_cu_t is not None) else ( + "CuPy/cuSPARSE unavailable" if (cp is None or cpx_sparse is None or cpx_spsolve_triangular is None) + else "cuSPARSE solve failed" + ), + "error": None, } return row, pt_skip_reason @@ -1277,7 +1331,7 @@ def _finalize_csv_row_csr_full( err_pt = None ok_pt = False pt_skip_reason = None - x_ref, pytorch_ms, pt_skip_reason = 
_benchmark_pytorch_reference( + x_ref, pytorch_ms, _pytorch_cpu_ms, pt_skip_reason = _benchmark_pytorch_reference( data, indices, indptr, @@ -1320,30 +1374,33 @@ def _finalize_csv_row_csr_full( err_ref = min(ref_errors) if ref_errors else None row = { - "format": "CSR", "matrix": os.path.basename(path), "value_dtype": _dtype_name(value_dtype), "index_dtype": _dtype_name(index_dtype), "opA": op_mode, - "fill_mode": _fill_mode_name(lower), - "diag_type": _diag_type_name(False), - "triton_alg": _triton_alg_name("CSR", op_mode=op_mode), - "reference_alg": _cusparse_alg_name(), "n_rows": n_rows, "n_cols": n_cols, "nnz": int(data.numel()), "triton_analysis_ms": analysis_ms, "triton_solve_ms": t_ms, "triton_time_total_ms": _sum_ms(analysis_ms, t_ms), - "pytorch_solve_ms": pytorch_ms, "cusparse_solve_ms": cupy_ms, - "pytorch/triton": _safe_ratio(pytorch_ms, t_ms), + "pytorch_solve_ms": pytorch_ms, "cusparse/triton": _safe_ratio(cupy_ms, t_ms), + "pytorch/triton": _safe_ratio(pytorch_ms, t_ms), + "pt_status": _status_str(ok_pt, err_pt is not None), + "cu_status": _status_str(ok_cu, err_cu is not None), "status": status, "err_ref": err_ref, "err_res": err_res, "err_pt": err_pt, "err_cu": err_cu, + "pytorch_reason": pt_skip_reason, + "cusparse_reason": None if x_cu_t is not None else ( + "CuPy/cuSPARSE unavailable" if (cp is None or cpx_sparse is None or cpx_spsolve_triangular is None) + else "cuSPARSE solve failed" + ), + "error": None, } return row, pt_skip_reason @@ -1376,7 +1433,7 @@ def run_all_supported_spsv_csr_csv( f"Value dtype: {_dtype_name(value_dtype)} | Index dtype: {_dtype_name(index_dtype)} | CSR | triA={'LOWER' if lower else 'UPPER'} | opA={op_mode}" ) print( - "Formats: FlagSparse=CSR, cuSPARSE=CSR ref, PyTorch=Dense solve." + "Formats: FlagSparse=CSR, cuSPARSE=CSR ref, PyTorch(ms)=CUDA sparse solve only." ) print( "RHS is generated directly, matching Library-main's SpSV test style. 
" @@ -1423,30 +1480,30 @@ def run_all_supported_spsv_csr_csv( status = "SKIP" if "SpSV requires square matrices" in err_msg else "ERROR" rows_out.append( { - "format": "CSR", "matrix": os.path.basename(path), "value_dtype": _dtype_name(value_dtype), "index_dtype": _dtype_name(index_dtype), "opA": op_mode, - "fill_mode": _fill_mode_name(lower), - "diag_type": _diag_type_name(False), - "triton_alg": _triton_alg_name("CSR", op_mode=op_mode), - "reference_alg": _cusparse_alg_name(), "n_rows": "ERR", "n_cols": "ERR", "nnz": "ERR", "triton_analysis_ms": None, "triton_solve_ms": None, "triton_time_total_ms": None, - "pytorch_solve_ms": None, "cusparse_solve_ms": None, - "pytorch/triton": None, + "pytorch_solve_ms": None, "cusparse/triton": None, + "pytorch/triton": None, + "pt_status": "N/A", + "cu_status": "N/A", "status": status, "err_ref": None, "err_res": None, "err_pt": None, "err_cu": None, + "pytorch_reason": None, + "cusparse_reason": None, + "error": err_msg, } ) name = os.path.basename(path)[:27] @@ -1461,36 +1518,36 @@ def run_all_supported_spsv_csr_csv( print(f" {status}: {e}") print("-" * 150) fieldnames = [ - "format", "matrix", "value_dtype", "index_dtype", "opA", - "fill_mode", - "diag_type", - "triton_alg", - "reference_alg", "n_rows", "n_cols", "nnz", "triton_analysis_ms", "triton_solve_ms", "triton_time_total_ms", - "pytorch_solve_ms", "cusparse_solve_ms", - "pytorch/triton", + "pytorch_solve_ms", "cusparse/triton", + "pytorch/triton", + "pt_status", + "cu_status", "status", "err_ref", "err_res", "err_pt", "err_cu", + "pytorch_reason", + "cusparse_reason", + "error", ] with open(csv_path, "w", newline="", encoding="utf-8") as f: w = csv.DictWriter(f, fieldnames=fieldnames) w.writeheader() for r in rows_out: - w.writerow(r) + w.writerow({k: ("" if v is None else v) for k, v in r.items()}) print(f"Wrote {len(rows_out)} rows to {csv_path}") @@ -1517,7 +1574,7 @@ def run_all_dtypes_spsv_coo_csv( f" triA={'LOWER' if lower else 'UPPER'} coo_mode={coo_mode}" ) print( - "Formats: FlagSparse=COO SpSV, cuSPARSE=COO ref, PyTorch=Dense solve. " + "Formats: FlagSparse=COO SpSV, cuSPARSE=COO ref, PyTorch(ms)=CUDA sparse solve only. " "RHS is generated directly, matching Library-main's SpSV test style." 
) print( @@ -1563,30 +1620,30 @@ def run_all_dtypes_spsv_coo_csv( status = "SKIP" if "SpSV requires square matrices" in err_msg else "ERROR" rows_out.append( { - "format": "COO", "matrix": os.path.basename(path), "value_dtype": _dtype_name(value_dtype), "index_dtype": _dtype_name(index_dtype), "opA": "NON", - "fill_mode": _fill_mode_name(lower), - "diag_type": _diag_type_name(False), - "triton_alg": _triton_alg_name("COO", op_mode="NON", coo_mode=coo_mode), - "reference_alg": _cusparse_alg_name(), "n_rows": "ERR", "n_cols": "ERR", "nnz": "ERR", "triton_analysis_ms": None, "triton_solve_ms": None, "triton_time_total_ms": None, - "pytorch_solve_ms": None, "cusparse_solve_ms": None, - "pytorch/triton": None, + "pytorch_solve_ms": None, "cusparse/triton": None, + "pytorch/triton": None, + "pt_status": "N/A", + "cu_status": "N/A", "status": status, "err_ref": None, "err_res": None, "err_pt": None, "err_cu": None, + "pytorch_reason": None, + "cusparse_reason": None, + "error": err_msg, } ) name = os.path.basename(path)[:27] @@ -1600,36 +1657,36 @@ def run_all_dtypes_spsv_coo_csv( print(f" {status}: {e}") print("-" * 150) fieldnames = [ - "format", "matrix", "value_dtype", "index_dtype", "opA", - "fill_mode", - "diag_type", - "triton_alg", - "reference_alg", "n_rows", "n_cols", "nnz", "triton_analysis_ms", "triton_solve_ms", "triton_time_total_ms", - "pytorch_solve_ms", "cusparse_solve_ms", - "pytorch/triton", + "pytorch_solve_ms", "cusparse/triton", + "pytorch/triton", + "pt_status", + "cu_status", "status", "err_ref", "err_res", "err_pt", "err_cu", + "pytorch_reason", + "cusparse_reason", + "error", ] with open(csv_path, "w", newline="", encoding="utf-8") as f: w = csv.DictWriter(f, fieldnames=fieldnames) w.writeheader() for r in rows_out: - w.writerow(r) + w.writerow({k: ("" if v is None else v) for k, v in r.items()}) print(f"Wrote {len(rows_out)} rows to {csv_path}") From d463197b321e5df7e7e61ffcefab6110f5679bb1 Mon Sep 17 00:00:00 2001 From: berlin020 <2261128688@qq.com> Date: Mon, 27 Apr 2026 16:15:47 +0800 Subject: [PATCH 13/22] pytorch --- src/flagsparse/sparse_operations/spsv.py | 33 +--- tests/pytest/test_spsv_csr_accuracy.py | 41 ++--- tests/test_spsv.py | 207 ++++------------------- 3 files changed, 45 insertions(+), 236 deletions(-) diff --git a/src/flagsparse/sparse_operations/spsv.py b/src/flagsparse/sparse_operations/spsv.py index a4952e5..0d31194 100644 --- a/src/flagsparse/sparse_operations/spsv.py +++ b/src/flagsparse/sparse_operations/spsv.py @@ -54,27 +54,6 @@ def _clear_spsv_csr_preprocess_cache(): _SPSV_CSR_PREPROCESS_CACHE.clear() -def _csr_to_dense(data, indices, indptr, shape): - """Convert CSR (torch CUDA tensors) to dense matrix on the same device.""" - device = data.device - dtype = data.dtype - n_rows, n_cols = int(shape[0]), int(shape[1]) - if n_rows == 0 or n_cols == 0: - return torch.zeros((n_rows, n_cols), dtype=dtype, device=device) - row_ind = torch.repeat_interleave( - torch.arange(n_rows, device=device, dtype=torch.int64), - indptr[1:] - indptr[:-1], - ) - col_ind = indices.to(torch.int64) - coo = torch.sparse_coo_tensor( - torch.stack([row_ind, col_ind]), - data, - (n_rows, n_cols), - device=device, - ).coalesce() - return coo.to_dense() - - def _validate_spsv_non_trans_combo(data_dtype, index_dtype, fmt_name): """Validate NON_TRANS support matrix and keep error messages explicit.""" if (data_dtype, index_dtype) in SPSV_NON_TRANS_SUPPORTED_COMBOS: @@ -174,15 +153,6 @@ def _prepare_spsv_inputs(data, indices, indptr, b, shape): n_cols, ) - -def 
_prepare_spsv_working_inputs(data, b): - return data, b, None - - -def _restore_spsv_output(x, target_dtype): - return x.to(target_dtype) - - def _spsv_diag_eps_for_dtype(value_dtype): return 1e-12 if value_dtype in (torch.float64, torch.complex128) else 1e-6 @@ -320,7 +290,6 @@ def _resolve_spsv_csr_runtime( data, indices, indptr, b, shape ) original_output_dtype = None - data, b, original_output_dtype = _prepare_spsv_working_inputs(data, b) if n_rows != n_cols: raise ValueError(f"A must be square, got shape={shape}") if trans_mode == "N": @@ -1469,7 +1438,7 @@ def flagsparse_spsv_csr( x = torch.stack(cols, dim=1) target_dtype = original_output_dtype if original_output_dtype is not None else data.dtype if x.dtype != target_dtype: - x = _restore_spsv_output(x, target_dtype) + x = x.to(target_dtype) if return_time: torch.cuda.synchronize() elapsed_ms = (time.perf_counter() - t0) * 1000.0 diff --git a/tests/pytest/test_spsv_csr_accuracy.py b/tests/pytest/test_spsv_csr_accuracy.py index fc9c891..3484f7f 100644 --- a/tests/pytest/test_spsv_csr_accuracy.py +++ b/tests/pytest/test_spsv_csr_accuracy.py @@ -42,19 +42,6 @@ def _rand_like(dtype, shape, device): i = torch.randn(shape, dtype=base, device=device) return torch.complex(r, i) - -def _ref_dtype(dtype): - return dtype - - -def _safe_cast_tensor(tensor, dtype): - return tensor.to(dtype) - - -def _cmp_view(tensor, dtype): - return tensor - - def _apply_ref_op(A, op_mode): if op_mode == "TRANS": return A.transpose(-2, -1) @@ -159,7 +146,7 @@ def test_spsv_csr_non_trans_supported_combos(n, dtype, index_dtype): A = _build_triangular(n, dtype, device, lower=True) b = _rand_like(dtype, (n,), device) x_ref = torch.linalg.solve_triangular( - A.to(_ref_dtype(dtype)), b.to(_ref_dtype(dtype)).unsqueeze(-1), upper=False + A.to(dtype), b.to(dtype).unsqueeze(-1), upper=False ).squeeze(-1) Asp = A.to_sparse_csr() @@ -178,7 +165,7 @@ def test_spsv_csr_non_trans_supported_combos(n, dtype, index_dtype): transpose=False, ) rtol, atol = _tol(dtype) - assert torch.allclose(_cmp_view(x, dtype), _cmp_view(x_ref, dtype), rtol=rtol, atol=atol) + assert torch.allclose(x, x_ref, rtol=rtol, atol=atol) @pytest.mark.spsv @@ -190,8 +177,8 @@ def test_spsv_csr_transpose_family_supported_combos(n, dtype, index_dtype, op_mo device = torch.device("cuda") A = _build_triangular(n, dtype, device, lower=True) b = _rand_like(dtype, (n,), device) - A_ref = A.to(_ref_dtype(dtype)) - b_ref = b.to(_ref_dtype(dtype)) + A_ref = A.to(dtype) + b_ref = b.to(dtype) x_ref = torch.linalg.solve_triangular( _apply_ref_op(A_ref, op_mode), b_ref.unsqueeze(-1), upper=_effective_upper(True, op_mode) ).squeeze(-1) @@ -212,7 +199,7 @@ def test_spsv_csr_transpose_family_supported_combos(n, dtype, index_dtype, op_mo transpose=_transpose_arg(op_mode), ) rtol, atol = _tol(dtype) - assert torch.allclose(_cmp_view(x, dtype), _cmp_view(x_ref, dtype), rtol=rtol, atol=atol) + assert torch.allclose(x, x_ref, rtol=rtol, atol=atol) @pytest.mark.spsv @@ -239,7 +226,7 @@ def test_spsv_csr_matches_cusparse_non_trans(n, dtype): x_non_ref = _cupy_ref_spsv(A_cp, b, lower=True, unit_diagonal=False) rtol, atol = _tol(dtype) - assert torch.allclose(_cmp_view(x_non, dtype), _cmp_view(x_non_ref, dtype), rtol=rtol, atol=atol) + assert torch.allclose(x_non, x_non_ref, rtol=rtol, atol=atol) @pytest.mark.spsv @@ -280,7 +267,7 @@ def test_spsv_csr_matches_cusparse_transpose_family(n, dtype, index_dtype, op_mo ) rtol, atol = _tol(dtype) - assert torch.allclose(_cmp_view(x_trans, dtype), _cmp_view(x_trans_ref, dtype), rtol=rtol, 
atol=atol) + assert torch.allclose(x_trans, x_trans_ref, rtol=rtol, atol=atol) @pytest.mark.spsv @@ -292,7 +279,7 @@ def test_spsv_csr_non_trans_upper_supported_combos(n, dtype, index_dtype): A = _build_triangular(n, dtype, device, lower=False) b = _rand_like(dtype, (n,), device) x_ref = torch.linalg.solve_triangular( - A.to(_ref_dtype(dtype)), b.to(_ref_dtype(dtype)).unsqueeze(-1), upper=True + A.to(dtype), b.to(dtype).unsqueeze(-1), upper=True ).squeeze(-1) Asp = A.to_sparse_csr() @@ -311,7 +298,7 @@ def test_spsv_csr_non_trans_upper_supported_combos(n, dtype, index_dtype): transpose=False, ) rtol, atol = _tol(dtype) - assert torch.allclose(_cmp_view(x, dtype), _cmp_view(x_ref, dtype), rtol=rtol, atol=atol) + assert torch.allclose(x, x_ref, rtol=rtol, atol=atol) @pytest.mark.spsv @@ -323,8 +310,8 @@ def test_spsv_csr_upper_transpose_family_supported_combos(n, dtype, index_dtype, device = torch.device("cuda") A = _build_triangular(n, dtype, device, lower=False) b = _rand_like(dtype, (n,), device) - A_ref = A.to(_ref_dtype(dtype)) - b_ref = b.to(_ref_dtype(dtype)) + A_ref = A.to(dtype) + b_ref = b.to(dtype) x_ref = torch.linalg.solve_triangular( _apply_ref_op(A_ref, op_mode), b_ref.unsqueeze(-1), upper=_effective_upper(False, op_mode) ).squeeze(-1) @@ -345,7 +332,7 @@ def test_spsv_csr_upper_transpose_family_supported_combos(n, dtype, index_dtype, transpose=_transpose_arg(op_mode), ) rtol, atol = _tol(dtype) - assert torch.allclose(_cmp_view(x, dtype), _cmp_view(x_ref, dtype), rtol=rtol, atol=atol) + assert torch.allclose(x, x_ref, rtol=rtol, atol=atol) @pytest.mark.spsv @@ -372,7 +359,7 @@ def test_spsv_csr_matches_cusparse_upper_non_trans(n, dtype): x_non_ref = _cupy_ref_spsv(A_cp, b, lower=False, unit_diagonal=False) rtol, atol = _tol(dtype) - assert torch.allclose(_cmp_view(x_non, dtype), _cmp_view(x_non_ref, dtype), rtol=rtol, atol=atol) + assert torch.allclose(x_non, x_non_ref, rtol=rtol, atol=atol) @pytest.mark.spsv @@ -413,7 +400,7 @@ def test_spsv_csr_matches_cusparse_upper_transpose_family(n, dtype, index_dtype, ) rtol, atol = _tol(dtype) - assert torch.allclose(_cmp_view(x_trans, dtype), _cmp_view(x_trans_ref, dtype), rtol=rtol, atol=atol) + assert torch.allclose(x_trans, x_trans_ref, rtol=rtol, atol=atol) @pytest.mark.spsv diff --git a/tests/test_spsv.py b/tests/test_spsv.py index c97799f..e8964a1 100644 --- a/tests/test_spsv.py +++ b/tests/test_spsv.py @@ -1,7 +1,8 @@ """SpSV tests: synthetic triangular systems and optional .mtx (CSR/COO). 
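# Review note (illustrative sketch, not part of this diff): the reference
# helpers below time GPU work with CUDA events rather than a host clock,
# since kernel launches are asynchronous and a wall-clock timer would mostly
# measure launch overhead. The pattern in isolation (the multiply is a
# stand-in for the real sparse_spsolve call):
#
#     import torch
#
#     x = torch.randn(1 << 20, device="cuda")
#     torch.cuda.synchronize()             # drain prior work
#     e0 = torch.cuda.Event(enable_timing=True)
#     e1 = torch.cuda.Event(enable_timing=True)
#     e0.record()
#     y = x * 2.0                          # stand-in for the timed solve
#     e1.record()
#     torch.cuda.synchronize()             # ensure e1 completed before reading
#     elapsed_ms = e0.elapsed_time(e1)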
-Timing columns, dense PyTorch reference, CSV fields, and PASS criteria match the CSR runs; for COO
-tests the CuPy baseline uses ``coo_matrix`` (isomorphic to the FlagSparse COO input), while CSR tests still use ``csr_matrix``.
+Timing columns, PyTorch CUDA sparse reference, CSV fields, and PASS criteria match
+the CSR runs; for COO tests the CuPy baseline uses ``coo_matrix`` (isomorphic to
+the FlagSparse COO input), while CSR tests still use ``csr_matrix``.
 """
 
 import argparse
@@ -153,22 +154,10 @@ def _randn_by_dtype(n, dtype, device, generator=None):
     return torch.complex(real, imag)
 
 
-def _dense_ref_dtype(dtype):
-    return dtype
-
-
 def _tensor_from_scalar_values(values, dtype, device):
     return torch.tensor(values, dtype=dtype, device=device)
 
 
-def _safe_cast_tensor(tensor, dtype):
-    return tensor.to(dtype)
-
-
-def _cast_real_tensor_to_value_dtype(values, value_dtype):
-    return values.to(value_dtype)
-
-
 def _matrix_market_value(parts, mm_field):
     if mm_field == "complex":
         if len(parts) < 4:
@@ -180,22 +169,6 @@ def _matrix_market_value(parts, mm_field):
         return 1.0
     raise ValueError("MatrixMarket entry is missing a numeric value")
 
-
-def _triangular_solve_reference(A, b, *, lower, op_mode="NON"):
-    if op_mode == "TRANS":
-        A_eff = A.transpose(0, 1)
-        upper = lower
-    elif op_mode == "CONJ":
-        A_eff = A.transpose(0, 1).conj() if torch.is_complex(A) else A.transpose(0, 1)
-        upper = lower
-    else:
-        A_eff = A
-        upper = not lower
-    return torch.linalg.solve_triangular(
-        A_eff, b.unsqueeze(1), upper=upper
-    ).squeeze(1)
-
-
 def _build_csr_tensor_for_op(data, indices, indptr, shape, op_mode):
     if op_mode == "TRANS":
         data_eff, indices_eff, indptr_eff = _csr_transpose(
@@ -223,118 +196,36 @@ def _benchmark_pytorch_sparse_reference(data, indices, indptr, shape, b, *, op_m
     if sparse_spsolve is None:
         raise NotImplementedError("torch.sparse.spsolve is unavailable")
     A_csr = _build_csr_tensor_for_op(data, indices, indptr, shape, op_mode)
-    if A_csr.is_cuda:
-        torch.cuda.synchronize()
-        e0 = torch.cuda.Event(True)
-        e1 = torch.cuda.Event(True)
-        e0.record()
-        x_ref = sparse_spsolve(A_csr, b)
-        e1.record()
-        torch.cuda.synchronize()
-        elapsed_ms = e0.elapsed_time(e1)
-    else:
-        t0 = time.perf_counter()
-        x_ref = sparse_spsolve(A_csr, b)
-        elapsed_ms = (time.perf_counter() - t0) * 1000.0
-    return x_ref.to(b.dtype), elapsed_ms
-
-
-def _benchmark_pytorch_dense_reference(data, indices, indptr, shape, b, *, lower, op_mode):
-    A_dense = _csr_to_dense(
-        data, indices.to(torch.int64), indptr.to(torch.int64), shape
-    ).to(_dense_ref_dtype(b.dtype))
-    b_ref = b.to(A_dense.dtype)
-    if A_dense.is_cuda:
-        torch.cuda.synchronize()
-        e0 = torch.cuda.Event(True)
-        e1 = torch.cuda.Event(True)
-        e0.record()
-        x_ref = _triangular_solve_reference(
-            A_dense, b_ref, lower=lower, op_mode=op_mode
-        )
-        e1.record()
-        torch.cuda.synchronize()
-        elapsed_ms = e0.elapsed_time(e1)
-    else:
-        t0 = time.perf_counter()
-        x_ref = _triangular_solve_reference(
-            A_dense, b_ref, lower=lower, op_mode=op_mode
-        )
-        elapsed_ms = (time.perf_counter() - t0) * 1000.0
+    if not A_csr.is_cuda:
+        raise RuntimeError("torch.sparse.spsolve CUDA path is unavailable")
+    torch.cuda.synchronize()
+    e0 = torch.cuda.Event(True)
+    e1 = torch.cuda.Event(True)
+    e0.record()
+    x_ref = sparse_spsolve(A_csr, b)
+    e1.record()
+    torch.cuda.synchronize()
+    elapsed_ms = e0.elapsed_time(e1)
     return x_ref.to(b.dtype), elapsed_ms
 
 
 def _benchmark_pytorch_reference(data, indices, indptr, shape, b, *, lower, op_mode):
-    sparse_err = None
+    del lower
     try:
         x_ref, ms = _benchmark_pytorch_sparse_reference(
             data, indices, indptr, shape, b, op_mode=op_mode
         )
-        return x_ref, ms, None, None
+        return x_ref, ms, None
-    except RuntimeError as e:
-        sparse_err = e
-        if "out of memory" in str(e).lower():
-            if
torch.cuda.is_available(): - torch.cuda.empty_cache() - except NotImplementedError as e: - sparse_err = e - - try: - data_cpu = data.detach().cpu() - indices_cpu = indices.to(torch.int64).detach().cpu() - indptr_cpu = indptr.to(torch.int64).detach().cpu() - b_cpu = b.detach().cpu() - x_ref_cpu, ms_cpu = _benchmark_pytorch_sparse_reference( - data_cpu, indices_cpu, indptr_cpu, shape, b_cpu, op_mode=op_mode - ) - return ( - x_ref_cpu.to(b.device, dtype=b.dtype), - None, - ms_cpu, - "PyTorch ref used CPU sparse.spsolve", - ) - except RuntimeError as e: - sparse_err = e - except NotImplementedError: - pass - - try: - data_cpu = data.detach().cpu() - indices_cpu = indices.to(torch.int64).detach().cpu() - indptr_cpu = indptr.to(torch.int64).detach().cpu() - b_cpu = b.detach().cpu() - x_ref_cpu, ms_cpu = _benchmark_pytorch_dense_reference( - data_cpu, - indices_cpu, - indptr_cpu, - shape, - b_cpu, - lower=lower, - op_mode=op_mode, - ) - note = "PyTorch sparse ref unavailable; CPU dense solve_triangular fallback" - if sparse_err is None: - note = "PyTorch ref used CPU dense solve_triangular" - return ( - x_ref_cpu.to(b.device, dtype=b.dtype), - None, - ms_cpu, - note, - ) - except RuntimeError as e: - if "out of memory" in str(e).lower(): - return None, None, None, "PyTorch sparse ref unavailable; CPU dense ref OOM; skipped" - raise + return x_ref, ms, None + except Exception as e: + if "out of memory" in str(e).lower() and torch.cuda.is_available(): + torch.cuda.empty_cache() + return None, None, f"PyTorch CUDA sparse ref unavailable; skipped ({e})" def _cupy_ref_inputs(data, b): return data, b -def _compare_view(tensor, value_dtype): - return tensor - - def _supported_csr_full_ops(value_dtype, index_dtype): if value_dtype not in CSR_FULL_VALUE_DTYPES: return [] @@ -345,28 +236,6 @@ def _supported_csr_full_ops(value_dtype, index_dtype): return [] -def _fill_mode_name(lower): - return "LOWER" if lower else "UPPER" - - -def _diag_type_name(unit_diagonal): - return "UNIT" if unit_diagonal else "NON_UNIT" - - -def _triton_alg_name(fmt, op_mode=None, coo_mode=None): - if fmt == "CSR": - if op_mode in ("TRANS", "CONJ"): - return "TRITON_CSR_LEVEL_TRANSPOSE_FAMILY" - return "TRITON_CSR_LEVEL" - if coo_mode: - return f"TRITON_COO_{str(coo_mode).upper()}" - return "TRITON_COO_LEVEL" - - -def _cusparse_alg_name(): - return "CUPY_SPSOLVE_TRIANGULAR" - - def _build_random_triangular_csr(n, value_dtype, index_dtype, device, lower=True): """Build a well-conditioned triangular CSR for real and complex dtypes.""" max_bandwidth = max(4, min(n, 16)) @@ -442,22 +311,6 @@ def _build_random_triangular_csr(n, value_dtype, index_dtype, device, lower=True indices = cols_t.to(index_dtype) return vals_t, indices, indptr, (n, n) - -def _csr_to_dense(data, indices, indptr, shape): - n_rows, n_cols = shape - row_ind = torch.repeat_interleave( - torch.arange(n_rows, device=data.device, dtype=torch.int64), - indptr[1:] - indptr[:-1], - ) - coo = torch.sparse_coo_tensor( - torch.stack([row_ind, indices.to(torch.int64)]), - data, - (n_rows, n_cols), - device=data.device, - ).coalesce() - return coo.to_dense() - - def _csr_to_coo(data, indices, indptr, shape, index_dtype=torch.int64): n_rows = int(shape[0]) row = torch.repeat_interleave( @@ -947,7 +800,7 @@ def run_spsv_synthetic_all(lower=True): ) torch.cuda.synchronize() - x_pt, pytorch_ms, _pytorch_cpu_ms, pt_skip_reason = _benchmark_pytorch_reference( + x_pt, pytorch_ms, pt_skip_reason = _benchmark_pytorch_reference( data, indices, indptr, @@ -1119,7 +972,7 @@ def 
_finalize_csv_row( err_pt = None ok_pt = False pt_skip_reason = None - x_ref, pytorch_ms, _pytorch_cpu_ms, pt_skip_reason = _benchmark_pytorch_reference( + x_ref, pytorch_ms, pt_skip_reason = _benchmark_pytorch_reference( data, indices, indptr, @@ -1129,8 +982,8 @@ def _finalize_csv_row( op_mode="NON", ) if x_ref is not None: - x_cmp = _compare_view(x, value_dtype) - x_ref_cmp = _compare_view(x_ref, value_dtype) + x_cmp = x + x_ref_cmp = x_ref err_pt = ( float(torch.max(torch.abs(x_cmp - x_ref_cmp)).item()) if n_rows > 0 @@ -1203,8 +1056,8 @@ def _finalize_csv_row( c1.synchronize() cupy_ms = cp.cuda.get_elapsed_time(c0, c1) / ITERS x_cu_t = torch.utils.dlpack.from_dlpack(x_cu.toDlpack()) - x_cmp = _compare_view(x, value_dtype) - x_cu_cmp = _compare_view(x_cu_t, value_dtype) + x_cmp = x + x_cu_cmp = x_cu_t err_cu = ( float(torch.max(torch.abs(x_cmp - x_cu_cmp)).item()) if n_rows > 0 @@ -1331,7 +1184,7 @@ def _finalize_csv_row_csr_full( err_pt = None ok_pt = False pt_skip_reason = None - x_ref, pytorch_ms, _pytorch_cpu_ms, pt_skip_reason = _benchmark_pytorch_reference( + x_ref, pytorch_ms, pt_skip_reason = _benchmark_pytorch_reference( data, indices, indptr, @@ -1341,8 +1194,8 @@ def _finalize_csv_row_csr_full( op_mode=op_mode, ) if x_ref is not None: - x_cmp = _compare_view(x, value_dtype) - x_ref_cmp = _compare_view(x_ref, value_dtype) + x_cmp = x + x_ref_cmp = x_ref err_pt = ( float(torch.max(torch.abs(x_cmp - x_ref_cmp)).item()) if n_rows > 0 @@ -1358,8 +1211,8 @@ def _finalize_csv_row_csr_full( data, indices, indptr, shape, b, op_mode, lower ) if x_cu_t is not None: - x_cmp = _compare_view(x, value_dtype) - x_cu_cmp = _compare_view(x_cu_t, value_dtype) + x_cmp = x + x_cu_cmp = x_cu_t err_cu = ( float(torch.max(torch.abs(x_cmp - x_cu_cmp)).item()) if n_rows > 0 From 4ac021880db82e5868cf2fc4758ede7120443b56 Mon Sep 17 00:00:00 2001 From: berlin020 <2261128688@qq.com> Date: Mon, 27 Apr 2026 16:43:31 +0800 Subject: [PATCH 14/22] pytorch --- tests/test_spsv.py | 123 ++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 111 insertions(+), 12 deletions(-) diff --git a/tests/test_spsv.py b/tests/test_spsv.py index e8964a1..4e146a5 100644 --- a/tests/test_spsv.py +++ b/tests/test_spsv.py @@ -1,8 +1,8 @@ """SpSV tests: synthetic triangular systems and optional .mtx (CSR/COO). 
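# Review note (illustrative sketch, not part of this diff): this patch gates
# the dense CUDA fallback on an explicit free-memory estimate instead of
# reacting to OOM errors after the fact. The primitive it builds on reports
# (free, total) bytes for the current device; the sizes below are demo
# assumptions:
#
#     import torch
#
#     free_b, _total_b = torch.cuda.mem_get_info()
#     n = 100_000
#     elem = torch.empty((), dtype=torch.float64).element_size()  # 8 bytes
#     dense_bytes = n * n * elem                                  # ~80 GB dense
#     # With the 3x working-space factor introduced below, a 100k x 100k
#     # float64 matrix needs ~240 GB, so the fallback would be skipped.
#     fits = dense_bytes * 3.0 <= free_b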
-Timing columns, PyTorch CUDA sparse reference, CSV fields, and PASS criteria match
-the CSR runs; for COO tests the CuPy baseline uses ``coo_matrix`` (isomorphic to
-the FlagSparse COO input), while CSR tests still use ``csr_matrix``.
+Timing columns, PyTorch CUDA reference (sparse first, dense fallback), CSV fields,
+and PASS criteria match the CSR runs; the CuPy baseline uses ``coo_matrix`` for COO
+tests (isomorphic to the FlagSparse COO input) and ``csr_matrix`` for CSR tests.
 """
 
 import argparse
@@ -39,6 +39,7 @@
 ITERS = 20
 
 SPSV_TRIANGULAR_DIAG_DOMINANCE = 4.0
+SPSV_PYTORCH_DENSE_GPU_SAFETY_FACTOR = 3.0
 # Full CSR combination coverage (added beside the original csv-csr logic; entry points unchanged)
 CSR_FULL_VALUE_DTYPES = [
     torch.float32,
@@ -169,6 +170,24 @@ def _matrix_market_value(parts, mm_field):
         return 1.0
     raise ValueError("MatrixMarket entry is missing a numeric value")
 
+
+def _csr_to_dense(data, indices, indptr, shape):
+    n_rows, n_cols = int(shape[0]), int(shape[1])
+    if n_rows == 0 or n_cols == 0:
+        return torch.zeros((n_rows, n_cols), dtype=data.dtype, device=data.device)
+    row_ind = torch.repeat_interleave(
+        torch.arange(n_rows, device=data.device, dtype=torch.int64),
+        indptr[1:] - indptr[:-1],
+    )
+    coo = torch.sparse_coo_tensor(
+        torch.stack([row_ind, indices.to(torch.int64)]),
+        data,
+        (n_rows, n_cols),
+        device=data.device,
+    ).coalesce()
+    return coo.to_dense()
+
+
 def _build_csr_tensor_for_op(data, indices, indptr, shape, op_mode):
     if op_mode == "TRANS":
         data_eff, indices_eff, indptr_eff = _csr_transpose(
@@ -210,6 +229,37 @@ def _build_csr_tensor_for_op(data, indices, indptr, shape, op_mode):
     )
 
 
+def _apply_dense_triangular_op(A_dense, b, *, lower, op_mode):
+    if op_mode == "TRANS":
+        A_eff = A_dense.transpose(0, 1)
+        upper = lower
+    elif op_mode == "CONJ":
+        A_eff = A_dense.transpose(0, 1).conj() if torch.is_complex(A_dense) else A_dense.transpose(0, 1)
+        upper = lower
+    else:
+        A_eff = A_dense
+        upper = not lower
+    return torch.linalg.solve_triangular(A_eff, b.unsqueeze(1), upper=upper).squeeze(1)
+
+
+def _gpu_dense_ref_fits(shape, dtype):
+    if not torch.cuda.is_available():
+        return False, "CUDA unavailable"
+    element_size = torch.empty((), dtype=dtype).element_size()
+    dense_bytes = int(shape[0]) * int(shape[1]) * element_size
+    rhs_bytes = int(shape[0]) * element_size
+    estimated_bytes = int(dense_bytes * SPSV_PYTORCH_DENSE_GPU_SAFETY_FACTOR + rhs_bytes * 4)
+    try:
+        free_bytes, _total_bytes = torch.cuda.mem_get_info()
+    except Exception as e:
+        return False, f"cannot query CUDA memory ({e})"
+    if estimated_bytes > free_bytes:
+        need_gib = estimated_bytes / (1024 ** 3)
+        free_gib = free_bytes / (1024 ** 3)
+        return False, f"CUDA dense fallback too large ({need_gib:.1f} GiB est > {free_gib:.1f} GiB free)"
+    return True, None
+
+
 def _benchmark_pytorch_sparse_reference(data, indices, indptr, shape, b, *, op_mode):
     sparse_spsolve = getattr(torch.sparse, "spsolve", None)
     if sparse_spsolve is None:
@@ -259,17 +309,52 @@ def _benchmark_pytorch_sparse_reference(data, indices, indptr, shape, b, *, op_m
     return x_ref.to(b.dtype), elapsed_ms
 
 
+def _benchmark_pytorch_dense_reference(data, indices, indptr, shape, b, *, lower, op_mode):
+    if not data.is_cuda:
+        raise RuntimeError("PyTorch CUDA dense fallback requires CUDA tensors")
+    fits, reason = _gpu_dense_ref_fits(shape, b.dtype)
+    if not fits:
+        raise RuntimeError(reason)
+    A_dense = _csr_to_dense(data, indices, indptr, shape)
+    torch.cuda.synchronize()
+    e0 = torch.cuda.Event(True)
+    e1 = torch.cuda.Event(True)
+    e0.record()
+    x_ref = _apply_dense_triangular_op(A_dense, b, lower=lower, op_mode=op_mode)
+    e1.record()
+    torch.cuda.synchronize()
+    elapsed_ms = e0.elapsed_time(e1)
+    return x_ref.to(b.dtype), elapsed_ms
+
+
 def _benchmark_pytorch_reference(data,
indices, indptr, shape, b, *, lower, op_mode): - del lower + sparse_err = None try: x_ref, ms = _benchmark_pytorch_sparse_reference( data, indices, indptr, shape, b, op_mode=op_mode ) - return x_ref, ms, None + return x_ref, ms, "gpu_sparse", None except Exception as e: + sparse_err = e if "out of memory" in str(e).lower() and torch.cuda.is_available(): torch.cuda.empty_cache() - return None, None, f"PyTorch CUDA sparse ref unavailable; skipped ({e})" + try: + x_ref, ms = _benchmark_pytorch_dense_reference( + data, indices, indptr, shape, b, lower=lower, op_mode=op_mode + ) + return ( + x_ref, + ms, + "gpu_dense", + f"sparse unavailable ({sparse_err}); using CUDA dense solve_triangular", + ) + except Exception as dense_err: + if "out of memory" in str(dense_err).lower() and torch.cuda.is_available(): + torch.cuda.empty_cache() + reason = f"sparse unavailable ({sparse_err}); dense unavailable ({dense_err})" + if sparse_err is None: + reason = f"CUDA ref unavailable ({dense_err})" + return None, None, "unavailable", reason def _cupy_ref_inputs(data, b): @@ -800,7 +885,7 @@ def run_spsv_synthetic_all(lower=True): ) torch.cuda.synchronize() - x_pt, pytorch_ms, pt_skip_reason = _benchmark_pytorch_reference( + x_pt, pytorch_ms, pt_backend, pt_skip_reason = _benchmark_pytorch_reference( data, indices, indptr, @@ -876,6 +961,8 @@ def run_spsv_synthetic_all(lower=True): f"{(f'{fs_vs_cu:.2f}' if fs_vs_cu is not None else 'N/A'):>10} " f"{status:>8} {_fmt_err(err_pt):>12} {_fmt_err(err_cu):>12}" ) + if pt_backend and pt_backend != "gpu_sparse": + print(f" NOTE: pt_backend={pt_backend}") if pt_skip_reason: print(f" NOTE: {pt_skip_reason}") print("-" * 110) @@ -972,7 +1059,7 @@ def _finalize_csv_row( err_pt = None ok_pt = False pt_skip_reason = None - x_ref, pytorch_ms, pt_skip_reason = _benchmark_pytorch_reference( + x_ref, pytorch_ms, pt_backend, pt_skip_reason = _benchmark_pytorch_reference( data, indices, indptr, @@ -1090,6 +1177,7 @@ def _finalize_csv_row( "triton_time_total_ms": _sum_ms(analysis_ms, t_ms), "cusparse_solve_ms": cupy_ms, "pytorch_solve_ms": pytorch_ms, + "pytorch_backend": pt_backend, "cusparse/triton": _safe_ratio(cupy_ms, t_ms), "pytorch/triton": _safe_ratio(pytorch_ms, t_ms), "pt_status": _status_str(ok_pt, err_pt is not None), @@ -1184,7 +1272,7 @@ def _finalize_csv_row_csr_full( err_pt = None ok_pt = False pt_skip_reason = None - x_ref, pytorch_ms, pt_skip_reason = _benchmark_pytorch_reference( + x_ref, pytorch_ms, pt_backend, pt_skip_reason = _benchmark_pytorch_reference( data, indices, indptr, @@ -1239,6 +1327,7 @@ def _finalize_csv_row_csr_full( "triton_time_total_ms": _sum_ms(analysis_ms, t_ms), "cusparse_solve_ms": cupy_ms, "pytorch_solve_ms": pytorch_ms, + "pytorch_backend": pt_backend, "cusparse/triton": _safe_ratio(cupy_ms, t_ms), "pytorch/triton": _safe_ratio(pytorch_ms, t_ms), "pt_status": _status_str(ok_pt, err_pt is not None), @@ -1286,7 +1375,8 @@ def run_all_supported_spsv_csr_csv( f"Value dtype: {_dtype_name(value_dtype)} | Index dtype: {_dtype_name(index_dtype)} | CSR | triA={'LOWER' if lower else 'UPPER'} | opA={op_mode}" ) print( - "Formats: FlagSparse=CSR, cuSPARSE=CSR ref, PyTorch(ms)=CUDA sparse solve only." + "Formats: FlagSparse=CSR, cuSPARSE=CSR ref, " + "PyTorch(ms)=CUDA sparse solve preferred, CUDA dense fallback if needed." ) print( "RHS is generated directly, matching Library-main's SpSV test style. 
" @@ -1326,6 +1416,8 @@ def run_all_supported_spsv_csr_csv( f"{_fmt_speedup(cupy_ms, t_ms):>10} {_fmt_speedup(pytorch_ms, t_ms):>10} " f"{status:>6} {_fmt_err(err_ref):>10} {_fmt_err(err_res):>10} {_fmt_err(err_pt):>10} {_fmt_err(err_cu):>10}" ) + if row["pytorch_backend"] and row["pytorch_backend"] != "gpu_sparse": + print(f" NOTE: pt_backend={row['pytorch_backend']}") if pt_skip: print(f" NOTE: {pt_skip}") except Exception as e: @@ -1345,6 +1437,7 @@ def run_all_supported_spsv_csr_csv( "triton_time_total_ms": None, "cusparse_solve_ms": None, "pytorch_solve_ms": None, + "pytorch_backend": None, "cusparse/triton": None, "pytorch/triton": None, "pt_status": "N/A", @@ -1383,6 +1476,7 @@ def run_all_supported_spsv_csr_csv( "triton_time_total_ms", "cusparse_solve_ms", "pytorch_solve_ms", + "pytorch_backend", "cusparse/triton", "pytorch/triton", "pt_status", @@ -1427,7 +1521,8 @@ def run_all_dtypes_spsv_coo_csv( f" triA={'LOWER' if lower else 'UPPER'} coo_mode={coo_mode}" ) print( - "Formats: FlagSparse=COO SpSV, cuSPARSE=COO ref, PyTorch(ms)=CUDA sparse solve only. " + "Formats: FlagSparse=COO SpSV, cuSPARSE=COO ref, " + "PyTorch(ms)=CUDA sparse solve preferred, CUDA dense fallback if needed. " "RHS is generated directly, matching Library-main's SpSV test style." ) print( @@ -1466,6 +1561,8 @@ def run_all_dtypes_spsv_coo_csv( f"{_fmt_speedup(cupy_ms, t_ms):>10} {_fmt_speedup(pytorch_ms, t_ms):>10} " f"{status:>6} {_fmt_err(err_ref):>10} {_fmt_err(err_res):>10} {_fmt_err(err_pt):>10} {_fmt_err(err_cu):>10}" ) + if row["pytorch_backend"] and row["pytorch_backend"] != "gpu_sparse": + print(f" NOTE: pt_backend={row['pytorch_backend']}") if pt_skip: print(f" NOTE: {pt_skip}") except Exception as e: @@ -1485,6 +1582,7 @@ def run_all_dtypes_spsv_coo_csv( "triton_time_total_ms": None, "cusparse_solve_ms": None, "pytorch_solve_ms": None, + "pytorch_backend": None, "cusparse/triton": None, "pytorch/triton": None, "pt_status": "N/A", @@ -1522,6 +1620,7 @@ def run_all_dtypes_spsv_coo_csv( "triton_time_total_ms", "cusparse_solve_ms", "pytorch_solve_ms", + "pytorch_backend", "cusparse/triton", "pytorch/triton", "pt_status", @@ -1627,7 +1726,7 @@ def _check_one_csr_transpose_case(path, value_dtype, index_dtype, op_mode, devic ref_err = None ref_ok = None - x_ref, _, _ = _benchmark_pytorch_reference( + x_ref, _, _, _ = _benchmark_pytorch_reference( data, indices, indptr, From aa9689bff36fd5ca6597b0fc8d6c1c79fcc7c290 Mon Sep 17 00:00:00 2001 From: berlin020 <2261128688@qq.com> Date: Mon, 27 Apr 2026 16:58:57 +0800 Subject: [PATCH 15/22] Refine SpSV PyTorch fallback notes and output --- tests/test_spsv.py | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/tests/test_spsv.py b/tests/test_spsv.py index 4e146a5..3baaa12 100644 --- a/tests/test_spsv.py +++ b/tests/test_spsv.py @@ -235,9 +235,14 @@ def _gpu_dense_ref_fits(shape, dtype): except Exception as e: return False, f"cannot query CUDA memory ({e})" if estimated_bytes > free_bytes: + dense_gib = dense_bytes / (1024 ** 3) need_gib = estimated_bytes / (1024 ** 3) free_gib = free_bytes / (1024 ** 3) - return False, f"CUDA dense fallback too large ({need_gib:.1f} GiB est > {free_gib:.1f} GiB free)" + return ( + False, + "CUDA dense fallback too large " + f"(dense matrix {dense_gib:.1f} GiB, est {need_gib:.1f} GiB > {free_gib:.1f} GiB free)", + ) return True, None @@ -961,10 +966,8 @@ def run_spsv_synthetic_all(lower=True): f"{(f'{fs_vs_cu:.2f}' if fs_vs_cu is not None else 'N/A'):>10} " f"{status:>8} 
{_fmt_err(err_pt):>12} {_fmt_err(err_cu):>12}" ) - if pt_backend and pt_backend != "gpu_sparse": - print(f" NOTE: pt_backend={pt_backend}") - if pt_skip_reason: - print(f" NOTE: {pt_skip_reason}") + # Synthetic benchmark keeps the main row compact; PyTorch fallback notes + # are only emitted in matrix CSV runs where failed reference checks matter. print("-" * 110) print() @@ -1376,7 +1379,7 @@ def run_all_supported_spsv_csr_csv( ) print( "Formats: FlagSparse=CSR, cuSPARSE=CSR ref, " - "PyTorch(ms)=CUDA sparse solve preferred, CUDA dense fallback if needed." + "PyTorch(ms)=CUDA reference (sparse if available, else dense triangular solve)" ) print( "RHS is generated directly, matching Library-main's SpSV test style. " @@ -1416,10 +1419,11 @@ def run_all_supported_spsv_csr_csv( f"{_fmt_speedup(cupy_ms, t_ms):>10} {_fmt_speedup(pytorch_ms, t_ms):>10} " f"{status:>6} {_fmt_err(err_ref):>10} {_fmt_err(err_res):>10} {_fmt_err(err_pt):>10} {_fmt_err(err_cu):>10}" ) - if row["pytorch_backend"] and row["pytorch_backend"] != "gpu_sparse": - print(f" NOTE: pt_backend={row['pytorch_backend']}") - if pt_skip: - print(f" NOTE: {pt_skip}") + if status in ("FAIL", "REF_FAIL"): + if row["pytorch_backend"] and row["pytorch_backend"] != "gpu_sparse": + print(f" NOTE: pt_backend={row['pytorch_backend']}") + if pt_skip: + print(f" NOTE: {pt_skip}") except Exception as e: err_msg = str(e) status = "SKIP" if "SpSV requires square matrices" in err_msg else "ERROR" @@ -1522,7 +1526,7 @@ def run_all_dtypes_spsv_coo_csv( ) print( "Formats: FlagSparse=COO SpSV, cuSPARSE=COO ref, " - "PyTorch(ms)=CUDA sparse solve preferred, CUDA dense fallback if needed. " + "PyTorch(ms)=CUDA reference (sparse if available, else dense triangular solve). " "RHS is generated directly, matching Library-main's SpSV test style." 
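# Review note (illustrative sketch, not part of this diff): the SpSM rework in
# the next patch precomputes dependency "levels" for the triangular solve:
# every row in a level depends only on rows from earlier levels, so each level
# becomes one batched kernel launch. The scheduling idea on a tiny
# lower-triangular CSR (pure-Python sketch; the real code batches rows per
# level on the GPU):
#
#     indptr = [0, 1, 2, 4]   # 3x3 lower triangular; row 2 holds A[2,0], A[2,2]
#     indices = [0, 1, 0, 2]
#     levels = [0, 0, 0]
#     for i in range(3):
#         for p in range(indptr[i], indptr[i + 1]):
#             c = indices[p]
#             if c < i:        # strictly-lower entry = dependency on row c
#                 levels[i] = max(levels[i], levels[c] + 1)
#     # levels == [0, 0, 1]: rows 0 and 1 solve together; row 2 waits one step.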
) print( @@ -1561,10 +1565,11 @@ def run_all_dtypes_spsv_coo_csv( f"{_fmt_speedup(cupy_ms, t_ms):>10} {_fmt_speedup(pytorch_ms, t_ms):>10} " f"{status:>6} {_fmt_err(err_ref):>10} {_fmt_err(err_res):>10} {_fmt_err(err_pt):>10} {_fmt_err(err_cu):>10}" ) - if row["pytorch_backend"] and row["pytorch_backend"] != "gpu_sparse": - print(f" NOTE: pt_backend={row['pytorch_backend']}") - if pt_skip: - print(f" NOTE: {pt_skip}") + if status in ("FAIL", "REF_FAIL"): + if row["pytorch_backend"] and row["pytorch_backend"] != "gpu_sparse": + print(f" NOTE: pt_backend={row['pytorch_backend']}") + if pt_skip: + print(f" NOTE: {pt_skip}") except Exception as e: err_msg = str(e) status = "SKIP" if "SpSV requires square matrices" in err_msg else "ERROR" From 9f8ef1b8e519481707041c4da6647c7e2f01c990 Mon Sep 17 00:00:00 2001 From: berlin020 <2261128688@qq.com> Date: Mon, 27 Apr 2026 18:32:14 +0800 Subject: [PATCH 16/22] spsm-opt --- src/flagsparse/sparse_operations/spsm.py | 530 +++++++++++++------- src/flagsparse/sparse_operations/spsv.py | 89 +++- tests/test_spsm.py | 607 +++++++++++++++-------- 3 files changed, 826 insertions(+), 400 deletions(-) diff --git a/src/flagsparse/sparse_operations/spsm.py b/src/flagsparse/sparse_operations/spsm.py index 28e26d3..8254a4c 100644 --- a/src/flagsparse/sparse_operations/spsm.py +++ b/src/flagsparse/sparse_operations/spsm.py @@ -1,13 +1,9 @@ """Sparse triangular matrix-matrix solve (SpSM) for CSR/COO.""" -from ._common import * +from collections import OrderedDict -import time +from ._common import * -try: - from cupyx.scipy.sparse.linalg import spsolve_triangular as cpx_spsolve_triangular -except Exception: - cpx_spsolve_triangular = None SUPPORTED_SPSM_VALUE_DTYPES = (torch.float32, torch.float64) SUPPORTED_SPSM_INDEX_DTYPES = (torch.int32, torch.int64) @@ -17,6 +13,12 @@ ("coo", torch.float32, torch.int32), ("coo", torch.float64, torch.int32), ) +_SPSM_PREPROCESS_CACHE = OrderedDict() +_SPSM_PREPROCESS_CACHE_SIZE = 8 + + +def _clear_spsm_preprocess_cache(): + _SPSM_PREPROCESS_CACHE.clear() def _is_non_transpose(op): @@ -69,44 +71,24 @@ def _prepare_spsm_csr_inputs(data, indices, indptr, B, shape, opA, opB, major): raise TypeError("indices dtype must be torch.int32 or torch.int64") if indptr.dtype not in SUPPORTED_SPSM_INDEX_DTYPES: raise TypeError("indptr dtype must be torch.int32 or torch.int64") + indices64 = indices.to(torch.int64).contiguous() + indptr64 = indptr.to(torch.int64).contiguous() if data.numel() > 0: - if int(indices.to(torch.int64).max().item()) > _INDEX_LIMIT_INT32: + if bool(torch.any(indices64 < 0).item()): + raise IndexError("indices must be non-negative") + if int(indices64.max().item()) >= n_cols: + raise IndexError("indices out of range") + if int(indices64.max().item()) > _INDEX_LIMIT_INT32: raise ValueError("index value exceeds int32 kernel range") + if indptr64.numel() > 0: + if int(indptr64[0].item()) != 0: + raise ValueError("indptr[0] must be 0") + if int(indptr64[-1].item()) != int(data.numel()): + raise ValueError("indptr[-1] must equal nnz") + if bool(torch.any(indptr64[1:] < indptr64[:-1]).item()): + raise ValueError("indptr must be non-decreasing") _validate_spsm_non_trans_combo("csr", data.dtype, torch.int32) - return ( - data.contiguous(), - indices.to(torch.int64).contiguous(), - indptr.to(torch.int64).contiguous(), - B.contiguous(), - n_rows, - ) - - -def _coo_to_csr_sorted_unique(data, row, col, n_rows, n_cols): - if data.numel() == 0: - return ( - data, - torch.empty(0, dtype=torch.int64, device=data.device), - 
torch.zeros(n_rows + 1, dtype=torch.int64, device=data.device), - ) - key = row * max(1, n_cols) + col - try: - order = torch.argsort(key, stable=True) - except TypeError: - order = torch.argsort(key) - key_s = key[order] - data_s = data[order] - unique_key, inverse = torch.unique_consecutive(key_s, return_inverse=True) - out_nnz = unique_key.numel() - data_u = torch.zeros(out_nnz, dtype=data.dtype, device=data.device) - data_u.scatter_add_(0, inverse, data_s) - row_u = torch.div(unique_key, max(1, n_cols), rounding_mode="floor") - col_u = unique_key - row_u * max(1, n_cols) - indptr = torch.zeros(n_rows + 1, dtype=torch.int64, device=data.device) - if out_nnz > 0: - nnz_per_row = torch.bincount(row_u, minlength=n_rows) - indptr[1:] = torch.cumsum(nnz_per_row, dim=0) - return data_u, col_u.to(torch.int64), indptr + return data.contiguous(), indices64, indptr64, B.contiguous(), n_rows, n_cols def _prepare_spsm_coo_inputs(data, row, col, B, shape, opA, opB, major): @@ -139,10 +121,50 @@ def _prepare_spsm_coo_inputs(data, row, col, B, shape, opA, opB, major): raise IndexError("row index out of range") if int(col64.max().item()) >= n_cols: raise IndexError("col index out of range") + if max(int(row64.max().item()), int(col64.max().item())) > _INDEX_LIMIT_INT32: + raise ValueError("row/col value exceeds int32 kernel range") _validate_spsm_non_trans_combo("coo", data.dtype, torch.int32) return data.contiguous(), row64, col64, B.contiguous(), n_rows, n_cols +def _tensor_cache_token(tensor): + try: + storage_ptr = int(tensor.untyped_storage().data_ptr()) + except Exception: + storage_ptr = 0 + return ( + str(tensor.device), + str(tensor.dtype), + tuple(int(v) for v in tensor.shape), + int(tensor.numel()), + storage_ptr, + int(getattr(tensor, "_version", 0)), + ) + + +def _spsm_cache_get(cache, key): + value = cache.get(key) + if value is not None: + cache.move_to_end(key) + return value + + +def _spsm_cache_put(cache, key, value, max_entries): + cache[key] = value + cache.move_to_end(key) + while len(cache) > max_entries: + cache.popitem(last=False) + + +def _spsm_preprocess_cache_key(fmt_name, tensors, shape, lower, unit_diagonal): + return ( + str(fmt_name).lower(), + bool(lower), + bool(unit_diagonal), + int(shape[0]), + int(shape[1]), + *(_tensor_cache_token(t) for t in tensors), + ) @triton.jit @@ -318,36 +340,69 @@ def _build_spsm_levels(indptr, indices, n_rows, lower=True): if lower: for i in range(n_rows): - s = int(indptr_h[i].item()) - e = int(indptr_h[i + 1].item()) - lvl = 0 - for p in range(s, e): - c = int(indices_h[p].item()) - if c < i: - lvl = max(lvl, levels[c] + 1) - levels[i] = lvl + start = int(indptr_h[i].item()) + end = int(indptr_h[i + 1].item()) + level = 0 + for p in range(start, end): + col = int(indices_h[p].item()) + if col < i: + level = max(level, levels[col] + 1) + levels[i] = level else: for i in range(n_rows - 1, -1, -1): - s = int(indptr_h[i].item()) - e = int(indptr_h[i + 1].item()) - lvl = 0 - for p in range(s, e): - c = int(indices_h[p].item()) - if c > i: - lvl = max(lvl, levels[c] + 1) - levels[i] = lvl + start = int(indptr_h[i].item()) + end = int(indptr_h[i + 1].item()) + level = 0 + for p in range(start, end): + col = int(indices_h[p].item()) + if col > i: + level = max(level, levels[col] + 1) + levels[i] = level max_level = max(levels) buckets = [[] for _ in range(max_level + 1)] - for r, lv in enumerate(levels): - buckets[lv].append(r) + for row, level in enumerate(levels): + buckets[level].append(row) + + device = indptr.device + return 
[torch.tensor(rows, dtype=torch.int32, device=device) for rows in buckets if rows] + + +def _build_spsm_frontiers(indptr, indices, levels, lower=True): + if not levels: + return [] + indptr_h = indptr.to(torch.int64).cpu() + indices_h = indices.to(torch.int64).cpu() device = indptr.device - return [ - torch.tensor(rows, dtype=torch.int32, device=device) - for rows in buckets - if rows - ] + frontier_rows = [] + frontier_row_set = set() + merged = [] + + def _flush_frontier(): + nonlocal frontier_rows, frontier_row_set + if frontier_rows: + merged.append(torch.tensor(frontier_rows, dtype=torch.int32, device=device)) + frontier_rows = [] + frontier_row_set = set() + + for rows_lv in levels: + for row in rows_lv.to(torch.int64).cpu().tolist(): + start = int(indptr_h[row].item()) + end = int(indptr_h[row + 1].item()) + depends_on_frontier = False + for p in range(start, end): + col = int(indices_h[p].item()) + is_dep = col < row if lower else col > row + if is_dep and col in frontier_row_set: + depends_on_frontier = True + break + if depends_on_frontier: + _flush_frontier() + frontier_rows.append(int(row)) + frontier_row_set.add(int(row)) + _flush_frontier() + return merged def _auto_spsm_launch_config(indptr, block_nnz=None, max_segments=None): @@ -403,11 +458,13 @@ def _auto_rhs_block(n_rhs): return 64 -def _coo_sort_unique_and_rowptr(data, row64, col64, n_rows, n_cols): +def _coo_to_csr_sorted_unique(data, row64, col64, n_rows, n_cols): if data.numel() == 0: - row_ptr = torch.zeros(n_rows + 1, dtype=torch.int64, device=data.device) - return data, torch.empty(0, dtype=torch.int64, device=data.device), row_ptr - + return ( + data, + torch.empty(0, dtype=torch.int64, device=data.device), + torch.zeros(n_rows + 1, dtype=torch.int64, device=data.device), + ) key = row64 * max(1, n_cols) + col64 try: order = torch.argsort(key, stable=True) @@ -415,31 +472,101 @@ def _coo_sort_unique_and_rowptr(data, row64, col64, n_rows, n_cols): order = torch.argsort(key) key_s = key[order] data_s = data[order] - unique_key, inverse = torch.unique_consecutive(key_s, return_inverse=True) out_nnz = unique_key.numel() data_u = torch.zeros(out_nnz, dtype=data.dtype, device=data.device) data_u.scatter_add_(0, inverse, data_s) - row_u = torch.div(unique_key, max(1, n_cols), rounding_mode="floor") col_u = unique_key - row_u * max(1, n_cols) - row_ptr = torch.zeros(n_rows + 1, dtype=torch.int64, device=data.device) - nnz_per_row = torch.bincount(row_u, minlength=n_rows) - row_ptr[1:] = torch.cumsum(nnz_per_row, dim=0) - return data_u, col_u.to(torch.int64), row_ptr + indptr = torch.zeros(n_rows + 1, dtype=torch.int64, device=data.device) + if out_nnz > 0: + nnz_per_row = torch.bincount(row_u, minlength=n_rows) + indptr[1:] = torch.cumsum(nnz_per_row, dim=0) + return data_u, col_u.to(torch.int64), indptr + + +def _prepare_spsm_csr_system(data, indices64, indptr64, n_rows, lower): + levels = _build_spsm_levels(indptr64, indices64, n_rows, lower=lower) + default_block_nnz, default_max_segments = _auto_spsm_launch_config(indptr64) + return { + "kernel_data": data, + "kernel_indices64": indices64, + "kernel_indices32": indices64.to(torch.int32), + "kernel_indptr64": indptr64, + "launch_groups": _build_spsm_frontiers(indptr64, indices64, levels, lower=lower), + "default_block_nnz": default_block_nnz, + "default_max_segments": default_max_segments, + "lower_eff": bool(lower), + } + + +def _prepare_spsm_coo_system(data, row64, col64, n_rows, n_cols, lower): + data_u, col_u64, row_ptr = _coo_to_csr_sorted_unique(data, row64, 
col64, n_rows, n_cols) + levels = _build_spsm_levels(row_ptr, col_u64, n_rows, lower=lower) + default_block_nnz, default_max_segments = _auto_spsm_launch_config(row_ptr) + return { + "kernel_data": data_u, + "kernel_cols64": col_u64, + "kernel_cols32": col_u64.to(torch.int32), + "kernel_row_ptr64": row_ptr, + "launch_groups": _build_spsm_frontiers(row_ptr, col_u64, levels, lower=lower), + "default_block_nnz": default_block_nnz, + "default_max_segments": default_max_segments, + "lower_eff": bool(lower), + } + + +def _resolve_spsm_csr_runtime(data, indices, indptr, B, shape, lower, unit_diagonal, opA, opB, major): + data, indices64, indptr64, B, n_rows, n_cols = _prepare_spsm_csr_inputs( + data, indices, indptr, B, shape, opA, opB, major + ) + cache_key = _spsm_preprocess_cache_key( + "csr", + (data, indices64, indptr64), + shape, + lower, + unit_diagonal, + ) + solve_plan = _spsm_cache_get(_SPSM_PREPROCESS_CACHE, cache_key) + if solve_plan is None: + solve_plan = _prepare_spsm_csr_system(data, indices64, indptr64, n_rows, lower) + _spsm_cache_put(_SPSM_PREPROCESS_CACHE, cache_key, solve_plan, _SPSM_PREPROCESS_CACHE_SIZE) + return data, B, n_rows, n_cols, solve_plan + + +def _resolve_spsm_coo_runtime(data, row, col, B, shape, lower, unit_diagonal, opA, opB, major): + data, row64, col64, B, n_rows, n_cols = _prepare_spsm_coo_inputs( + data, row, col, B, shape, opA, opB, major + ) + cache_key = _spsm_preprocess_cache_key( + "coo", + (data, row64, col64), + shape, + lower, + unit_diagonal, + ) + solve_plan = _spsm_cache_get(_SPSM_PREPROCESS_CACHE, cache_key) + if solve_plan is None: + solve_plan = _prepare_spsm_coo_system(data, row64, col64, n_rows, n_cols, lower) + _spsm_cache_put(_SPSM_PREPROCESS_CACHE, cache_key, solve_plan, _SPSM_PREPROCESS_CACHE_SIZE) + return data, B, n_rows, n_cols, solve_plan def _run_spsm_csr_core( data, - indices64, + indices32, indptr64, rhs, n_rows, + *, lower=True, unit_diagonal=False, block_nnz=None, max_segments=None, block_rhs=None, + launch_groups=None, + block_nnz_use=None, + max_segments_use=None, ): if rhs.ndim != 2: raise ValueError("rhs must be 2D") @@ -451,11 +578,13 @@ def _run_spsm_csr_core( if n_rows == 0 or n_rhs == 0: return x - indices32 = indices64.to(torch.int32) - levels = _build_spsm_levels(indptr64, indices32, n_rows, lower=lower) - block_nnz_use, max_segments_use = _auto_spsm_launch_config( - indptr64, block_nnz=block_nnz, max_segments=max_segments - ) + if launch_groups is None: + levels = _build_spsm_levels(indptr64, indices32, n_rows, lower=lower) + launch_groups = _build_spsm_frontiers(indptr64, indices32, levels, lower=lower) + if block_nnz_use is None or max_segments_use is None: + block_nnz_use, max_segments_use = _auto_spsm_launch_config( + indptr64, block_nnz=block_nnz, max_segments=max_segments + ) block_rhs_use = _auto_rhs_block(n_rhs) if block_rhs is None else int(block_rhs) if block_rhs_use <= 0: raise ValueError("block_rhs must be positive") @@ -463,7 +592,7 @@ def _run_spsm_csr_core( use_fp64 = data.dtype == torch.float64 diag_eps = 1e-12 if use_fp64 else 1e-6 - for rows_lv in levels: + for rows_lv in launch_groups: n_lv = rows_lv.numel() if n_lv == 0: continue @@ -492,16 +621,19 @@ def _run_spsm_csr_core( def _run_spsm_coo_core( data, - row64, - col64, + cols32, + row_ptr64, rhs, n_rows, - n_cols, + *, lower=True, unit_diagonal=False, block_nnz=None, max_segments=None, block_rhs=None, + launch_groups=None, + block_nnz_use=None, + max_segments_use=None, ): if rhs.ndim != 2: raise ValueError("rhs must be 2D") @@ -513,12 +645,13 
@@ def _run_spsm_coo_core( if n_rows == 0 or n_rhs == 0: return x - data_u, col_u64, row_ptr = _coo_sort_unique_and_rowptr(data, row64, col64, n_rows, n_cols) - cols32 = col_u64.to(torch.int32) - levels = _build_spsm_levels(row_ptr, cols32, n_rows, lower=lower) - block_nnz_use, max_segments_use = _auto_spsm_launch_config( - row_ptr, block_nnz=block_nnz, max_segments=max_segments - ) + if launch_groups is None: + levels = _build_spsm_levels(row_ptr64, cols32, n_rows, lower=lower) + launch_groups = _build_spsm_frontiers(row_ptr64, cols32, levels, lower=lower) + if block_nnz_use is None or max_segments_use is None: + block_nnz_use, max_segments_use = _auto_spsm_launch_config( + row_ptr64, block_nnz=block_nnz, max_segments=max_segments + ) block_rhs_use = _auto_rhs_block(n_rhs) if block_rhs is None else int(block_rhs) if block_rhs_use <= 0: raise ValueError("block_rhs must be positive") @@ -526,14 +659,14 @@ def _run_spsm_coo_core( use_fp64 = data.dtype == torch.float64 diag_eps = 1e-12 if use_fp64 else 1e-6 - for rows_lv in levels: + for rows_lv in launch_groups: n_lv = rows_lv.numel() if n_lv == 0: continue grid = (n_lv, triton.cdiv(n_rhs, block_rhs_use)) _spsm_coo_level_kernel_real[grid]( - data_u, - row_ptr, + data, + row_ptr64, cols32, rhs, x, @@ -568,21 +701,24 @@ def flagsparse_spsm_csr( out=None, return_time=False, ): - data, indices, indptr, B, _ = _prepare_spsm_csr_inputs( - data, indices, indptr, B, shape, opA, opB, major + data, B, n_rows, _n_cols, solve_plan = _resolve_spsm_csr_runtime( + data, indices, indptr, B, shape, lower, unit_diagonal, opA, opB, major ) alpha_t = torch.as_tensor(alpha, dtype=B.dtype, device=B.device) rhs = alpha_t * B torch.cuda.synchronize() t0 = time.perf_counter() x = _run_spsm_csr_core( - data, - indices, - indptr, + solve_plan["kernel_data"], + solve_plan["kernel_indices32"], + solve_plan["kernel_indptr64"], rhs, - int(shape[0]), - lower=lower, + n_rows, + lower=solve_plan["lower_eff"], unit_diagonal=unit_diagonal, + launch_groups=solve_plan["launch_groups"], + block_nnz_use=solve_plan["default_block_nnz"], + max_segments_use=solve_plan["default_max_segments"], ) torch.cuda.synchronize() elapsed_ms = (time.perf_counter() - t0) * 1000.0 @@ -611,22 +747,24 @@ def flagsparse_spsm_coo( out=None, return_time=False, ): - data, row64, col64, B, n_rows, n_cols = _prepare_spsm_coo_inputs( - data, row, col, B, shape, opA, opB, major + data, B, n_rows, _n_cols, solve_plan = _resolve_spsm_coo_runtime( + data, row, col, B, shape, lower, unit_diagonal, opA, opB, major ) alpha_t = torch.as_tensor(alpha, dtype=B.dtype, device=B.device) rhs = alpha_t * B torch.cuda.synchronize() t0 = time.perf_counter() x = _run_spsm_coo_core( - data, - row64, - col64, + solve_plan["kernel_data"], + solve_plan["kernel_cols32"], + solve_plan["kernel_row_ptr64"], rhs, n_rows, - n_cols, - lower=lower, + lower=solve_plan["lower_eff"], unit_diagonal=unit_diagonal, + launch_groups=solve_plan["launch_groups"], + block_nnz_use=solve_plan["default_block_nnz"], + max_segments_use=solve_plan["default_max_segments"], ) torch.cuda.synchronize() elapsed_ms = (time.perf_counter() - t0) * 1000.0 @@ -640,36 +778,58 @@ def flagsparse_spsm_coo( return x -def _cupy_spsm_baseline_from_csr( - data, indices, indptr, B, shape, alpha, lower, unit_diagonal, warmup=10, iters=50 +def _analyze_spsm_csr( + data, + indices, + indptr, + B, + shape, + *, + lower=True, + unit_diagonal=False, + opA="NON_TRANS", + opB="NON_TRANS", + major="row", + clear_cache=False, + return_time=False, ): - if cp is None or cpx_sparse 
is None or cpx_spsolve_triangular is None: - return None, None, "cupy/cusparse unavailable" - try: - data_cp = _cupy_from_torch(data) - idx_cp = _cupy_from_torch(indices.to(torch.int64)) - ptr_cp = _cupy_from_torch(indptr) - B_cp = _cupy_from_torch(B) - A_cp = cpx_sparse.csr_matrix((data_cp, idx_cp, ptr_cp), shape=shape) - alpha_cp = cp.asarray(alpha, dtype=B_cp.dtype) - for _ in range(max(0, int(warmup))): - _ = cpx_spsolve_triangular( - A_cp, alpha_cp * B_cp, lower=lower, unit_diagonal=unit_diagonal - ) - cp.cuda.runtime.deviceSynchronize() - e0 = cp.cuda.Event() - e1 = cp.cuda.Event() - e0.record() - for _ in range(max(1, int(iters))): - C_cp = cpx_spsolve_triangular( - A_cp, alpha_cp * B_cp, lower=lower, unit_diagonal=unit_diagonal - ) - e1.record() - e1.synchronize() - ms = cp.cuda.get_elapsed_time(e0, e1) / max(1, int(iters)) - return _torch_from_cupy(C_cp).to(B.dtype), ms, None - except Exception as exc: - return None, None, str(exc) + if clear_cache: + _clear_spsm_preprocess_cache() + torch.cuda.synchronize() + t0 = time.perf_counter() + _resolve_spsm_csr_runtime(data, indices, indptr, B, shape, lower, unit_diagonal, opA, opB, major) + torch.cuda.synchronize() + elapsed_ms = (time.perf_counter() - t0) * 1000.0 + if return_time: + return elapsed_ms + return None + + +def _analyze_spsm_coo( + data, + row, + col, + B, + shape, + *, + lower=True, + unit_diagonal=False, + opA="NON_TRANS", + opB="NON_TRANS", + major="row", + clear_cache=False, + return_time=False, +): + if clear_cache: + _clear_spsm_preprocess_cache() + torch.cuda.synchronize() + t0 = time.perf_counter() + _resolve_spsm_coo_runtime(data, row, col, B, shape, lower, unit_diagonal, opA, opB, major) + torch.cuda.synchronize() + elapsed_ms = (time.perf_counter() - t0) * 1000.0 + if return_time: + return elapsed_ms + return None def benchmark_spsm_case( @@ -685,52 +845,59 @@ def benchmark_spsm_case( warmup=10, iters=50, ): - """Benchmark SpSM (NON_TRANS/NON_TRANS, row-major B) against cuSPARSE baseline.""" + """Pure FlagSparse SpSM benchmark entry for one configuration.""" device = torch.device("cuda") data, indices, indptr = _build_random_csr( n_rows, n_rows, nnz, value_dtype, index_dtype, device ) - # Make A triangular and diagonally dominant. row_ids = torch.repeat_interleave( torch.arange(n_rows, device=device, dtype=torch.int64), indptr.to(torch.int64)[1:] - indptr.to(torch.int64)[:-1], ) col_ids = indices.to(torch.int64) tri_mask = (col_ids <= row_ids) if lower else (col_ids >= row_ids) - if tri_mask.numel() > 0: - data = data[tri_mask] - col_ids = col_ids[tri_mask] - row_ids = row_ids[tri_mask] + data = data[tri_mask] + row_ids = row_ids[tri_mask] + col_ids = col_ids[tri_mask] data, col_ids, indptr = _coo_to_csr_sorted_unique(data, row_ids, col_ids, n_rows, n_rows) - # Ensure diagonal exists without densifying A. 
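
# The replacement hunk that follows implements diagonal completion without
# densifying A: any row whose diagonal entry is missing gets an explicit 1
# appended in COO form, after which the matrix is re-canonicalized to CSR.
# A minimal, self-contained sketch of that idiom (the names and values below
# are illustrative only, not part of this patch):

import torch

n = 4
row = torch.tensor([1, 2, 3], dtype=torch.int64)   # lower-triangular COO
col = torch.tensor([0, 2, 1], dtype=torch.int64)   # only row 2 has A[i, i]
val = torch.ones(3)

present = torch.zeros(n, dtype=torch.bool)
present[row[col == row]] = True                    # rows that already own a diagonal
missing = torch.nonzero(~present).flatten()        # -> tensor([0, 1, 3])
row = torch.cat([row, missing])
col = torch.cat([col, missing])
val = torch.cat([val, torch.ones(missing.numel())])
# ...then sort and deduplicate back to CSR, as _coo_to_csr_sorted_unique does.
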
- row = torch.repeat_interleave( + row_ids = torch.repeat_interleave( torch.arange(n_rows, device=device, dtype=torch.int64), indptr[1:] - indptr[:-1], ) - diag_mask = col_ids == row + diag_mask = col_ids == row_ids diag_present = torch.zeros(n_rows, dtype=torch.bool, device=device) if diag_mask.numel() > 0 and bool(torch.any(diag_mask).item()): - diag_present[row[diag_mask]] = True + diag_present[row_ids[diag_mask]] = True missing_diag = torch.nonzero(~diag_present, as_tuple=False).reshape(-1).to(torch.int64) if missing_diag.numel() > 0: - diag_data = torch.ones(missing_diag.numel(), dtype=value_dtype, device=device) - data = torch.cat([data, diag_data], dim=0) - row = torch.cat([row, missing_diag], dim=0) - col = torch.cat([col_ids, missing_diag], dim=0) - data, col_ids, indptr = _coo_to_csr_sorted_unique(data, row, col, n_rows, n_rows) - row = torch.repeat_interleave( - torch.arange(n_rows, device=device, dtype=torch.int64), - indptr[1:] - indptr[:-1], - ) - col = col_ids.to(torch.int64) + diag_values = torch.ones(missing_diag.numel(), dtype=value_dtype, device=device) + data = torch.cat([data, diag_values], dim=0) + row_ids = torch.cat([row_ids, missing_diag], dim=0) + col_ids = torch.cat([col_ids, missing_diag], dim=0) + data, col_ids, indptr = _coo_to_csr_sorted_unique(data, row_ids, col_ids, n_rows, n_rows) B = torch.randn((n_rows, n_rhs), dtype=value_dtype, device=device).contiguous() shape = (n_rows, n_rows) if str(fmt).lower() == "coo": - triton_op = lambda: flagsparse_spsm_coo( + row_ids = torch.repeat_interleave( + torch.arange(n_rows, device=device, dtype=torch.int64), + indptr[1:] - indptr[:-1], + ) + analysis_ms = _analyze_spsm_coo( data, - row, - col, + row_ids, + col_ids, + B, + shape, + lower=lower, + unit_diagonal=unit_diagonal, + clear_cache=True, + return_time=True, + ) + solve_call = lambda: flagsparse_spsm_coo( + data, + row_ids, + col_ids, B, shape, alpha=alpha, @@ -741,7 +908,18 @@ def benchmark_spsm_case( major="row", ) else: - triton_op = lambda: flagsparse_spsm_csr( + analysis_ms = _analyze_spsm_csr( + data, + col_ids.to(index_dtype), + indptr, + B, + shape, + lower=lower, + unit_diagonal=unit_diagonal, + clear_cache=True, + return_time=True, + ) + solve_call = lambda: flagsparse_spsm_csr( data, col_ids.to(index_dtype), indptr, @@ -754,18 +932,7 @@ def benchmark_spsm_case( opB="NON_TRANS", major="row", ) - C_fs, fs_ms = _benchmark_cuda_op(triton_op, warmup=warmup, iters=iters) - atol, rtol = _tolerance_for_dtype(value_dtype) - - C_cu, cu_ms, cu_reason = _cupy_spsm_baseline_from_csr( - data, col_ids, indptr, B, shape, alpha, lower, unit_diagonal, warmup=warmup, iters=iters - ) - cu_ok = None - cu_err = None - if C_cu is not None: - cu_ok = torch.allclose(C_fs, C_cu, atol=atol, rtol=rtol) - cu_err = float(torch.max(torch.abs(C_fs - C_cu)).item()) if C_fs.numel() > 0 else 0.0 - + C_fs, solve_ms = _benchmark_cuda_op(solve_call, warmup=warmup, iters=iters) return { "parameters": { "format": str(fmt).lower(), @@ -779,16 +946,11 @@ def benchmark_spsm_case( "major": "row", }, "performance": { - "flagsparse_ms": fs_ms, - "cusparse_ms": cu_ms, - "speedup_vs_cusparse": (cu_ms / fs_ms if (cu_ms is not None and fs_ms > 0) else None), - }, - "verification": { - "flagsparse_match_cusparse": cu_ok, - "flagsparse_vs_cusparse_max_error": cu_err, - }, - "backend_status": { - "cusparse_unavailable_reason": cu_reason, + "triton_analysis_ms": analysis_ms, + "triton_solve_ms": solve_ms, + "triton_time_total_ms": ( + analysis_ms + solve_ms if analysis_ms is not None and solve_ms is not 
None else None + ), }, - "samples": {"flagsparse": C_fs, "cusparse": C_cu}, - } \ No newline at end of file + "samples": {"flagsparse": C_fs}, + } diff --git a/src/flagsparse/sparse_operations/spsv.py b/src/flagsparse/sparse_operations/spsv.py index 0d31194..c89c4e6 100644 --- a/src/flagsparse/sparse_operations/spsv.py +++ b/src/flagsparse/sparse_operations/spsv.py @@ -246,29 +246,89 @@ def _flush_frontier(): return merged +def _build_spsv_reverse_frontiers(indptr, indices, levels, lower=True): + """Greedily merge reverse-topological launch groups for transpose push. + + Rows processed by the currently active reverse frontier push residual + updates into their dependency targets. A candidate row can be merged only + when no active row would update it; otherwise it must be delayed until the + current frontier completes. + """ + if not levels: + return [] + + indptr_h = indptr.to(torch.int64).cpu() + indices_h = indices.to(torch.int64).cpu() + device = indptr.device + dependency_targets = {} + for rows_lv in levels: + for row in rows_lv.to(torch.int64).cpu().tolist(): + start = int(indptr_h[row].item()) + end = int(indptr_h[row + 1].item()) + targets = set() + for p in range(start, end): + col = int(indices_h[p].item()) + is_dep = (col < row) if lower else (col > row) + if is_dep: + targets.add(col) + dependency_targets[int(row)] = targets + + frontier_rows = [] + frontier_targets = set() + merged = [] + + def _flush_frontier(): + nonlocal frontier_rows, frontier_targets + if frontier_rows: + merged.append(torch.tensor(frontier_rows, dtype=torch.int32, device=device)) + frontier_rows = [] + frontier_targets = set() + + for rows_lv in reversed(levels): + for row in rows_lv.to(torch.int64).cpu().tolist(): + if int(row) in frontier_targets: + _flush_frontier() + frontier_rows.append(int(row)) + frontier_targets.update(dependency_targets.get(int(row), ())) + _flush_frontier() + return merged + + def _prepare_spsv_csr_system(data, indices64, indptr64, n_rows, n_cols, lower, trans_mode): if trans_mode == "N": levels = _build_spsv_levels(indptr64, indices64, n_rows, lower=lower) + default_block_nnz, default_max_segments = _auto_spsv_launch_config(indptr64) return { "solve_kind": "csr_levels", "kernel_data": data, "kernel_indices64": indices64, + "kernel_indices32": indices64.to(torch.int32), "kernel_indptr64": indptr64, "lower_eff": lower, "launch_groups": _build_spsv_frontiers( indptr64, indices64, levels, lower=lower ), + "default_block_nnz": default_block_nnz, + "default_max_segments": default_max_segments, "transpose_conjugate": False, } levels = _build_spsv_levels(indptr64, indices64, n_rows, lower=lower) + default_block_nnz, default_max_segments = _choose_transpose_family_launch_config( + indptr64 + ) return { "solve_kind": "transpose_push", "kernel_data": data, "kernel_indices64": indices64, + "kernel_indices32": indices64.to(torch.int32), "kernel_indptr64": indptr64, "lower_eff": lower, - "launch_groups": list(reversed(levels)), + "launch_groups": _build_spsv_reverse_frontiers( + indptr64, indices64, levels, lower=lower + ), + "default_block_nnz": default_block_nnz, + "default_max_segments": default_max_segments, "transpose_conjugate": trans_mode == "C", } @@ -1237,15 +1297,14 @@ def flagsparse_spsv_csr( solve_kind = solve_plan["solve_kind"] kernel_data = solve_plan["kernel_data"] kernel_indices64 = solve_plan["kernel_indices64"] + kernel_indices32 = solve_plan["kernel_indices32"] kernel_indptr64 = solve_plan["kernel_indptr64"] lower_eff = solve_plan["lower_eff"] launch_groups = 
solve_plan["launch_groups"] transpose_conjugate = solve_plan["transpose_conjugate"] - kernel_indices = ( - kernel_indices64.to(torch.int32) - if kernel_indices64.dtype != torch.int32 - else kernel_indices64 - ) + default_block_nnz = solve_plan["default_block_nnz"] + default_max_segments = solve_plan["default_max_segments"] + kernel_indices = kernel_indices32 kernel_indptr = kernel_indptr64 compute_dtype = data.dtype data_in = kernel_data @@ -1272,15 +1331,21 @@ def flagsparse_spsv_csr( b_in = b.to(torch.float64) if solve_kind == "transpose_push": - block_nnz_use, max_segments_use = _choose_transpose_family_launch_config( - kernel_indptr, block_nnz=block_nnz, max_segments=max_segments - ) + if block_nnz is None and max_segments is None: + block_nnz_use, max_segments_use = default_block_nnz, default_max_segments + else: + block_nnz_use, max_segments_use = _choose_transpose_family_launch_config( + kernel_indptr, block_nnz=block_nnz, max_segments=max_segments + ) vec_real = _triton_spsv_csr_transpose_push_vector vec_complex = _triton_spsv_csr_transpose_push_vector_complex else: - block_nnz_use, max_segments_use = _auto_spsv_launch_config( - kernel_indptr, block_nnz=block_nnz, max_segments=max_segments - ) + if block_nnz is None and max_segments is None: + block_nnz_use, max_segments_use = default_block_nnz, default_max_segments + else: + block_nnz_use, max_segments_use = _auto_spsv_launch_config( + kernel_indptr, block_nnz=block_nnz, max_segments=max_segments + ) vec_real = _triton_spsv_csr_vector vec_complex = _triton_spsv_csr_vector_complex diag_eps = _spsv_diag_eps_for_dtype(compute_dtype) diff --git a/tests/test_spsm.py b/tests/test_spsm.py index 5784b20..988cb5e 100644 --- a/tests/test_spsm.py +++ b/tests/test_spsm.py @@ -1,13 +1,21 @@ -"""SpSM tests: synthetic and optional .mtx CSV export (CSR/COO).""" +"""SpSM tests: synthetic triangular systems and optional .mtx batch CSV.""" import argparse import csv import glob import os +import sys +from pathlib import Path import torch +_PROJECT_ROOT = Path(__file__).resolve().parents[1] +_SRC_ROOT = _PROJECT_ROOT / "src" +if str(_SRC_ROOT) not in sys.path: + sys.path.insert(0, str(_SRC_ROOT)) + import flagsparse as fs +import flagsparse.sparse_operations.spsm as fs_spsm_impl try: import cupy as cp @@ -24,8 +32,10 @@ INDEX_DTYPES = [torch.int32] CSV_VALUE_DTYPES = [torch.float32, torch.float64] CSV_INDEX_DTYPES = [torch.int32] -WARMUP = 10 -ITERS = 50 +WARMUP = 5 +ITERS = 20 +SPSM_PYTORCH_DENSE_GPU_SAFETY_FACTOR = 3.0 +SPSM_OP_MODES = ["NON"] def _dtype_name(dtype): @@ -42,16 +52,36 @@ def _fmt_ms(v): return "N/A" if v is None else f"{v:.4f}" -def _fmt_speedup(other_ms, fs_ms): - if other_ms is None or fs_ms is None or fs_ms <= 0: - return "N/A" - return f"{other_ms / fs_ms:.2f}x" +def _fmt_ratio(v): + return "N/A" if v is None else f"{v:.2f}" def _fmt_err(v): return "N/A" if v is None else f"{v:.2e}" +def _safe_ratio(other_ms, triton_ms): + if other_ms is None or triton_ms is None or triton_ms <= 0: + return None + return other_ms / triton_ms + + +def _parse_csv_tokens(raw): + return [tok.strip() for tok in str(raw).split(",") if tok.strip()] + + +def _parse_ops_filter(raw): + tokens = [tok.strip().upper() for tok in _parse_csv_tokens(raw)] + if not tokens: + return ["NON"] + invalid = [tok for tok in tokens if tok not in SPSM_OP_MODES] + if invalid: + raise ValueError( + f"unsupported spsm ops: {invalid}; current SpSM test only supports NON/NON_TRANS" + ) + return tokens + + def _build_triangular_case(n=512, n_rhs=32, value_dtype=torch.float32): 
device = torch.device("cuda") A = torch.tril(torch.randn((n, n), dtype=value_dtype, device=device) * 0.02) @@ -72,14 +102,197 @@ def _build_triangular_case(n=512, n_rhs=32, value_dtype=torch.float32): return data, row, col, indptr, B, (n, n) -def _csr_to_coo(data, indices, indptr, shape): - n_rows = int(shape[0]) +def _csr_to_coo(indices, indptr, n_rows): + row = torch.repeat_interleave( + torch.arange(n_rows, device=indptr.device, dtype=torch.int64), + indptr[1:] - indptr[:-1], + ) + return row, indices.to(torch.int64) + + +def _csr_to_dense(data, indices, indptr, shape): + n_rows, n_cols = int(shape[0]), int(shape[1]) + if n_rows == 0 or n_cols == 0: + return torch.zeros((n_rows, n_cols), dtype=data.dtype, device=data.device) row = torch.repeat_interleave( torch.arange(n_rows, device=data.device, dtype=torch.int64), indptr[1:] - indptr[:-1], ) - col = indices.to(torch.int64) - return data, row, col + coo = torch.sparse_coo_tensor( + torch.stack([row, indices.to(torch.int64)]), + data, + shape, + device=data.device, + ).coalesce() + return coo.to_dense() + + +def _gpu_dense_ref_fits(shape, n_rhs, dtype): + if not torch.cuda.is_available(): + return False, "CUDA unavailable" + element_size = torch.empty((), dtype=dtype).element_size() + dense_bytes = int(shape[0]) * int(shape[1]) * element_size + rhs_bytes = int(shape[0]) * int(n_rhs) * element_size + estimated_bytes = int( + dense_bytes * SPSM_PYTORCH_DENSE_GPU_SAFETY_FACTOR + rhs_bytes * 3 + ) + try: + free_bytes, _ = torch.cuda.mem_get_info() + except Exception as exc: + return False, f"cannot query CUDA memory ({exc})" + if estimated_bytes > free_bytes: + dense_gib = dense_bytes / (1024 ** 3) + need_gib = estimated_bytes / (1024 ** 3) + free_gib = free_bytes / (1024 ** 3) + return ( + False, + "CUDA dense fallback too large " + f"(dense matrix {dense_gib:.1f} GiB, est {need_gib:.1f} GiB > {free_gib:.1f} GiB free)", + ) + return True, None + + +def _benchmark_pytorch_reference(data, indices, indptr, shape, B): + if not data.is_cuda: + return None, None, "unavailable", "PyTorch CUDA dense reference requires CUDA tensors" + fits, reason = _gpu_dense_ref_fits(shape, B.shape[1], B.dtype) + if not fits: + return None, None, "unavailable", reason + try: + A_dense = _csr_to_dense(data, indices, indptr, shape) + torch.cuda.synchronize() + e0 = torch.cuda.Event(True) + e1 = torch.cuda.Event(True) + e0.record() + X_ref = torch.linalg.solve_triangular(A_dense, B, upper=False) + e1.record() + torch.cuda.synchronize() + ms = e0.elapsed_time(e1) + return X_ref.to(B.dtype), ms, "gpu_dense", None + except Exception as exc: + if "out of memory" in str(exc).lower() and torch.cuda.is_available(): + torch.cuda.empty_cache() + return None, None, "unavailable", f"CUDA dense reference unavailable ({exc})" + + +def _benchmark_cusparse_reference(data, row, col, indptr, B, shape, fmt): + if cp is None or cpx_sparse is None or cpx_spsolve_triangular is None: + return None, None, "cusparse unavailable" + try: + data_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(data.contiguous())) + B_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(B.contiguous())) + if fmt == "coo": + row_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(row.contiguous())) + col_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(col.contiguous())) + A_cp = cpx_sparse.coo_matrix((data_cp, (row_cp, col_cp)), shape=shape) + else: + idx_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(col.contiguous())) + ptr_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(indptr.contiguous())) + A_cp = 
cpx_sparse.csr_matrix((data_cp, idx_cp, ptr_cp), shape=shape) + for _ in range(WARMUP): + _ = cpx_spsolve_triangular(A_cp, B_cp, lower=True, unit_diagonal=False) + cp.cuda.runtime.deviceSynchronize() + c0 = cp.cuda.Event() + c1 = cp.cuda.Event() + c0.record() + for _ in range(ITERS): + X_cp = cpx_spsolve_triangular(A_cp, B_cp, lower=True, unit_diagonal=False) + c1.record() + c1.synchronize() + ms = cp.cuda.get_elapsed_time(c0, c1) / ITERS + X_t = torch.utils.dlpack.from_dlpack(X_cp.toDlpack()).to(B.dtype) + return X_t, ms, None + except Exception as exc: + return None, None, str(exc) + + +def _apply_csr_to_dense_rhs(data, indices, indptr, X, shape): + n_rows = int(shape[0]) + row, col = _csr_to_coo(indices, indptr, n_rows) + out = torch.zeros((n_rows, X.shape[1]), dtype=X.dtype, device=X.device) + out.index_add_(0, row, data[:, None] * X[col]) + return out + + +def _solution_residual_metrics(data, indices, indptr, shape, X, B, value_dtype): + atol, rtol = _tol(value_dtype) + B_recon = _apply_csr_to_dense_rhs(data, indices, indptr, X, shape) + err = float(torch.max(torch.abs(B_recon - B)).item()) if B.numel() else 0.0 + ok = torch.allclose(B_recon, B, atol=atol, rtol=rtol) + return err, ok + + +def _benchmark_flagsparse(call): + X = None + for _ in range(WARMUP): + X = call() + torch.cuda.synchronize() + e0 = torch.cuda.Event(True) + e1 = torch.cuda.Event(True) + e0.record() + for _ in range(ITERS): + X = call() + e1.record() + torch.cuda.synchronize() + return X, e0.elapsed_time(e1) / ITERS + + +def _benchmark_flagsparse_spsm_csr_split(data, indices, indptr, B, shape): + analysis_ms = fs_spsm_impl._analyze_spsm_csr( + data, + indices, + indptr, + B, + shape, + lower=True, + unit_diagonal=False, + clear_cache=True, + return_time=True, + ) + X, solve_ms = _benchmark_flagsparse( + lambda: fs.flagsparse_spsm_csr( + data, + indices, + indptr, + B, + shape, + lower=True, + unit_diagonal=False, + opA="NON_TRANS", + opB="NON_TRANS", + major="row", + ) + ) + return X, analysis_ms, solve_ms + + +def _benchmark_flagsparse_spsm_coo_split(data, row, col, B, shape): + analysis_ms = fs_spsm_impl._analyze_spsm_coo( + data, + row, + col, + B, + shape, + lower=True, + unit_diagonal=False, + clear_cache=True, + return_time=True, + ) + X, solve_ms = _benchmark_flagsparse( + lambda: fs.flagsparse_spsm_coo( + data, + row, + col, + B, + shape, + lower=True, + unit_diagonal=False, + opA="NON_TRANS", + opB="NON_TRANS", + major="row", + ) + ) + return X, analysis_ms, solve_ms def _load_mtx_to_csr_torch(file_path, dtype=torch.float32, device=None): @@ -140,7 +353,6 @@ def _accum(r, c, v): elif mm_symmetry == "skew-symmetric" and r != c: _accum(c, r, -v) - # Force lower-triangular + strong diagonal so triangular solve is well-defined. 
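
# The residual check added above (_apply_csr_to_dense_rhs feeding
# _solution_residual_metrics) verifies a solve independently of any reference
# backend: it reconstructs A @ X with index_add_ and compares it against B.
# A small worked instance of that reconstruction (illustrative values, not
# part of this patch):

import torch

indptr = torch.tensor([0, 1, 3])            # A = [[2, 0], [1, 3]] in CSR
indices = torch.tensor([0, 0, 1])
data = torch.tensor([2.0, 1.0, 3.0])
X = torch.tensor([[1.0], [1.0]])

row = torch.repeat_interleave(torch.arange(2), indptr[1:] - indptr[:-1])
out = torch.zeros(2, 1)
out.index_add_(0, row, data[:, None] * X[indices])
# out == [[2.0], [4.0]] == A @ X; max|A @ X - B| is then the solve residual.
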
for r in range(n_rows): row = row_maps[r] lower_row = {} @@ -168,107 +380,89 @@ def _accum(r, c, v): return data, indices, indptr, (n_rows, n_cols) -def _cupy_spsm_ref(data, row, col, indptr, B, shape, fmt="csr"): - if cp is None or cpx_sparse is None or cpx_spsolve_triangular is None: - return None, None, "cupy/cusparse unavailable" - try: - data_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(data.contiguous())) - B_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(B.contiguous())) - if fmt == "coo": - row_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(row.to(torch.int64).contiguous())) - col_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(col.to(torch.int64).contiguous())) - A_cp = cpx_sparse.coo_matrix((data_cp, (row_cp, col_cp)), shape=shape) - else: - idx_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(col.to(torch.int64).contiguous())) - ptr_cp = cp.from_dlpack(torch.utils.dlpack.to_dlpack(indptr.contiguous())) - A_cp = cpx_sparse.csr_matrix((data_cp, idx_cp, ptr_cp), shape=shape) - for _ in range(max(0, int(WARMUP))): - _ = cpx_spsolve_triangular(A_cp, B_cp, lower=True, unit_diagonal=False) - cp.cuda.runtime.deviceSynchronize() - c0 = cp.cuda.Event() - c1 = cp.cuda.Event() - c0.record() - for _ in range(max(1, int(ITERS))): - X_cp = cpx_spsolve_triangular(A_cp, B_cp, lower=True, unit_diagonal=False) - c1.record() - c1.synchronize() - ms = cp.cuda.get_elapsed_time(c0, c1) / max(1, int(ITERS)) - X_t = torch.utils.dlpack.from_dlpack(X_cp.toDlpack()).to(B.dtype) - return X_t, ms, None - except Exception as exc: - return None, None, str(exc) - - def _run_one_spsm_case(data, indices, indptr, shape, value_dtype, index_dtype, n_rhs, fmt): - n_rows, _ = shape + n_rows = int(shape[0]) B = torch.randn((n_rows, n_rhs), dtype=value_dtype, device=data.device).contiguous() atol, rtol = _tol(value_dtype) + row, col = _csr_to_coo(indices, indptr, n_rows) if fmt == "csr": - X_fs, fs_ms = fs.flagsparse_spsm_csr( - data=data, - indices=indices.to(index_dtype), - indptr=indptr, - B=B, - shape=shape, - alpha=1.0, - lower=True, - unit_diagonal=False, - opA="NON_TRANS", - opB="NON_TRANS", - major="row", - return_time=True, + X_fs, analysis_ms, solve_ms = _benchmark_flagsparse_spsm_csr_split( + data, + indices.to(index_dtype), + indptr, + B, + shape, ) - data_c, row_c, col_c = _csr_to_coo(data, indices, indptr, shape) else: - data_c, row_c, col_c = _csr_to_coo(data, indices, indptr, shape) - X_fs, fs_ms = fs.flagsparse_spsm_coo( - data=data_c, - row=row_c.to(index_dtype), - col=col_c.to(index_dtype), - B=B, - shape=shape, - alpha=1.0, - lower=True, - unit_diagonal=False, - opA="NON_TRANS", - opB="NON_TRANS", - major="row", - return_time=True, + X_fs, analysis_ms, solve_ms = _benchmark_flagsparse_spsm_coo_split( + data, + row.to(index_dtype), + col.to(index_dtype), + B, + shape, ) + total_ms = analysis_ms + solve_ms if analysis_ms is not None and solve_ms is not None else None - X_cu, cu_ms, cu_reason = _cupy_spsm_ref( - data, row_c, col_c, indptr, B, shape, fmt=fmt + X_cu, cusparse_ms, cusparse_reason = _benchmark_cusparse_reference( + data, row, col, indptr, B, shape, fmt ) - ok_cu = None + X_pt, pytorch_ms, pt_backend, pytorch_reason = _benchmark_pytorch_reference( + data, indices, indptr, shape, B + ) + err_cu = None + ok_cu = None if X_cu is not None: - ok_cu = torch.allclose(X_fs, X_cu, atol=atol, rtol=rtol) err_cu = float(torch.max(torch.abs(X_fs - X_cu)).item()) if X_fs.numel() else 0.0 + ok_cu = torch.allclose(X_fs, X_cu, atol=atol, rtol=rtol) - if ok_cu is None: - status = "SKIP" - elif ok_cu: 
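
# The status computation assembled in this hunk folds three signals (PyTorch
# reference, cuSPARSE reference, residual check) into PASS/REF_FAIL/FAIL. Its
# decision table, written as an equivalent standalone function (a summary
# sketch; this helper is not defined by the patch):

def spsm_case_status(ok_pt, ok_cu, ok_res):
    """ok_pt / ok_cu are None when that reference backend was unavailable."""
    if ok_pt is None and ok_cu is None:
        return "REF_FAIL"                  # no reference could be run at all
    if not (bool(ok_pt) or bool(ok_cu)):
        return "FAIL"                      # every available reference disagreed
    return "PASS" if ok_res else "FAIL"    # residual mismatch vetoes a PASS
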
+ err_pt = None + ok_pt = None + if X_pt is not None: + err_pt = float(torch.max(torch.abs(X_fs - X_pt)).item()) if X_fs.numel() else 0.0 + ok_pt = torch.allclose(X_fs, X_pt, atol=atol, rtol=rtol) + + err_res, ok_res = _solution_residual_metrics(data, indices, indptr, shape, X_fs, B, value_dtype) + ref_errors = [v for v in (err_pt, err_cu) if v is not None] + err_ref = min(ref_errors) if ref_errors else None + ref_ok = False + if ok_pt is not None: + ref_ok = ref_ok or ok_pt + if ok_cu is not None: + ref_ok = ref_ok or ok_cu + + if ref_ok: status = "PASS" + elif X_pt is None and X_cu is None: + status = "REF_FAIL" else: status = "FAIL" - - note_parts = [] - if cu_reason: - note_parts.append(cu_reason) + if not ok_res and status == "PASS": + status = "FAIL" return { - "fmt": fmt, - "n_rows": int(shape[0]), + "format": fmt, + "n_rows": n_rows, "n_cols": int(shape[1]), "nnz": int(data.numel()), - "rhs": int(n_rhs), - "flagsparse_ms": fs_ms, - "cusparse_ms": cu_ms, - "fs_vs_cu": _fmt_speedup(cu_ms, fs_ms), + "n_rhs": int(n_rhs), + "triton_analysis_ms": analysis_ms, + "triton_solve_ms": solve_ms, + "triton_time_total_ms": total_ms, + "cusparse_solve_ms": cusparse_ms, + "pytorch_solve_ms": pytorch_ms, + "pytorch_backend": pt_backend, + "cusparse/triton": _safe_ratio(cusparse_ms, solve_ms), + "pytorch/triton": _safe_ratio(pytorch_ms, solve_ms), "status": status, + "err_ref": err_ref, + "err_res": err_res, + "err_pt": err_pt, "err_cu": err_cu, - "note": " | ".join(note_parts), + "pytorch_reason": pytorch_reason, + "cusparse_reason": cusparse_reason, + "error": None, } @@ -278,83 +472,61 @@ def run_spsm_synthetic_all(n=512, n_rhs=32): return total = 0 failed = 0 - print("=" * 120) - print("FLAGSPARSE SpSM (NON_TRANS/NON_TRANS, row-major) synthetic test") - print("=" * 120) + print("=" * 160) + print("FLAGSPARSE SpSM synthetic test") + print("=" * 160) + print( + "PyTorch(ms)=CUDA reference (dense triangular solve). " + "FlagSparse analysis is measured separately; FlagSparse(ms) below reports solve only." 
+ ) print( f"{'Fmt':>5} {'dtype':>9} {'index':>7} {'N':>6} {'RHS':>6} {'NNZ':>10} " - f"{'FS(ms)':>10} {'CU(ms)':>10} {'FS/CU':>8} " - f"{'Status':>8} {'Err(CU)':>12}" + f"{'FS.anlys':>10} {'FS.solve':>10} {'FS.total':>10} " + f"{'cu.solve':>10} {'pt.solve':>10} {'cu/triton':>10} {'pt/triton':>10} " + f"{'Status':>10} {'Err(Ref)':>12} {'Err(Res)':>12} {'Err(PT)':>12} {'Err(CU)':>12}" ) - print("-" * 120) + print("-" * 160) + for fmt in FORMATS: for value_dtype in VALUE_DTYPES: for index_dtype in INDEX_DTYPES: - data, row, col, indptr, B, shape = _build_triangular_case( - n=n, n_rhs=n_rhs, value_dtype=value_dtype + data, row, col, indptr, _B, shape = _build_triangular_case( + n=n, + n_rhs=n_rhs, + value_dtype=value_dtype, ) - atol, rtol = _tol(value_dtype) - if fmt == "csr": - X_fs, fs_ms = fs.flagsparse_spsm_csr( - data=data, - indices=col.to(index_dtype), - indptr=indptr, - B=B, - shape=shape, - alpha=1.0, - lower=True, - unit_diagonal=False, - opA="NON_TRANS", - opB="NON_TRANS", - major="row", - return_time=True, - ) - else: - X_fs, fs_ms = fs.flagsparse_spsm_coo( - data=data, - row=row.to(index_dtype), - col=col.to(index_dtype), - B=B, - shape=shape, - alpha=1.0, - lower=True, - unit_diagonal=False, - opA="NON_TRANS", - opB="NON_TRANS", - major="row", - return_time=True, - ) - - X_cu, cu_ms, cu_reason = _cupy_spsm_ref( - data, row, col, indptr, B, shape, fmt=fmt + one = _run_one_spsm_case( + data, + col, + indptr, + shape, + value_dtype, + index_dtype, + n_rhs, + fmt, ) - ok_cu = None - err_cu = None - if X_cu is not None: - ok_cu = torch.allclose(X_fs, X_cu, atol=atol, rtol=rtol) - err_cu = float(torch.max(torch.abs(X_fs - X_cu)).item()) if X_fs.numel() else 0.0 - - if ok_cu is None: - status = "SKIP" - elif ok_cu: - status = "PASS" - else: - status = "FAIL" total += 1 - if status != "PASS": + if one["status"] != "PASS": failed += 1 - print( f"{fmt:>5} {_dtype_name(value_dtype):>9} {_dtype_name(index_dtype):>7} " - f"{shape[0]:>6} {B.shape[1]:>6} {int(data.numel()):>10} " - f"{_fmt_ms(fs_ms):>10} {_fmt_ms(cu_ms):>10} {_fmt_speedup(cu_ms, fs_ms):>8} " - f"{status:>8} {_fmt_err(err_cu):>12}" + f"{shape[0]:>6} {n_rhs:>6} {one['nnz']:>10} " + f"{_fmt_ms(one['triton_analysis_ms']):>10} {_fmt_ms(one['triton_solve_ms']):>10} {_fmt_ms(one['triton_time_total_ms']):>10} " + f"{_fmt_ms(one['cusparse_solve_ms']):>10} {_fmt_ms(one['pytorch_solve_ms']):>10} " + f"{_fmt_ratio(one['cusparse/triton']):>10} {_fmt_ratio(one['pytorch/triton']):>10} " + f"{one['status']:>10} {_fmt_err(one['err_ref']):>12} {_fmt_err(one['err_res']):>12} " + f"{_fmt_err(one['err_pt']):>12} {_fmt_err(one['err_cu']):>12}" ) - if cu_reason is not None: - print(f" NOTE: {cu_reason}") - print("-" * 120) + if one["status"] in ("FAIL", "REF_FAIL"): + if one["pytorch_backend"] and one["pytorch_backend"] != "gpu_dense": + print(f" NOTE: pt_backend={one['pytorch_backend']}") + if one["pytorch_reason"]: + print(f" NOTE: {one['pytorch_reason']}") + if one["cusparse_reason"]: + print(f" NOTE: {one['cusparse_reason']}") + print("-" * 160) print(f"Total cases: {total} Failed: {failed}") - print("=" * 120) + print("=" * 160) def run_all_dtypes_spsm_csv(mtx_paths, csv_path, use_coo=False, n_rhs=32): @@ -365,15 +537,19 @@ def run_all_dtypes_spsm_csv(mtx_paths, csv_path, use_coo=False, n_rhs=32): rows_out = [] fmt = "coo" if use_coo else "csr" - print("=" * 132) - print(f"FLAGSPARSE SpSM .mtx batch ({fmt.upper()}), NON_TRANS/NON_TRANS, row-major") - print("=" * 132) + print("=" * 176) + print( + f"FLAGSPARSE SpSM .mtx batch ({fmt.upper()}) | " 
+ "PyTorch(ms)=CUDA reference (dense triangular solve)" + ) + print("=" * 176) print( f"{'Matrix':<28} {'dtype':>9} {'index':>7} {'N':>7} {'RHS':>6} {'NNZ':>10} " - f"{'FS(ms)':>10} {'CU(ms)':>10} {'FS/CU':>8} " - f"{'Status':>8} {'Err(CU)':>12}" + f"{'FS.anlys':>10} {'FS.solve':>10} {'FS.total':>10} " + f"{'cu.solve':>10} {'pt.solve':>10} {'cu/triton':>10} {'pt/triton':>10} " + f"{'Status':>10} {'Err(Ref)':>12} {'Err(Res)':>12} {'Err(PT)':>12} {'Err(CU)':>12}" ) - print("-" * 132) + print("-" * 176) for value_dtype in CSV_VALUE_DTYPES: for index_dtype in CSV_INDEX_DTYPES: @@ -385,9 +561,11 @@ def run_all_dtypes_spsm_csv(mtx_paths, csv_path, use_coo=False, n_rhs=32): } try: data, indices, indptr, shape = _load_mtx_to_csr_torch( - path, dtype=value_dtype, device=device + path, + dtype=value_dtype, + device=device, ) - one = _run_one_spsm_case( + row = _run_one_spsm_case( data, indices, indptr, @@ -397,102 +575,123 @@ def run_all_dtypes_spsm_csv(mtx_paths, csv_path, use_coo=False, n_rhs=32): n_rhs, fmt, ) - row = { - **base, - **one, - } + row = {**base, **row} rows_out.append(row) short = base["matrix"][:27] + ("…" if len(base["matrix"]) > 27 else "") print( f"{short:<28} {base['value_dtype']:>9} {base['index_dtype']:>7} " - f"{row['n_rows']:>7} {row['rhs']:>6} {row['nnz']:>10} " - f"{_fmt_ms(row['flagsparse_ms']):>10} {_fmt_ms(row['cusparse_ms']):>10} " - f"{row['fs_vs_cu']:>8} {row['status']:>8} " - f"{_fmt_err(row['err_cu']):>12}" + f"{row['n_rows']:>7} {row['n_rhs']:>6} {row['nnz']:>10} " + f"{_fmt_ms(row['triton_analysis_ms']):>10} {_fmt_ms(row['triton_solve_ms']):>10} {_fmt_ms(row['triton_time_total_ms']):>10} " + f"{_fmt_ms(row['cusparse_solve_ms']):>10} {_fmt_ms(row['pytorch_solve_ms']):>10} " + f"{_fmt_ratio(row['cusparse/triton']):>10} {_fmt_ratio(row['pytorch/triton']):>10} " + f"{row['status']:>10} {_fmt_err(row['err_ref']):>12} {_fmt_err(row['err_res']):>12} " + f"{_fmt_err(row['err_pt']):>12} {_fmt_err(row['err_cu']):>12}" ) - if row.get("note"): - print(f" NOTE: {row['note']}") + if row["status"] in ("FAIL", "REF_FAIL"): + if row["pytorch_backend"] and row["pytorch_backend"] != "gpu_dense": + print(f" NOTE: pt_backend={row['pytorch_backend']}") + if row["pytorch_reason"]: + print(f" NOTE: {row['pytorch_reason']}") + if row["cusparse_reason"]: + print(f" NOTE: {row['cusparse_reason']}") except Exception as exc: row = { **base, - "fmt": fmt, + "format": fmt, "n_rows": "ERR", "n_cols": "ERR", "nnz": "ERR", - "rhs": int(n_rhs), - "flagsparse_ms": None, - "cusparse_ms": None, - "fs_vs_cu": "N/A", + "n_rhs": int(n_rhs), + "triton_analysis_ms": None, + "triton_solve_ms": None, + "triton_time_total_ms": None, + "cusparse_solve_ms": None, + "pytorch_solve_ms": None, + "pytorch_backend": None, + "cusparse/triton": None, + "pytorch/triton": None, "status": "ERROR", + "err_ref": None, + "err_res": None, + "err_pt": None, "err_cu": None, - "note": str(exc), + "pytorch_reason": None, + "cusparse_reason": None, + "error": str(exc), } rows_out.append(row) short = base["matrix"][:27] + ("…" if len(base["matrix"]) > 27 else "") print( f"{short:<28} {base['value_dtype']:>9} {base['index_dtype']:>7} " f"{'ERR':>7} {int(n_rhs):>6} {'ERR':>10} " + f"{_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>10} " f"{_fmt_ms(None):>10} {_fmt_ms(None):>10} " - f"{'N/A':>8} {'ERROR':>8} " - f"{_fmt_err(None):>12}" + f"{'N/A':>10} {'N/A':>10} {'ERROR':>10} " + f"{_fmt_err(None):>12} {_fmt_err(None):>12} {_fmt_err(None):>12} {_fmt_err(None):>12}" ) print(f" ERROR: {exc}") - print("-" * 132) + print("-" * 
176) fieldnames = [ "matrix", "value_dtype", "index_dtype", - "fmt", + "format", "n_rows", "n_cols", "nnz", - "rhs", - "flagsparse_ms", - "cusparse_ms", - "fs_vs_cu", + "n_rhs", + "triton_analysis_ms", + "triton_solve_ms", + "triton_time_total_ms", + "cusparse_solve_ms", + "pytorch_solve_ms", + "pytorch_backend", + "cusparse/triton", + "pytorch/triton", "status", + "err_ref", + "err_res", + "err_pt", "err_cu", - "note", + "pytorch_reason", + "cusparse_reason", + "error", ] with open(csv_path, "w", newline="", encoding="utf-8") as f: w = csv.DictWriter(f, fieldnames=fieldnames) w.writeheader() - for r in rows_out: - w.writerow(r) + for row in rows_out: + w.writerow(row) print(f"Wrote {len(rows_out)} rows to {csv_path}") def main(): parser = argparse.ArgumentParser( - description="SpSM test: synthetic and optional .mtx CSV export (CSR/COO)." + description="SpSM test: synthetic triangular systems and optional .mtx batch CSV." ) parser.add_argument( "mtx", nargs="*", help=".mtx file path(s), or directory(ies) to glob for *.mtx", ) - parser.add_argument( - "--synthetic", action="store_true", help="Run synthetic triangular tests" - ) + parser.add_argument("--synthetic", action="store_true", help="Run synthetic triangular tests") parser.add_argument("--n", type=int, default=512, help="matrix size (synthetic)") parser.add_argument("--rhs", type=int, default=32, help="number of RHS columns") + parser.add_argument("--csv-csr", type=str, default=None, metavar="FILE") + parser.add_argument("--csv-coo", type=str, default=None, metavar="FILE") parser.add_argument( - "--csv-csr", + "--ops", type=str, - default=None, - metavar="FILE", - help="Run .mtx batch in CSR mode and export CSV", - ) - parser.add_argument( - "--csv-coo", - type=str, - default=None, - metavar="FILE", - help="Run .mtx batch in COO mode and export CSV", + default="NON", + help="comma-separated op(A) modes; currently only NON/NON_TRANS is supported", ) args = parser.parse_args() + ops = _parse_ops_filter(args.ops) + if any(op != "NON" for op in ops): + raise ValueError("SpSM test currently supports only --ops NON") + if args.synthetic: run_spsm_synthetic_all(n=args.n, n_rhs=args.rhs) return @@ -526,4 +725,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() From dbdb9d077677c0aba4a3944e75c3972e257cadb2 Mon Sep 17 00:00:00 2001 From: berlin020 <2261128688@qq.com> Date: Mon, 27 Apr 2026 19:18:51 +0800 Subject: [PATCH 17/22] spsm-opt --- tests/test_spsm.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/tests/test_spsm.py b/tests/test_spsm.py index 988cb5e..78ff983 100644 --- a/tests/test_spsm.py +++ b/tests/test_spsm.py @@ -35,7 +35,7 @@ WARMUP = 5 ITERS = 20 SPSM_PYTORCH_DENSE_GPU_SAFETY_FACTOR = 3.0 -SPSM_OP_MODES = ["NON"] +SPSM_OP_MODES = ["NON", "NON_TRANS"] def _dtype_name(dtype): @@ -79,7 +79,10 @@ def _parse_ops_filter(raw): raise ValueError( f"unsupported spsm ops: {invalid}; current SpSM test only supports NON/NON_TRANS" ) - return tokens + normalized = [] + for tok in tokens: + normalized.append("NON" if tok == "NON_TRANS" else tok) + return normalized def _build_triangular_case(n=512, n_rhs=32, value_dtype=torch.float32): @@ -560,6 +563,11 @@ def run_all_dtypes_spsm_csv(mtx_paths, csv_path, use_coo=False, n_rhs=32): "index_dtype": _dtype_name(index_dtype), } try: + print( + f"RUNNING: {base['matrix']} | dtype={base['value_dtype']} | " + f"index={base['index_dtype']} | fmt={fmt}", + flush=True, + ) data, indices, indptr, shape = 
_load_mtx_to_csr_torch( path, dtype=value_dtype, @@ -595,6 +603,8 @@ def run_all_dtypes_spsm_csv(mtx_paths, csv_path, use_coo=False, n_rhs=32): if row["cusparse_reason"]: print(f" NOTE: {row['cusparse_reason']}") except Exception as exc: + err_msg = str(exc) + status = "SKIP" if "SpSM requires square matrices" in err_msg else "ERROR" row = { **base, "format": fmt, @@ -610,14 +620,14 @@ def run_all_dtypes_spsm_csv(mtx_paths, csv_path, use_coo=False, n_rhs=32): "pytorch_backend": None, "cusparse/triton": None, "pytorch/triton": None, - "status": "ERROR", + "status": status, "err_ref": None, "err_res": None, "err_pt": None, "err_cu": None, "pytorch_reason": None, "cusparse_reason": None, - "error": str(exc), + "error": err_msg, } rows_out.append(row) short = base["matrix"][:27] + ("…" if len(base["matrix"]) > 27 else "") @@ -626,10 +636,10 @@ def run_all_dtypes_spsm_csv(mtx_paths, csv_path, use_coo=False, n_rhs=32): f"{'ERR':>7} {int(n_rhs):>6} {'ERR':>10} " f"{_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>10} " f"{_fmt_ms(None):>10} {_fmt_ms(None):>10} " - f"{'N/A':>10} {'N/A':>10} {'ERROR':>10} " + f"{'N/A':>10} {'N/A':>10} {status:>10} " f"{_fmt_err(None):>12} {_fmt_err(None):>12} {_fmt_err(None):>12} {_fmt_err(None):>12}" ) - print(f" ERROR: {exc}") + print(f" {status}: {exc}") print("-" * 176) fieldnames = [ @@ -690,7 +700,7 @@ def main(): ops = _parse_ops_filter(args.ops) if any(op != "NON" for op in ops): - raise ValueError("SpSM test currently supports only --ops NON") + raise ValueError("SpSM test currently supports only --ops NON/NON_TRANS") if args.synthetic: run_spsm_synthetic_all(n=args.n, n_rhs=args.rhs) From 852e38c4cb89a6c73826037bcf267f91ac814576 Mon Sep 17 00:00:00 2001 From: berlin020 <2261128688@qq.com> Date: Tue, 28 Apr 2026 12:04:07 +0800 Subject: [PATCH 18/22] spsv&spsm_opt --- src/flagsparse/__init__.py | 2 +- src/flagsparse/sparse_operations/__init__.py | 2 +- src/flagsparse/sparse_operations/spsm.py | 45 +- src/flagsparse/sparse_operations/spsv.py | 1131 +++++++++++++++++- 4 files changed, 1128 insertions(+), 52 deletions(-) diff --git a/src/flagsparse/__init__.py b/src/flagsparse/__init__.py index 02b5d76..8b6e245 100644 --- a/src/flagsparse/__init__.py +++ b/src/flagsparse/__init__.py @@ -107,8 +107,8 @@ "benchmark_spgemm_case", "benchmark_sddmm_case", "comprehensive_spmm_test", - "benchmark_spmv_case", "comprehensive_spsm_test", + "benchmark_spmv_case", } _FORMAT_EXPORTS = { diff --git a/src/flagsparse/sparse_operations/__init__.py b/src/flagsparse/sparse_operations/__init__.py index 0fac868..1614179 100644 --- a/src/flagsparse/sparse_operations/__init__.py +++ b/src/flagsparse/sparse_operations/__init__.py @@ -1,6 +1,6 @@ """FlagSparse sparse operations (gather, scatter, SpMV, SpMM, SpGEMM, SDDMM, SpSM).""" -from ._common import SUPPORTED_INDEX_DTYPES, SUPPORTED_VALUE_DTYPES, cp, cpx_sparse +from ._common import SUPPORTED_INDEX_DTYPES, SUPPORTED_VALUE_DTYPES from .benchmarks import ( benchmark_gather_case, benchmark_performance, diff --git a/src/flagsparse/sparse_operations/spsm.py b/src/flagsparse/sparse_operations/spsm.py index 8254a4c..fc5d805 100644 --- a/src/flagsparse/sparse_operations/spsm.py +++ b/src/flagsparse/sparse_operations/spsm.py @@ -368,43 +368,6 @@ def _build_spsm_levels(indptr, indices, n_rows, lower=True): return [torch.tensor(rows, dtype=torch.int32, device=device) for rows in buckets if rows] -def _build_spsm_frontiers(indptr, indices, levels, lower=True): - if not levels: - return [] - - indptr_h = indptr.to(torch.int64).cpu() - 
indices_h = indices.to(torch.int64).cpu() - device = indptr.device - frontier_rows = [] - frontier_row_set = set() - merged = [] - - def _flush_frontier(): - nonlocal frontier_rows, frontier_row_set - if frontier_rows: - merged.append(torch.tensor(frontier_rows, dtype=torch.int32, device=device)) - frontier_rows = [] - frontier_row_set = set() - - for rows_lv in levels: - for row in rows_lv.to(torch.int64).cpu().tolist(): - start = int(indptr_h[row].item()) - end = int(indptr_h[row + 1].item()) - depends_on_frontier = False - for p in range(start, end): - col = int(indices_h[p].item()) - is_dep = col < row if lower else col > row - if is_dep and col in frontier_row_set: - depends_on_frontier = True - break - if depends_on_frontier: - _flush_frontier() - frontier_rows.append(int(row)) - frontier_row_set.add(int(row)) - _flush_frontier() - return merged - - def _auto_spsm_launch_config(indptr, block_nnz=None, max_segments=None): if indptr.numel() <= 1: max_nnz_per_row = 0 @@ -493,7 +456,7 @@ def _prepare_spsm_csr_system(data, indices64, indptr64, n_rows, lower): "kernel_indices64": indices64, "kernel_indices32": indices64.to(torch.int32), "kernel_indptr64": indptr64, - "launch_groups": _build_spsm_frontiers(indptr64, indices64, levels, lower=lower), + "launch_groups": levels, "default_block_nnz": default_block_nnz, "default_max_segments": default_max_segments, "lower_eff": bool(lower), @@ -509,7 +472,7 @@ def _prepare_spsm_coo_system(data, row64, col64, n_rows, n_cols, lower): "kernel_cols64": col_u64, "kernel_cols32": col_u64.to(torch.int32), "kernel_row_ptr64": row_ptr, - "launch_groups": _build_spsm_frontiers(row_ptr, col_u64, levels, lower=lower), + "launch_groups": levels, "default_block_nnz": default_block_nnz, "default_max_segments": default_max_segments, "lower_eff": bool(lower), @@ -580,7 +543,7 @@ def _run_spsm_csr_core( if launch_groups is None: levels = _build_spsm_levels(indptr64, indices32, n_rows, lower=lower) - launch_groups = _build_spsm_frontiers(indptr64, indices32, levels, lower=lower) + launch_groups = levels if block_nnz_use is None or max_segments_use is None: block_nnz_use, max_segments_use = _auto_spsm_launch_config( indptr64, block_nnz=block_nnz, max_segments=max_segments @@ -647,7 +610,7 @@ def _run_spsm_coo_core( if launch_groups is None: levels = _build_spsm_levels(row_ptr64, cols32, n_rows, lower=lower) - launch_groups = _build_spsm_frontiers(row_ptr64, cols32, levels, lower=lower) + launch_groups = levels if block_nnz_use is None or max_segments_use is None: block_nnz_use, max_segments_use = _auto_spsm_launch_config( row_ptr64, block_nnz=block_nnz, max_segments=max_segments diff --git a/src/flagsparse/sparse_operations/spsv.py b/src/flagsparse/sparse_operations/spsv.py index c89c4e6..203e78f 100644 --- a/src/flagsparse/sparse_operations/spsv.py +++ b/src/flagsparse/sparse_operations/spsv.py @@ -186,11 +186,12 @@ def _spsv_cache_put(cache, key, value, max_entries): cache.popitem(last=False) -def _csr_preprocess_cache_key(data, indices, indptr, shape, lower, trans_mode): +def _csr_preprocess_cache_key(data, indices, indptr, shape, lower, trans_mode, unit_diagonal): return ( "csr_preprocess", trans_mode, bool(lower), + bool(unit_diagonal), int(shape[0]), int(shape[1]), _tensor_cache_token(data), @@ -294,11 +295,233 @@ def _flush_frontier(): return merged -def _prepare_spsv_csr_system(data, indices64, indptr64, n_rows, n_cols, lower, trans_mode): +def _prepare_spsv_transpose_cw_metadata(data, indices64, indptr64, n_rows, lower, unit_diagonal=False): + indegree = 
torch.zeros(n_rows, dtype=torch.int32, device=data.device) + diag = torch.ones(n_rows, dtype=data.dtype, device=data.device) + if n_rows == 0: + return diag, indegree + row_ids = torch.repeat_interleave( + torch.arange(n_rows, device=data.device, dtype=torch.int64), + indptr64[1:] - indptr64[:-1], + ) + dep_mask = indices64 < row_ids if lower else indices64 > row_ids + if dep_mask.numel() > 0: + dep_rows = row_ids[dep_mask] + dep_counts = torch.bincount(dep_rows, minlength=n_rows) + indegree.copy_(dep_counts.to(torch.int32)) + if not unit_diagonal: + indegree.add_(1) + diag_mask = indices64 == row_ids + if bool(torch.any(diag_mask).item()): + diag.scatter_(0, row_ids[diag_mask], data[diag_mask]) + return diag, indegree + + +def _sort_csr_rows(data, indices64, indptr64, n_rows, n_cols, lower=True): + if data.numel() == 0: + return data, indices64, indptr64 + row_ids = torch.repeat_interleave( + torch.arange(n_rows, device=data.device, dtype=torch.int64), + indptr64[1:] - indptr64[:-1], + ) + key = row_ids * max(1, n_cols) + if lower: + key = key + indices64 + else: + key = key + (n_cols - 1 - indices64) + try: + order = torch.argsort(key, stable=True) + except TypeError: + order = torch.argsort(key) + return data[order], indices64[order], indptr64 + + +def _prepare_spsv_nontrans_cw_metadata(data, indices64, indptr64, n_rows, lower, unit_diagonal=False): + diag = torch.ones(n_rows, dtype=data.dtype, device=data.device) + if n_rows == 0: + return diag + if unit_diagonal: + return diag + row_ids = torch.repeat_interleave( + torch.arange(n_rows, device=data.device, dtype=torch.int64), + indptr64[1:] - indptr64[:-1], + ) + diag_mask = indices64 == row_ids + if bool(torch.any(diag_mask).item()): + diag.scatter_(0, row_ids[diag_mask], data[diag_mask]) + return diag + + +def _cw_worker_count(n_rows, max_frontier, avg_nnz_per_row, n_rhs): + if n_rows <= 0: + return 1 + target = max(256, min(n_rows, 4096)) + if max_frontier > 0: + target = min(target, max(256, min(n_rows, max_frontier * 8))) + if avg_nnz_per_row > 2048: + target = max(128, target // 2) + if n_rhs >= 16: + target = max(128, target // 2) + return int(max(1, min(n_rows, target))) + + +def _spsv_level_stats(levels, n_rows): + if not levels: + return { + "num_levels": 0, + "max_frontier": 0, + "avg_frontier": 0.0, + "frontier_ratio": 0.0, + } + num_levels = len(levels) + max_frontier = max(int(rows.numel()) for rows in levels) + avg_frontier = float(n_rows) / float(max(num_levels, 1)) + frontier_ratio = float(max_frontier) / float(max(n_rows, 1)) + return { + "num_levels": num_levels, + "max_frontier": max_frontier, + "avg_frontier": avg_frontier, + "frontier_ratio": frontier_ratio, + } + + +def _build_spsv_matrix_stats(indptr64, levels, n_rows): + stats = _spsv_level_stats(levels, n_rows) + if indptr64.numel() <= 1: + avg_nnz_per_row = 0.0 + max_nnz_per_row = 0 + else: + row_lengths = indptr64[1:] - indptr64[:-1] + avg_nnz_per_row = float(row_lengths.to(torch.float32).mean().item()) + max_nnz_per_row = int(row_lengths.max().item()) + stats["avg_nnz_per_row"] = avg_nnz_per_row + stats["max_nnz_per_row"] = max_nnz_per_row + stats["n_rows"] = int(n_rows) + return stats + + +def _choose_spsv_block_rhs(n_rhs, matrix_stats, complex_mode=False): + avg_nnz = float(matrix_stats.get("avg_nnz_per_row", 0.0)) + max_nnz = int(matrix_stats.get("max_nnz_per_row", 0)) + frontier_ratio = float(matrix_stats.get("frontier_ratio", 0.0)) + if n_rhs <= 1: + return 1 + if complex_mode: + if avg_nnz > 512 or max_nnz > 4096: + return 4 if n_rhs >= 4 else n_rhs + 
return 8 if n_rhs >= 8 else n_rhs + if avg_nnz > 2048 or max_nnz > 16384: + return 4 if n_rhs >= 4 else n_rhs + if avg_nnz > 512 or frontier_ratio < 0.02: + return 8 if n_rhs >= 8 else n_rhs + if avg_nnz > 128: + return 16 if n_rhs >= 16 else n_rhs + if n_rhs <= 8: + return n_rhs + if n_rhs <= 16: + return 16 + if n_rhs <= 32: + return 32 + return 64 + + +def _score_nontrans_levels(matrix_stats, n_rhs, complex_mode): + score = 0.0 + score += float(matrix_stats["frontier_ratio"]) * 10.0 + score += min(float(matrix_stats["avg_frontier"]) / 64.0, 6.0) + score += min(float(matrix_stats["avg_nnz_per_row"]) / 256.0, 4.0) + if n_rhs >= 8: + score += 2.5 + if complex_mode: + score += 2.0 + return score + + +def _score_nontrans_cw(matrix_stats, n_rhs, complex_mode): + score = 0.0 + if matrix_stats["num_levels"] > 1024: + score += 2.0 + if matrix_stats["num_levels"] > 4096: + score += 2.0 + if matrix_stats["frontier_ratio"] < 0.03: + score += 4.0 + elif matrix_stats["frontier_ratio"] < 0.06: + score += 2.0 + if matrix_stats["avg_frontier"] < 16.0: + score += 3.0 + elif matrix_stats["avg_frontier"] < 32.0: + score += 1.5 + if matrix_stats["avg_nnz_per_row"] <= 128.0: + score += 1.5 + if matrix_stats["max_nnz_per_row"] > 4096: + score -= 2.5 + if n_rhs >= 8: + score -= 3.0 + elif n_rhs >= 4: + score -= 1.5 + if complex_mode: + score -= 2.0 + return score + + +def _score_transpose_push(matrix_stats, n_rhs, complex_mode): + score = 0.0 + score += float(matrix_stats["frontier_ratio"]) * 12.0 + score += min(float(matrix_stats["avg_frontier"]) / 48.0, 6.0) + score += min(float(matrix_stats["avg_nnz_per_row"]) / 256.0, 4.0) + if n_rhs >= 4: + score += 2.0 + if complex_mode: + score += 2.5 + return score + + +def _score_transpose_cw(matrix_stats, n_rhs, complex_mode): + score = 0.0 + if matrix_stats["num_levels"] > 2048: + score += 2.0 + if matrix_stats["num_levels"] > 8192: + score += 2.0 + if matrix_stats["frontier_ratio"] < 0.02: + score += 4.0 + elif matrix_stats["frontier_ratio"] < 0.05: + score += 2.0 + if matrix_stats["avg_frontier"] < 12.0: + score += 3.0 + elif matrix_stats["avg_frontier"] < 24.0: + score += 1.5 + if matrix_stats["avg_nnz_per_row"] <= 96.0: + score += 1.0 + if matrix_stats["max_nnz_per_row"] > 2048: + score -= 2.5 + if n_rhs >= 4: + score -= 2.5 + elif n_rhs >= 2: + score -= 1.0 + if complex_mode: + score -= 2.5 + return score + + +def _select_nontrans_route(matrix_stats, n_rhs, complex_mode): + levels_score = _score_nontrans_levels(matrix_stats, n_rhs, complex_mode) + cw_score = _score_nontrans_cw(matrix_stats, n_rhs, complex_mode) + return "csr_cw" if cw_score > levels_score else "csr_levels" + + +def _select_transpose_route(matrix_stats, n_rhs, complex_mode): + push_score = _score_transpose_push(matrix_stats, n_rhs, complex_mode) + cw_score = _score_transpose_cw(matrix_stats, n_rhs, complex_mode) + return "transpose_cw" if cw_score > push_score else "transpose_push" + + +def _prepare_spsv_csr_system(data, indices64, indptr64, n_rows, n_cols, lower, trans_mode, unit_diagonal): + complex_mode = torch.is_complex(data) if trans_mode == "N": levels = _build_spsv_levels(indptr64, indices64, n_rows, lower=lower) + matrix_stats = _build_spsv_matrix_stats(indptr64, levels, n_rows) default_block_nnz, default_max_segments = _auto_spsv_launch_config(indptr64) - return { + levels_plan = { "solve_kind": "csr_levels", "kernel_data": data, "kernel_indices64": indices64, @@ -311,13 +534,49 @@ def _prepare_spsv_csr_system(data, indices64, indptr64, n_rows, n_cols, lower, t "default_block_nnz": 
default_block_nnz, "default_max_segments": default_max_segments, "transpose_conjugate": False, + "cw_worker_count": None, + "matrix_stats": matrix_stats, + "route_name": "csr_levels", + "alt_plan": None, } + data_sorted, indices_sorted64, indptr_sorted64 = _sort_csr_rows( + data, indices64, indptr64, n_rows, n_cols, lower=lower + ) + default_block_nnz, default_max_segments = _auto_spsv_launch_config(indptr_sorted64) + cw_plan = { + "solve_kind": "csr_cw", + "kernel_data": data_sorted, + "kernel_indices64": indices_sorted64, + "kernel_indices32": indices_sorted64.to(torch.int32), + "kernel_indptr64": indptr_sorted64, + "lower_eff": lower, + "launch_groups": None, + "default_block_nnz": default_block_nnz, + "default_max_segments": default_max_segments, + "transpose_conjugate": False, + "cw_diag": _prepare_spsv_nontrans_cw_metadata( + data_sorted, indices_sorted64, indptr_sorted64, n_rows, lower, unit_diagonal=unit_diagonal + ), + "cw_worker_count": _cw_worker_count( + n_rows, matrix_stats["max_frontier"], matrix_stats["avg_nnz_per_row"], 1 + ), + "matrix_stats": matrix_stats, + "route_name": "csr_cw", + "alt_plan": None, + } + preferred_kind = _select_nontrans_route(matrix_stats, 1, complex_mode) + if preferred_kind == "csr_cw": + cw_plan["alt_plan"] = levels_plan + return cw_plan + levels_plan["alt_plan"] = cw_plan + return levels_plan levels = _build_spsv_levels(indptr64, indices64, n_rows, lower=lower) + matrix_stats = _build_spsv_matrix_stats(indptr64, levels, n_rows) default_block_nnz, default_max_segments = _choose_transpose_family_launch_config( indptr64 ) - return { + push_plan = { "solve_kind": "transpose_push", "kernel_data": data, "kernel_indices64": indices64, @@ -330,7 +589,44 @@ def _prepare_spsv_csr_system(data, indices64, indptr64, n_rows, n_cols, lower, t "default_block_nnz": default_block_nnz, "default_max_segments": default_max_segments, "transpose_conjugate": trans_mode == "C", + "cw_worker_count": None, + "matrix_stats": matrix_stats, + "route_name": "transpose_push", + "alt_plan": None, + } + + diag, indegree_init = _prepare_spsv_transpose_cw_metadata( + data, indices64, indptr64, n_rows, lower, unit_diagonal=unit_diagonal + ) + default_block_nnz, default_max_segments = _choose_transpose_family_launch_config( + indptr64 + ) + cw_plan = { + "solve_kind": "transpose_cw", + "kernel_data": data, + "kernel_indices64": indices64, + "kernel_indices32": indices64.to(torch.int32), + "kernel_indptr64": indptr64, + "lower_eff": lower, + "launch_groups": None, + "default_block_nnz": default_block_nnz, + "default_max_segments": default_max_segments, + "transpose_conjugate": trans_mode == "C", + "transpose_diag": diag, + "transpose_indegree_init": indegree_init, + "cw_worker_count": _cw_worker_count( + n_rows, matrix_stats["max_frontier"], matrix_stats["avg_nnz_per_row"], 1 + ), + "matrix_stats": matrix_stats, + "route_name": "transpose_cw", + "alt_plan": None, } + preferred_kind = _select_transpose_route(matrix_stats, 1, complex_mode) + if preferred_kind == "transpose_cw": + cw_plan["alt_plan"] = push_plan + return cw_plan + push_plan["alt_plan"] = cw_plan + return push_plan def _resolve_spsv_csr_runtime( @@ -341,6 +637,7 @@ def _resolve_spsv_csr_runtime( shape, lower, transpose, + unit_diagonal=False, ): input_data = data input_indices = indices @@ -358,7 +655,7 @@ def _resolve_spsv_csr_runtime( _validate_spsv_trans_combo(data.dtype, input_index_dtype, "CSR") preprocess_key = _csr_preprocess_cache_key( - input_data, input_indices, input_indptr, (n_rows, n_cols), lower, trans_mode + 
input_data, input_indices, input_indptr, (n_rows, n_cols), lower, trans_mode, unit_diagonal ) cached = _spsv_cache_get(_SPSV_CSR_PREPROCESS_CACHE, preprocess_key) if cached is None: @@ -370,6 +667,7 @@ def _resolve_spsv_csr_runtime( n_cols, lower, trans_mode, + unit_diagonal, ) _spsv_cache_put( _SPSV_CSR_PREPROCESS_CACHE, @@ -388,6 +686,24 @@ def _resolve_spsv_csr_runtime( ) +def _select_spsv_runtime_plan(solve_plan, rhs_cols, compute_dtype, trans_mode): + matrix_stats = solve_plan.get("matrix_stats", {}) + route_name = solve_plan.get("route_name", solve_plan["solve_kind"]) + alt_plan = solve_plan.get("alt_plan") + complex_mode = compute_dtype in (torch.complex64, torch.complex128) + if trans_mode == "N": + desired = _select_nontrans_route(matrix_stats, rhs_cols, complex_mode) + else: + desired = _select_transpose_route(matrix_stats, rhs_cols, complex_mode) + if complex_mode and rhs_cols >= 2: + desired = "transpose_push" + if desired == route_name or alt_plan is None: + return solve_plan + if alt_plan.get("route_name", alt_plan["solve_kind"]) == desired: + return alt_plan + return solve_plan + + @triton.jit def _spsv_csr_level_kernel( data_ptr, @@ -538,6 +854,168 @@ def _spsv_csr_level_kernel_complex( tl.store(x_ri_ptr + row * 2 + 1 + offs1, x_im_out) +@triton.jit +def _spsv_csr_cw_kernel( + data_ptr, + indices_ptr, + indptr_ptr, + diag_ptr, + b_ptr, + x_ptr, + ready_ptr, + row_counter_ptr, + n_rows, + n_rhs, + stride_b0, + stride_x0, + BLOCK_RHS: tl.constexpr, + BLOCK_NNZ: tl.constexpr, + MAX_SEGMENTS: tl.constexpr, + LOWER: tl.constexpr, + UNIT_DIAG: tl.constexpr, + DIAG_EPS: tl.constexpr, +): + row = tl.atomic_add(row_counter_ptr, 1) + while row < n_rows: + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) + diag = tl.load(diag_ptr + row) + diag_safe = tl.where(tl.abs(diag) < DIAG_EPS, 1.0, diag) + + for rhs_base in range(0, n_rhs, BLOCK_RHS): + rhs_offsets = rhs_base + tl.arange(0, BLOCK_RHS) + rhs_mask = rhs_offsets < n_rhs + acc = tl.zeros((BLOCK_RHS,), dtype=tl.float32) + for seg in range(MAX_SEGMENTS): + idx = start + seg * BLOCK_NNZ + nnz_offsets = idx + tl.arange(0, BLOCK_NNZ) + nnz_mask = nnz_offsets < end + a = tl.load(data_ptr + nnz_offsets, mask=nnz_mask, other=0.0) + col = tl.load(indices_ptr + nnz_offsets, mask=nnz_mask, other=0) + if LOWER: + dep_mask = nnz_mask & (col < row) + else: + dep_mask = nnz_mask & (col > row) + + for k in range(BLOCK_NNZ): + if dep_mask[k]: + dep_col = col[k] + while tl.load(ready_ptr + dep_col) == 0: + pass + x_ptrs = x_ptr + dep_col * stride_x0 + rhs_offsets + x_vals = tl.load(x_ptrs, mask=rhs_mask, other=0.0) + acc += a[k] * x_vals + + rhs_ptrs = b_ptr + row * stride_b0 + rhs_offsets + rhs = tl.load(rhs_ptrs, mask=rhs_mask, other=0.0) + x_row = (rhs - acc) / diag_safe + x_row = tl.where(x_row == x_row, x_row, 0.0) + out_ptrs = x_ptr + row * stride_x0 + rhs_offsets + tl.store(out_ptrs, x_row, mask=rhs_mask) + + tl.debug_barrier() + tl.store(ready_ptr + row, 1) + row = tl.atomic_add(row_counter_ptr, 1) + + +@triton.jit +def _spsv_csr_cw_kernel_complex( + data_ri_ptr, + indices_ptr, + indptr_ptr, + diag_ri_ptr, + b_ri_ptr, + x_ri_ptr, + ready_ptr, + row_counter_ptr, + n_rows, + BLOCK_NNZ: tl.constexpr, + MAX_SEGMENTS: tl.constexpr, + LOWER: tl.constexpr, + UNIT_DIAG: tl.constexpr, + USE_FP64_ACC: tl.constexpr, + DIAG_EPS: tl.constexpr, +): + row = tl.atomic_add(row_counter_ptr, 1) + while row < n_rows: + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) + if UNIT_DIAG: + diag_re = 1.0 + diag_im = 0.0 + 
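+        # Editorial note (illustrative, not in the original patch): the solve
+        # below divides the complex residual by the diagonal on split re/im
+        # components via the conjugate identity
+        #
+        #   num / d = num * conj(d) / |d|^2
+        #   x_re = (num_re*d_re + num_im*d_im) / (d_re^2 + d_im^2)
+        #   x_im = (num_im*d_re - num_re*d_im) / (d_re^2 + d_im^2)
+        #
+        # with |d|^2 clamped to 1.0 below DIAG_EPS^2 (den_safe) and NaNs
+        # scrubbed by the `x == x` self-comparison, mirroring the real kernels.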
else: + diag_re = tl.load(diag_ri_ptr + row * 2) + diag_im = tl.load(diag_ri_ptr + row * 2 + 1) + if USE_FP64_ACC: + diag_re = diag_re.to(tl.float64) + diag_im = diag_im.to(tl.float64) + acc_re = tl.zeros((1,), dtype=tl.float64) + acc_im = tl.zeros((1,), dtype=tl.float64) + else: + diag_re = diag_re.to(tl.float32) + diag_im = diag_im.to(tl.float32) + acc_re = tl.zeros((1,), dtype=tl.float32) + acc_im = tl.zeros((1,), dtype=tl.float32) + + for seg in range(MAX_SEGMENTS): + idx = start + seg * BLOCK_NNZ + offsets = idx + tl.arange(0, BLOCK_NNZ) + mask = offsets < end + col = tl.load(indices_ptr + offsets, mask=mask, other=0) + a_re = tl.load(data_ri_ptr + offsets * 2, mask=mask, other=0.0) + a_im = tl.load(data_ri_ptr + offsets * 2 + 1, mask=mask, other=0.0) + if USE_FP64_ACC: + a_re = a_re.to(tl.float64) + a_im = a_im.to(tl.float64) + else: + a_re = a_re.to(tl.float32) + a_im = a_im.to(tl.float32) + if LOWER: + dep_mask = mask & (col < row) + else: + dep_mask = mask & (col > row) + + for k in range(BLOCK_NNZ): + if dep_mask[k]: + dep_col = col[k] + while tl.load(ready_ptr + dep_col) == 0: + pass + x_re = tl.load(x_ri_ptr + dep_col * 2) + x_im = tl.load(x_ri_ptr + dep_col * 2 + 1) + if USE_FP64_ACC: + x_re = x_re.to(tl.float64) + x_im = x_im.to(tl.float64) + else: + x_re = x_re.to(tl.float32) + x_im = x_im.to(tl.float32) + acc_re += a_re[k] * x_re - a_im[k] * x_im + acc_im += a_re[k] * x_im + a_im[k] * x_re + + rhs_re = tl.load(b_ri_ptr + row * 2) + rhs_im = tl.load(b_ri_ptr + row * 2 + 1) + if USE_FP64_ACC: + rhs_re = rhs_re.to(tl.float64) + rhs_im = rhs_im.to(tl.float64) + else: + rhs_re = rhs_re.to(tl.float32) + rhs_im = rhs_im.to(tl.float32) + + num_re = rhs_re - acc_re + num_im = rhs_im - acc_im + den = diag_re * diag_re + diag_im * diag_im + den_safe = tl.where(den < (DIAG_EPS * DIAG_EPS), 1.0, den) + x_re_out = (num_re * diag_re + num_im * diag_im) / den_safe + x_im_out = (num_im * diag_re - num_re * diag_im) / den_safe + x_re_out = tl.where(x_re_out == x_re_out, x_re_out, 0.0) + x_im_out = tl.where(x_im_out == x_im_out, x_im_out, 0.0) + + tl.store(x_ri_ptr + row * 2, x_re_out) + tl.store(x_ri_ptr + row * 2 + 1, x_im_out) + tl.debug_barrier() + tl.store(ready_ptr + row, 1) + row = tl.atomic_add(row_counter_ptr, 1) + + @triton.jit def _spsv_csr_transpose_push_kernel( data_ptr, @@ -690,6 +1168,145 @@ def _spsv_csr_transpose_push_kernel_complex( tl.atomic_add(residual_ri_ptr + col * 2 + 1, -prod_im, mask=target_mask) +@triton.jit +def _spsv_csr_transpose_cw_kernel( + data_ptr, + indices_ptr, + indptr_ptr, + diag_ptr, + indegree_ptr, + residual_ptr, + x_ptr, + row_counter_ptr, + n_rows, + BLOCK_NNZ: tl.constexpr, + MAX_SEGMENTS: tl.constexpr, + LOWER: tl.constexpr, + UNIT_DIAG: tl.constexpr, + DIAG_EPS: tl.constexpr, +): + row = tl.atomic_add(row_counter_ptr, 1) + while row < n_rows: + ready_value = 0 if UNIT_DIAG else 1 + while tl.load(indegree_ptr + row) != ready_value: + pass + + rhs = tl.load(residual_ptr + row) + if UNIT_DIAG: + diag = rhs * 0 + 1.0 + else: + diag = tl.load(diag_ptr + row) + diag_safe = tl.where(tl.abs(diag) < DIAG_EPS, 1.0, diag) + x_row = rhs / diag_safe + x_row = tl.where(x_row == x_row, x_row, 0.0) + tl.store(x_ptr + row, x_row) + + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) + for seg in range(MAX_SEGMENTS): + idx = start + seg * BLOCK_NNZ + offsets = idx + tl.arange(0, BLOCK_NNZ) + mask = offsets < end + a = tl.load(data_ptr + offsets, mask=mask, other=0.0) + col = tl.load(indices_ptr + offsets, mask=mask, other=0) + if LOWER: + 
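+                # Editorial sketch (not part of the original patch): this is
+                # the "cw" (completion-wavefront) scheduler. Each worker claims
+                # rows via atomic_add on row_counter_ptr, spins until
+                # indegree[row] hits the ready value (1 with a stored diagonal,
+                # 0 for UNIT_DIAG -- matching the +1 seeded by
+                # _prepare_spsv_transpose_cw_metadata), divides the settled
+                # residual by the diagonal, then pushes -a * x_row into each
+                # dependent column's residual and decrements its indegree.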
target_mask = mask & (col < row) + else: + target_mask = mask & (col > row) + tl.atomic_add(residual_ptr + col, -a * x_row, mask=target_mask) + tl.atomic_add(indegree_ptr + col, -1, mask=target_mask) + row = tl.atomic_add(row_counter_ptr, 1) + + +@triton.jit +def _spsv_csr_transpose_cw_kernel_complex( + data_ri_ptr, + indices_ptr, + indptr_ptr, + diag_ri_ptr, + indegree_ptr, + residual_ri_ptr, + x_ri_ptr, + row_counter_ptr, + n_rows, + BLOCK_NNZ: tl.constexpr, + MAX_SEGMENTS: tl.constexpr, + LOWER: tl.constexpr, + UNIT_DIAG: tl.constexpr, + CONJ_TRANS: tl.constexpr, + USE_FP64_ACC: tl.constexpr, + DIAG_EPS: tl.constexpr, +): + row = tl.atomic_add(row_counter_ptr, 1) + while row < n_rows: + ready_value = 0 if UNIT_DIAG else 1 + while tl.load(indegree_ptr + row) != ready_value: + pass + + rhs_re = tl.load(residual_ri_ptr + row * 2) + rhs_im = tl.load(residual_ri_ptr + row * 2 + 1) + if USE_FP64_ACC: + rhs_re = rhs_re.to(tl.float64) + rhs_im = rhs_im.to(tl.float64) + else: + rhs_re = rhs_re.to(tl.float32) + rhs_im = rhs_im.to(tl.float32) + + if UNIT_DIAG: + diag_re = rhs_re * 0 + 1.0 + diag_im = rhs_im * 0 + else: + diag_re = tl.load(diag_ri_ptr + row * 2) + diag_im = tl.load(diag_ri_ptr + row * 2 + 1) + if CONJ_TRANS: + diag_im = -diag_im + if USE_FP64_ACC: + diag_re = diag_re.to(tl.float64) + diag_im = diag_im.to(tl.float64) + else: + diag_re = diag_re.to(tl.float32) + diag_im = diag_im.to(tl.float32) + + den = diag_re * diag_re + diag_im * diag_im + den_safe = tl.where(den < (DIAG_EPS * DIAG_EPS), 1.0, den) + x_re_out = (rhs_re * diag_re + rhs_im * diag_im) / den_safe + x_im_out = (rhs_im * diag_re - rhs_re * diag_im) / den_safe + x_re_out = tl.where(x_re_out == x_re_out, x_re_out, 0.0) + x_im_out = tl.where(x_im_out == x_im_out, x_im_out, 0.0) + + offs1 = tl.arange(0, 1) + tl.store(x_ri_ptr + row * 2 + offs1, x_re_out) + tl.store(x_ri_ptr + row * 2 + 1 + offs1, x_im_out) + + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) + for seg in range(MAX_SEGMENTS): + idx = start + seg * BLOCK_NNZ + offsets = idx + tl.arange(0, BLOCK_NNZ) + mask = offsets < end + col = tl.load(indices_ptr + offsets, mask=mask, other=0) + a_re = tl.load(data_ri_ptr + offsets * 2, mask=mask, other=0.0) + a_im = tl.load(data_ri_ptr + offsets * 2 + 1, mask=mask, other=0.0) + if CONJ_TRANS: + a_im = -a_im + if USE_FP64_ACC: + a_re = a_re.to(tl.float64) + a_im = a_im.to(tl.float64) + else: + a_re = a_re.to(tl.float32) + a_im = a_im.to(tl.float32) + if LOWER: + target_mask = mask & (col < row) + else: + target_mask = mask & (col > row) + prod_re = a_re * x_re_out - a_im * x_im_out + prod_im = a_re * x_im_out + a_im * x_re_out + tl.atomic_add(residual_ri_ptr + col * 2, -prod_re, mask=target_mask) + tl.atomic_add(residual_ri_ptr + col * 2 + 1, -prod_im, mask=target_mask) + tl.atomic_add(indegree_ptr + col, -1, mask=target_mask) + row = tl.atomic_add(row_counter_ptr, 1) + + @triton.jit def _spsv_coo_level_kernel_real( data_ptr, @@ -946,6 +1563,168 @@ def _triton_spsv_csr_vector_complex( return x +def _triton_spsv_csr_cw_vector( + data, + indices, + indptr, + diag, + b_vec, + n_rows, + lower=True, + unit_diagonal=False, + block_nnz=None, + max_segments=None, + diag_eps=1e-12, + block_nnz_use=None, + max_segments_use=None, + worker_count=None, + matrix_stats=None, +): + if b_vec.ndim == 1: + b_mat = b_vec.unsqueeze(1).contiguous() + else: + b_mat = b_vec.contiguous() + x = torch.zeros_like(b_mat) + ready = torch.zeros(n_rows, dtype=torch.int32, device=b_mat.device) + row_counter = torch.zeros(1, 
dtype=torch.int32, device=b_mat.device) + n_rhs = int(b_mat.shape[1]) + if n_rows == 0 or n_rhs == 0: + return x.squeeze(1) if b_vec.ndim == 1 else x + if block_nnz_use is None or max_segments_use is None: + block_nnz_use, max_segments_use = _auto_spsv_launch_config( + indptr, block_nnz=block_nnz, max_segments=max_segments + ) + matrix_stats = matrix_stats or {} + block_rhs = _choose_spsv_block_rhs(n_rhs, matrix_stats, complex_mode=False) + if worker_count is None: + worker_count = _cw_worker_count( + n_rows, + int(matrix_stats.get("max_frontier", n_rows)), + float(matrix_stats.get("avg_nnz_per_row", 0.0)), + n_rhs, + ) + grid = (worker_count,) + _spsv_csr_cw_kernel[grid]( + data, + indices, + indptr, + diag, + b_mat, + x, + ready, + row_counter, + n_rows, + n_rhs, + b_mat.stride(0), + x.stride(0), + BLOCK_RHS=block_rhs, + BLOCK_NNZ=block_nnz_use, + MAX_SEGMENTS=max_segments_use, + LOWER=lower, + UNIT_DIAG=unit_diagonal, + DIAG_EPS=diag_eps, + ) + return x.squeeze(1) if b_vec.ndim == 1 else x + + +def _triton_spsv_csr_cw_vector_complex( + data, + indices, + indptr, + diag, + b_vec, + n_rows, + lower=True, + unit_diagonal=False, + block_nnz=None, + max_segments=None, + diag_eps=1e-12, + block_nnz_use=None, + max_segments_use=None, + worker_count=None, + matrix_stats=None, +): + if b_vec.ndim != 1: + cols = [] + for j in range(b_vec.shape[1]): + cols.append( + _triton_spsv_csr_cw_vector_complex( + data, + indices, + indptr, + diag, + b_vec[:, j].contiguous(), + n_rows, + lower=lower, + unit_diagonal=unit_diagonal, + block_nnz=block_nnz, + max_segments=max_segments, + diag_eps=diag_eps, + block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + worker_count=worker_count, + matrix_stats=matrix_stats, + ) + ) + return torch.stack(cols, dim=1) + + x = torch.zeros_like(b_vec) + ready = torch.zeros(n_rows, dtype=torch.int32, device=b_vec.device) + row_counter = torch.zeros(1, dtype=torch.int32, device=b_vec.device) + if n_rows == 0: + return x + if block_nnz_use is None or max_segments_use is None: + block_nnz_use, max_segments_use = _auto_spsv_launch_config( + indptr, block_nnz=block_nnz, max_segments=max_segments + ) + + if data.layout != torch.strided: + data_strided = torch.empty(data.shape, dtype=data.dtype, device=data.device) + data_strided.copy_(data) + else: + data_strided = data.contiguous() + if diag.layout != torch.strided: + diag_strided = torch.empty(diag.shape, dtype=diag.dtype, device=diag.device) + diag_strided.copy_(diag) + else: + diag_strided = diag.contiguous() + + data_ri = torch.view_as_real(data_strided).reshape(-1).contiguous() + diag_ri = torch.view_as_real(diag_strided).reshape(-1).contiguous() + b_ri = torch.view_as_real(b_vec.contiguous()).reshape(-1).contiguous() + component_dtype = _component_dtype_for_complex(data.dtype) + use_fp64 = component_dtype == torch.float64 + x_ri = torch.view_as_real(x.contiguous()).reshape(-1).contiguous() + + if worker_count is None: + matrix_stats = matrix_stats or {} + worker_count = _cw_worker_count( + n_rows, + int(matrix_stats.get("max_frontier", n_rows)), + float(matrix_stats.get("avg_nnz_per_row", 0.0)), + 1, + ) + grid = (worker_count,) + _spsv_csr_cw_kernel_complex[grid]( + data_ri, + indices, + indptr, + diag_ri, + b_ri, + x_ri, + ready, + row_counter, + n_rows, + BLOCK_NNZ=block_nnz_use, + MAX_SEGMENTS=max_segments_use, + LOWER=lower, + UNIT_DIAG=unit_diagonal, + USE_FP64_ACC=use_fp64, + DIAG_EPS=diag_eps, + ) + return x + + def _triton_spsv_csr_transpose_push_vector( data, indices, @@ -1065,6 +1844,146 @@ def 
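+# Editorial note (illustrative, not in the original patch): the complex
+# wrappers marshal values for Triton as interleaved component buffers.
+# torch.view_as_real appends a (re, im) pair per element:
+#
+#   >>> torch.view_as_real(torch.tensor([1 + 2j, 3 + 4j])).reshape(-1)
+#   tensor([1., 2., 3., 4.])
+#
+# so kernels address element k as data_ri[2*k] (real) and data_ri[2*k + 1]
+# (imag); complex32 inputs accumulate their output components in a float32
+# work buffer (see x_ri_work below) before being viewed back as complex.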
_triton_spsv_csr_transpose_push_vector_complex( return x +def _triton_spsv_csr_transpose_cw_vector( + data, + indices, + indptr, + diag, + indegree_init, + b_vec, + n_rows, + lower=True, + unit_diagonal=False, + block_nnz=None, + max_segments=None, + diag_eps=1e-12, + block_nnz_use=None, + max_segments_use=None, + worker_count=None, + matrix_stats=None, +): + x = torch.zeros_like(b_vec) + if n_rows == 0: + return x + residual = b_vec.clone() + indegree = indegree_init.clone() + row_counter = torch.zeros(1, dtype=torch.int32, device=b_vec.device) + if block_nnz_use is None or max_segments_use is None: + block_nnz_use, max_segments_use = _choose_transpose_family_launch_config( + indptr, block_nnz=block_nnz, max_segments=max_segments + ) + if worker_count is None: + matrix_stats = matrix_stats or {} + worker_count = _cw_worker_count( + n_rows, + int(matrix_stats.get("max_frontier", n_rows)), + float(matrix_stats.get("avg_nnz_per_row", 0.0)), + 1, + ) + grid = (worker_count,) + _spsv_csr_transpose_cw_kernel[grid]( + data, + indices, + indptr, + diag, + indegree, + residual, + x, + row_counter, + n_rows, + BLOCK_NNZ=block_nnz_use, + MAX_SEGMENTS=max_segments_use, + LOWER=lower, + UNIT_DIAG=unit_diagonal, + DIAG_EPS=diag_eps, + ) + return x + + +def _triton_spsv_csr_transpose_cw_vector_complex( + data, + indices, + indptr, + diag, + indegree_init, + b_vec, + n_rows, + lower=True, + unit_diagonal=False, + conjugate=False, + block_nnz=None, + max_segments=None, + diag_eps=1e-12, + block_nnz_use=None, + max_segments_use=None, + worker_count=None, + matrix_stats=None, +): + x = torch.zeros_like(b_vec) + if n_rows == 0: + return x + if block_nnz_use is None or max_segments_use is None: + block_nnz_use, max_segments_use = _choose_transpose_family_launch_config( + indptr, block_nnz=block_nnz, max_segments=max_segments + ) + + if data.layout != torch.strided: + data_strided = torch.empty(data.shape, dtype=data.dtype, device=data.device) + data_strided.copy_(data) + else: + data_strided = data.contiguous() + if diag.layout != torch.strided: + diag_strided = torch.empty(diag.shape, dtype=diag.dtype, device=diag.device) + diag_strided.copy_(diag) + else: + diag_strided = diag.contiguous() + + residual_work = b_vec.contiguous().clone() + indegree = indegree_init.clone() + row_counter = torch.zeros(1, dtype=torch.int32, device=b_vec.device) + data_ri = torch.view_as_real(data_strided).reshape(-1).contiguous() + diag_ri = torch.view_as_real(diag_strided).reshape(-1).contiguous() + residual_ri = torch.view_as_real(residual_work).reshape(-1).contiguous() + component_dtype = _component_dtype_for_complex(data.dtype) + use_fp64 = component_dtype == torch.float64 + if component_dtype == torch.float16: + x_ri_work = torch.zeros((n_rows, 2), dtype=torch.float32, device=b_vec.device) + x_ri = x_ri_work.reshape(-1).contiguous() + else: + x_ri = torch.view_as_real(x.contiguous()).reshape(-1).contiguous() + + if worker_count is None: + matrix_stats = matrix_stats or {} + worker_count = _cw_worker_count( + n_rows, + int(matrix_stats.get("max_frontier", n_rows)), + float(matrix_stats.get("avg_nnz_per_row", 0.0)), + 1, + ) + grid = (worker_count,) + _spsv_csr_transpose_cw_kernel_complex[grid]( + data_ri, + indices, + indptr, + diag_ri, + indegree, + residual_ri, + x_ri, + row_counter, + n_rows, + BLOCK_NNZ=block_nnz_use, + MAX_SEGMENTS=max_segments_use, + LOWER=lower, + UNIT_DIAG=unit_diagonal, + CONJ_TRANS=conjugate, + USE_FP64_ACC=use_fp64, + DIAG_EPS=diag_eps, + ) + if component_dtype == torch.float16: + return 
torch.view_as_complex(x_ri_work.contiguous()) + return x + + def _choose_transpose_family_launch_config(indptr, block_nnz=None, max_segments=None): if block_nnz is not None or max_segments is not None: return _auto_spsv_launch_config(indptr, block_nnz=block_nnz, max_segments=max_segments) @@ -1292,8 +2211,13 @@ def flagsparse_spsv_csr( shape, lower, transpose, + unit_diagonal, ) + rhs_cols = 1 if b.ndim == 1 else int(b.shape[1]) + solve_plan = _select_spsv_runtime_plan( + solve_plan, rhs_cols, data.dtype, trans_mode + ) solve_kind = solve_plan["solve_kind"] kernel_data = solve_plan["kernel_data"] kernel_indices64 = solve_plan["kernel_indices64"] @@ -1304,6 +2228,11 @@ def flagsparse_spsv_csr( transpose_conjugate = solve_plan["transpose_conjugate"] default_block_nnz = solve_plan["default_block_nnz"] default_max_segments = solve_plan["default_max_segments"] + transpose_diag = solve_plan.get("transpose_diag") + transpose_indegree_init = solve_plan.get("transpose_indegree_init") + cw_diag = solve_plan.get("cw_diag") + cw_worker_count = solve_plan.get("cw_worker_count") + matrix_stats = solve_plan.get("matrix_stats", {}) kernel_indices = kernel_indices32 kernel_indptr = kernel_indptr64 compute_dtype = data.dtype @@ -1330,15 +2259,28 @@ def flagsparse_spsv_csr( data_in = kernel_data.to(torch.float64) b_in = b.to(torch.float64) - if solve_kind == "transpose_push": + if solve_kind in ("transpose_push", "transpose_cw"): if block_nnz is None and max_segments is None: block_nnz_use, max_segments_use = default_block_nnz, default_max_segments else: block_nnz_use, max_segments_use = _choose_transpose_family_launch_config( kernel_indptr, block_nnz=block_nnz, max_segments=max_segments ) - vec_real = _triton_spsv_csr_transpose_push_vector - vec_complex = _triton_spsv_csr_transpose_push_vector_complex + if solve_kind == "transpose_cw": + vec_real = _triton_spsv_csr_transpose_cw_vector + vec_complex = _triton_spsv_csr_transpose_cw_vector_complex + else: + vec_real = _triton_spsv_csr_transpose_push_vector + vec_complex = _triton_spsv_csr_transpose_push_vector_complex + elif solve_kind == "csr_cw": + if block_nnz is None and max_segments is None: + block_nnz_use, max_segments_use = default_block_nnz, default_max_segments + else: + block_nnz_use, max_segments_use = _auto_spsv_launch_config( + kernel_indptr, block_nnz=block_nnz, max_segments=max_segments + ) + vec_real = _triton_spsv_csr_cw_vector + vec_complex = _triton_spsv_csr_cw_vector_complex else: if block_nnz is None and max_segments is None: block_nnz_use, max_segments_use = default_block_nnz, default_max_segments @@ -1353,6 +2295,22 @@ def flagsparse_spsv_csr( if return_time: torch.cuda.synchronize() t0 = time.perf_counter() + transpose_diag_in = transpose_diag + if transpose_diag is not None and compute_dtype != data.dtype: + transpose_diag_in = transpose_diag.to(compute_dtype) + cw_diag_in = cw_diag + if cw_diag is not None and compute_dtype != data.dtype: + cw_diag_in = cw_diag.to(compute_dtype) + worker_count_use = cw_worker_count + rhs_cols = 1 if b_in.ndim == 1 else int(b_in.shape[1]) + matrix_stats_use = dict(matrix_stats) + if solve_kind in ("csr_cw", "transpose_cw"): + worker_count_use = _cw_worker_count( + n_rows, + int(matrix_stats_use.get("max_frontier", cw_worker_count if cw_worker_count is not None else n_rows)), + float(matrix_stats_use.get("avg_nnz_per_row", default_block_nnz)), + rhs_cols, + ) if b_in.ndim == 1: if torch.is_complex(data_in): if solve_kind == "transpose_push": @@ -1372,6 +2330,43 @@ def flagsparse_spsv_csr( 
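+            # Editorial sketch (not part of the original patch): from here the
+            # solver fans out per solve_kind. Each cached plan carries exactly
+            # the metadata its route consumes -- transpose_cw reads
+            # transpose_diag and transpose_indegree_init, csr_cw reads cw_diag
+            # plus a cw_worker_count hint -- and _select_spsv_runtime_plan may
+            # already have swapped in alt_plan when the RHS width or dtype
+            # favors the sibling route for the same cached preprocess.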
block_nnz_use=block_nnz_use, max_segments_use=max_segments_use, ) + elif solve_kind == "transpose_cw": + x = vec_complex( + data_in, + kernel_indices, + kernel_indptr, + transpose_diag_in, + transpose_indegree_init, + b_in, + n_rows, + lower=lower_eff, + unit_diagonal=unit_diagonal, + conjugate=transpose_conjugate, + block_nnz=block_nnz, + max_segments=max_segments, + diag_eps=diag_eps, + block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + matrix_stats=matrix_stats_use, + ) + elif solve_kind == "csr_cw": + x = vec_complex( + data_in, + kernel_indices, + kernel_indptr, + cw_diag_in, + b_in, + n_rows, + lower=lower_eff, + unit_diagonal=unit_diagonal, + block_nnz=block_nnz, + max_segments=max_segments, + diag_eps=diag_eps, + block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + worker_count=worker_count_use, + matrix_stats=matrix_stats_use, + ) else: x = vec_complex( data_in, @@ -1405,6 +2400,42 @@ def flagsparse_spsv_csr( block_nnz_use=block_nnz_use, max_segments_use=max_segments_use, ) + elif solve_kind == "transpose_cw": + x = vec_real( + data_in, + kernel_indices, + kernel_indptr, + transpose_diag_in, + transpose_indegree_init, + b_in, + n_rows, + lower=lower_eff, + unit_diagonal=unit_diagonal, + block_nnz=block_nnz, + max_segments=max_segments, + diag_eps=diag_eps, + block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + matrix_stats=matrix_stats_use, + ) + elif solve_kind == "csr_cw": + x = vec_real( + data_in, + kernel_indices, + kernel_indptr, + cw_diag_in, + b_in, + n_rows, + lower=lower_eff, + unit_diagonal=unit_diagonal, + block_nnz=block_nnz, + max_segments=max_segments, + diag_eps=diag_eps, + block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + worker_count=worker_count_use, + matrix_stats=matrix_stats_use, + ) else: x = vec_real( data_in, @@ -1445,6 +2476,47 @@ def flagsparse_spsv_csr( max_segments_use=max_segments_use, ) ) + elif solve_kind == "transpose_cw": + cols.append( + vec_complex( + data_in, + kernel_indices, + kernel_indptr, + transpose_diag_in, + transpose_indegree_init, + bj, + n_rows, + lower=lower_eff, + unit_diagonal=unit_diagonal, + conjugate=transpose_conjugate, + block_nnz=block_nnz, + max_segments=max_segments, + diag_eps=diag_eps, + block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + matrix_stats=matrix_stats_use, + ) + ) + elif solve_kind == "csr_cw": + cols.append( + vec_complex( + data_in, + kernel_indices, + kernel_indptr, + cw_diag_in, + bj, + n_rows, + lower=lower_eff, + unit_diagonal=unit_diagonal, + block_nnz=block_nnz, + max_segments=max_segments, + diag_eps=diag_eps, + block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + worker_count=worker_count_use, + matrix_stats=matrix_stats_use, + ) + ) else: cols.append( vec_complex( @@ -1482,6 +2554,47 @@ def flagsparse_spsv_csr( max_segments_use=max_segments_use, ) ) + elif solve_kind == "transpose_cw": + cols.append( + vec_real( + data_in, + kernel_indices, + kernel_indptr, + transpose_diag_in, + transpose_indegree_init, + bj, + n_rows, + lower=lower_eff, + unit_diagonal=unit_diagonal, + block_nnz=block_nnz, + max_segments=max_segments, + diag_eps=diag_eps, + block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + worker_count=worker_count_use, + matrix_stats=matrix_stats_use, + ) + ) + elif solve_kind == "csr_cw": + cols.append( + vec_real( + data_in, + kernel_indices, + kernel_indptr, + cw_diag_in, + bj, + n_rows, + lower=lower_eff, + unit_diagonal=unit_diagonal, + block_nnz=block_nnz, + 
max_segments=max_segments, + diag_eps=diag_eps, + block_nnz_use=block_nnz_use, + max_segments_use=max_segments_use, + worker_count=worker_count_use, + matrix_stats=matrix_stats_use, + ) + ) else: cols.append( vec_real( @@ -1530,7 +2643,6 @@ def _analyze_spsv_csr( clear_cache=False, return_time=False, ): - del unit_diagonal if clear_cache: _clear_spsv_csr_preprocess_cache() if return_time: @@ -1544,6 +2656,7 @@ def _analyze_spsv_csr( shape, lower, transpose, + unit_diagonal, ) if return_time: torch.cuda.synchronize() From bf259c249a569dcd3ce8b4c327f17add02788797 Mon Sep 17 00:00:00 2001 From: berlin020 <2261128688@qq.com> Date: Tue, 28 Apr 2026 12:33:36 +0800 Subject: [PATCH 19/22] opt --- src/flagsparse/sparse_operations/spsv.py | 20 ++++++++++++++++---- tests/test_spsm.py | 2 -- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/src/flagsparse/sparse_operations/spsv.py b/src/flagsparse/sparse_operations/spsv.py index 203e78f..8c8f54c 100644 --- a/src/flagsparse/sparse_operations/spsv.py +++ b/src/flagsparse/sparse_operations/spsv.py @@ -46,6 +46,10 @@ def _spsv_env_flag(name, default="0"): SPSV_PROMOTE_TRANSPOSE_COMPLEX64_TO_COMPLEX128 = _spsv_env_flag( "FLAGSPARSE_SPSV_PROMOTE_TRANSPOSE_COMPLEX64_TO_COMPLEX128", "0" ) +SPSV_ENABLE_CSR_CW = _spsv_env_flag("FLAGSPARSE_SPSV_ENABLE_CSR_CW", "0") +SPSV_ENABLE_TRANSPOSE_CW = _spsv_env_flag("FLAGSPARSE_SPSV_ENABLE_TRANSPOSE_CW", "0") +SPSV_ENABLE_LEVEL_FRONTIERS = _spsv_env_flag("FLAGSPARSE_SPSV_ENABLE_LEVEL_FRONTIERS", "0") +SPSV_ENABLE_REVERSE_FRONTIERS = _spsv_env_flag("FLAGSPARSE_SPSV_ENABLE_REVERSE_FRONTIERS", "0") _SPSV_CSR_PREPROCESS_CACHE = OrderedDict() _SPSV_CSR_PREPROCESS_CACHE_SIZE = 8 @@ -528,8 +532,10 @@ def _prepare_spsv_csr_system(data, indices64, indptr64, n_rows, n_cols, lower, t "kernel_indices32": indices64.to(torch.int32), "kernel_indptr64": indptr64, "lower_eff": lower, - "launch_groups": _build_spsv_frontiers( - indptr64, indices64, levels, lower=lower + "launch_groups": ( + _build_spsv_frontiers(indptr64, indices64, levels, lower=lower) + if SPSV_ENABLE_LEVEL_FRONTIERS + else levels ), "default_block_nnz": default_block_nnz, "default_max_segments": default_max_segments, @@ -539,6 +545,8 @@ def _prepare_spsv_csr_system(data, indices64, indptr64, n_rows, n_cols, lower, t "route_name": "csr_levels", "alt_plan": None, } + if not SPSV_ENABLE_CSR_CW: + return levels_plan data_sorted, indices_sorted64, indptr_sorted64 = _sort_csr_rows( data, indices64, indptr64, n_rows, n_cols, lower=lower ) @@ -583,8 +591,10 @@ def _prepare_spsv_csr_system(data, indices64, indptr64, n_rows, n_cols, lower, t "kernel_indices32": indices64.to(torch.int32), "kernel_indptr64": indptr64, "lower_eff": lower, - "launch_groups": _build_spsv_reverse_frontiers( - indptr64, indices64, levels, lower=lower + "launch_groups": ( + _build_spsv_reverse_frontiers(indptr64, indices64, levels, lower=lower) + if SPSV_ENABLE_REVERSE_FRONTIERS + else list(reversed(levels)) ), "default_block_nnz": default_block_nnz, "default_max_segments": default_max_segments, @@ -594,6 +604,8 @@ def _prepare_spsv_csr_system(data, indices64, indptr64, n_rows, n_cols, lower, t "route_name": "transpose_push", "alt_plan": None, } + if not SPSV_ENABLE_TRANSPOSE_CW: + return push_plan diag, indegree_init = _prepare_spsv_transpose_cw_metadata( data, indices64, indptr64, n_rows, lower, unit_diagonal=unit_diagonal diff --git a/tests/test_spsm.py b/tests/test_spsm.py index 78ff983..ab15232 100644 --- a/tests/test_spsm.py +++ b/tests/test_spsm.py @@ -441,8 +441,6 @@ def 
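+# Editorial note (illustrative, not in the original patch): the spsv.py hunk
+# above makes each experimental route opt-in. With the flags unset,
+# _prepare_spsv_csr_system returns only the levels/push plan, skips building
+# the cw alternative, and uses plain level lists in place of the frontier
+# grouping. One assumed way to exercise the gated non-transpose path:
+#
+#   FLAGSPARSE_SPSV_ENABLE_CSR_CW=1 python -m pytest tests/test_spsv.py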
_run_one_spsm_case(data, indices, indptr, shape, value_dtype, index_dtype, n status = "REF_FAIL" else: status = "FAIL" - if not ok_res and status == "PASS": - status = "FAIL" return { "format": fmt, From 06053a768f471a27cb22771f27095fbc3492f2f6 Mon Sep 17 00:00:00 2001 From: berlin020 <2261128688@qq.com> Date: Tue, 28 Apr 2026 17:19:57 +0800 Subject: [PATCH 20/22] opt --- src/flagsparse/sparse_operations/spsm.py | 607 +++++++++++++++++++---- src/flagsparse/sparse_operations/spsv.py | 70 ++- 2 files changed, 577 insertions(+), 100 deletions(-) diff --git a/src/flagsparse/sparse_operations/spsm.py b/src/flagsparse/sparse_operations/spsm.py index fc5d805..36066bc 100644 --- a/src/flagsparse/sparse_operations/spsm.py +++ b/src/flagsparse/sparse_operations/spsm.py @@ -1,6 +1,7 @@ """Sparse triangular matrix-matrix solve (SpSM) for CSR/COO.""" from collections import OrderedDict +import os from ._common import * @@ -13,6 +14,12 @@ ("coo", torch.float32, torch.int32), ("coo", torch.float64, torch.int32), ) +def _spsm_env_flag(name, default="0"): + return str(os.environ.get(name, default)).lower() in ("1", "true", "yes", "on") + + +SPSM_ENABLE_TILED_KERNEL = _spsm_env_flag("FLAGSPARSE_SPSM_ENABLE_TILED_KERNEL", "1") +SPSM_TILED_MIN_RHS = 16 _SPSM_PREPROCESS_CACHE = OrderedDict() _SPSM_PREPROCESS_CACHE_SIZE = 8 @@ -167,11 +174,176 @@ def _spsm_preprocess_cache_key(fmt_name, tensors, shape, lower, unit_diagonal): ) +def _prepare_spsm_diag(data, indices64, indptr64, n_rows, unit_diagonal=False): + diag = torch.ones(n_rows, dtype=data.dtype, device=data.device) + if unit_diagonal or n_rows == 0 or data.numel() == 0: + return diag + row_ids = torch.repeat_interleave( + torch.arange(n_rows, device=data.device, dtype=torch.int64), + indptr64[1:] - indptr64[:-1], + ) + diag_mask = indices64 == row_ids + if bool(torch.any(diag_mask).item()): + diag.scatter_(0, row_ids[diag_mask], data[diag_mask]) + return diag + + +def _prepare_spsm_kernel_row_ptr(indptr64): + if indptr64.numel() == 0: + return indptr64.to(torch.int32) + if int(indptr64[-1].item()) <= _INDEX_LIMIT_INT32: + return indptr64.to(torch.int32) + return indptr64 + + +def _csr_row_ids_from_indptr(indptr64, n_rows): + if n_rows == 0 or indptr64.numel() <= 1: + return torch.empty(0, dtype=torch.int64, device=indptr64.device) + return torch.repeat_interleave( + torch.arange(n_rows, device=indptr64.device, dtype=torch.int64), + indptr64[1:] - indptr64[:-1], + ) + + +def _csr_rows_sorted_by_col(indices64, indptr64, n_rows): + if indices64.numel() <= 1: + return True + row_ids = _csr_row_ids_from_indptr(indptr64, n_rows) + if row_ids.numel() <= 1: + return True + same_row = row_ids[1:] == row_ids[:-1] + if not bool(torch.any(same_row).item()): + return True + return bool(torch.all(indices64[1:][same_row] >= indices64[:-1][same_row]).item()) + + +def _prepare_spsm_csr_sorted_view(data, indices64, indptr64, n_rows, n_cols): + row_ids = _csr_row_ids_from_indptr(indptr64, n_rows) + if data.numel() <= 1: + return data, indices64, row_ids + if _csr_rows_sorted_by_col(indices64, indptr64, n_rows): + return data, indices64, row_ids + key = row_ids * max(1, n_cols) + indices64 + try: + order = torch.argsort(key, stable=True) + except TypeError: + order = torch.argsort(key) + return data[order].contiguous(), indices64[order].contiguous(), row_ids + + +def _prepare_spsm_csr_dependency_bounds(indices64, indptr64, n_rows, lower, row_ids=None): + row_begin = indptr64[:-1].clone() + row_end = indptr64[1:].clone() + if n_rows == 0 or indices64.numel() == 0: + return 
row_begin, row_begin if lower else row_end + if row_ids is None: + row_ids = _csr_row_ids_from_indptr(indptr64, n_rows) + dep_mask = indices64 < row_ids if lower else indices64 > row_ids + dep_counts = torch.bincount(row_ids[dep_mask], minlength=n_rows) + if lower: + dep_begin = row_begin + dep_end = row_begin + dep_counts + else: + dep_begin = row_end - dep_counts + dep_end = row_end + return dep_begin, dep_end + + +def _prepare_spsm_csr_dependency_view( + data, indices64, indptr64, n_rows, lower, row_ids=None +): + dep_begin64, dep_end64 = _prepare_spsm_csr_dependency_bounds( + indices64, indptr64, n_rows, lower=lower, row_ids=row_ids + ) + dep_counts = dep_end64 - dep_begin64 + dep_ptr64 = torch.zeros(n_rows + 1, dtype=torch.int64, device=indptr64.device) + if n_rows > 0: + dep_ptr64[1:] = torch.cumsum(dep_counts, dim=0) + total_dep_nnz = int(dep_ptr64[-1].item()) if dep_ptr64.numel() > 0 else 0 + if total_dep_nnz == 0: + empty_data = torch.empty(0, dtype=data.dtype, device=data.device) + empty_indices = torch.empty(0, dtype=torch.int64, device=indices64.device) + return empty_data, empty_indices, dep_ptr64 + + if row_ids is None: + row_ids = _csr_row_ids_from_indptr(indptr64, n_rows) + dep_mask = indices64 < row_ids if lower else indices64 > row_ids + dep_data = data[dep_mask].contiguous() + dep_indices64 = indices64[dep_mask].contiguous() + return dep_data, dep_indices64, dep_ptr64 + + +def _alpha_is_one(alpha): + if torch.is_tensor(alpha): + if alpha.numel() != 1: + return False + return bool((alpha.detach().cpu() == 1).item()) + return alpha == 1 or alpha == 1.0 + + +def _alpha_to_host_scalar(alpha): + if torch.is_tensor(alpha): + if alpha.numel() != 1: + raise ValueError("alpha tensor must be scalar") + return float(alpha.detach().cpu().item()) + return float(alpha) + + +def _should_use_spsm_tiled_kernel(fmt_name, rhs, dtype): + if not SPSM_ENABLE_TILED_KERNEL: + return False + if dtype not in (torch.float32, torch.float64): + return False + if rhs.ndim != 2: + return False + n_rhs = int(rhs.shape[1]) + if n_rhs < SPSM_TILED_MIN_RHS: + return False + fmt = str(fmt_name).lower() + if fmt == "csr": + return True + if fmt == "coo": + return True + return False + + +@triton.jit +def _spsm_pack_rhs_work_kernel( + src_ptr, + dst_ptr, + n_rows, + n_rhs, + stride_src0, + stride_dst0, + alpha, + BLOCK_ROWS: tl.constexpr, + BLOCK_RHS: tl.constexpr, + USE_FP64_ACC: tl.constexpr, +): + pid_row = tl.program_id(0) + pid_rhs = tl.program_id(1) + + row_offsets = pid_row * BLOCK_ROWS + tl.arange(0, BLOCK_ROWS) + rhs_offsets = pid_rhs * BLOCK_RHS + tl.arange(0, BLOCK_RHS) + mask = (row_offsets[:, None] < n_rows) & (rhs_offsets[None, :] < n_rhs) + + src_ptrs = src_ptr + row_offsets[:, None] * stride_src0 + rhs_offsets[None, :] + vals = tl.load(src_ptrs, mask=mask, other=0.0) + if USE_FP64_ACC: + vals = vals.to(tl.float64) * alpha + else: + vals = vals.to(tl.float32) * alpha + + dst_ptrs = dst_ptr + row_offsets[:, None] * stride_dst0 + rhs_offsets[None, :] + tl.store(dst_ptrs, vals, mask=mask) + + @triton.jit def _spsm_csr_level_kernel_real( data_ptr, indices_ptr, indptr_ptr, + diag_ptr, b_ptr, x_ptr, rows_ptr, @@ -183,7 +355,6 @@ def _spsm_csr_level_kernel_real( BLOCK_RHS: tl.constexpr, MAX_SEGMENTS: tl.constexpr, LOWER: tl.constexpr, - UNIT_DIAG: tl.constexpr, USE_FP64_ACC: tl.constexpr, DIAG_EPS: tl.constexpr, ): @@ -201,13 +372,8 @@ def _spsm_csr_level_kernel_real( if USE_FP64_ACC: acc = tl.zeros((BLOCK_RHS,), dtype=tl.float64) - diag = tl.zeros((1,), dtype=tl.float64) else: acc = 
tl.zeros((BLOCK_RHS,), dtype=tl.float32) - diag = tl.zeros((1,), dtype=tl.float32) - - if UNIT_DIAG: - diag = diag + 1.0 for seg in range(MAX_SEGMENTS): nnz_offsets = start + seg * BLOCK_NNZ + tl.arange(0, BLOCK_NNZ) @@ -221,7 +387,6 @@ def _spsm_csr_level_kernel_real( a = a.to(tl.float32) solved_mask = col < row if LOWER else col > row - diag_mask = col == row x_ptrs = x_ptr + col[:, None] * stride_x0 + rhs_offsets[None, :] x_mask = nnz_mask[:, None] & rhs_mask[None, :] @@ -234,13 +399,12 @@ def _spsm_csr_level_kernel_real( contrib = tl.where((nnz_mask & solved_mask)[:, None], a[:, None] * x_vals, 0.0) acc += tl.sum(contrib, axis=0) - if not UNIT_DIAG: - diag += tl.sum(tl.where(nnz_mask & diag_mask, a, 0.0), axis=0) - b_ptrs = b_ptr + row * stride_b0 + rhs_offsets rhs = tl.load(b_ptrs, mask=rhs_mask, other=0.0) rhs = rhs.to(tl.float64) if USE_FP64_ACC else rhs.to(tl.float32) + diag = tl.load(diag_ptr + row) + diag = diag.to(tl.float64) if USE_FP64_ACC else diag.to(tl.float32) diag_safe = tl.where(tl.abs(diag) < DIAG_EPS, 1.0, diag) out = (rhs - acc) / diag_safe out = tl.where(out == out, out, 0.0) @@ -254,6 +418,7 @@ def _spsm_coo_level_kernel_real( data_ptr, row_ptr_ptr, col_ptr, + diag_ptr, b_ptr, x_ptr, rows_ptr, @@ -265,7 +430,6 @@ def _spsm_coo_level_kernel_real( BLOCK_RHS: tl.constexpr, MAX_SEGMENTS: tl.constexpr, LOWER: tl.constexpr, - UNIT_DIAG: tl.constexpr, USE_FP64_ACC: tl.constexpr, DIAG_EPS: tl.constexpr, ): @@ -283,13 +447,8 @@ def _spsm_coo_level_kernel_real( if USE_FP64_ACC: acc = tl.zeros((BLOCK_RHS,), dtype=tl.float64) - diag = tl.zeros((1,), dtype=tl.float64) else: acc = tl.zeros((BLOCK_RHS,), dtype=tl.float32) - diag = tl.zeros((1,), dtype=tl.float32) - - if UNIT_DIAG: - diag = diag + 1.0 for seg in range(MAX_SEGMENTS): nnz_offsets = start + seg * BLOCK_NNZ + tl.arange(0, BLOCK_NNZ) @@ -303,7 +462,6 @@ def _spsm_coo_level_kernel_real( a = a.to(tl.float32) solved_mask = col < row if LOWER else col > row - diag_mask = col == row x_ptrs = x_ptr + col[:, None] * stride_x0 + rhs_offsets[None, :] x_mask = nnz_mask[:, None] & rhs_mask[None, :] @@ -316,13 +474,12 @@ def _spsm_coo_level_kernel_real( contrib = tl.where((nnz_mask & solved_mask)[:, None], a[:, None] * x_vals, 0.0) acc += tl.sum(contrib, axis=0) - if not UNIT_DIAG: - diag += tl.sum(tl.where(nnz_mask & diag_mask, a, 0.0), axis=0) - b_ptrs = b_ptr + row * stride_b0 + rhs_offsets rhs = tl.load(b_ptrs, mask=rhs_mask, other=0.0) rhs = rhs.to(tl.float64) if USE_FP64_ACC else rhs.to(tl.float32) + diag = tl.load(diag_ptr + row) + diag = diag.to(tl.float64) if USE_FP64_ACC else diag.to(tl.float32) diag_safe = tl.where(tl.abs(diag) < DIAG_EPS, 1.0, diag) out = (rhs - acc) / diag_safe out = tl.where(out == out, out, 0.0) @@ -331,6 +488,81 @@ def _spsm_coo_level_kernel_real( tl.store(out_ptrs, out, mask=rhs_mask) +@triton.jit +def _spsm_csr_tiled_kernel_real( + data_ptr, + indices_ptr, + dep_ptr_ptr, + diag_ptr, + y_ptr, + ready_ptr, + n_rows, + n_rhs, + stride_y0, + ready_stride0, + BLOCK_NNZ: tl.constexpr, + BLOCK_RHS: tl.constexpr, + MAX_SEGMENTS: tl.constexpr, + USE_FP64_ACC: tl.constexpr, + DIAG_EPS: tl.constexpr, +): + row = tl.program_id(0) + pid_rhs = tl.program_id(1) + if row >= n_rows: + return + + rhs_offsets = pid_rhs * BLOCK_RHS + tl.arange(0, BLOCK_RHS) + rhs_mask = rhs_offsets < n_rhs + ready_row_ptr = ready_ptr + pid_rhs * ready_stride0 + + start = tl.load(dep_ptr_ptr + row) + end = tl.load(dep_ptr_ptr + row + 1) + + if USE_FP64_ACC: + acc = tl.zeros((BLOCK_RHS,), dtype=tl.float64) + else: + acc = 
tl.zeros((BLOCK_RHS,), dtype=tl.float32) + + for seg in range(MAX_SEGMENTS): + nnz_offsets = start + seg * BLOCK_NNZ + tl.arange(0, BLOCK_NNZ) + nnz_mask = nnz_offsets < end + a = tl.load(data_ptr + nnz_offsets, mask=nnz_mask, other=0.0) + col = tl.load(indices_ptr + nnz_offsets, mask=nnz_mask, other=0) + + if USE_FP64_ACC: + a = a.to(tl.float64) + else: + a = a.to(tl.float32) + + for k in range(BLOCK_NNZ): + if nnz_mask[k]: + dep_col = col[k] + while tl.load(ready_row_ptr + dep_col) == 0: + pass + x_ptrs = y_ptr + dep_col * stride_y0 + rhs_offsets + x_vals = tl.load(x_ptrs, mask=rhs_mask, other=0.0) + if USE_FP64_ACC: + x_vals = x_vals.to(tl.float64) + else: + x_vals = x_vals.to(tl.float32) + acc += a[k] * x_vals + + rhs_ptrs = y_ptr + row * stride_y0 + rhs_offsets + rhs = tl.load(rhs_ptrs, mask=rhs_mask, other=0.0) + rhs = rhs.to(tl.float64) if USE_FP64_ACC else rhs.to(tl.float32) + + diag = tl.load(diag_ptr + row) + diag = diag.to(tl.float64) if USE_FP64_ACC else diag.to(tl.float32) + diag_safe = tl.where(tl.abs(diag) < DIAG_EPS, 1.0, diag) + out = (rhs - acc) / diag_safe + out = tl.where(out == out, out, 0.0) + + out_ptrs = y_ptr + row * stride_y0 + rhs_offsets + tl.store(out_ptrs, out, mask=rhs_mask) + tl.debug_barrier() + tl.store(ready_row_ptr + row, 1) + + def _build_spsm_levels(indptr, indices, n_rows, lower=True): if n_rows == 0: return [] @@ -410,8 +642,51 @@ def _auto_spsm_launch_config(indptr, block_nnz=None, max_segments=None): return block_nnz_use, max_segments_use +def _auto_spsm_launch_config_from_row_lengths(row_lengths, block_nnz=None, max_segments=None): + if row_lengths.numel() == 0: + max_nnz_per_row = 0 + else: + max_nnz_per_row = int(row_lengths.max().item()) + + auto_block = block_nnz is None + if block_nnz is None: + if max_nnz_per_row <= 64: + block_nnz_use = 64 + elif max_nnz_per_row <= 256: + block_nnz_use = 128 + elif max_nnz_per_row <= 1024: + block_nnz_use = 256 + elif max_nnz_per_row <= 4096: + block_nnz_use = 512 + else: + block_nnz_use = 1024 + else: + block_nnz_use = int(block_nnz) + if block_nnz_use <= 0: + raise ValueError("block_nnz must be positive") + + required_segments = max((max_nnz_per_row + block_nnz_use - 1) // block_nnz_use, 1) + if max_segments is None: + max_segments_use = required_segments + if auto_block: + while max_segments_use > 2048 and block_nnz_use < 65536: + block_nnz_use *= 2 + max_segments_use = max((max_nnz_per_row + block_nnz_use - 1) // block_nnz_use, 1) + else: + max_segments_use = int(max_segments) + if max_segments_use <= 0: + raise ValueError("max_segments must be positive") + if max_segments_use < required_segments: + raise ValueError( + f"max_segments={max_segments_use} is too small; at least {required_segments} required" + ) + return block_nnz_use, max_segments_use + + def _auto_rhs_block(n_rhs): n_rhs = int(n_rhs) + if n_rhs <= 4: + return 4 if n_rhs <= 8: return 8 if n_rhs <= 16: @@ -421,6 +696,34 @@ def _auto_rhs_block(n_rhs): return 64 +def _pack_spsm_rhs_work_buffer(rhs, alpha): + if rhs.ndim != 2: + raise ValueError("rhs must be 2D") + n_rows, n_rhs = int(rhs.shape[0]), int(rhs.shape[1]) + work = torch.empty_like(rhs, memory_format=torch.contiguous_format) + if n_rows == 0 or n_rhs == 0: + return work + + block_rows = 32 + block_rhs = _auto_rhs_block(n_rhs) + alpha_value = _alpha_to_host_scalar(alpha) + use_fp64 = rhs.dtype == torch.float64 + grid = (triton.cdiv(n_rows, block_rows), triton.cdiv(n_rhs, block_rhs)) + _spsm_pack_rhs_work_kernel[grid]( + rhs, + work, + n_rows, + n_rhs, + rhs.stride(0), + 
work.stride(0), + alpha_value, + BLOCK_ROWS=block_rows, + BLOCK_RHS=block_rhs, + USE_FP64_ACC=use_fp64, + ) + return work + + def _coo_to_csr_sorted_unique(data, row64, col64, n_rows, n_cols): if data.numel() == 0: return ( @@ -448,33 +751,62 @@ def _coo_to_csr_sorted_unique(data, row64, col64, n_rows, n_cols): return data_u, col_u.to(torch.int64), indptr -def _prepare_spsm_csr_system(data, indices64, indptr64, n_rows, lower): - levels = _build_spsm_levels(indptr64, indices64, n_rows, lower=lower) +def _prepare_spsm_csr_system(data, indices64, indptr64, n_rows, lower, unit_diagonal): + sorted_data, sorted_indices64, row_ids = _prepare_spsm_csr_sorted_view( + data, indices64, indptr64, n_rows, n_rows + ) + dep_data, dep_indices64, dep_ptr64 = _prepare_spsm_csr_dependency_view( + sorted_data, sorted_indices64, indptr64, n_rows, lower=lower, row_ids=row_ids + ) + levels = _build_spsm_levels(indptr64, sorted_indices64, n_rows, lower=lower) default_block_nnz, default_max_segments = _auto_spsm_launch_config(indptr64) + tiled_block_nnz, tiled_max_segments = _auto_spsm_launch_config_from_row_lengths( + dep_ptr64[1:] - dep_ptr64[:-1] + ) return { - "kernel_data": data, - "kernel_indices64": indices64, - "kernel_indices32": indices64.to(torch.int32), - "kernel_indptr64": indptr64, + "kernel_data": sorted_data, + "kernel_indices32": sorted_indices64.to(torch.int32), + "kernel_dep_data": dep_data, + "kernel_dep_indices32": dep_indices64.to(torch.int32), + "kernel_dep_ptr": _prepare_spsm_kernel_row_ptr(dep_ptr64), + "kernel_diag": _prepare_spsm_diag( + sorted_data, sorted_indices64, indptr64, n_rows, unit_diagonal=unit_diagonal + ), + "kernel_indptr": _prepare_spsm_kernel_row_ptr(indptr64), "launch_groups": levels, "default_block_nnz": default_block_nnz, "default_max_segments": default_max_segments, + "tiled_block_nnz": tiled_block_nnz, + "tiled_max_segments": tiled_max_segments, "lower_eff": bool(lower), } -def _prepare_spsm_coo_system(data, row64, col64, n_rows, n_cols, lower): +def _prepare_spsm_coo_system(data, row64, col64, n_rows, n_cols, lower, unit_diagonal): data_u, col_u64, row_ptr = _coo_to_csr_sorted_unique(data, row64, col64, n_rows, n_cols) + dep_data, dep_indices64, dep_ptr64 = _prepare_spsm_csr_dependency_view( + data_u, col_u64, row_ptr, n_rows, lower=lower + ) levels = _build_spsm_levels(row_ptr, col_u64, n_rows, lower=lower) default_block_nnz, default_max_segments = _auto_spsm_launch_config(row_ptr) + tiled_block_nnz, tiled_max_segments = _auto_spsm_launch_config_from_row_lengths( + dep_ptr64[1:] - dep_ptr64[:-1] + ) return { "kernel_data": data_u, - "kernel_cols64": col_u64, "kernel_cols32": col_u64.to(torch.int32), - "kernel_row_ptr64": row_ptr, + "kernel_dep_data": dep_data, + "kernel_dep_indices32": dep_indices64.to(torch.int32), + "kernel_dep_ptr": _prepare_spsm_kernel_row_ptr(dep_ptr64), + "kernel_diag": _prepare_spsm_diag( + data_u, col_u64, row_ptr, n_rows, unit_diagonal=unit_diagonal + ), + "kernel_row_ptr": _prepare_spsm_kernel_row_ptr(row_ptr), "launch_groups": levels, "default_block_nnz": default_block_nnz, "default_max_segments": default_max_segments, + "tiled_block_nnz": tiled_block_nnz, + "tiled_max_segments": tiled_max_segments, "lower_eff": bool(lower), } @@ -492,7 +824,9 @@ def _resolve_spsm_csr_runtime(data, indices, indptr, B, shape, lower, unit_diago ) solve_plan = _spsm_cache_get(_SPSM_PREPROCESS_CACHE, cache_key) if solve_plan is None: - solve_plan = _prepare_spsm_csr_system(data, indices64, indptr64, n_rows, lower) + solve_plan = _prepare_spsm_csr_system( + 
data, indices64, indptr64, n_rows, lower, unit_diagonal + ) _spsm_cache_put(_SPSM_PREPROCESS_CACHE, cache_key, solve_plan, _SPSM_PREPROCESS_CACHE_SIZE) return data, B, n_rows, n_cols, solve_plan @@ -510,7 +844,9 @@ def _resolve_spsm_coo_runtime(data, row, col, B, shape, lower, unit_diagonal, op ) solve_plan = _spsm_cache_get(_SPSM_PREPROCESS_CACHE, cache_key) if solve_plan is None: - solve_plan = _prepare_spsm_coo_system(data, row64, col64, n_rows, n_cols, lower) + solve_plan = _prepare_spsm_coo_system( + data, row64, col64, n_rows, n_cols, lower, unit_diagonal + ) _spsm_cache_put(_SPSM_PREPROCESS_CACHE, cache_key, solve_plan, _SPSM_PREPROCESS_CACHE_SIZE) return data, B, n_rows, n_cols, solve_plan @@ -518,7 +854,8 @@ def _resolve_spsm_coo_runtime(data, row, col, B, shape, lower, unit_diagonal, op def _run_spsm_csr_core( data, indices32, - indptr64, + indptr, + diag, rhs, n_rows, *, @@ -542,11 +879,11 @@ def _run_spsm_csr_core( return x if launch_groups is None: - levels = _build_spsm_levels(indptr64, indices32, n_rows, lower=lower) + levels = _build_spsm_levels(indptr, indices32, n_rows, lower=lower) launch_groups = levels if block_nnz_use is None or max_segments_use is None: block_nnz_use, max_segments_use = _auto_spsm_launch_config( - indptr64, block_nnz=block_nnz, max_segments=max_segments + indptr, block_nnz=block_nnz, max_segments=max_segments ) block_rhs_use = _auto_rhs_block(n_rhs) if block_rhs is None else int(block_rhs) if block_rhs_use <= 0: @@ -563,7 +900,8 @@ def _run_spsm_csr_core( _spsm_csr_level_kernel_real[grid]( data, indices32, - indptr64, + indptr, + diag, rhs, x, rows_lv, @@ -575,17 +913,78 @@ def _run_spsm_csr_core( BLOCK_RHS=block_rhs_use, MAX_SEGMENTS=max_segments_use, LOWER=lower, - UNIT_DIAG=unit_diagonal, USE_FP64_ACC=use_fp64, DIAG_EPS=diag_eps, ) return x +def _run_spsm_csr_tiled_core( + data, + indices32, + dep_ptr, + diag, + rhs, + n_rows, + *, + alpha=1.0, + block_nnz=None, + max_segments=None, + block_rhs=None, + block_nnz_use=None, + max_segments_use=None, +): + if rhs.ndim != 2: + raise ValueError("rhs must be 2D") + rhs = rhs.contiguous() + if rhs.shape[0] != n_rows: + raise ValueError("rhs first dim must equal n_rows") + n_rhs = int(rhs.shape[1]) + y = _pack_spsm_rhs_work_buffer(rhs, alpha) + if n_rows == 0 or n_rhs == 0: + return y + + if block_nnz_use is None or max_segments_use is None: + block_nnz_use, max_segments_use = _auto_spsm_launch_config_from_row_lengths( + dep_ptr[1:] - dep_ptr[:-1], + block_nnz=block_nnz, + max_segments=max_segments, + ) + block_rhs_use = _auto_rhs_block(n_rhs) if block_rhs is None else int(block_rhs) + if block_rhs_use <= 0: + raise ValueError("block_rhs must be positive") + + use_fp64 = data.dtype == torch.float64 + diag_eps = 1e-12 if use_fp64 else 1e-6 + rhs_tiles = triton.cdiv(n_rhs, block_rhs_use) + ready = torch.zeros((rhs_tiles, n_rows), dtype=torch.int32, device=rhs.device) + + grid = (n_rows, rhs_tiles) + _spsm_csr_tiled_kernel_real[grid]( + data, + indices32, + dep_ptr, + diag, + y, + ready, + n_rows, + n_rhs, + y.stride(0), + ready.stride(0), + BLOCK_NNZ=block_nnz_use, + BLOCK_RHS=block_rhs_use, + MAX_SEGMENTS=max_segments_use, + USE_FP64_ACC=use_fp64, + DIAG_EPS=diag_eps, + ) + return y + + def _run_spsm_coo_core( data, cols32, - row_ptr64, + row_ptr, + diag, rhs, n_rows, *, @@ -609,11 +1008,11 @@ def _run_spsm_coo_core( return x if launch_groups is None: - levels = _build_spsm_levels(row_ptr64, cols32, n_rows, lower=lower) + levels = _build_spsm_levels(row_ptr, cols32, n_rows, lower=lower) launch_groups = 
levels if block_nnz_use is None or max_segments_use is None: block_nnz_use, max_segments_use = _auto_spsm_launch_config( - row_ptr64, block_nnz=block_nnz, max_segments=max_segments + row_ptr, block_nnz=block_nnz, max_segments=max_segments ) block_rhs_use = _auto_rhs_block(n_rhs) if block_rhs is None else int(block_rhs) if block_rhs_use <= 0: @@ -629,8 +1028,9 @@ def _run_spsm_coo_core( grid = (n_lv, triton.cdiv(n_rhs, block_rhs_use)) _spsm_coo_level_kernel_real[grid]( data, - row_ptr64, + row_ptr, cols32, + diag, rhs, x, rows_lv, @@ -642,13 +1042,14 @@ def _run_spsm_coo_core( BLOCK_RHS=block_rhs_use, MAX_SEGMENTS=max_segments_use, LOWER=lower, - UNIT_DIAG=unit_diagonal, USE_FP64_ACC=use_fp64, DIAG_EPS=diag_eps, ) return x + + def flagsparse_spsm_csr( data, indices, @@ -667,24 +1068,41 @@ def flagsparse_spsm_csr( data, B, n_rows, _n_cols, solve_plan = _resolve_spsm_csr_runtime( data, indices, indptr, B, shape, lower, unit_diagonal, opA, opB, major ) - alpha_t = torch.as_tensor(alpha, dtype=B.dtype, device=B.device) - rhs = alpha_t * B - torch.cuda.synchronize() - t0 = time.perf_counter() - x = _run_spsm_csr_core( - solve_plan["kernel_data"], - solve_plan["kernel_indices32"], - solve_plan["kernel_indptr64"], - rhs, - n_rows, - lower=solve_plan["lower_eff"], - unit_diagonal=unit_diagonal, - launch_groups=solve_plan["launch_groups"], - block_nnz_use=solve_plan["default_block_nnz"], - max_segments_use=solve_plan["default_max_segments"], - ) - torch.cuda.synchronize() - elapsed_ms = (time.perf_counter() - t0) * 1000.0 + use_tiled_kernel = _should_use_spsm_tiled_kernel("csr", B, data.dtype) + alpha_value = 1.0 if _alpha_is_one(alpha) else _alpha_to_host_scalar(alpha) + rhs = B if use_tiled_kernel else (B if alpha_value == 1.0 else torch.as_tensor(alpha, dtype=B.dtype, device=B.device) * B) + if return_time: + torch.cuda.synchronize() + t0 = time.perf_counter() + if use_tiled_kernel: + x = _run_spsm_csr_tiled_core( + solve_plan["kernel_dep_data"], + solve_plan["kernel_dep_indices32"], + solve_plan["kernel_dep_ptr"], + solve_plan["kernel_diag"], + rhs, + n_rows, + alpha=alpha_value, + block_nnz_use=solve_plan["tiled_block_nnz"], + max_segments_use=solve_plan["tiled_max_segments"], + ) + else: + x = _run_spsm_csr_core( + solve_plan["kernel_data"], + solve_plan["kernel_indices32"], + solve_plan["kernel_indptr"], + solve_plan["kernel_diag"], + rhs, + n_rows, + lower=solve_plan["lower_eff"], + unit_diagonal=unit_diagonal, + launch_groups=solve_plan["launch_groups"], + block_nnz_use=solve_plan["default_block_nnz"], + max_segments_use=solve_plan["default_max_segments"], + ) + if return_time: + torch.cuda.synchronize() + elapsed_ms = (time.perf_counter() - t0) * 1000.0 if out is not None: if out.shape != x.shape or out.dtype != x.dtype: raise ValueError("out shape/dtype must match result") @@ -713,24 +1131,41 @@ def flagsparse_spsm_coo( data, B, n_rows, _n_cols, solve_plan = _resolve_spsm_coo_runtime( data, row, col, B, shape, lower, unit_diagonal, opA, opB, major ) - alpha_t = torch.as_tensor(alpha, dtype=B.dtype, device=B.device) - rhs = alpha_t * B - torch.cuda.synchronize() - t0 = time.perf_counter() - x = _run_spsm_coo_core( - solve_plan["kernel_data"], - solve_plan["kernel_cols32"], - solve_plan["kernel_row_ptr64"], - rhs, - n_rows, - lower=solve_plan["lower_eff"], - unit_diagonal=unit_diagonal, - launch_groups=solve_plan["launch_groups"], - block_nnz_use=solve_plan["default_block_nnz"], - max_segments_use=solve_plan["default_max_segments"], - ) - torch.cuda.synchronize() - elapsed_ms = 
(time.perf_counter() - t0) * 1000.0 + use_tiled_kernel = _should_use_spsm_tiled_kernel("coo", B, data.dtype) + alpha_value = 1.0 if _alpha_is_one(alpha) else _alpha_to_host_scalar(alpha) + rhs = B if use_tiled_kernel else (B if alpha_value == 1.0 else torch.as_tensor(alpha, dtype=B.dtype, device=B.device) * B) + if return_time: + torch.cuda.synchronize() + t0 = time.perf_counter() + if use_tiled_kernel: + x = _run_spsm_csr_tiled_core( + solve_plan["kernel_dep_data"], + solve_plan["kernel_dep_indices32"], + solve_plan["kernel_dep_ptr"], + solve_plan["kernel_diag"], + rhs, + n_rows, + alpha=alpha_value, + block_nnz_use=solve_plan["tiled_block_nnz"], + max_segments_use=solve_plan["tiled_max_segments"], + ) + else: + x = _run_spsm_coo_core( + solve_plan["kernel_data"], + solve_plan["kernel_cols32"], + solve_plan["kernel_row_ptr"], + solve_plan["kernel_diag"], + rhs, + n_rows, + lower=solve_plan["lower_eff"], + unit_diagonal=unit_diagonal, + launch_groups=solve_plan["launch_groups"], + block_nnz_use=solve_plan["default_block_nnz"], + max_segments_use=solve_plan["default_max_segments"], + ) + if return_time: + torch.cuda.synchronize() + elapsed_ms = (time.perf_counter() - t0) * 1000.0 if out is not None: if out.shape != x.shape or out.dtype != x.dtype: raise ValueError("out shape/dtype must match result") @@ -758,12 +1193,13 @@ def _analyze_spsm_csr( ): if clear_cache: _clear_spsm_preprocess_cache() - torch.cuda.synchronize() - t0 = time.perf_counter() + if return_time: + torch.cuda.synchronize() + t0 = time.perf_counter() _resolve_spsm_csr_runtime(data, indices, indptr, B, shape, lower, unit_diagonal, opA, opB, major) - torch.cuda.synchronize() - elapsed_ms = (time.perf_counter() - t0) * 1000.0 if return_time: + torch.cuda.synchronize() + elapsed_ms = (time.perf_counter() - t0) * 1000.0 return elapsed_ms return None @@ -785,12 +1221,13 @@ def _analyze_spsm_coo( ): if clear_cache: _clear_spsm_preprocess_cache() - torch.cuda.synchronize() - t0 = time.perf_counter() + if return_time: + torch.cuda.synchronize() + t0 = time.perf_counter() _resolve_spsm_coo_runtime(data, row, col, B, shape, lower, unit_diagonal, opA, opB, major) - torch.cuda.synchronize() - elapsed_ms = (time.perf_counter() - t0) * 1000.0 if return_time: + torch.cuda.synchronize() + elapsed_ms = (time.perf_counter() - t0) * 1000.0 return elapsed_ms return None diff --git a/src/flagsparse/sparse_operations/spsv.py b/src/flagsparse/sparse_operations/spsv.py index 8c8f54c..9783cbe 100644 --- a/src/flagsparse/sparse_operations/spsv.py +++ b/src/flagsparse/sparse_operations/spsv.py @@ -507,20 +507,65 @@ def _score_transpose_cw(matrix_stats, n_rhs, complex_mode): return score +def _nontrans_cw_eligible(matrix_stats, n_rhs, complex_mode): + if complex_mode: + return False + if n_rhs > 2: + return False + if int(matrix_stats.get("num_levels", 0)) < 256: + return False + if float(matrix_stats.get("avg_frontier", 0.0)) > 24.0: + return False + if float(matrix_stats.get("frontier_ratio", 1.0)) > 0.08: + return False + avg_nnz = float(matrix_stats.get("avg_nnz_per_row", 0.0)) + max_nnz = int(matrix_stats.get("max_nnz_per_row", 0)) + if avg_nnz <= 0.0 or avg_nnz > 256.0: + return False + if max_nnz > 2048: + return False + return True + + +def _transpose_cw_eligible(matrix_stats, n_rhs, complex_mode, trans_mode): + if trans_mode != "T": + return False + if complex_mode: + return False + if n_rhs != 1: + return False + if int(matrix_stats.get("num_levels", 0)) < 512: + return False + if float(matrix_stats.get("avg_frontier", 0.0)) > 16.0: + return 
False + if float(matrix_stats.get("frontier_ratio", 1.0)) > 0.04: + return False + avg_nnz = float(matrix_stats.get("avg_nnz_per_row", 0.0)) + max_nnz = int(matrix_stats.get("max_nnz_per_row", 0)) + if avg_nnz <= 0.0 or avg_nnz > 160.0: + return False + if max_nnz > 1024: + return False + return True + + def _select_nontrans_route(matrix_stats, n_rhs, complex_mode): + if not _nontrans_cw_eligible(matrix_stats, n_rhs, complex_mode): + return "csr_levels" levels_score = _score_nontrans_levels(matrix_stats, n_rhs, complex_mode) cw_score = _score_nontrans_cw(matrix_stats, n_rhs, complex_mode) - return "csr_cw" if cw_score > levels_score else "csr_levels" + return "csr_cw" if cw_score >= (levels_score + 1.0) else "csr_levels" -def _select_transpose_route(matrix_stats, n_rhs, complex_mode): +def _select_transpose_route(matrix_stats, n_rhs, complex_mode, trans_mode): + if not _transpose_cw_eligible(matrix_stats, n_rhs, complex_mode, trans_mode): + return "transpose_push" push_score = _score_transpose_push(matrix_stats, n_rhs, complex_mode) cw_score = _score_transpose_cw(matrix_stats, n_rhs, complex_mode) - return "transpose_cw" if cw_score > push_score else "transpose_push" + return "transpose_cw" if cw_score >= (push_score + 1.0) else "transpose_push" def _prepare_spsv_csr_system(data, indices64, indptr64, n_rows, n_cols, lower, trans_mode, unit_diagonal): - complex_mode = torch.is_complex(data) if trans_mode == "N": levels = _build_spsv_levels(indptr64, indices64, n_rows, lower=lower) matrix_stats = _build_spsv_matrix_stats(indptr64, levels, n_rows) @@ -572,10 +617,6 @@ def _prepare_spsv_csr_system(data, indices64, indptr64, n_rows, n_cols, lower, t "route_name": "csr_cw", "alt_plan": None, } - preferred_kind = _select_nontrans_route(matrix_stats, 1, complex_mode) - if preferred_kind == "csr_cw": - cw_plan["alt_plan"] = levels_plan - return cw_plan levels_plan["alt_plan"] = cw_plan return levels_plan @@ -633,10 +674,6 @@ def _prepare_spsv_csr_system(data, indices64, indptr64, n_rows, n_cols, lower, t "route_name": "transpose_cw", "alt_plan": None, } - preferred_kind = _select_transpose_route(matrix_stats, 1, complex_mode) - if preferred_kind == "transpose_cw": - cw_plan["alt_plan"] = push_plan - return cw_plan push_plan["alt_plan"] = cw_plan return push_plan @@ -706,9 +743,9 @@ def _select_spsv_runtime_plan(solve_plan, rhs_cols, compute_dtype, trans_mode): if trans_mode == "N": desired = _select_nontrans_route(matrix_stats, rhs_cols, complex_mode) else: - desired = _select_transpose_route(matrix_stats, rhs_cols, complex_mode) - if complex_mode and rhs_cols >= 2: - desired = "transpose_push" + desired = _select_transpose_route( + matrix_stats, rhs_cols, complex_mode, trans_mode + ) if desired == route_name or alt_plan is None: return solve_plan if alt_plan.get("route_name", alt_plan["solve_kind"]) == desired: @@ -2359,6 +2396,7 @@ def flagsparse_spsv_csr( diag_eps=diag_eps, block_nnz_use=block_nnz_use, max_segments_use=max_segments_use, + worker_count=worker_count_use, matrix_stats=matrix_stats_use, ) elif solve_kind == "csr_cw": @@ -2428,6 +2466,7 @@ def flagsparse_spsv_csr( diag_eps=diag_eps, block_nnz_use=block_nnz_use, max_segments_use=max_segments_use, + worker_count=worker_count_use, matrix_stats=matrix_stats_use, ) elif solve_kind == "csr_cw": @@ -2506,6 +2545,7 @@ def flagsparse_spsv_csr( diag_eps=diag_eps, block_nnz_use=block_nnz_use, max_segments_use=max_segments_use, + worker_count=worker_count_use, matrix_stats=matrix_stats_use, ) ) From 9d3eb686aeff5e9d3495fd29750b220720890747 Mon 
Sep 17 00:00:00 2001 From: berlin020 <2261128688@qq.com> Date: Tue, 28 Apr 2026 22:24:22 +0800 Subject: [PATCH 21/22] opt --- src/flagsparse/sparse_operations/spsm.py | 836 +++++++++++------------ src/flagsparse/sparse_operations/spsv.py | 249 ++++--- tests/test_spsm.py | 84 +-- tests/test_spsv.py | 128 ++-- 4 files changed, 651 insertions(+), 646 deletions(-) diff --git a/src/flagsparse/sparse_operations/spsm.py b/src/flagsparse/sparse_operations/spsm.py index 36066bc..b034a56 100644 --- a/src/flagsparse/sparse_operations/spsm.py +++ b/src/flagsparse/sparse_operations/spsm.py @@ -1,7 +1,6 @@ """Sparse triangular matrix-matrix solve (SpSM) for CSR/COO.""" from collections import OrderedDict -import os from ._common import * @@ -14,12 +13,6 @@ ("coo", torch.float32, torch.int32), ("coo", torch.float64, torch.int32), ) -def _spsm_env_flag(name, default="0"): - return str(os.environ.get(name, default)).lower() in ("1", "true", "yes", "on") - - -SPSM_ENABLE_TILED_KERNEL = _spsm_env_flag("FLAGSPARSE_SPSM_ENABLE_TILED_KERNEL", "1") -SPSM_TILED_MIN_RHS = 16 _SPSM_PREPROCESS_CACHE = OrderedDict() _SPSM_PREPROCESS_CACHE_SIZE = 8 @@ -188,6 +181,14 @@ def _prepare_spsm_diag(data, indices64, indptr64, n_rows, unit_diagonal=False): return diag +def _prepare_spsm_inv_diag(diag): + if diag.numel() == 0: + return diag + eps = 1e-12 if diag.dtype == torch.float64 else 1e-6 + safe_diag = torch.where(torch.abs(diag) < eps, torch.ones_like(diag), diag) + return torch.reciprocal(safe_diag) + + def _prepare_spsm_kernel_row_ptr(indptr64): if indptr64.numel() == 0: return indptr64.to(torch.int32) @@ -289,53 +290,45 @@ def _alpha_to_host_scalar(alpha): return float(alpha) -def _should_use_spsm_tiled_kernel(fmt_name, rhs, dtype): - if not SPSM_ENABLE_TILED_KERNEL: - return False - if dtype not in (torch.float32, torch.float64): - return False - if rhs.ndim != 2: - return False - n_rhs = int(rhs.shape[1]) - if n_rhs < SPSM_TILED_MIN_RHS: - return False - fmt = str(fmt_name).lower() - if fmt == "csr": - return True - if fmt == "coo": - return True - return False - - @triton.jit -def _spsm_pack_rhs_work_kernel( - src_ptr, - dst_ptr, - n_rows, +def _spsm_csr_diag_only_kernel_real( + inv_diag_ptr, + b_ptr, + x_ptr, + rows_ptr, + n_level_rows, n_rhs, - stride_src0, - stride_dst0, + stride_b0, + stride_x0, alpha, - BLOCK_ROWS: tl.constexpr, BLOCK_RHS: tl.constexpr, USE_FP64_ACC: tl.constexpr, ): pid_row = tl.program_id(0) pid_rhs = tl.program_id(1) + if pid_row >= n_level_rows: + return - row_offsets = pid_row * BLOCK_ROWS + tl.arange(0, BLOCK_ROWS) + row = tl.load(rows_ptr + pid_row) rhs_offsets = pid_rhs * BLOCK_RHS + tl.arange(0, BLOCK_RHS) - mask = (row_offsets[:, None] < n_rows) & (rhs_offsets[None, :] < n_rhs) + rhs_mask = rhs_offsets < n_rhs - src_ptrs = src_ptr + row_offsets[:, None] * stride_src0 + rhs_offsets[None, :] - vals = tl.load(src_ptrs, mask=mask, other=0.0) + b_ptrs = b_ptr + row * stride_b0 + rhs_offsets + rhs = tl.load(b_ptrs, mask=rhs_mask, other=0.0) if USE_FP64_ACC: - vals = vals.to(tl.float64) * alpha + rhs = rhs.to(tl.float64) + alpha_val = tl.full((BLOCK_RHS,), alpha, tl.float64) else: - vals = vals.to(tl.float32) * alpha + rhs = rhs.to(tl.float32) + alpha_val = tl.full((BLOCK_RHS,), alpha, tl.float32) - dst_ptrs = dst_ptr + row_offsets[:, None] * stride_dst0 + rhs_offsets[None, :] - tl.store(dst_ptrs, vals, mask=mask) + inv_diag = tl.load(inv_diag_ptr + row) + inv_diag = inv_diag.to(tl.float64) if USE_FP64_ACC else inv_diag.to(tl.float32) + out = rhs * alpha_val * inv_diag + out = tl.where(out 
== out, out, 0.0) + + out_ptrs = x_ptr + row * stride_x0 + rhs_offsets + tl.store(out_ptrs, out, mask=rhs_mask) @triton.jit @@ -343,7 +336,7 @@ def _spsm_csr_level_kernel_real( data_ptr, indices_ptr, indptr_ptr, - diag_ptr, + inv_diag_ptr, b_ptr, x_ptr, rows_ptr, @@ -351,12 +344,11 @@ def _spsm_csr_level_kernel_real( n_rhs, stride_b0, stride_x0, + alpha, BLOCK_NNZ: tl.constexpr, BLOCK_RHS: tl.constexpr, MAX_SEGMENTS: tl.constexpr, - LOWER: tl.constexpr, USE_FP64_ACC: tl.constexpr, - DIAG_EPS: tl.constexpr, ): pid_row = tl.program_id(0) pid_rhs = tl.program_id(1) @@ -386,8 +378,6 @@ def _spsm_csr_level_kernel_real( else: a = a.to(tl.float32) - solved_mask = col < row if LOWER else col > row - x_ptrs = x_ptr + col[:, None] * stride_x0 + rhs_offsets[None, :] x_mask = nnz_mask[:, None] & rhs_mask[None, :] x_vals = tl.load(x_ptrs, mask=x_mask, other=0.0) @@ -396,17 +386,21 @@ def _spsm_csr_level_kernel_real( else: x_vals = x_vals.to(tl.float32) - contrib = tl.where((nnz_mask & solved_mask)[:, None], a[:, None] * x_vals, 0.0) + contrib = tl.where(nnz_mask[:, None], a[:, None] * x_vals, 0.0) acc += tl.sum(contrib, axis=0) b_ptrs = b_ptr + row * stride_b0 + rhs_offsets rhs = tl.load(b_ptrs, mask=rhs_mask, other=0.0) - rhs = rhs.to(tl.float64) if USE_FP64_ACC else rhs.to(tl.float32) + if USE_FP64_ACC: + rhs = rhs.to(tl.float64) + alpha_val = tl.full((BLOCK_RHS,), alpha, tl.float64) + else: + rhs = rhs.to(tl.float32) + alpha_val = tl.full((BLOCK_RHS,), alpha, tl.float32) - diag = tl.load(diag_ptr + row) - diag = diag.to(tl.float64) if USE_FP64_ACC else diag.to(tl.float32) - diag_safe = tl.where(tl.abs(diag) < DIAG_EPS, 1.0, diag) - out = (rhs - acc) / diag_safe + inv_diag = tl.load(inv_diag_ptr + row) + inv_diag = inv_diag.to(tl.float64) if USE_FP64_ACC else inv_diag.to(tl.float32) + out = (rhs * alpha_val - acc) * inv_diag out = tl.where(out == out, out, 0.0) out_ptrs = x_ptr + row * stride_x0 + rhs_offsets @@ -414,11 +408,11 @@ def _spsm_csr_level_kernel_real( @triton.jit -def _spsm_coo_level_kernel_real( +def _spsm_csr_level_kernel_single_segment_real( data_ptr, - row_ptr_ptr, - col_ptr, - diag_ptr, + indices_ptr, + indptr_ptr, + inv_diag_ptr, b_ptr, x_ptr, rows_ptr, @@ -426,12 +420,10 @@ def _spsm_coo_level_kernel_real( n_rhs, stride_b0, stride_x0, + alpha, BLOCK_NNZ: tl.constexpr, BLOCK_RHS: tl.constexpr, - MAX_SEGMENTS: tl.constexpr, - LOWER: tl.constexpr, USE_FP64_ACC: tl.constexpr, - DIAG_EPS: tl.constexpr, ): pid_row = tl.program_id(0) pid_rhs = tl.program_id(1) @@ -442,46 +434,37 @@ def _spsm_coo_level_kernel_real( rhs_offsets = pid_rhs * BLOCK_RHS + tl.arange(0, BLOCK_RHS) rhs_mask = rhs_offsets < n_rhs - start = tl.load(row_ptr_ptr + row) - end = tl.load(row_ptr_ptr + row + 1) + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) + nnz_offsets = start + tl.arange(0, BLOCK_NNZ) + nnz_mask = nnz_offsets < end + a = tl.load(data_ptr + nnz_offsets, mask=nnz_mask, other=0.0) + col = tl.load(indices_ptr + nnz_offsets, mask=nnz_mask, other=0) if USE_FP64_ACC: - acc = tl.zeros((BLOCK_RHS,), dtype=tl.float64) + a = a.to(tl.float64) else: - acc = tl.zeros((BLOCK_RHS,), dtype=tl.float32) - - for seg in range(MAX_SEGMENTS): - nnz_offsets = start + seg * BLOCK_NNZ + tl.arange(0, BLOCK_NNZ) - nnz_mask = nnz_offsets < end - a = tl.load(data_ptr + nnz_offsets, mask=nnz_mask, other=0.0) - col = tl.load(col_ptr + nnz_offsets, mask=nnz_mask, other=0) + a = a.to(tl.float32) - if USE_FP64_ACC: - a = a.to(tl.float64) - else: - a = a.to(tl.float32) - - solved_mask = col < row if LOWER else 
col > row - - x_ptrs = x_ptr + col[:, None] * stride_x0 + rhs_offsets[None, :] - x_mask = nnz_mask[:, None] & rhs_mask[None, :] - x_vals = tl.load(x_ptrs, mask=x_mask, other=0.0) - if USE_FP64_ACC: - x_vals = x_vals.to(tl.float64) - else: - x_vals = x_vals.to(tl.float32) - - contrib = tl.where((nnz_mask & solved_mask)[:, None], a[:, None] * x_vals, 0.0) - acc += tl.sum(contrib, axis=0) + x_ptrs = x_ptr + col[:, None] * stride_x0 + rhs_offsets[None, :] + x_mask = nnz_mask[:, None] & rhs_mask[None, :] + x_vals = tl.load(x_ptrs, mask=x_mask, other=0.0) + if USE_FP64_ACC: + x_vals = x_vals.to(tl.float64) + alpha_val = tl.full((BLOCK_RHS,), alpha, tl.float64) + else: + x_vals = x_vals.to(tl.float32) + alpha_val = tl.full((BLOCK_RHS,), alpha, tl.float32) + contrib = tl.where(nnz_mask[:, None], a[:, None] * x_vals, 0.0) + acc = tl.sum(contrib, axis=0) b_ptrs = b_ptr + row * stride_b0 + rhs_offsets rhs = tl.load(b_ptrs, mask=rhs_mask, other=0.0) rhs = rhs.to(tl.float64) if USE_FP64_ACC else rhs.to(tl.float32) - diag = tl.load(diag_ptr + row) - diag = diag.to(tl.float64) if USE_FP64_ACC else diag.to(tl.float32) - diag_safe = tl.where(tl.abs(diag) < DIAG_EPS, 1.0, diag) - out = (rhs - acc) / diag_safe + inv_diag = tl.load(inv_diag_ptr + row) + inv_diag = inv_diag.to(tl.float64) if USE_FP64_ACC else inv_diag.to(tl.float32) + out = (rhs * alpha_val - acc) * inv_diag out = tl.where(out == out, out, 0.0) out_ptrs = x_ptr + row * stride_x0 + rhs_offsets @@ -489,78 +472,96 @@ def _spsm_coo_level_kernel_real( @triton.jit -def _spsm_csr_tiled_kernel_real( +def _spsm_csr_level_kernel_staged_real( data_ptr, indices_ptr, - dep_ptr_ptr, - diag_ptr, - y_ptr, - ready_ptr, - n_rows, + indptr_ptr, + inv_diag_ptr, + work_ptr, + rows_ptr, + n_level_rows, n_rhs, - stride_y0, - ready_stride0, + stride_work0, + alpha, BLOCK_NNZ: tl.constexpr, BLOCK_RHS: tl.constexpr, + RHS_TILE_GROUPS: tl.constexpr, MAX_SEGMENTS: tl.constexpr, USE_FP64_ACC: tl.constexpr, - DIAG_EPS: tl.constexpr, ): - row = tl.program_id(0) - pid_rhs = tl.program_id(1) - if row >= n_rows: + pid_row = tl.program_id(0) + pid_rhs_group = tl.program_id(1) + if pid_row >= n_level_rows: return - rhs_offsets = pid_rhs * BLOCK_RHS + tl.arange(0, BLOCK_RHS) - rhs_mask = rhs_offsets < n_rhs - ready_row_ptr = ready_ptr + pid_rhs * ready_stride0 - - start = tl.load(dep_ptr_ptr + row) - end = tl.load(dep_ptr_ptr + row + 1) + row = tl.load(rows_ptr + pid_row) + rhs_group_base = pid_rhs_group * BLOCK_RHS * RHS_TILE_GROUPS + start = tl.load(indptr_ptr + row) + end = tl.load(indptr_ptr + row + 1) if USE_FP64_ACC: - acc = tl.zeros((BLOCK_RHS,), dtype=tl.float64) + inv_diag = tl.load(inv_diag_ptr + row).to(tl.float64) + alpha_val = tl.full((BLOCK_RHS,), alpha, tl.float64) + acc0 = tl.zeros((BLOCK_RHS,), dtype=tl.float64) + acc1 = tl.zeros((BLOCK_RHS,), dtype=tl.float64) + acc2 = tl.zeros((BLOCK_RHS,), dtype=tl.float64) + acc3 = tl.zeros((BLOCK_RHS,), dtype=tl.float64) else: - acc = tl.zeros((BLOCK_RHS,), dtype=tl.float32) + inv_diag = tl.load(inv_diag_ptr + row).to(tl.float32) + alpha_val = tl.full((BLOCK_RHS,), alpha, tl.float32) + acc0 = tl.zeros((BLOCK_RHS,), dtype=tl.float32) + acc1 = tl.zeros((BLOCK_RHS,), dtype=tl.float32) + acc2 = tl.zeros((BLOCK_RHS,), dtype=tl.float32) + acc3 = tl.zeros((BLOCK_RHS,), dtype=tl.float32) for seg in range(MAX_SEGMENTS): nnz_offsets = start + seg * BLOCK_NNZ + tl.arange(0, BLOCK_NNZ) nnz_mask = nnz_offsets < end a = tl.load(data_ptr + nnz_offsets, mask=nnz_mask, other=0.0) col = tl.load(indices_ptr + nnz_offsets, mask=nnz_mask, 
other=0) - if USE_FP64_ACC: a = a.to(tl.float64) else: a = a.to(tl.float32) - for k in range(BLOCK_NNZ): - if nnz_mask[k]: - dep_col = col[k] - while tl.load(ready_row_ptr + dep_col) == 0: - pass - x_ptrs = y_ptr + dep_col * stride_y0 + rhs_offsets - x_vals = tl.load(x_ptrs, mask=rhs_mask, other=0.0) - if USE_FP64_ACC: - x_vals = x_vals.to(tl.float64) - else: - x_vals = x_vals.to(tl.float32) - acc += a[k] * x_vals - - rhs_ptrs = y_ptr + row * stride_y0 + rhs_offsets - rhs = tl.load(rhs_ptrs, mask=rhs_mask, other=0.0) - rhs = rhs.to(tl.float64) if USE_FP64_ACC else rhs.to(tl.float32) - - diag = tl.load(diag_ptr + row) - diag = diag.to(tl.float64) if USE_FP64_ACC else diag.to(tl.float32) - diag_safe = tl.where(tl.abs(diag) < DIAG_EPS, 1.0, diag) - out = (rhs - acc) / diag_safe - out = tl.where(out == out, out, 0.0) - - out_ptrs = y_ptr + row * stride_y0 + rhs_offsets - tl.store(out_ptrs, out, mask=rhs_mask) - tl.debug_barrier() - tl.store(ready_row_ptr + row, 1) + for tile_id in range(RHS_TILE_GROUPS): + rhs_offsets = rhs_group_base + tile_id * BLOCK_RHS + tl.arange(0, BLOCK_RHS) + rhs_mask = rhs_offsets < n_rhs + x_ptrs = work_ptr + col[:, None] * stride_work0 + rhs_offsets[None, :] + x_mask = nnz_mask[:, None] & rhs_mask[None, :] + x_vals = tl.load(x_ptrs, mask=x_mask, other=0.0) + if USE_FP64_ACC: + x_vals = x_vals.to(tl.float64) + else: + x_vals = x_vals.to(tl.float32) + contrib = tl.where(nnz_mask[:, None], a[:, None] * x_vals, 0.0) + tile_acc = tl.sum(contrib, axis=0) + if tile_id == 0: + acc0 += tile_acc + elif tile_id == 1: + acc1 += tile_acc + elif tile_id == 2: + acc2 += tile_acc + else: + acc3 += tile_acc + + for tile_id in range(RHS_TILE_GROUPS): + rhs_offsets = rhs_group_base + tile_id * BLOCK_RHS + tl.arange(0, BLOCK_RHS) + rhs_mask = rhs_offsets < n_rhs + work_ptrs = work_ptr + row * stride_work0 + rhs_offsets + rhs = tl.load(work_ptrs, mask=rhs_mask, other=0.0) + rhs = rhs.to(tl.float64) if USE_FP64_ACC else rhs.to(tl.float32) + if tile_id == 0: + acc_tile = acc0 + elif tile_id == 1: + acc_tile = acc1 + elif tile_id == 2: + acc_tile = acc2 + else: + acc_tile = acc3 + out = (rhs * alpha_val - acc_tile) * inv_diag + out = tl.where(out == out, out, 0.0) + tl.store(work_ptrs, out, mask=rhs_mask) def _build_spsm_levels(indptr, indices, n_rows, lower=True): @@ -600,52 +601,91 @@ def _build_spsm_levels(indptr, indices, n_rows, lower=True): return [torch.tensor(rows, dtype=torch.int32, device=device) for rows in buckets if rows] -def _auto_spsm_launch_config(indptr, block_nnz=None, max_segments=None): - if indptr.numel() <= 1: - max_nnz_per_row = 0 - else: - row_lengths = indptr[1:] - indptr[:-1] - max_nnz_per_row = int(row_lengths.max().item()) - - auto_block = block_nnz is None - if block_nnz is None: - if max_nnz_per_row <= 64: - block_nnz_use = 64 - elif max_nnz_per_row <= 256: - block_nnz_use = 128 - elif max_nnz_per_row <= 1024: - block_nnz_use = 256 - elif max_nnz_per_row <= 4096: - block_nnz_use = 512 - else: - block_nnz_use = 1024 - else: - block_nnz_use = int(block_nnz) - if block_nnz_use <= 0: - raise ValueError("block_nnz must be positive") - - required_segments = max((max_nnz_per_row + block_nnz_use - 1) // block_nnz_use, 1) - if max_segments is None: - max_segments_use = required_segments - if auto_block: - while max_segments_use > 2048 and block_nnz_use < 65536: - block_nnz_use *= 2 - max_segments_use = max((max_nnz_per_row + block_nnz_use - 1) // block_nnz_use, 1) - else: - max_segments_use = int(max_segments) - if max_segments_use <= 0: - raise 
ValueError("max_segments must be positive") - if max_segments_use < required_segments: - raise ValueError( - f"max_segments={max_segments_use} is too small; at least {required_segments} required" +def _spsm_block_nnz_for_row_length(max_nnz_per_row): + max_nnz_per_row = int(max_nnz_per_row) + if max_nnz_per_row <= 32: + return 32 + if max_nnz_per_row <= 64: + return 64 + if max_nnz_per_row <= 128: + return 128 + if max_nnz_per_row <= 256: + return 256 + if max_nnz_per_row <= 512: + return 512 + if max_nnz_per_row <= 1024: + return 1024 + if max_nnz_per_row <= 4096: + return 1024 + return 1024 + + +def _spsm_single_segment_block_nnz(limit): + return _spsm_block_nnz_for_row_length(limit) + + +def _bucketize_spsm_launch_groups(levels, indptr_like): + if not levels: + return [] + row_lengths = (indptr_like[1:] - indptr_like[:-1]).to(torch.int64) + bucket_limits = (32, 64, 96, 128, 192, 256, 384, 512, 768, 1024) + grouped = [] + for rows_lv in levels: + if rows_lv.numel() == 0: + continue + rows_i64 = rows_lv.to(torch.int64) + lv_lengths = row_lengths.index_select(0, rows_i64) + remaining = torch.ones(rows_lv.numel(), dtype=torch.bool, device=rows_lv.device) + zero_mask = remaining & (lv_lengths == 0) + if bool(torch.any(zero_mask).item()): + grouped.append( + { + "rows": rows_lv[zero_mask], + "diag_only": True, + } ) - return block_nnz_use, max_segments_use + remaining = remaining & (~zero_mask) + for limit in bucket_limits: + mask = remaining & (lv_lengths <= limit) + if bool(torch.any(mask).item()): + bucket_rows = rows_lv[mask] + # Keep single-segment buckets fully covered while staying on hardware-friendly powers of two. + block_nnz = _spsm_single_segment_block_nnz(limit) + grouped.append( + { + "rows": bucket_rows, + "block_nnz": block_nnz, + "max_segments": 1, + "diag_only": False, + "single_segment": True, + "staged": False, + } + ) + remaining = remaining & (~mask) + if bool(torch.any(remaining).item()): + bucket_rows = rows_lv[remaining] + bucket_lengths = lv_lengths[remaining] + bucket_max = int(bucket_lengths.max().item()) if bucket_lengths.numel() > 0 else 0 + block_nnz = _spsm_block_nnz_for_row_length(bucket_max) + max_segments = max((bucket_max + block_nnz - 1) // block_nnz, 1) + grouped.append( + { + "rows": bucket_rows, + "block_nnz": block_nnz, + "max_segments": max_segments, + "diag_only": False, + "single_segment": max_segments == 1, + "staged": bucket_max >= 256 or max_segments > 1, + } + ) + return grouped -def _auto_spsm_launch_config_from_row_lengths(row_lengths, block_nnz=None, max_segments=None): - if row_lengths.numel() == 0: +def _auto_spsm_launch_config(indptr, block_nnz=None, max_segments=None): + if indptr.numel() <= 1: max_nnz_per_row = 0 else: + row_lengths = indptr[1:] - indptr[:-1] max_nnz_per_row = int(row_lengths.max().item()) auto_block = block_nnz is None @@ -696,32 +736,46 @@ def _auto_rhs_block(n_rhs): return 64 -def _pack_spsm_rhs_work_buffer(rhs, alpha): - if rhs.ndim != 2: - raise ValueError("rhs must be 2D") - n_rows, n_rhs = int(rhs.shape[0]), int(rhs.shape[1]) - work = torch.empty_like(rhs, memory_format=torch.contiguous_format) - if n_rows == 0 or n_rhs == 0: - return work - - block_rows = 32 - block_rhs = _auto_rhs_block(n_rhs) - alpha_value = _alpha_to_host_scalar(alpha) - use_fp64 = rhs.dtype == torch.float64 - grid = (triton.cdiv(n_rows, block_rows), triton.cdiv(n_rhs, block_rhs)) - _spsm_pack_rhs_work_kernel[grid]( - rhs, - work, - n_rows, - n_rhs, - rhs.stride(0), - work.stride(0), - alpha_value, - BLOCK_ROWS=block_rows, - 
BLOCK_RHS=block_rhs, - USE_FP64_ACC=use_fp64, - ) - return work +def _choose_group_block_rhs( + n_rhs, + *, + block_nnz, + max_segments, + diag_only, + single_segment, + value_dtype, +): + base = _auto_rhs_block(n_rhs) + if diag_only: + return min(base, 32 if value_dtype == torch.float32 else 16) + if block_nnz <= 64 and single_segment: + return min(base, 32 if value_dtype == torch.float32 else 16) + if block_nnz <= 128 and single_segment: + return min(base, 16 if value_dtype == torch.float32 else 8) + if max_segments > 2 or block_nnz >= 512: + return min(base, 8) + if max_segments > 1 or block_nnz >= 256: + return min(base, 16 if value_dtype == torch.float32 else 8) + return base + + +def _choose_group_rhs_tile_groups(n_rhs, block_rhs_use, staged): + if not staged: + return 1 + if block_rhs_use <= 0: + return 1 + if n_rhs >= block_rhs_use * 4: + return 4 + if n_rhs >= block_rhs_use * 2: + return 2 + return 1 + + +def _prepare_spsm_rhs_work_buffer(rhs): + # Library-main solves through a dedicated RHS work buffer. In our current + # row-major NON_TRANS path that buffer already matches the final layout, so + # we keep a contiguous in-place work copy rather than doing an extra transpose. + return rhs.contiguous().clone() def _coo_to_csr_sorted_unique(data, row64, col64, n_rows, n_cols): @@ -758,26 +812,20 @@ def _prepare_spsm_csr_system(data, indices64, indptr64, n_rows, lower, unit_diag dep_data, dep_indices64, dep_ptr64 = _prepare_spsm_csr_dependency_view( sorted_data, sorted_indices64, indptr64, n_rows, lower=lower, row_ids=row_ids ) - levels = _build_spsm_levels(indptr64, sorted_indices64, n_rows, lower=lower) - default_block_nnz, default_max_segments = _auto_spsm_launch_config(indptr64) - tiled_block_nnz, tiled_max_segments = _auto_spsm_launch_config_from_row_lengths( - dep_ptr64[1:] - dep_ptr64[:-1] + levels = _build_spsm_levels(dep_ptr64, dep_indices64, n_rows, lower=lower) + launch_groups = _bucketize_spsm_launch_groups(levels, dep_ptr64) + diag = _prepare_spsm_diag( + sorted_data, sorted_indices64, indptr64, n_rows, unit_diagonal=unit_diagonal ) + default_block_nnz, default_max_segments = _auto_spsm_launch_config(dep_ptr64) return { - "kernel_data": sorted_data, - "kernel_indices32": sorted_indices64.to(torch.int32), "kernel_dep_data": dep_data, "kernel_dep_indices32": dep_indices64.to(torch.int32), "kernel_dep_ptr": _prepare_spsm_kernel_row_ptr(dep_ptr64), - "kernel_diag": _prepare_spsm_diag( - sorted_data, sorted_indices64, indptr64, n_rows, unit_diagonal=unit_diagonal - ), - "kernel_indptr": _prepare_spsm_kernel_row_ptr(indptr64), - "launch_groups": levels, + "kernel_inv_diag": _prepare_spsm_inv_diag(diag), + "launch_groups": launch_groups, "default_block_nnz": default_block_nnz, "default_max_segments": default_max_segments, - "tiled_block_nnz": tiled_block_nnz, - "tiled_max_segments": tiled_max_segments, "lower_eff": bool(lower), } @@ -787,26 +835,20 @@ def _prepare_spsm_coo_system(data, row64, col64, n_rows, n_cols, lower, unit_dia dep_data, dep_indices64, dep_ptr64 = _prepare_spsm_csr_dependency_view( data_u, col_u64, row_ptr, n_rows, lower=lower ) - levels = _build_spsm_levels(row_ptr, col_u64, n_rows, lower=lower) - default_block_nnz, default_max_segments = _auto_spsm_launch_config(row_ptr) - tiled_block_nnz, tiled_max_segments = _auto_spsm_launch_config_from_row_lengths( - dep_ptr64[1:] - dep_ptr64[:-1] + levels = _build_spsm_levels(dep_ptr64, dep_indices64, n_rows, lower=lower) + launch_groups = _bucketize_spsm_launch_groups(levels, dep_ptr64) + diag = _prepare_spsm_diag( + 
data_u, col_u64, row_ptr, n_rows, unit_diagonal=unit_diagonal ) + default_block_nnz, default_max_segments = _auto_spsm_launch_config(dep_ptr64) return { - "kernel_data": data_u, - "kernel_cols32": col_u64.to(torch.int32), "kernel_dep_data": dep_data, "kernel_dep_indices32": dep_indices64.to(torch.int32), "kernel_dep_ptr": _prepare_spsm_kernel_row_ptr(dep_ptr64), - "kernel_diag": _prepare_spsm_diag( - data_u, col_u64, row_ptr, n_rows, unit_diagonal=unit_diagonal - ), - "kernel_row_ptr": _prepare_spsm_kernel_row_ptr(row_ptr), - "launch_groups": levels, + "kernel_inv_diag": _prepare_spsm_inv_diag(diag), + "launch_groups": launch_groups, "default_block_nnz": default_block_nnz, "default_max_segments": default_max_segments, - "tiled_block_nnz": tiled_block_nnz, - "tiled_max_segments": tiled_max_segments, "lower_eff": bool(lower), } @@ -855,12 +897,12 @@ def _run_spsm_csr_core( data, indices32, indptr, - diag, + inv_diag, rhs, n_rows, *, + alpha=1.0, lower=True, - unit_diagonal=False, block_nnz=None, max_segments=None, block_rhs=None, @@ -874,180 +916,128 @@ def _run_spsm_csr_core( if rhs.shape[0] != n_rows: raise ValueError("rhs first dim must equal n_rows") n_rhs = int(rhs.shape[1]) - x = torch.zeros_like(rhs) + rhs_work = _prepare_spsm_rhs_work_buffer(rhs) if n_rows == 0 or n_rhs == 0: - return x + return rhs_work if launch_groups is None: levels = _build_spsm_levels(indptr, indices32, n_rows, lower=lower) - launch_groups = levels + launch_groups = _bucketize_spsm_launch_groups(levels, indptr) if block_nnz_use is None or max_segments_use is None: block_nnz_use, max_segments_use = _auto_spsm_launch_config( indptr, block_nnz=block_nnz, max_segments=max_segments ) - block_rhs_use = _auto_rhs_block(n_rhs) if block_rhs is None else int(block_rhs) - if block_rhs_use <= 0: - raise ValueError("block_rhs must be positive") - use_fp64 = data.dtype == torch.float64 - diag_eps = 1e-12 if use_fp64 else 1e-6 - for rows_lv in launch_groups: + for group in launch_groups: + if isinstance(group, dict): + rows_lv = group["rows"] + diag_only = bool(group.get("diag_only", False)) + single_segment = bool(group.get("single_segment", False)) + staged = bool(group.get("staged", False)) + block_nnz_lv = int(group.get("block_nnz", block_nnz_use)) + max_segments_lv = int(group.get("max_segments", max_segments_use)) + else: + rows_lv = group + diag_only = False + single_segment = False + staged = False + block_nnz_lv = block_nnz_use + max_segments_lv = max_segments_use n_lv = rows_lv.numel() if n_lv == 0: continue + block_rhs_use = ( + int(block_rhs) + if block_rhs is not None + else _choose_group_block_rhs( + n_rhs, + block_nnz=block_nnz_lv, + max_segments=max_segments_lv, + diag_only=diag_only, + single_segment=single_segment, + value_dtype=data.dtype, + ) + ) + if block_rhs_use <= 0: + raise ValueError("block_rhs must be positive") + rhs_tile_groups = _choose_group_rhs_tile_groups(n_rhs, block_rhs_use, staged) + if diag_only: + grid = (n_lv, triton.cdiv(n_rhs, block_rhs_use)) + _spsm_csr_diag_only_kernel_real[grid]( + inv_diag, + rhs_work, + rhs_work, + rows_lv, + n_level_rows=n_lv, + n_rhs=n_rhs, + stride_b0=rhs_work.stride(0), + stride_x0=rhs_work.stride(0), + alpha=alpha, + BLOCK_RHS=block_rhs_use, + USE_FP64_ACC=use_fp64, + ) + continue + if single_segment: + grid = (n_lv, triton.cdiv(n_rhs, block_rhs_use)) + _spsm_csr_level_kernel_single_segment_real[grid]( + data, + indices32, + indptr, + inv_diag, + rhs_work, + rhs_work, + rows_lv, + n_level_rows=n_lv, + n_rhs=n_rhs, + stride_b0=rhs_work.stride(0), + 
stride_x0=rhs_work.stride(0), + alpha=alpha, + BLOCK_NNZ=block_nnz_lv, + BLOCK_RHS=block_rhs_use, + USE_FP64_ACC=use_fp64, + ) + continue + if staged: + grid = (n_lv, triton.cdiv(n_rhs, block_rhs_use * rhs_tile_groups)) + _spsm_csr_level_kernel_staged_real[grid]( + data, + indices32, + indptr, + inv_diag, + rhs_work, + rows_lv, + n_level_rows=n_lv, + n_rhs=n_rhs, + stride_work0=rhs_work.stride(0), + alpha=alpha, + BLOCK_NNZ=block_nnz_lv, + BLOCK_RHS=block_rhs_use, + RHS_TILE_GROUPS=rhs_tile_groups, + MAX_SEGMENTS=max_segments_lv, + USE_FP64_ACC=use_fp64, + ) + continue grid = (n_lv, triton.cdiv(n_rhs, block_rhs_use)) _spsm_csr_level_kernel_real[grid]( data, indices32, indptr, - diag, - rhs, - x, - rows_lv, - n_level_rows=n_lv, - n_rhs=n_rhs, - stride_b0=rhs.stride(0), - stride_x0=x.stride(0), - BLOCK_NNZ=block_nnz_use, - BLOCK_RHS=block_rhs_use, - MAX_SEGMENTS=max_segments_use, - LOWER=lower, - USE_FP64_ACC=use_fp64, - DIAG_EPS=diag_eps, - ) - return x - - -def _run_spsm_csr_tiled_core( - data, - indices32, - dep_ptr, - diag, - rhs, - n_rows, - *, - alpha=1.0, - block_nnz=None, - max_segments=None, - block_rhs=None, - block_nnz_use=None, - max_segments_use=None, -): - if rhs.ndim != 2: - raise ValueError("rhs must be 2D") - rhs = rhs.contiguous() - if rhs.shape[0] != n_rows: - raise ValueError("rhs first dim must equal n_rows") - n_rhs = int(rhs.shape[1]) - y = _pack_spsm_rhs_work_buffer(rhs, alpha) - if n_rows == 0 or n_rhs == 0: - return y - - if block_nnz_use is None or max_segments_use is None: - block_nnz_use, max_segments_use = _auto_spsm_launch_config_from_row_lengths( - dep_ptr[1:] - dep_ptr[:-1], - block_nnz=block_nnz, - max_segments=max_segments, - ) - block_rhs_use = _auto_rhs_block(n_rhs) if block_rhs is None else int(block_rhs) - if block_rhs_use <= 0: - raise ValueError("block_rhs must be positive") - - use_fp64 = data.dtype == torch.float64 - diag_eps = 1e-12 if use_fp64 else 1e-6 - rhs_tiles = triton.cdiv(n_rhs, block_rhs_use) - ready = torch.zeros((rhs_tiles, n_rows), dtype=torch.int32, device=rhs.device) - - grid = (n_rows, rhs_tiles) - _spsm_csr_tiled_kernel_real[grid]( - data, - indices32, - dep_ptr, - diag, - y, - ready, - n_rows, - n_rhs, - y.stride(0), - ready.stride(0), - BLOCK_NNZ=block_nnz_use, - BLOCK_RHS=block_rhs_use, - MAX_SEGMENTS=max_segments_use, - USE_FP64_ACC=use_fp64, - DIAG_EPS=diag_eps, - ) - return y - - -def _run_spsm_coo_core( - data, - cols32, - row_ptr, - diag, - rhs, - n_rows, - *, - lower=True, - unit_diagonal=False, - block_nnz=None, - max_segments=None, - block_rhs=None, - launch_groups=None, - block_nnz_use=None, - max_segments_use=None, -): - if rhs.ndim != 2: - raise ValueError("rhs must be 2D") - rhs = rhs.contiguous() - if rhs.shape[0] != n_rows: - raise ValueError("rhs first dim must equal n_rows") - n_rhs = int(rhs.shape[1]) - x = torch.zeros_like(rhs) - if n_rows == 0 or n_rhs == 0: - return x - - if launch_groups is None: - levels = _build_spsm_levels(row_ptr, cols32, n_rows, lower=lower) - launch_groups = levels - if block_nnz_use is None or max_segments_use is None: - block_nnz_use, max_segments_use = _auto_spsm_launch_config( - row_ptr, block_nnz=block_nnz, max_segments=max_segments - ) - block_rhs_use = _auto_rhs_block(n_rhs) if block_rhs is None else int(block_rhs) - if block_rhs_use <= 0: - raise ValueError("block_rhs must be positive") - - use_fp64 = data.dtype == torch.float64 - diag_eps = 1e-12 if use_fp64 else 1e-6 - - for rows_lv in launch_groups: - n_lv = rows_lv.numel() - if n_lv == 0: - continue - grid = (n_lv, 
triton.cdiv(n_rhs, block_rhs_use)) - _spsm_coo_level_kernel_real[grid]( - data, - row_ptr, - cols32, - diag, - rhs, - x, + inv_diag, + rhs_work, + rhs_work, rows_lv, n_level_rows=n_lv, n_rhs=n_rhs, - stride_b0=rhs.stride(0), - stride_x0=x.stride(0), - BLOCK_NNZ=block_nnz_use, + stride_b0=rhs_work.stride(0), + stride_x0=rhs_work.stride(0), + alpha=alpha, + BLOCK_NNZ=block_nnz_lv, BLOCK_RHS=block_rhs_use, - MAX_SEGMENTS=max_segments_use, - LOWER=lower, + MAX_SEGMENTS=max_segments_lv, USE_FP64_ACC=use_fp64, - DIAG_EPS=diag_eps, ) - return x - - + return rhs_work def flagsparse_spsm_csr( @@ -1068,38 +1058,23 @@ def flagsparse_spsm_csr( data, B, n_rows, _n_cols, solve_plan = _resolve_spsm_csr_runtime( data, indices, indptr, B, shape, lower, unit_diagonal, opA, opB, major ) - use_tiled_kernel = _should_use_spsm_tiled_kernel("csr", B, data.dtype) alpha_value = 1.0 if _alpha_is_one(alpha) else _alpha_to_host_scalar(alpha) - rhs = B if use_tiled_kernel else (B if alpha_value == 1.0 else torch.as_tensor(alpha, dtype=B.dtype, device=B.device) * B) if return_time: torch.cuda.synchronize() t0 = time.perf_counter() - if use_tiled_kernel: - x = _run_spsm_csr_tiled_core( - solve_plan["kernel_dep_data"], - solve_plan["kernel_dep_indices32"], - solve_plan["kernel_dep_ptr"], - solve_plan["kernel_diag"], - rhs, - n_rows, - alpha=alpha_value, - block_nnz_use=solve_plan["tiled_block_nnz"], - max_segments_use=solve_plan["tiled_max_segments"], - ) - else: - x = _run_spsm_csr_core( - solve_plan["kernel_data"], - solve_plan["kernel_indices32"], - solve_plan["kernel_indptr"], - solve_plan["kernel_diag"], - rhs, - n_rows, - lower=solve_plan["lower_eff"], - unit_diagonal=unit_diagonal, - launch_groups=solve_plan["launch_groups"], - block_nnz_use=solve_plan["default_block_nnz"], - max_segments_use=solve_plan["default_max_segments"], - ) + x = _run_spsm_csr_core( + solve_plan["kernel_dep_data"], + solve_plan["kernel_dep_indices32"], + solve_plan["kernel_dep_ptr"], + solve_plan["kernel_inv_diag"], + B, + n_rows, + alpha=alpha_value, + lower=solve_plan["lower_eff"], + launch_groups=solve_plan["launch_groups"], + block_nnz_use=solve_plan["default_block_nnz"], + max_segments_use=solve_plan["default_max_segments"], + ) if return_time: torch.cuda.synchronize() elapsed_ms = (time.perf_counter() - t0) * 1000.0 @@ -1131,38 +1106,23 @@ def flagsparse_spsm_coo( data, B, n_rows, _n_cols, solve_plan = _resolve_spsm_coo_runtime( data, row, col, B, shape, lower, unit_diagonal, opA, opB, major ) - use_tiled_kernel = _should_use_spsm_tiled_kernel("coo", B, data.dtype) alpha_value = 1.0 if _alpha_is_one(alpha) else _alpha_to_host_scalar(alpha) - rhs = B if use_tiled_kernel else (B if alpha_value == 1.0 else torch.as_tensor(alpha, dtype=B.dtype, device=B.device) * B) if return_time: torch.cuda.synchronize() t0 = time.perf_counter() - if use_tiled_kernel: - x = _run_spsm_csr_tiled_core( - solve_plan["kernel_dep_data"], - solve_plan["kernel_dep_indices32"], - solve_plan["kernel_dep_ptr"], - solve_plan["kernel_diag"], - rhs, - n_rows, - alpha=alpha_value, - block_nnz_use=solve_plan["tiled_block_nnz"], - max_segments_use=solve_plan["tiled_max_segments"], - ) - else: - x = _run_spsm_coo_core( - solve_plan["kernel_data"], - solve_plan["kernel_cols32"], - solve_plan["kernel_row_ptr"], - solve_plan["kernel_diag"], - rhs, - n_rows, - lower=solve_plan["lower_eff"], - unit_diagonal=unit_diagonal, - launch_groups=solve_plan["launch_groups"], - block_nnz_use=solve_plan["default_block_nnz"], - max_segments_use=solve_plan["default_max_segments"], - ) + 
x = _run_spsm_csr_core( + solve_plan["kernel_dep_data"], + solve_plan["kernel_dep_indices32"], + solve_plan["kernel_dep_ptr"], + solve_plan["kernel_inv_diag"], + B, + n_rows, + alpha=alpha_value, + lower=solve_plan["lower_eff"], + launch_groups=solve_plan["launch_groups"], + block_nnz_use=solve_plan["default_block_nnz"], + max_segments_use=solve_plan["default_max_segments"], + ) if return_time: torch.cuda.synchronize() elapsed_ms = (time.perf_counter() - t0) * 1000.0 diff --git a/src/flagsparse/sparse_operations/spsv.py b/src/flagsparse/sparse_operations/spsv.py index 9783cbe..e9e4f6e 100644 --- a/src/flagsparse/sparse_operations/spsv.py +++ b/src/flagsparse/sparse_operations/spsv.py @@ -58,6 +58,35 @@ def _clear_spsv_csr_preprocess_cache(): _SPSV_CSR_PREPROCESS_CACHE.clear() +def _as_strided_contiguous(tensor): + if tensor is None: + return None + if tensor.layout != torch.strided: + out = torch.empty(tensor.shape, dtype=tensor.dtype, device=tensor.device) + out.copy_(tensor) + return out + return tensor.contiguous() + + +def _complex_interleaved_view(tensor): + tensor_strided = _as_strided_contiguous(tensor) + return torch.view_as_real(tensor_strided).reshape(-1).contiguous() + + +def _attach_spsv_complex_plan_views(plan): + kernel_data = plan.get("kernel_data") + if kernel_data is None or not torch.is_complex(kernel_data): + return plan + plan["kernel_data_ri"] = _complex_interleaved_view(kernel_data) + transpose_diag = plan.get("transpose_diag") + if transpose_diag is not None and torch.is_complex(transpose_diag): + plan["transpose_diag_ri"] = _complex_interleaved_view(transpose_diag) + cw_diag = plan.get("cw_diag") + if cw_diag is not None and torch.is_complex(cw_diag): + plan["cw_diag_ri"] = _complex_interleaved_view(cw_diag) + return plan + + def _validate_spsv_non_trans_combo(data_dtype, index_dtype, fmt_name): """Validate NON_TRANS support matrix and keep error messages explicit.""" if (data_dtype, index_dtype) in SPSV_NON_TRANS_SUPPORTED_COMBOS: @@ -340,7 +369,7 @@ def _sort_csr_rows(data, indices64, indptr64, n_rows, n_cols, lower=True): return data[order], indices64[order], indptr64 -def _prepare_spsv_nontrans_cw_metadata(data, indices64, indptr64, n_rows, lower, unit_diagonal=False): +def _prepare_spsv_nontrans_cw_metadata(data, indices64, indptr64, n_rows, unit_diagonal=False): diag = torch.ones(n_rows, dtype=data.dtype, device=data.device) if n_rows == 0: return diag @@ -356,17 +385,61 @@ def _prepare_spsv_nontrans_cw_metadata(data, indices64, indptr64, n_rows, lower, return diag +def _cw_rhs_bucket(n_rhs): + if n_rhs <= 1: + return 1 + if n_rhs <= 2: + return 2 + if n_rhs <= 4: + return 4 + if n_rhs <= 8: + return 8 + if n_rhs <= 16: + return 16 + return 32 + + +def _snap_cw_worker_count(target, n_rows): + if n_rows <= 0: + return 1 + target = max(1, min(int(target), int(n_rows))) + snapped = 1 + tier = 1 + while tier < target and tier < 4096: + tier *= 2 + if tier <= target: + snapped = tier + return int(max(1, min(snapped, int(n_rows)))) + + def _cw_worker_count(n_rows, max_frontier, avg_nnz_per_row, n_rhs): if n_rows <= 0: return 1 + rhs_bucket = _cw_rhs_bucket(n_rhs) target = max(256, min(n_rows, 4096)) if max_frontier > 0: - target = min(target, max(256, min(n_rows, max_frontier * 8))) + target = min(target, max(128, min(n_rows, max_frontier * 8))) if avg_nnz_per_row > 2048: target = max(128, target // 2) - if n_rhs >= 16: - target = max(128, target // 2) - return int(max(1, min(n_rows, target))) + if rhs_bucket >= 16: + target = max(64, target // 4) + elif rhs_bucket 
>= 8: + target = max(64, target // 2) + elif rhs_bucket >= 4: + target = max(128, (target * 3) // 4) + return _snap_cw_worker_count(target, n_rows) + + +def _resolve_cw_worker_count(n_rows, matrix_stats, n_rhs, cached_worker_count=None): + rhs_bucket = _cw_rhs_bucket(n_rhs) + if cached_worker_count is not None and rhs_bucket == 1: + return int(max(1, min(int(cached_worker_count), int(max(n_rows, 1))))) + return _cw_worker_count( + n_rows, + int(matrix_stats.get("max_frontier", n_rows)), + float(matrix_stats.get("avg_nnz_per_row", 0.0)), + rhs_bucket, + ) def _spsv_level_stats(levels, n_rows): @@ -457,8 +530,12 @@ def _score_nontrans_cw(matrix_stats, n_rhs, complex_mode): score += 1.5 if matrix_stats["avg_nnz_per_row"] <= 128.0: score += 1.5 + elif matrix_stats["avg_nnz_per_row"] <= 160.0: + score += 0.5 if matrix_stats["max_nnz_per_row"] > 4096: score -= 2.5 + elif matrix_stats["max_nnz_per_row"] > 1536: + score -= 1.5 if n_rhs >= 8: score -= 3.0 elif n_rhs >= 4: @@ -510,19 +587,19 @@ def _score_transpose_cw(matrix_stats, n_rhs, complex_mode): def _nontrans_cw_eligible(matrix_stats, n_rhs, complex_mode): if complex_mode: return False - if n_rhs > 2: + if n_rhs != 1: return False - if int(matrix_stats.get("num_levels", 0)) < 256: + if int(matrix_stats.get("num_levels", 0)) < 1024: return False - if float(matrix_stats.get("avg_frontier", 0.0)) > 24.0: + if float(matrix_stats.get("avg_frontier", 0.0)) > 16.0: return False - if float(matrix_stats.get("frontier_ratio", 1.0)) > 0.08: + if float(matrix_stats.get("frontier_ratio", 1.0)) > 0.05: return False avg_nnz = float(matrix_stats.get("avg_nnz_per_row", 0.0)) max_nnz = int(matrix_stats.get("max_nnz_per_row", 0)) - if avg_nnz <= 0.0 or avg_nnz > 256.0: + if avg_nnz <= 0.0 or avg_nnz > 160.0: return False - if max_nnz > 2048: + if max_nnz > 1536: return False return True @@ -554,7 +631,7 @@ def _select_nontrans_route(matrix_stats, n_rhs, complex_mode): return "csr_levels" levels_score = _score_nontrans_levels(matrix_stats, n_rhs, complex_mode) cw_score = _score_nontrans_cw(matrix_stats, n_rhs, complex_mode) - return "csr_cw" if cw_score >= (levels_score + 1.0) else "csr_levels" + return "csr_cw" if cw_score >= (levels_score + 2.0) else "csr_levels" def _select_transpose_route(matrix_stats, n_rhs, complex_mode, trans_mode): @@ -573,7 +650,6 @@ def _prepare_spsv_csr_system(data, indices64, indptr64, n_rows, n_cols, lower, t levels_plan = { "solve_kind": "csr_levels", "kernel_data": data, - "kernel_indices64": indices64, "kernel_indices32": indices64.to(torch.int32), "kernel_indptr64": indptr64, "lower_eff": lower, @@ -590,6 +666,7 @@ def _prepare_spsv_csr_system(data, indices64, indptr64, n_rows, n_cols, lower, t "route_name": "csr_levels", "alt_plan": None, } + _attach_spsv_complex_plan_views(levels_plan) if not SPSV_ENABLE_CSR_CW: return levels_plan data_sorted, indices_sorted64, indptr_sorted64 = _sort_csr_rows( @@ -599,7 +676,6 @@ def _prepare_spsv_csr_system(data, indices64, indptr64, n_rows, n_cols, lower, t cw_plan = { "solve_kind": "csr_cw", "kernel_data": data_sorted, - "kernel_indices64": indices_sorted64, "kernel_indices32": indices_sorted64.to(torch.int32), "kernel_indptr64": indptr_sorted64, "lower_eff": lower, @@ -608,7 +684,7 @@ def _prepare_spsv_csr_system(data, indices64, indptr64, n_rows, n_cols, lower, t "default_max_segments": default_max_segments, "transpose_conjugate": False, "cw_diag": _prepare_spsv_nontrans_cw_metadata( - data_sorted, indices_sorted64, indptr_sorted64, n_rows, lower, unit_diagonal=unit_diagonal + 
data_sorted, indices_sorted64, indptr_sorted64, n_rows, unit_diagonal=unit_diagonal ), "cw_worker_count": _cw_worker_count( n_rows, matrix_stats["max_frontier"], matrix_stats["avg_nnz_per_row"], 1 @@ -617,6 +693,7 @@ def _prepare_spsv_csr_system(data, indices64, indptr64, n_rows, n_cols, lower, t "route_name": "csr_cw", "alt_plan": None, } + _attach_spsv_complex_plan_views(cw_plan) levels_plan["alt_plan"] = cw_plan return levels_plan @@ -628,7 +705,6 @@ def _prepare_spsv_csr_system(data, indices64, indptr64, n_rows, n_cols, lower, t push_plan = { "solve_kind": "transpose_push", "kernel_data": data, - "kernel_indices64": indices64, "kernel_indices32": indices64.to(torch.int32), "kernel_indptr64": indptr64, "lower_eff": lower, @@ -645,6 +721,7 @@ def _prepare_spsv_csr_system(data, indices64, indptr64, n_rows, n_cols, lower, t "route_name": "transpose_push", "alt_plan": None, } + _attach_spsv_complex_plan_views(push_plan) if not SPSV_ENABLE_TRANSPOSE_CW: return push_plan @@ -657,7 +734,6 @@ def _prepare_spsv_csr_system(data, indices64, indptr64, n_rows, n_cols, lower, t cw_plan = { "solve_kind": "transpose_cw", "kernel_data": data, - "kernel_indices64": indices64, "kernel_indices32": indices64.to(torch.int32), "kernel_indptr64": indptr64, "lower_eff": lower, @@ -674,6 +750,7 @@ def _prepare_spsv_csr_system(data, indices64, indptr64, n_rows, n_cols, lower, t "route_name": "transpose_cw", "alt_plan": None, } + _attach_spsv_complex_plan_views(cw_plan) push_plan["alt_plan"] = cw_plan return push_plan @@ -921,7 +998,6 @@ def _spsv_csr_cw_kernel( BLOCK_NNZ: tl.constexpr, MAX_SEGMENTS: tl.constexpr, LOWER: tl.constexpr, - UNIT_DIAG: tl.constexpr, DIAG_EPS: tl.constexpr, ): row = tl.atomic_add(row_counter_ptr, 1) @@ -1558,6 +1634,7 @@ def _triton_spsv_csr_vector_complex( levels=None, block_nnz_use=None, max_segments_use=None, + data_ri_in=None, ): x = torch.zeros_like(b_vec) if n_rows == 0: @@ -1571,13 +1648,7 @@ def _triton_spsv_csr_vector_complex( # Some PyTorch builds return CSR values with a non-strided layout wrapper. # Materialize a plain 1D strided buffer before splitting into real/imag parts. 
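# --- Editor's note: illustrative sketch, not part of the patch ----------------
# The hunk below replaces the per-call real/imag materialization with the cached
# view built by _complex_interleaved_view and stored on the solve plan by
# _attach_spsv_complex_plan_views (kernel_data_ri, transpose_diag_ri, cw_diag_ri).
# This minimal, self-contained example shows the flat interleaved layout that
# helper produces; the tensor values and dtype are invented for demonstration,
# and only the documented torch.view_as_real semantics are assumed.

import torch

vals = torch.tensor([1 + 2j, 3 + 4j], dtype=torch.complex64)
ri = torch.view_as_real(vals).reshape(-1).contiguous()
assert ri.tolist() == [1.0, 2.0, 3.0, 4.0]
# Element k of the complex buffer maps to (ri[2 * k], ri[2 * k + 1]), so the
# complex kernels can address real and imaginary parts with plain float loads.
# Threading the cached buffer through data_ri_in, as the hunk below does,
# avoids repeating the copy + reshape on every solve call.
# ------------------------------------------------------------------------------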
- if data.layout != torch.strided: - data_strided = torch.empty(data.shape, dtype=data.dtype, device=data.device) - data_strided.copy_(data) - else: - data_strided = data.contiguous() - - data_ri = torch.view_as_real(data_strided).reshape(-1).contiguous() + data_ri = data_ri_in if data_ri_in is not None else _complex_interleaved_view(data) b_ri = torch.view_as_real(b_vec.contiguous()).reshape(-1).contiguous() component_dtype = _component_dtype_for_complex(data.dtype) use_fp64 = component_dtype == torch.float64 @@ -1620,7 +1691,6 @@ def _triton_spsv_csr_cw_vector( b_vec, n_rows, lower=True, - unit_diagonal=False, block_nnz=None, max_segments=None, diag_eps=1e-12, @@ -1646,12 +1716,7 @@ def _triton_spsv_csr_cw_vector( matrix_stats = matrix_stats or {} block_rhs = _choose_spsv_block_rhs(n_rhs, matrix_stats, complex_mode=False) if worker_count is None: - worker_count = _cw_worker_count( - n_rows, - int(matrix_stats.get("max_frontier", n_rows)), - float(matrix_stats.get("avg_nnz_per_row", 0.0)), - n_rhs, - ) + worker_count = _resolve_cw_worker_count(n_rows, matrix_stats, n_rhs) grid = (worker_count,) _spsv_csr_cw_kernel[grid]( data, @@ -1670,7 +1735,6 @@ def _triton_spsv_csr_cw_vector( BLOCK_NNZ=block_nnz_use, MAX_SEGMENTS=max_segments_use, LOWER=lower, - UNIT_DIAG=unit_diagonal, DIAG_EPS=diag_eps, ) return x.squeeze(1) if b_vec.ndim == 1 else x @@ -1692,17 +1756,20 @@ def _triton_spsv_csr_cw_vector_complex( max_segments_use=None, worker_count=None, matrix_stats=None, + data_ri_in=None, + diag_ri_in=None, ): if b_vec.ndim != 1: + shared_b = b_vec if b_vec.is_contiguous() else b_vec.contiguous() cols = [] - for j in range(b_vec.shape[1]): + for bj in torch.unbind(shared_b, dim=1): cols.append( _triton_spsv_csr_cw_vector_complex( data, indices, indptr, diag, - b_vec[:, j].contiguous(), + bj, n_rows, lower=lower, unit_diagonal=unit_diagonal, @@ -1713,6 +1780,8 @@ def _triton_spsv_csr_cw_vector_complex( max_segments_use=max_segments_use, worker_count=worker_count, matrix_stats=matrix_stats, + data_ri_in=data_ri_in, + diag_ri_in=diag_ri_in, ) ) return torch.stack(cols, dim=1) @@ -1727,19 +1796,8 @@ def _triton_spsv_csr_cw_vector_complex( indptr, block_nnz=block_nnz, max_segments=max_segments ) - if data.layout != torch.strided: - data_strided = torch.empty(data.shape, dtype=data.dtype, device=data.device) - data_strided.copy_(data) - else: - data_strided = data.contiguous() - if diag.layout != torch.strided: - diag_strided = torch.empty(diag.shape, dtype=diag.dtype, device=diag.device) - diag_strided.copy_(diag) - else: - diag_strided = diag.contiguous() - - data_ri = torch.view_as_real(data_strided).reshape(-1).contiguous() - diag_ri = torch.view_as_real(diag_strided).reshape(-1).contiguous() + data_ri = data_ri_in if data_ri_in is not None else _complex_interleaved_view(data) + diag_ri = diag_ri_in if diag_ri_in is not None else _complex_interleaved_view(diag) b_ri = torch.view_as_real(b_vec.contiguous()).reshape(-1).contiguous() component_dtype = _component_dtype_for_complex(data.dtype) use_fp64 = component_dtype == torch.float64 @@ -1747,12 +1805,7 @@ def _triton_spsv_csr_cw_vector_complex( if worker_count is None: matrix_stats = matrix_stats or {} - worker_count = _cw_worker_count( - n_rows, - int(matrix_stats.get("max_frontier", n_rows)), - float(matrix_stats.get("avg_nnz_per_row", 0.0)), - 1, - ) + worker_count = _resolve_cw_worker_count(n_rows, matrix_stats, 1) grid = (worker_count,) _spsv_csr_cw_kernel_complex[grid]( data_ri, @@ -1838,6 +1891,7 @@ def 
_triton_spsv_csr_transpose_push_vector_complex( launch_groups=None, block_nnz_use=None, max_segments_use=None, + data_ri_in=None, ): x = torch.zeros_like(b_vec) if n_rows == 0: @@ -1850,14 +1904,8 @@ def _triton_spsv_csr_transpose_push_vector_complex( indptr, block_nnz=block_nnz, max_segments=max_segments ) - if data.layout != torch.strided: - data_strided = torch.empty(data.shape, dtype=data.dtype, device=data.device) - data_strided.copy_(data) - else: - data_strided = data.contiguous() - residual_work = b_vec.contiguous().clone() - data_ri = torch.view_as_real(data_strided).reshape(-1).contiguous() + data_ri = data_ri_in if data_ri_in is not None else _complex_interleaved_view(data) residual_ri = torch.view_as_real(residual_work).reshape(-1).contiguous() component_dtype = _component_dtype_for_complex(data.dtype) use_fp64 = component_dtype == torch.float64 @@ -1923,12 +1971,7 @@ def _triton_spsv_csr_transpose_cw_vector( ) if worker_count is None: matrix_stats = matrix_stats or {} - worker_count = _cw_worker_count( - n_rows, - int(matrix_stats.get("max_frontier", n_rows)), - float(matrix_stats.get("avg_nnz_per_row", 0.0)), - 1, - ) + worker_count = _resolve_cw_worker_count(n_rows, matrix_stats, 1) grid = (worker_count,) _spsv_csr_transpose_cw_kernel[grid]( data, @@ -1967,6 +2010,8 @@ def _triton_spsv_csr_transpose_cw_vector_complex( max_segments_use=None, worker_count=None, matrix_stats=None, + data_ri_in=None, + diag_ri_in=None, ): x = torch.zeros_like(b_vec) if n_rows == 0: @@ -1976,22 +2021,11 @@ def _triton_spsv_csr_transpose_cw_vector_complex( indptr, block_nnz=block_nnz, max_segments=max_segments ) - if data.layout != torch.strided: - data_strided = torch.empty(data.shape, dtype=data.dtype, device=data.device) - data_strided.copy_(data) - else: - data_strided = data.contiguous() - if diag.layout != torch.strided: - diag_strided = torch.empty(diag.shape, dtype=diag.dtype, device=diag.device) - diag_strided.copy_(diag) - else: - diag_strided = diag.contiguous() - residual_work = b_vec.contiguous().clone() indegree = indegree_init.clone() row_counter = torch.zeros(1, dtype=torch.int32, device=b_vec.device) - data_ri = torch.view_as_real(data_strided).reshape(-1).contiguous() - diag_ri = torch.view_as_real(diag_strided).reshape(-1).contiguous() + data_ri = data_ri_in if data_ri_in is not None else _complex_interleaved_view(data) + diag_ri = diag_ri_in if diag_ri_in is not None else _complex_interleaved_view(diag) residual_ri = torch.view_as_real(residual_work).reshape(-1).contiguous() component_dtype = _component_dtype_for_complex(data.dtype) use_fp64 = component_dtype == torch.float64 @@ -2003,12 +2037,7 @@ def _triton_spsv_csr_transpose_cw_vector_complex( if worker_count is None: matrix_stats = matrix_stats or {} - worker_count = _cw_worker_count( - n_rows, - int(matrix_stats.get("max_frontier", n_rows)), - float(matrix_stats.get("avg_nnz_per_row", 0.0)), - 1, - ) + worker_count = _resolve_cw_worker_count(n_rows, matrix_stats, 1) grid = (worker_count,) _spsv_csr_transpose_cw_kernel_complex[grid]( data_ri, @@ -2049,7 +2078,7 @@ def _choose_transpose_family_launch_config(indptr, block_nnz=None, max_segments= return cand, req -def _prepare_spsv_coo_inputs(data, row, col, b, shape, transpose=False): +def _prepare_spsv_coo_inputs(data, row, col, b, shape): if not all(torch.is_tensor(t) for t in (data, row, col, b)): raise TypeError("data, row, col, b must all be torch.Tensor") if not all(t.is_cuda for t in (data, row, col, b)): @@ -2269,7 +2298,6 @@ def flagsparse_spsv_csr( ) solve_kind = 
solve_plan["solve_kind"] kernel_data = solve_plan["kernel_data"] - kernel_indices64 = solve_plan["kernel_indices64"] kernel_indices32 = solve_plan["kernel_indices32"] kernel_indptr64 = solve_plan["kernel_indptr64"] lower_eff = solve_plan["lower_eff"] @@ -2354,12 +2382,26 @@ def flagsparse_spsv_csr( rhs_cols = 1 if b_in.ndim == 1 else int(b_in.shape[1]) matrix_stats_use = dict(matrix_stats) if solve_kind in ("csr_cw", "transpose_cw"): - worker_count_use = _cw_worker_count( + worker_count_use = _resolve_cw_worker_count( n_rows, - int(matrix_stats_use.get("max_frontier", cw_worker_count if cw_worker_count is not None else n_rows)), - float(matrix_stats_use.get("avg_nnz_per_row", default_block_nnz)), + matrix_stats_use, rhs_cols, + cached_worker_count=cw_worker_count, ) + complex_kernel_data_ri = None + complex_transpose_diag_ri = None + complex_cw_diag_ri = None + if torch.is_complex(data_in): + if compute_dtype == solve_plan["kernel_data"].dtype: + complex_kernel_data_ri = solve_plan.get("kernel_data_ri") + complex_transpose_diag_ri = solve_plan.get("transpose_diag_ri") + complex_cw_diag_ri = solve_plan.get("cw_diag_ri") + if complex_kernel_data_ri is None: + complex_kernel_data_ri = _complex_interleaved_view(data_in) + if transpose_diag_in is not None and complex_transpose_diag_ri is None: + complex_transpose_diag_ri = _complex_interleaved_view(transpose_diag_in) + if cw_diag_in is not None and complex_cw_diag_ri is None: + complex_cw_diag_ri = _complex_interleaved_view(cw_diag_in) if b_in.ndim == 1: if torch.is_complex(data_in): if solve_kind == "transpose_push": @@ -2378,6 +2420,7 @@ def flagsparse_spsv_csr( launch_groups=launch_groups, block_nnz_use=block_nnz_use, max_segments_use=max_segments_use, + data_ri_in=complex_kernel_data_ri, ) elif solve_kind == "transpose_cw": x = vec_complex( @@ -2398,6 +2441,8 @@ def flagsparse_spsv_csr( max_segments_use=max_segments_use, worker_count=worker_count_use, matrix_stats=matrix_stats_use, + data_ri_in=complex_kernel_data_ri, + diag_ri_in=complex_transpose_diag_ri, ) elif solve_kind == "csr_cw": x = vec_complex( @@ -2416,6 +2461,8 @@ def flagsparse_spsv_csr( max_segments_use=max_segments_use, worker_count=worker_count_use, matrix_stats=matrix_stats_use, + data_ri_in=complex_kernel_data_ri, + diag_ri_in=complex_cw_diag_ri, ) else: x = vec_complex( @@ -2432,6 +2479,7 @@ def flagsparse_spsv_csr( levels=launch_groups, block_nnz_use=block_nnz_use, max_segments_use=max_segments_use, + data_ri_in=complex_kernel_data_ri, ) else: if solve_kind == "transpose_push": @@ -2478,7 +2526,6 @@ def flagsparse_spsv_csr( b_in, n_rows, lower=lower_eff, - unit_diagonal=unit_diagonal, block_nnz=block_nnz, max_segments=max_segments, diag_eps=diag_eps, @@ -2504,9 +2551,9 @@ def flagsparse_spsv_csr( max_segments_use=max_segments_use, ) else: + b_cols = b_in if b_in.is_contiguous() else b_in.contiguous() cols = [] - for j in range(b_in.shape[1]): - bj = b_in[:, j].contiguous() + for bj in torch.unbind(b_cols, dim=1): if torch.is_complex(data_in): if solve_kind == "transpose_push": cols.append( @@ -2525,6 +2572,7 @@ def flagsparse_spsv_csr( launch_groups=launch_groups, block_nnz_use=block_nnz_use, max_segments_use=max_segments_use, + data_ri_in=complex_kernel_data_ri, ) ) elif solve_kind == "transpose_cw": @@ -2547,6 +2595,8 @@ def flagsparse_spsv_csr( max_segments_use=max_segments_use, worker_count=worker_count_use, matrix_stats=matrix_stats_use, + data_ri_in=complex_kernel_data_ri, + diag_ri_in=complex_transpose_diag_ri, ) ) elif solve_kind == "csr_cw": @@ -2567,6 +2617,8 
@@ def flagsparse_spsv_csr( max_segments_use=max_segments_use, worker_count=worker_count_use, matrix_stats=matrix_stats_use, + data_ri_in=complex_kernel_data_ri, + diag_ri_in=complex_cw_diag_ri, ) ) else: @@ -2585,6 +2637,7 @@ def flagsparse_spsv_csr( levels=launch_groups, block_nnz_use=block_nnz_use, max_segments_use=max_segments_use, + data_ri_in=complex_kernel_data_ri, ) ) else: @@ -2637,7 +2690,6 @@ def flagsparse_spsv_csr( bj, n_rows, lower=lower_eff, - unit_diagonal=unit_diagonal, block_nnz=block_nnz, max_segments=max_segments, diag_eps=diag_eps, @@ -2741,7 +2793,7 @@ def flagsparse_spsv_coo( - complex dtypes and TRANS/CONJ always route through the CSR implementation """ data, row64, col64, b, n_rows, n_cols = _prepare_spsv_coo_inputs( - data, row, col, b, shape, transpose=transpose + data, row, col, b, shape ) if n_rows != n_cols: raise ValueError(f"A must be square, got shape={shape}") @@ -2820,14 +2872,15 @@ def flagsparse_spsv_coo( max_segments_use=max_segments_use, ) else: + b_cols = b_in if b_in.is_contiguous() else b_in.contiguous() cols_out = [] - for j in range(b_in.shape[1]): + for bj in torch.unbind(b_cols, dim=1): cols_out.append( _triton_spsv_coo_vector( data_in, kernel_cols, row_ptr, - b_in[:, j].contiguous(), + bj, n_rows, lower=lower, unit_diagonal=unit_diagonal, diff --git a/tests/test_spsm.py b/tests/test_spsm.py index ab15232..d03637e 100644 --- a/tests/test_spsm.py +++ b/tests/test_spsm.py @@ -66,6 +66,31 @@ def _safe_ratio(other_ms, triton_ms): return other_ms / triton_ms +def _csv_export_row_spsm(row): + return { + "matrix": row.get("matrix"), + "value_dtype": row.get("value_dtype"), + "index_dtype": row.get("index_dtype"), + "format": row.get("format"), + "n_rows": row.get("n_rows"), + "n_cols": row.get("n_cols"), + "nnz": row.get("nnz"), + "n_rhs": row.get("n_rhs"), + "triton_ms": row.get("triton_ms"), + "cusparse_ms": row.get("cusparse_ms"), + "pytorch_ms": row.get("pytorch_ms"), + "cusparse/triton": row.get("cusparse/triton"), + "pytorch/triton": row.get("pytorch/triton"), + "status": row.get("status"), + "err_ref": row.get("err_ref"), + "err_res": row.get("err_res"), + "err_pt": row.get("err_pt"), + "err_cu": row.get("err_cu"), + "pytorch_reason": row.get("pytorch_reason"), + "error": row.get("error"), + } + + def _parse_csv_tokens(raw): return [tok.strip() for tok in str(raw).split(",") if tok.strip()] @@ -405,9 +430,7 @@ def _run_one_spsm_case(data, indices, indptr, shape, value_dtype, index_dtype, n B, shape, ) - total_ms = analysis_ms + solve_ms if analysis_ms is not None and solve_ms is not None else None - - X_cu, cusparse_ms, cusparse_reason = _benchmark_cusparse_reference( + X_cu, cusparse_ms, _cusparse_reason = _benchmark_cusparse_reference( data, row, col, indptr, B, shape, fmt ) X_pt, pytorch_ms, pt_backend, pytorch_reason = _benchmark_pytorch_reference( @@ -448,12 +471,9 @@ def _run_one_spsm_case(data, indices, indptr, shape, value_dtype, index_dtype, n "n_cols": int(shape[1]), "nnz": int(data.numel()), "n_rhs": int(n_rhs), - "triton_analysis_ms": analysis_ms, - "triton_solve_ms": solve_ms, - "triton_time_total_ms": total_ms, - "cusparse_solve_ms": cusparse_ms, - "pytorch_solve_ms": pytorch_ms, - "pytorch_backend": pt_backend, + "triton_ms": solve_ms, + "cusparse_ms": cusparse_ms, + "pytorch_ms": pytorch_ms, "cusparse/triton": _safe_ratio(cusparse_ms, solve_ms), "pytorch/triton": _safe_ratio(pytorch_ms, solve_ms), "status": status, @@ -462,7 +482,6 @@ def _run_one_spsm_case(data, indices, indptr, shape, value_dtype, index_dtype, n "err_pt": 
@@ -448,12 +471,9 @@ def _run_one_spsm_case(data, indices, indptr, shape, value_dtype, index_dtype, n
         "n_cols": int(shape[1]),
         "nnz": int(data.numel()),
         "n_rhs": int(n_rhs),
-        "triton_analysis_ms": analysis_ms,
-        "triton_solve_ms": solve_ms,
-        "triton_time_total_ms": total_ms,
-        "cusparse_solve_ms": cusparse_ms,
-        "pytorch_solve_ms": pytorch_ms,
-        "pytorch_backend": pt_backend,
+        "triton_ms": solve_ms,
+        "cusparse_ms": cusparse_ms,
+        "pytorch_ms": pytorch_ms,
         "cusparse/triton": _safe_ratio(cusparse_ms, solve_ms),
         "pytorch/triton": _safe_ratio(pytorch_ms, solve_ms),
         "status": status,
@@ -462,7 +482,6 @@ def _run_one_spsm_case(data, indices, indptr, shape, value_dtype, index_dtype, n
         "err_pt": err_pt,
         "err_cu": err_cu,
         "pytorch_reason": pytorch_reason,
-        "cusparse_reason": cusparse_reason,
         "error": None,
     }

@@ -478,11 +497,11 @@ def run_spsm_synthetic_all(n=512, n_rhs=32):
     print("=" * 160)
     print(
         "PyTorch(ms)=CUDA reference (dense triangular solve). "
-        "FlagSparse analysis is measured separately; FlagSparse(ms) below reports solve only."
+        "FlagSparse(ms) below reports solve only."
     )
     print(
         f"{'Fmt':>5} {'dtype':>9} {'index':>7} {'N':>6} {'RHS':>6} {'NNZ':>10} "
-        f"{'FS.anlys':>10} {'FS.solve':>10} {'FS.total':>10} "
+        f"{'FS.solve':>10} "
         f"{'cu.solve':>10} {'pt.solve':>10} {'cu/triton':>10} {'pt/triton':>10} "
         f"{'Status':>10} {'Err(Ref)':>12} {'Err(Res)':>12} {'Err(PT)':>12} {'Err(CU)':>12}"
     )
@@ -512,19 +531,15 @@ def run_spsm_synthetic_all(n=512, n_rhs=32):
             print(
                 f"{fmt:>5} {_dtype_name(value_dtype):>9} {_dtype_name(index_dtype):>7} "
                 f"{shape[0]:>6} {n_rhs:>6} {one['nnz']:>10} "
-                f"{_fmt_ms(one['triton_analysis_ms']):>10} {_fmt_ms(one['triton_solve_ms']):>10} {_fmt_ms(one['triton_time_total_ms']):>10} "
-                f"{_fmt_ms(one['cusparse_solve_ms']):>10} {_fmt_ms(one['pytorch_solve_ms']):>10} "
+                f"{_fmt_ms(one['triton_ms']):>10} "
+                f"{_fmt_ms(one['cusparse_ms']):>10} {_fmt_ms(one['pytorch_ms']):>10} "
                 f"{_fmt_ratio(one['cusparse/triton']):>10} {_fmt_ratio(one['pytorch/triton']):>10} "
                 f"{one['status']:>10} {_fmt_err(one['err_ref']):>12} {_fmt_err(one['err_res']):>12} "
                 f"{_fmt_err(one['err_pt']):>12} {_fmt_err(one['err_cu']):>12}"
             )
             if one["status"] in ("FAIL", "REF_FAIL"):
-                if one["pytorch_backend"] and one["pytorch_backend"] != "gpu_dense":
-                    print(f" NOTE: pt_backend={one['pytorch_backend']}")
                 if one["pytorch_reason"]:
                     print(f" NOTE: {one['pytorch_reason']}")
-                if one["cusparse_reason"]:
-                    print(f" NOTE: {one['cusparse_reason']}")
     print("-" * 160)
     print(f"Total cases: {total} Failed: {failed}")
     print("=" * 160)
@@ -546,7 +561,7 @@ def run_all_dtypes_spsm_csv(mtx_paths, csv_path, use_coo=False, n_rhs=32):
     print("=" * 176)
     print(
         f"{'Matrix':<28} {'dtype':>9} {'index':>7} {'N':>7} {'RHS':>6} {'NNZ':>10} "
-        f"{'FS.anlys':>10} {'FS.solve':>10} {'FS.total':>10} "
+        f"{'FS.solve':>10} "
         f"{'cu.solve':>10} {'pt.solve':>10} {'cu/triton':>10} {'pt/triton':>10} "
         f"{'Status':>10} {'Err(Ref)':>12} {'Err(Res)':>12} {'Err(PT)':>12} {'Err(CU)':>12}"
     )
@@ -587,19 +602,15 @@ def run_all_dtypes_spsm_csv(mtx_paths, csv_path, use_coo=False, n_rhs=32):
                 print(
                     f"{short:<28} {base['value_dtype']:>9} {base['index_dtype']:>7} "
                     f"{row['n_rows']:>7} {row['n_rhs']:>6} {row['nnz']:>10} "
-                    f"{_fmt_ms(row['triton_analysis_ms']):>10} {_fmt_ms(row['triton_solve_ms']):>10} {_fmt_ms(row['triton_time_total_ms']):>10} "
-                    f"{_fmt_ms(row['cusparse_solve_ms']):>10} {_fmt_ms(row['pytorch_solve_ms']):>10} "
+                    f"{_fmt_ms(row['triton_ms']):>10} "
+                    f"{_fmt_ms(row['cusparse_ms']):>10} {_fmt_ms(row['pytorch_ms']):>10} "
                     f"{_fmt_ratio(row['cusparse/triton']):>10} {_fmt_ratio(row['pytorch/triton']):>10} "
                     f"{row['status']:>10} {_fmt_err(row['err_ref']):>12} {_fmt_err(row['err_res']):>12} "
                     f"{_fmt_err(row['err_pt']):>12} {_fmt_err(row['err_cu']):>12}"
                 )
                 if row["status"] in ("FAIL", "REF_FAIL"):
-                    if row["pytorch_backend"] and row["pytorch_backend"] != "gpu_dense":
-                        print(f" NOTE: pt_backend={row['pytorch_backend']}")
                     if row["pytorch_reason"]:
                         print(f" NOTE: {row['pytorch_reason']}")
-                    if row["cusparse_reason"]:
-                        print(f" NOTE: {row['cusparse_reason']}")
             except Exception as exc:
                 err_msg = str(exc)
                 status = "SKIP" if "SpSM requires square matrices" in err_msg else "ERROR"
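
Aside: the cu/triton and pt/triton columns printed above rely on guarded division so that a skipped reference solver shows up as N/A instead of crashing the report. A hypothetical combined helper in that spirit (the tests split this between `_safe_ratio` and `_fmt_ratio`; this exact body is an assumption):

    def fmt_ratio(other_ms, triton_ms):
        # A missing or zero timing yields "N/A" rather than an exception.
        if other_ms is None or triton_ms is None or triton_ms == 0:
            return "N/A"
        return f"{other_ms / triton_ms:.2f}"

    print(fmt_ratio(1.25, 0.5))  # 2.50 (reference is 2.5x slower than Triton)
    print(fmt_ratio(None, 0.5))  # N/A  (reference solver was skipped)
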
@@ -610,12 +621,9 @@ def run_all_dtypes_spsm_csv(mtx_paths, csv_path, use_coo=False, n_rhs=32):
                     "n_cols": "ERR",
                     "nnz": "ERR",
                     "n_rhs": int(n_rhs),
-                    "triton_analysis_ms": None,
-                    "triton_solve_ms": None,
-                    "triton_time_total_ms": None,
-                    "cusparse_solve_ms": None,
-                    "pytorch_solve_ms": None,
-                    "pytorch_backend": None,
+                    "triton_ms": None,
+                    "cusparse_ms": None,
+                    "pytorch_ms": None,
                     "cusparse/triton": None,
                     "pytorch/triton": None,
                     "status": status,
@@ -624,7 +632,6 @@ def run_all_dtypes_spsm_csv(mtx_paths, csv_path, use_coo=False, n_rhs=32):
                     "err_pt": None,
                     "err_cu": None,
                     "pytorch_reason": None,
-                    "cusparse_reason": None,
                     "error": err_msg,
                 }
                 rows_out.append(row)
@@ -633,7 +640,6 @@ def run_all_dtypes_spsm_csv(mtx_paths, csv_path, use_coo=False, n_rhs=32):
                     f"{short:<28} {base['value_dtype']:>9} {base['index_dtype']:>7} "
                     f"{'ERR':>7} {int(n_rhs):>6} {'ERR':>10} "
                     f"{_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>10} "
-                    f"{_fmt_ms(None):>10} {_fmt_ms(None):>10} "
                     f"{'N/A':>10} {'N/A':>10} {status:>10} "
                     f"{_fmt_err(None):>12} {_fmt_err(None):>12} {_fmt_err(None):>12} {_fmt_err(None):>12}"
                 )
@@ -649,12 +655,9 @@ def run_all_dtypes_spsm_csv(mtx_paths, csv_path, use_coo=False, n_rhs=32):
         "n_cols",
         "nnz",
         "n_rhs",
-        "triton_analysis_ms",
-        "triton_solve_ms",
-        "triton_time_total_ms",
-        "cusparse_solve_ms",
-        "pytorch_solve_ms",
-        "pytorch_backend",
+        "triton_ms",
+        "cusparse_ms",
+        "pytorch_ms",
         "cusparse/triton",
         "pytorch/triton",
         "status",
@@ -663,14 +666,13 @@ def run_all_dtypes_spsm_csv(mtx_paths, csv_path, use_coo=False, n_rhs=32):
         "err_pt",
         "err_cu",
         "pytorch_reason",
-        "cusparse_reason",
         "error",
     ]
     with open(csv_path, "w", newline="", encoding="utf-8") as f:
         w = csv.DictWriter(f, fieldnames=fieldnames)
         w.writeheader()
         for row in rows_out:
-            w.writerow(row)
+            w.writerow({k: ("" if v is None else v) for k, v in _csv_export_row_spsm(row).items()})
     print(f"Wrote {len(rows_out)} rows to {csv_path}")
diff --git a/tests/test_spsv.py b/tests/test_spsv.py
index 3baaa12..534e54a 100644
--- a/tests/test_spsv.py
+++ b/tests/test_spsv.py
@@ -127,6 +127,32 @@ def _status_str(ok_flag, has_value):
     return "FAIL" if has_value else "N/A"


+def _csv_export_row_spsv(row):
+    return {
+        "matrix": row.get("matrix"),
+        "value_dtype": row.get("value_dtype"),
+        "index_dtype": row.get("index_dtype"),
+        "opA": row.get("opA"),
+        "n_rows": row.get("n_rows"),
+        "n_cols": row.get("n_cols"),
+        "nnz": row.get("nnz"),
+        "triton_ms": row.get("triton_ms"),
+        "cusparse_ms": row.get("cusparse_ms"),
+        "pytorch_ms": row.get("pytorch_ms"),
+        "cusparse/triton": row.get("cusparse/triton"),
+        "pytorch/triton": row.get("pytorch/triton"),
+        "pt_status": row.get("pt_status"),
+        "cu_status": row.get("cu_status"),
+        "status": row.get("status"),
+        "err_ref": row.get("err_ref"),
+        "err_res": row.get("err_res"),
+        "err_pt": row.get("err_pt"),
+        "err_cu": row.get("err_cu"),
+        "pytorch_reason": row.get("pytorch_reason"),
+        "error": row.get("error"),
+    }
+
+
 def _tol_for_dtype(dtype):
     if dtype in (torch.float32, torch.complex64):
         return 1e-4, 1e-2
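
Aside: `_tol_for_dtype` keys comparison tolerances off the value dtype, treating complex64 like float32. A runnable sketch of that pattern; the (1e-4, 1e-2) pair comes from the hunk above, while the ordering of the pair and the non-single-precision branches are assumptions:

    import torch

    def tol_for_dtype(dtype):
        if dtype in (torch.float32, torch.complex64):
            return 1e-4, 1e-2  # single-precision pair from the patch
        if dtype in (torch.float64, torch.complex128):
            return 1e-8, 1e-6  # placeholder double-precision pair
        return 1e-2, 1e-1      # placeholder for half-precision dtypes

    rtol, atol = tol_for_dtype(torch.complex64)
    x = torch.ones(4, dtype=torch.complex64)
    print(torch.allclose(x, x + atol / 10, rtol=rtol, atol=atol))  # True
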
@@ -1175,12 +1201,9 @@ def _finalize_csv_row(
         "n_rows": n_rows,
         "n_cols": n_cols,
         "nnz": nnz_out,
-        "triton_analysis_ms": analysis_ms,
-        "triton_solve_ms": t_ms,
-        "triton_time_total_ms": _sum_ms(analysis_ms, t_ms),
-        "cusparse_solve_ms": cupy_ms,
-        "pytorch_solve_ms": pytorch_ms,
-        "pytorch_backend": pt_backend,
+        "triton_ms": t_ms,
+        "cusparse_ms": cupy_ms,
+        "pytorch_ms": pytorch_ms,
         "cusparse/triton": _safe_ratio(cupy_ms, t_ms),
         "pytorch/triton": _safe_ratio(pytorch_ms, t_ms),
         "pt_status": _status_str(ok_pt, err_pt is not None),
@@ -1191,10 +1214,6 @@ def _finalize_csv_row(
         "err_pt": err_pt,
         "err_cu": err_cu,
         "pytorch_reason": pt_skip_reason,
-        "cusparse_reason": None if (cupy_ms is not None or x_cu_t is not None) else (
-            "CuPy/cuSPARSE unavailable" if (cp is None or cpx_sparse is None or cpx_spsolve_triangular is None)
-            else "cuSPARSE solve failed"
-        ),
         "error": None,
     }
     return row, pt_skip_reason
@@ -1325,12 +1344,9 @@ def _finalize_csv_row_csr_full(
         "n_rows": n_rows,
         "n_cols": n_cols,
         "nnz": int(data.numel()),
-        "triton_analysis_ms": analysis_ms,
-        "triton_solve_ms": t_ms,
-        "triton_time_total_ms": _sum_ms(analysis_ms, t_ms),
-        "cusparse_solve_ms": cupy_ms,
-        "pytorch_solve_ms": pytorch_ms,
-        "pytorch_backend": pt_backend,
+        "triton_ms": t_ms,
+        "cusparse_ms": cupy_ms,
+        "pytorch_ms": pytorch_ms,
         "cusparse/triton": _safe_ratio(cupy_ms, t_ms),
         "pytorch/triton": _safe_ratio(pytorch_ms, t_ms),
         "pt_status": _status_str(ok_pt, err_pt is not None),
@@ -1341,10 +1357,6 @@ def _finalize_csv_row_csr_full(
         "err_pt": err_pt,
         "err_cu": err_cu,
         "pytorch_reason": pt_skip_reason,
-        "cusparse_reason": None if x_cu_t is not None else (
-            "CuPy/cuSPARSE unavailable" if (cp is None or cpx_sparse is None or cpx_spsolve_triangular is None)
-            else "cuSPARSE solve failed"
-        ),
         "error": None,
     }
     return row, pt_skip_reason
@@ -1383,7 +1395,7 @@ def run_all_supported_spsv_csr_csv(
     )
     print(
         "RHS is generated directly, matching Library-main's SpSV test style. "
-        "FlagSparse analysis is measured separately; FlagSparse(ms) below reports solve only. "
+        "FlagSparse(ms) below reports solve only. "
         "Err(Ref)=best |FlagSparse-reference|, Err(Res)=|op(A)*x-b|, "
         "Err(PT)=|FlagSparse-PyTorch|, Err(CU)=|FlagSparse-cuSPARSE|. "
         "PASS if PyTorch / cuSPARSE reference passes. Residual is diagnostic only."
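
Aside: the Err(Res) column described in the legend above is a reference-free sanity check: it substitutes the computed x back into op(A) x = b. A small dense stand-in for the sparse case (everything below is illustrative scaffolding, not the test harness's code):

    import torch

    torch.manual_seed(0)
    n = 8
    # Well-conditioned lower-triangular system standing in for the CSR operand.
    A = torch.tril(torch.rand(n, n, dtype=torch.float64)) + n * torch.eye(n, dtype=torch.float64)
    b = torch.rand(n, dtype=torch.float64)
    x = torch.linalg.solve_triangular(A, b.unsqueeze(-1), upper=False).squeeze(-1)
    err_res = (A @ x - b).abs().max().item()
    print(f"Err(Res) = {err_res:.3e}")  # near machine epsilon for float64
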
@@ -1391,7 +1403,7 @@ def run_all_supported_spsv_csr_csv(
     print("-" * 150)
     print(
         f"{'Matrix':<28} {'N_rows':>7} {'N_cols':>7} {'NNZ':>10} "
-        f"{'FS.anlys':>10} {'FS.solve':>10} {'FS.total':>10} {'cu.solve':>10} {'pt.solve':>10} "
+        f"{'FS.solve':>10} {'cu.solve':>10} {'pt.solve':>10} "
         f"{'cu/triton':>10} {'pt/triton':>10} {'Status':>6} {'Err(Ref)':>10} {'Err(Res)':>10} {'Err(PT)':>10} {'Err(CU)':>10}"
     )
     print("-" * 150)
@@ -1406,22 +1418,19 @@ def run_all_supported_spsv_csr_csv(
                 name = name + "…"
             n_rows, n_cols = row["n_rows"], row["n_cols"]
             nnz = row["nnz"]
-            analysis_ms = row["triton_analysis_ms"]
-            t_ms = row["triton_solve_ms"]
-            cupy_ms = row["cusparse_solve_ms"]
-            pytorch_ms = row["pytorch_solve_ms"]
+            t_ms = row["triton_ms"]
+            cupy_ms = row["cusparse_ms"]
+            pytorch_ms = row["pytorch_ms"]
             err_ref, err_res = row["err_ref"], row["err_res"]
             err_pt, err_cu = row["err_pt"], row["err_cu"]
             status = row["status"]
             print(
                 f"{name:<28} {n_rows:>7} {n_cols:>7} {nnz:>10} "
-                f"{_fmt_ms(analysis_ms):>10} {_fmt_ms(t_ms):>10} {_fmt_ms(row['triton_time_total_ms']):>10} {_fmt_ms(cupy_ms):>10} {_fmt_ms(pytorch_ms):>10} "
+                f"{_fmt_ms(t_ms):>10} {_fmt_ms(cupy_ms):>10} {_fmt_ms(pytorch_ms):>10} "
                 f"{_fmt_speedup(cupy_ms, t_ms):>10} {_fmt_speedup(pytorch_ms, t_ms):>10} "
                 f"{status:>6} {_fmt_err(err_ref):>10} {_fmt_err(err_res):>10} {_fmt_err(err_pt):>10} {_fmt_err(err_cu):>10}"
             )
             if status in ("FAIL", "REF_FAIL"):
-                if row["pytorch_backend"] and row["pytorch_backend"] != "gpu_sparse":
-                    print(f" NOTE: pt_backend={row['pytorch_backend']}")
                 if pt_skip:
                     print(f" NOTE: {pt_skip}")
         except Exception as e:
@@ -1436,12 +1445,9 @@ def run_all_supported_spsv_csr_csv(
                     "n_rows": "ERR",
                     "n_cols": "ERR",
                     "nnz": "ERR",
-                    "triton_analysis_ms": None,
-                    "triton_solve_ms": None,
-                    "triton_time_total_ms": None,
-                    "cusparse_solve_ms": None,
-                    "pytorch_solve_ms": None,
-                    "pytorch_backend": None,
+                    "triton_ms": None,
+                    "cusparse_ms": None,
+                    "pytorch_ms": None,
                     "cusparse/triton": None,
                     "pytorch/triton": None,
                     "pt_status": "N/A",
@@ -1452,7 +1458,6 @@ def run_all_supported_spsv_csr_csv(
                     "err_pt": None,
                     "err_cu": None,
                     "pytorch_reason": None,
-                    "cusparse_reason": None,
                     "error": err_msg,
                 }
             )
@@ -1461,7 +1466,7 @@ def run_all_supported_spsv_csr_csv(
                 name = name + "…"
             print(
                 f"{name:<28} {'ERR':>7} {'ERR':>7} {'ERR':>10} "
-                f"{_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>10} "
+                f"{_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>10} "
                 f"{'N/A':>10} {'N/A':>10} "
                 f"{status:>6} {_fmt_err(None):>10} {_fmt_err(None):>10} {_fmt_err(None):>10} {_fmt_err(None):>10}"
             )
@@ -1475,12 +1480,9 @@ def run_all_supported_spsv_csr_csv(
         "n_rows",
         "n_cols",
         "nnz",
-        "triton_analysis_ms",
-        "triton_solve_ms",
-        "triton_time_total_ms",
-        "cusparse_solve_ms",
-        "pytorch_solve_ms",
-        "pytorch_backend",
+        "triton_ms",
+        "cusparse_ms",
+        "pytorch_ms",
         "cusparse/triton",
         "pytorch/triton",
         "pt_status",
@@ -1491,14 +1493,13 @@ def run_all_supported_spsv_csr_csv(
         "err_pt",
         "err_cu",
         "pytorch_reason",
-        "cusparse_reason",
         "error",
     ]
     with open(csv_path, "w", newline="", encoding="utf-8") as f:
         w = csv.DictWriter(f, fieldnames=fieldnames)
         w.writeheader()
         for r in rows_out:
-            w.writerow({k: ("" if v is None else v) for k, v in r.items()})
+            w.writerow({k: ("" if v is None else v) for k, v in _csv_export_row_spsv(r).items()})
     print(f"Wrote {len(rows_out)} rows to {csv_path}")
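
Aside: routing every row through `_csv_export_row_spsv` before `DictWriter` pins the CSV schema, so internal result dicts can keep extra diagnostic keys without breaking the writer. A self-contained sketch of the projection pattern with hypothetical field names:

    import csv, io

    FIELDNAMES = ["matrix", "triton_ms", "cusparse_ms", "pytorch_ms", "status", "error"]

    def export_row(row):
        # Project onto the stable schema; extra diagnostic keys are dropped
        # and keys missing from older rows default to None.
        return {k: row.get(k) for k in FIELDNAMES}

    rows = [
        {"matrix": "demo.mtx", "triton_ms": 0.42, "status": "PASS",
         "error": None, "internal_debug": "never reaches the CSV"},
    ]
    buf = io.StringIO()
    w = csv.DictWriter(buf, fieldnames=FIELDNAMES)
    w.writeheader()
    for r in rows:
        w.writerow({k: ("" if v is None else v) for k, v in export_row(r).items()})
    print(buf.getvalue())
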
@@ -1537,7 +1538,7 @@ def run_all_dtypes_spsv_coo_csv(
     print("-" * 150)
     print(
         f"{'Matrix':<28} {'N_rows':>7} {'N_cols':>7} {'NNZ':>10} "
-        f"{'FS.anlys':>10} {'FS.solve':>10} {'FS.total':>10} {'cu.solve':>10} {'pt.solve':>10} "
+        f"{'FS.solve':>10} {'cu.solve':>10} {'pt.solve':>10} "
         f"{'cu/triton':>10} {'pt/triton':>10} {'Status':>6} {'Err(Ref)':>10} {'Err(Res)':>10} {'Err(PT)':>10} {'Err(CU)':>10}"
     )
     print("-" * 150)
@@ -1552,22 +1553,19 @@ def run_all_dtypes_spsv_coo_csv(
                 name = name + "…"
             n_rows, n_cols = row["n_rows"], row["n_cols"]
             nnz = row["nnz"]
-            analysis_ms = row["triton_analysis_ms"]
-            t_ms = row["triton_solve_ms"]
-            cupy_ms = row["cusparse_solve_ms"]
-            pytorch_ms = row["pytorch_solve_ms"]
+            t_ms = row["triton_ms"]
+            cupy_ms = row["cusparse_ms"]
+            pytorch_ms = row["pytorch_ms"]
             err_ref, err_res = row["err_ref"], row["err_res"]
             err_pt, err_cu = row["err_pt"], row["err_cu"]
             status = row["status"]
             print(
                 f"{name:<28} {n_rows:>7} {n_cols:>7} {nnz:>10} "
-                f"{_fmt_ms(analysis_ms):>10} {_fmt_ms(t_ms):>10} {_fmt_ms(row['triton_time_total_ms']):>10} {_fmt_ms(cupy_ms):>10} {_fmt_ms(pytorch_ms):>10} "
+                f"{_fmt_ms(t_ms):>10} {_fmt_ms(cupy_ms):>10} {_fmt_ms(pytorch_ms):>10} "
                 f"{_fmt_speedup(cupy_ms, t_ms):>10} {_fmt_speedup(pytorch_ms, t_ms):>10} "
                 f"{status:>6} {_fmt_err(err_ref):>10} {_fmt_err(err_res):>10} {_fmt_err(err_pt):>10} {_fmt_err(err_cu):>10}"
             )
             if status in ("FAIL", "REF_FAIL"):
-                if row["pytorch_backend"] and row["pytorch_backend"] != "gpu_sparse":
-                    print(f" NOTE: pt_backend={row['pytorch_backend']}")
                 if pt_skip:
                     print(f" NOTE: {pt_skip}")
         except Exception as e:
@@ -1582,12 +1580,9 @@ def run_all_dtypes_spsv_coo_csv(
                     "n_rows": "ERR",
                     "n_cols": "ERR",
                     "nnz": "ERR",
-                    "triton_analysis_ms": None,
-                    "triton_solve_ms": None,
-                    "triton_time_total_ms": None,
-                    "cusparse_solve_ms": None,
-                    "pytorch_solve_ms": None,
-                    "pytorch_backend": None,
+                    "triton_ms": None,
+                    "cusparse_ms": None,
+                    "pytorch_ms": None,
                     "cusparse/triton": None,
                     "pytorch/triton": None,
                     "pt_status": "N/A",
@@ -1598,7 +1593,6 @@ def run_all_dtypes_spsv_coo_csv(
                     "err_pt": None,
                     "err_cu": None,
                     "pytorch_reason": None,
-                    "cusparse_reason": None,
                     "error": err_msg,
                 }
             )
@@ -1607,7 +1601,7 @@ def run_all_dtypes_spsv_coo_csv(
                 name = name + "…"
             print(
                 f"{name:<28} {'ERR':>7} {'ERR':>7} {'ERR':>10} "
-                f"{_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>10} "
+                f"{_fmt_ms(None):>10} {_fmt_ms(None):>10} {_fmt_ms(None):>10} "
                 f"{'N/A':>10} {'N/A':>10} {status:>6} {_fmt_err(None):>10} {_fmt_err(None):>10} {_fmt_err(None):>10} {_fmt_err(None):>10}"
             )
             print(f" {status}: {e}")
@@ -1620,12 +1614,9 @@ def run_all_dtypes_spsv_coo_csv(
         "n_rows",
         "n_cols",
         "nnz",
-        "triton_analysis_ms",
-        "triton_solve_ms",
-        "triton_time_total_ms",
-        "cusparse_solve_ms",
-        "pytorch_solve_ms",
-        "pytorch_backend",
+        "triton_ms",
+        "cusparse_ms",
+        "pytorch_ms",
         "cusparse/triton",
         "pytorch/triton",
         "pt_status",
@@ -1636,14 +1627,13 @@ def run_all_dtypes_spsv_coo_csv(
         "err_pt",
         "err_cu",
         "pytorch_reason",
-        "cusparse_reason",
         "error",
     ]
     with open(csv_path, "w", newline="", encoding="utf-8") as f:
         w = csv.DictWriter(f, fieldnames=fieldnames)
         w.writeheader()
         for r in rows_out:
-            w.writerow({k: ("" if v is None else v) for k, v in r.items()})
+            w.writerow({k: ("" if v is None else v) for k, v in _csv_export_row_spsv(r).items()})
     print(f"Wrote {len(rows_out)} rows to {csv_path}")

From 590256f37dcf6688e23c9f0f58a809318479e7ab Mon Sep 17 00:00:00 2001
From: berlin020 <2261128688@qq.com>
Date: Wed, 29 Apr 2026 20:06:56 +0800
Subject: [PATCH 22/22] gather update