diff --git a/src/rapids_singlecell/preprocessing/_scrublet/__init__.py b/src/rapids_singlecell/preprocessing/_scrublet/__init__.py
index 39947997..90458380 100644
--- a/src/rapids_singlecell/preprocessing/_scrublet/__init__.py
+++ b/src/rapids_singlecell/preprocessing/_scrublet/__init__.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import gc
+import warnings
 from typing import TYPE_CHECKING
 
 import cupy as cp
@@ -12,6 +14,9 @@
 from scanpy.get import _get_obs_rep
 
 from rapids_singlecell import preprocessing as pp
+from rapids_singlecell._compat import DaskArray
+from rapids_singlecell.get import X_to_GPU
+from rapids_singlecell.preprocessing._utils import _check_gpu_X
 
 from . import pipeline
 from .core import Scrublet
@@ -74,6 +79,7 @@ def scrublet(
     preprocessing, simulate doublets with
     :func:`~rapids_singlecell.pp.scrublet_simulate_doublets`, and run the core
     scrublet function :func:`~rapids_singlecell.pp.scrublet` with ``adata_sim`` set.
+    Scrublet can also be run on a Dask array if ``batch_key`` is provided; make sure each batch fits into GPU memory. On this path only the per-cell ``doublet_score`` and ``predicted_doublet`` are returned, and ``.uns['scrublet']`` holds just the batch key rather than the full Scrublet results. ``adata_sim`` is not supported for Dask arrays.
 
     Parameters
     ----------
@@ -89,6 +95,7 @@ def scrublet(
         :func:`~rapids_singlecell.pp.scrublet_simulate_doublets`, with same
         number of vars as adata. This should have been built from adata_obs
         after filtering genes and cells and selecting highly-variable genes.
+        Not supported for Dask arrays.
     batch_key
         Optional :attr:`~anndata.AnnData.obs` column name discriminating between batches.
     sim_doublet_ratio
@@ -180,11 +187,10 @@ def scrublet(
 
     start = logg.info("Running Scrublet")
 
-    adata_obs = adata.copy()
-
     def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None):
         # With no adata_sim we assume the regular use case, starting with raw
         # counts and simulating doublets
 
+        ad_obs.X = X_to_GPU(ad_obs.X)
         if ad_sim is None:
             pp.filter_genes(ad_obs, min_cells=3, verbose=False)
@@ -241,7 +247,8 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None):
             verbose=verbose,
         )
 
-        return {"obs": ad_obs.obs, "uns": ad_obs.uns["scrublet"]}
+        out = {"obs": ad_obs.obs, "uns": ad_obs.uns["scrublet"]}
+        return out
 
     if batch_key is not None:
         if batch_key not in adata.obs.keys():
@@ -252,33 +259,74 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None):
 
         # Run Scrublet independently on batches and return just the
        # scrublet-relevant parts of the objects to add to the input object
-        batches = np.unique(adata.obs[batch_key])
-        scrubbed = [
-            _run_scrublet(
-                adata_obs[adata_obs.obs[batch_key] == batch].copy(),
-                adata_sim,
+        if isinstance(adata.X, DaskArray):
+            # Define function to process each batch chunk
+            def _process_batch_chunk(X_chunk):
+                """Process a single batch chunk through Scrublet."""
+                batch_adata = AnnData(X_chunk)
+                batch_results = _run_scrublet(batch_adata, None)
+                return np.array(
+                    batch_results["obs"][["doublet_score", "predicted_doublet"]]
+                ).astype(np.float64)
+
+            # Get batch information and sort data by batch
+            batch_codes = adata.obs[batch_key].astype("category").cat.codes
+            sort_indices = np.argsort(batch_codes)
+            X_sorted = adata.X[sort_indices]
+
+            # Calculate chunk sizes based on batch sizes
+            batch_sizes = np.bincount(batch_codes.iloc[sort_indices])
+            X_rechunked = X_sorted.rechunk((tuple(batch_sizes), adata.X.shape[1]))
+
+            # Process all batches in parallel using map_blocks
+            batch_results = X_rechunked.map_blocks(
+                _process_batch_chunk,
+                meta=np.array([], dtype=np.float64),
+                dtype=np.float64,
+                chunks=(X_rechunked.chunks[0], 2),
             )
-            for batch in batches
-        ]
-        scrubbed_obs = pd.concat([scrub["obs"] for scrub in scrubbed])
-
-        # Now reset the obs to get the scrublet scores
-
-        adata.obs = scrubbed_obs.loc[adata.obs_names.values]
-
-        # Save the .uns from each batch separately
-
-        adata.uns["scrublet"] = {}
-        adata.uns["scrublet"]["batches"] = dict(
-            zip(batches, [scrub["uns"] for scrub in scrubbed])
-        )
 
-        # Record that we've done batched analysis, so e.g. the plotting
-        # function knows what to do.
-
-        adata.uns["scrublet"]["batched_by"] = batch_key
+            # Convert results to DataFrame and restore original order
+            results_df = pd.DataFrame(
+                batch_results.compute(), columns=["doublet_score", "predicted_doublet"]
+            )
+            final_results = results_df.iloc[np.argsort(sort_indices)]
+
+            # Update the original AnnData object with results
+            adata.obs["doublet_score"] = final_results["doublet_score"].values
+            adata.obs["predicted_doublet"] = final_results[
+                "predicted_doublet"
+            ].values.astype(bool)
+            adata.uns["scrublet"] = {"batched_by": batch_key}
+
+        else:
+            batches = np.unique(adata.obs[batch_key])
+            scrubbed = [
+                _run_scrublet(
+                    adata[adata.obs[batch_key] == batch].copy(),
+                    adata_sim,
+                )
+                for batch in batches
+            ]
+
+            scrubbed_obs = pd.concat([scrub["obs"] for scrub in scrubbed])
+            # Now reset the obs to get the scrublet scores
+            adata.obs = scrubbed_obs.loc[adata.obs_names.values]
+
+            # Save the .uns from each batch separately
+            adata.uns["scrublet"] = {}
+            adata.uns["scrublet"]["batches"] = dict(
+                zip(batches, [scrub["uns"] for scrub in scrubbed])
+            )
+            adata.uns["scrublet"]["batched_by"] = batch_key
     else:
+        adata_obs = adata.copy()
+        if isinstance(adata_obs.X, DaskArray):
+            raise ValueError(
+                "Dask arrays are not supported for Scrublet without a batch key. Please provide a batch key."
+            )
         scrubbed = _run_scrublet(adata_obs, adata_sim)
 
         # Copy outcomes to input object from our processed version
diff --git a/tests/dask/test_dask_scrublet.py b/tests/dask/test_dask_scrublet.py
new file mode 100644
index 00000000..ad3d11c2
--- /dev/null
+++ b/tests/dask/test_dask_scrublet.py
@@ -0,0 +1,49 @@
+from __future__ import annotations
+
+import cupy as cp
+import numpy as np
+import pytest
+from cupyx.scipy import sparse as cusparse
+from scanpy.datasets import paul15, pbmc3k
+
+import rapids_singlecell as rsc
+from rapids_singlecell._testing import (
+    as_dense_cupy_dask_array,
+    as_sparse_cupy_dask_array,
+)
+
+
+@pytest.mark.parametrize("data_kind", ["sparse", "dense"])
+def test_dask_scrublet(data_kind):
+    if data_kind == "sparse":
+        adata_1 = pbmc3k()[200:400].copy()
+        adata_2 = pbmc3k()[200:400].copy()
+        adata_2.X = cusparse.csr_matrix(adata_2.X.astype(np.float64))
+        adata_1.X = as_sparse_cupy_dask_array(adata_1.X.astype(np.float64))
+    elif data_kind == "dense":
+        adata_1 = paul15()[200:400].copy()
+        adata_2 = paul15()[200:400].copy()
+        adata_2.X = cp.array(adata_2.X.astype(np.float64))
+        adata_1.X = as_dense_cupy_dask_array(adata_1.X.astype(np.float64))
+    else:
+        raise ValueError(f"Unknown data_kind {data_kind}")
+
+    batch = np.random.randint(0, 2, size=adata_1.shape[0])
+    adata_1.obs["batch"] = batch
+    adata_2.obs["batch"] = batch
+    rsc.pp.scrublet(adata_1, batch_key="batch", verbose=False)
+
+    # Sort adata_2 by batch so results align with the Dask path's ordering
+    batch_codes = adata_2.obs["batch"].astype("category").cat.codes
+    order = np.argsort(batch_codes)
+    adata_2 = adata_2[order]
+
+    rsc.pp.scrublet(adata_2, batch_key="batch", verbose=False)
+    adata_2 = adata_2[np.argsort(order)]
+
+    np.testing.assert_allclose(
+        adata_1.obs["doublet_score"], adata_2.obs["doublet_score"]
+    )
+    np.testing.assert_array_equal(
+        adata_1.obs["predicted_doublet"], adata_2.obs["predicted_doublet"]
+    )
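
The sort/rechunk/map_blocks pattern at the heart of this patch can be shown in isolation. Below is a minimal CPU-only sketch using plain NumPy and dask.array, with no GPU or AnnData involved; `per_batch` is a hypothetical stand-in for `_run_scrublet` that emits two result columns per row. It illustrates why rechunking to one chunk per batch lets `map_blocks` run the per-batch function independently on each batch, and why indexing with `np.argsort(order)` restores the original row order afterwards.

import dask.array as da
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
X = da.from_array(rng.poisson(1.0, size=(100, 20)), chunks=(25, 20))
batch = pd.Series(rng.integers(0, 3, size=100)).astype("category")

# Sort rows so each batch is contiguous
codes = batch.cat.codes.to_numpy()
order = np.argsort(codes)
X_sorted = X[order]

# One chunk per batch: the row-chunk sizes are exactly the batch sizes
sizes = np.bincount(codes[order])
X_batched = X_sorted.rechunk((tuple(sizes), X.shape[1]))

def per_batch(chunk):
    # Hypothetical stand-in for the real per-batch scoring:
    # returns two float64 columns per row of the chunk
    return np.stack([chunk.mean(axis=1), chunk.sum(axis=1)], axis=1).astype(np.float64)

scores = X_batched.map_blocks(
    per_batch,
    dtype=np.float64,
    meta=np.empty((0, 2), dtype=np.float64),
    chunks=(X_batched.chunks[0], 2),  # same row chunks, two result columns
)

# Undo the sort to restore the original row order
result = scores.compute()[np.argsort(order)]
assert result.shape == (100, 2)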
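For reviewers, a minimal usage sketch of the batched Dask path this patch adds. It mirrors the new test: `as_sparse_cupy_dask_array` is the same testing helper the test imports, while the dataset, slicing, and random batch assignment are illustrative only.

import numpy as np
import scanpy as sc

import rapids_singlecell as rsc
from rapids_singlecell._testing import as_sparse_cupy_dask_array

# Back .X with a CuPy-sparse Dask array (same helper the new test uses)
adata = sc.datasets.pbmc3k()[200:400].copy()
adata.X = as_sparse_cupy_dask_array(adata.X.astype(np.float64))

# A batch key is mandatory on the Dask path; each batch becomes one chunk
# and must fit into GPU memory on its own
adata.obs["batch"] = np.random.randint(0, 2, size=adata.shape[0])
rsc.pp.scrublet(adata, batch_key="batch", verbose=False)

# Only per-cell results are written on this path
print(adata.obs[["doublet_score", "predicted_doublet"]].head())
print(adata.uns["scrublet"])  # {'batched_by': 'batch'}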