Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/scanpy_gpu.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ Any transformation of the data matrix that is not a tool. Other than `tools`, pr
:toctree: generated/

tl.rank_genes_groups_logreg
tl.rank_genes_groups_wilcoxon
```

## Plotting
Expand Down
2 changes: 1 addition & 1 deletion src/rapids_singlecell/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ._draw_graph import draw_graph
from ._embedding_density import embedding_density
from ._pymde import mde
from ._rank_gene_groups import rank_genes_groups_logreg
from ._rank_gene_groups import rank_genes_groups_logreg, rank_genes_groups_wilcoxon
from ._score_genes import score_genes, score_genes_cell_cycle
from ._tsne import tsne
from ._umap import umap
256 changes: 256 additions & 0 deletions src/rapids_singlecell/tools/_rank_gene_groups.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Literal

import cupy as cp
import cupyx.scipy.special as cupyx_special
import numpy as np
import pandas as pd
from statsmodels.stats.multitest import multipletests

from rapids_singlecell._compat import DaskArray, _meta_dense

Expand Down Expand Up @@ -228,3 +231,256 @@ def rank_genes_groups_logreg(
}
adata.uns["rank_genes_groups"]["scores"] = scores
adata.uns["rank_genes_groups"]["names"] = names


# Pseudocount added to expm1-transformed means before the log2 ratio, so that
# zero-expression groups do not produce log2(0) or a division by zero.
EPS = 1e-9


def _choose_chunk_size(requested: int | None, n_obs: int, dtype_size: int = 8) -> int:
    """Pick how many genes to process per GPU chunk.

    An explicit ``requested`` value always wins. Otherwise the width is
    derived from the free GPU memory reported by the CUDA runtime,
    budgeting roughly four ``dtype_size``-byte copies of each gene column
    and clamping the result to the range [100, 1000]. A conservative
    default of 500 is returned when the memory query fails or ``n_obs``
    is zero.
    """
    if requested is not None:
        return int(requested)
    try:
        free_bytes, _total = cp.cuda.runtime.memGetInfo()
    except cp.cuda.runtime.CUDARuntimeError:
        return 500
    per_gene = n_obs * dtype_size * 4
    if per_gene == 0:
        return 500
    budget = int(0.6 * free_bytes / per_gene)
    # Clamp the memory-derived estimate to [100, 1000].
    return min(max(budget, 100), 1000)


def _average_ranks(matrix: cp.ndarray) -> cp.ndarray:
    """Rank every column of ``matrix``, averaging the ranks of ties.

    Column-wise equivalent of ``scipy.stats.rankdata(..., method="average")``:
    tied values all receive the mean of the 1-based rank positions they
    occupy.
    """
    out = cp.empty_like(matrix, dtype=cp.float64)
    for col in range(matrix.shape[1]):
        values = matrix[:, col]
        order = cp.argsort(values)
        ordered = values[order]
        # True at every position that starts a new run of equal values.
        is_new = cp.concatenate(
            (cp.array([True]), ordered[1:] != ordered[:-1])
        )
        # Dense (1-based) rank of each element, scattered back to original order.
        dense_rank = cp.empty(values.size, dtype=cp.int64)
        dense_rank[order] = cp.cumsum(is_new)
        # Boundaries of the tie runs; averaging the run's two edges yields the
        # mean rank shared by all members of that run.
        edges = cp.concatenate((cp.flatnonzero(is_new), cp.array([is_new.size])))
        out[:, col] = 0.5 * (edges[dense_rank] + edges[dense_rank - 1] + 1.0)
    return out


def _tie_correction(ranks: cp.ndarray) -> cp.ndarray:
    """Per-column tie-correction factor for the rank-sum variance.

    Returns ``1 - sum(t**3 - t) / (n**3 - n)`` for each column, where ``t``
    runs over the sizes of the tie groups and ``n`` is the number of rows.
    Columns with fewer than two rows keep a factor of 1.
    """
    factors = cp.ones(ranks.shape[1], dtype=cp.float64)
    n_rows = cp.float64(ranks.shape[0])
    if n_rows < 2:
        # No correction is defined for fewer than two observations.
        return factors
    denom = n_rows**3 - n_rows
    for col in range(ranks.shape[1]):
        ordered = cp.sort(ranks[:, col])
        # Start/end offsets of each run of identical ranks.
        run_edges = cp.concatenate(
            (
                cp.array([0]),
                cp.flatnonzero(ordered[1:] != ordered[:-1]) + 1,
                cp.array([ordered.size]),
            )
        )
        run_lengths = cp.diff(run_edges).astype(cp.float64)
        factors[col] = 1.0 - (run_lengths**3 - run_lengths).sum() / denom
    return factors


def rank_genes_groups_wilcoxon(
    adata: AnnData,
    groupby: str,
    *,
    groups: Literal["all"] | Iterable[str] = "all",
    use_raw: bool | None = None,
    reference: str = "rest",
    n_genes: int | None = None,
    tie_correct: bool = False,
    layer: str | None = None,
    chunk_size: int | None = None,
    corr_method: str = "benjamini-hochberg",
) -> None:
    """Rank genes per group with a GPU-accelerated Wilcoxon rank-sum test.

    Each group in ``adata.obs[groupby]`` is compared against the union of
    all remaining cells ("rest") using the normal approximation of the
    Wilcoxon statistic. Results (scores, gene names, log2 fold changes,
    raw and adjusted p-values) are written to
    ``adata.uns['rank_genes_groups']`` in scanpy's record-array layout.

    Parameters
    ----------
    adata
        Annotated data matrix.
    groupby
        Key of ``adata.obs`` holding the group labels.
    groups
        Subset of group names to rank, or ``"all"``.
    use_raw
        Use ``adata.raw`` instead of ``adata.X``. When ``None`` (default)
        ``.raw`` is used if it is present and no ``layer`` is given.
    reference
        Only ``"rest"`` is currently supported.
    n_genes
        Number of top genes reported per group; all genes by default.
    tie_correct
        Apply tie correction to the variance of the rank statistic.
    chunk_size
        Number of genes processed per GPU chunk; auto-sized from free GPU
        memory when ``None``.
    layer
        Key of ``adata.layers`` to use instead of ``adata.X``.
    corr_method
        ``"benjamini-hochberg"`` or ``"bonferroni"``.

    Raises
    ------
    ValueError
        For an invalid ``corr_method``, a scalar ``groups`` argument, or
        ``layer`` combined with ``use_raw=True``.
    NotImplementedError
        For any ``reference`` other than ``"rest"``.
    """
    if corr_method not in {"benjamini-hochberg", "bonferroni"}:
        msg = "corr_method must be either 'benjamini-hochberg' or 'bonferroni'."
        raise ValueError(msg)
    if reference != "rest":
        msg = "Only reference='rest' is currently supported for the GPU Wilcoxon test."
        raise NotImplementedError(msg)

    # Normalize the requested groups to either "all" or a list of strings.
    if groups == "all" or groups is None:
        groups_order = "all"
    elif isinstance(groups, str | int):
        raise ValueError("Specify a sequence of groups")
    else:
        groups_order = list(groups)
        if isinstance(groups_order[0], int):
            groups_order = [str(n) for n in groups_order]

    # BUGFIX: pass a real boolean to reset_index (previously the string
    # "True", which only behaved correctly because non-empty strings are
    # truthy).
    labels = pd.Series(adata.obs[groupby]).reset_index(drop=True)
    groups_order, groups_masks = _select_groups(labels, groups_order)

    # Warn when either side of the comparison is small enough that the
    # normal approximation of the rank-sum statistic becomes questionable.
    group_sizes = groups_masks.sum(axis=1).astype(np.int64)
    n_cells = labels.shape[0]
    for name, size in zip(groups_order, group_sizes, strict=False):
        rest = n_cells - size
        if size <= 25 or rest <= 25:
            warnings.warn(
                f"Group {name} has size {size} (rest {rest}); normal approximation "
                "of the Wilcoxon statistic may be inaccurate.",
                RuntimeWarning,
                stacklevel=2,
            )

    # Resolve which expression matrix and variable names to test.
    if layer and use_raw is True:
        raise ValueError("Cannot specify `layer` and have `use_raw=True`.")
    elif layer:
        X = adata.layers[layer]
        var_names = adata.var_names
    elif use_raw is None and adata.raw:
        print("defaulting to using `.raw`")
        X = adata.raw.X
        var_names = adata.raw.var_names
    elif use_raw is True:
        X = adata.raw.X
        var_names = adata.raw.var_names
    else:
        X = adata.X
        var_names = adata.var_names

    # Ranking needs dense data; densify sparse inputs up front.
    if hasattr(X, "toarray"):
        X = X.toarray()

    n_cells, n_total_genes = X.shape
    n_top = n_total_genes if n_genes is None else min(n_genes, n_total_genes)

    # One-hot group membership (cells x groups) and per-group sizes on device.
    group_matrix = cp.asarray(groups_masks.T, dtype=cp.float64)
    group_sizes_dev = cp.asarray(group_sizes, dtype=cp.float64)
    rest_sizes = n_cells - group_sizes_dev

    # Invert the stored log1p transform (honoring a custom base) so fold
    # changes are computed on linear-scale means.
    base = adata.uns.get("log1p", {}).get("base")
    if base is not None:
        log_base = cp.log(base)  # hoisted: constant across all calls

        def log_expm1(arr: cp.ndarray) -> cp.ndarray:
            return cp.expm1(arr * log_base)
    else:
        log_expm1 = cp.expm1

    chunk_width = _choose_chunk_size(chunk_size, n_cells)
    group_keys = [str(key) for key in groups_order]

    scores: dict[str, list[np.ndarray]] = {key: [] for key in group_keys}
    logfc: dict[str, list[np.ndarray]] = {key: [] for key in group_keys}
    pvals: dict[str, list[np.ndarray]] = {key: [] for key in group_keys}
    gene_indices: dict[str, list[np.ndarray]] = {key: [] for key in group_keys}

    # Process genes in chunks to bound peak GPU memory usage.
    for start in range(0, n_total_genes, chunk_width):
        stop = min(start + chunk_width, n_total_genes)
        block = cp.asarray(X[:, start:stop], dtype=cp.float64)
        ranks = _average_ranks(block)
        if tie_correct:
            tie_corr = _tie_correction(ranks)
        else:
            tie_corr = cp.ones(ranks.shape[1], dtype=cp.float64)

        # z = (R - E[R]) / sqrt(Var[R]) under the normal approximation, with
        # Var[R] = tie_corr * n1 * n2 * (n + 1) / 12.
        rank_sums = group_matrix.T @ ranks
        expected = group_sizes_dev[:, None] * (n_cells + 1) / 2.0
        variance = tie_corr[None, :] * group_sizes_dev[:, None] * rest_sizes[:, None]
        variance *= (n_cells + 1) / 12.0
        std = cp.sqrt(variance)
        z = (rank_sums - expected) / std
        cp.nan_to_num(z, copy=False)  # zero-variance genes yield z = 0
        p_values = 2.0 * (1.0 - cupyx_special.ndtr(cp.abs(z)))

        # log2 fold change of (un-logged) group mean vs. mean of the rest.
        group_sums = group_matrix.T @ block
        group_means = group_sums / group_sizes_dev[:, None]
        total_mean = cp.mean(block, axis=0)
        rest_sum = total_mean * n_cells - group_means * group_sizes_dev[:, None]
        mean_rest = rest_sum / rest_sizes[:, None]
        numerator = log_expm1(group_means) + EPS
        denominator = log_expm1(mean_rest) + EPS
        log_fold = cp.log2(numerator / denominator)

        indices = np.arange(start, stop, dtype=int)
        z_host = z.get()
        p_host = p_values.get()
        logfc_host = log_fold.get()

        for idx, key in enumerate(group_keys):
            scores[key].append(z_host[idx])
            logfc[key].append(logfc_host[idx])
            pvals[key].append(p_host[idx])
            gene_indices[key].append(indices)

    # Assemble per-group results, correct p-values, and keep the top genes.
    var_array = np.asarray(var_names)
    structured = {}
    for key in group_keys:
        all_scores = (
            np.concatenate(scores[key]) if scores[key] else np.empty(0, dtype=float)
        )
        all_logfc = (
            np.concatenate(logfc[key]) if logfc[key] else np.empty(0, dtype=float)
        )
        all_pvals = (
            np.concatenate(pvals[key]) if pvals[key] else np.empty(0, dtype=float)
        )
        all_genes = (
            np.concatenate(gene_indices[key])
            if gene_indices[key]
            else np.empty(0, dtype=int)
        )

        # NaN p-values (zero-variance genes) are treated as 1.0 before
        # multiple-testing correction.
        clean = np.array(all_pvals, copy=True)
        clean[np.isnan(clean)] = 1.0
        if clean.size and corr_method == "benjamini-hochberg":
            _, adjusted, _, _ = multipletests(clean, alpha=0.05, method="fdr_bh")
        elif clean.size:
            adjusted = np.minimum(clean * n_total_genes, 1.0)
        else:
            adjusted = np.array([], dtype=float)

        # Order genes by descending z-score and keep the top n_top.
        if all_scores.size:
            order = np.argsort(all_scores)[::-1]
        else:
            order = np.empty(0, dtype=int)
        keep = order[: min(n_top, order.size)]

        structured[key] = {
            "scores": all_scores[keep].astype(np.float32, copy=False),
            "logfc": all_logfc[keep].astype(np.float32, copy=False),
            "pvals": clean[keep].astype(np.float64, copy=False),
            "pvals_adj": adjusted[keep].astype(np.float64, copy=False),
            "names": var_array[all_genes[keep]].astype("U50", copy=False),
        }

    # Record-array dtypes mirror scanpy's rank_genes_groups output layout.
    dtype_scores = [(key, "float32") for key in group_keys]
    dtype_names = [(key, "U50") for key in group_keys]
    dtype_logfc = [(key, "float32") for key in group_keys]
    dtype_pvals = [(key, "float64") for key in group_keys]

    adata.uns["rank_genes_groups"] = {
        "params": {
            "groupby": groupby,
            "method": "wilcoxon",
            "reference": reference,
            "use_raw": use_raw,
            "tie_correct": tie_correct,
            "layer": layer,
            "corr_method": corr_method,
        },
        "scores": np.rec.fromarrays(
            [structured[key]["scores"] for key in group_keys],
            dtype=dtype_scores,
        ),
        "names": np.rec.fromarrays(
            [structured[key]["names"] for key in group_keys],
            dtype=dtype_names,
        ),
        "logfoldchanges": np.rec.fromarrays(
            [structured[key]["logfc"] for key in group_keys],
            dtype=dtype_logfc,
        ),
        "pvals": np.rec.fromarrays(
            [structured[key]["pvals"] for key in group_keys],
            dtype=dtype_pvals,
        ),
        "pvals_adj": np.rec.fromarrays(
            [structured[key]["pvals_adj"] for key in group_keys],
            dtype=dtype_pvals,
        ),
    }
Loading
Loading