From 73f2ba65981bfcf014d80115307588c2b7f0693d Mon Sep 17 00:00:00 2001
From: Intron7
Date: Mon, 14 Jul 2025 10:06:53 +0200
Subject: [PATCH 1/3] tester

---
 .../preprocessing/_scrublet/__init__.py       | 41 +++++++++++++++----
 1 file changed, 32 insertions(+), 9 deletions(-)

diff --git a/src/rapids_singlecell/preprocessing/_scrublet/__init__.py b/src/rapids_singlecell/preprocessing/_scrublet/__init__.py
index 39947997..8c778079 100644
--- a/src/rapids_singlecell/preprocessing/_scrublet/__init__.py
+++ b/src/rapids_singlecell/preprocessing/_scrublet/__init__.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import warnings
 from typing import TYPE_CHECKING
 
 import cupy as cp
@@ -12,6 +13,8 @@
 from scanpy.get import _get_obs_rep
 
 from rapids_singlecell import preprocessing as pp
+from rapids_singlecell._compat import DaskArray
+from rapids_singlecell.preprocessing._utils import _check_gpu_X
 
 from . import pipeline
 from .core import Scrublet
@@ -180,11 +183,11 @@ def scrublet(
 
     start = logg.info("Running Scrublet")
 
-    adata_obs = adata.copy()
-
     def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None):
         # With no adata_sim we assume the regular use case, starting with raw
         # counts and simulating doublets
+        if isinstance(ad_obs.X, DaskArray):
+            ad_obs.X = ad_obs.X.compute()
         if ad_sim is None:
             pp.filter_genes(ad_obs, min_cells=3, verbose=False)
@@ -243,6 +246,8 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None):
 
         return {"obs": ad_obs.obs, "uns": ad_obs.uns["scrublet"]}
 
+    _check_gpu_X(adata.X, allow_dask=True)
+
     if batch_key is not None:
         if batch_key not in adata.obs.keys():
             raise ValueError(
@@ -253,13 +258,26 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None):
         # scrublet-relevant parts of the objects to add to the input object
 
         batches = np.unique(adata.obs[batch_key])
-        scrubbed = [
-            _run_scrublet(
-                adata_obs[adata_obs.obs[batch_key] == batch].copy(),
-                adata_sim,
-            )
-            for batch in batches
-        ]
+        if isinstance(adata.X, DaskArray):
+            from dask.distributed import get_client
+
+            client = get_client()
+            futures = [
+                client.submit(
+                    _run_scrublet, adata[adata.obs[batch_key] == batch].copy(), None
+                )
+                for batch in batches
+            ]
+            scrubbed = client.gather(futures)
+        else:
+            scrubbed = [
+                _run_scrublet(
+                    adata[adata.obs[batch_key] == batch].copy(),
+                    adata_sim,
+                )
+                for batch in batches
+            ]
+
         scrubbed_obs = pd.concat([scrub["obs"] for scrub in scrubbed])
 
         # Now reset the obs to get the scrublet scores
@@ -279,6 +297,11 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None):
         adata.uns["scrublet"]["batched_by"] = batch_key
 
     else:
+        adata_obs = adata.copy()
+        if isinstance(adata_obs.X, DaskArray):
+            warnings.warn(
+                "Dask arrays are only supported for Scrublet with a batch key. We are computing the object to run Scrublet."
+            )
         scrubbed = _run_scrublet(adata_obs, adata_sim)
 
     # Copy outcomes to input object from our processed version

From 77aa8c8a212e9c530a3a6667ef9455478a58ab89 Mon Sep 17 00:00:00 2001
From: Intron7
Date: Tue, 15 Jul 2025 02:48:36 -0700
Subject: [PATCH 2/3] new implementation & tests

---
 .../preprocessing/_scrublet/__init__.py       | 79 ++++++++++++-------
 tests/dask/test_dask_scrublet.py              | 49 ++++++++++++
 2 files changed, 101 insertions(+), 27 deletions(-)
 create mode 100644 tests/dask/test_dask_scrublet.py

diff --git a/src/rapids_singlecell/preprocessing/_scrublet/__init__.py b/src/rapids_singlecell/preprocessing/_scrublet/__init__.py
index 8c778079..7f23224a 100644
--- a/src/rapids_singlecell/preprocessing/_scrublet/__init__.py
+++ b/src/rapids_singlecell/preprocessing/_scrublet/__init__.py
@@ -77,6 +77,7 @@ def scrublet(
     preprocessing, simulate doublets with
     :func:`~rapids_singlecell.pp.scrublet_simulate_doublets`, and run the core
     scrublet function :func:`~rapids_singlecell.pp.scrublet` with ``adata_sim`` set.
+    Scrublet can also be run with a `dask array` if a batch key is provided. Please make sure that each batch fits into memory. In this mode Scrublet does not return the full results: only the `doublet score` and `predicted doublet` are written, not `.uns['scrublet']`. `adata_sim` is not supported for `dask arrays`.
 
     Parameters
     ----------
@@ -92,6 +93,7 @@
         :func:`~rapids_singlecell.pp.scrublet_simulate_doublets`, with same
         number of vars as adata. This should have been built from adata_obs
         after filtering genes and cells and selecting highly-variable genes.
+        Not supported for dask arrays.
     batch_key
         Optional :attr:`~anndata.AnnData.obs` column name discriminating between batches.
     sim_doublet_ratio
@@ -257,19 +259,48 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None):
         # Run Scrublet independently on batches and return just the
         # scrublet-relevant parts of the objects to add to the input object
 
-        batches = np.unique(adata.obs[batch_key])
         if isinstance(adata.X, DaskArray):
-            from dask.distributed import get_client
+            # Define function to process each batch chunk
+            def _process_batch_chunk(X_chunk):
+                """Process a single batch chunk through Scrublet."""
+                batch_adata = AnnData(X_chunk)
+                batch_results = _run_scrublet(batch_adata, None)
+                return np.array(
+                    batch_results["obs"][["doublet_score", "predicted_doublet"]]
+                ).astype(np.float64)
+
+            # Get batch information and sort data by batch
+            batch_codes = adata.obs[batch_key].astype("category").cat.codes
+            sort_indices = np.argsort(batch_codes)
+            X_sorted = adata.X[sort_indices]
+
+            # Calculate chunk sizes based on batch sizes
+            batch_sizes = np.bincount(batch_codes.iloc[sort_indices])
+            X_rechunked = X_sorted.rechunk((tuple(batch_sizes), adata.X.shape[1]))
+
+            # Process all batches in parallel using map_blocks
+            batch_results = X_rechunked.map_blocks(
+                _process_batch_chunk,
+                meta=np.array([], dtype=np.float64),
+                dtype=np.float64,
+                chunks=(X_rechunked.chunks[0], 2),
+            )
+
+            # Convert results to DataFrame and restore original order
+            results_df = pd.DataFrame(
+                batch_results.compute(), columns=["doublet_score", "predicted_doublet"]
+            )
+            final_results = results_df.iloc[np.argsort(sort_indices)]
+
+            # Update the original AnnData object with results
+            adata.obs["doublet_score"] = final_results["doublet_score"].values
+            adata.obs["predicted_doublet"] = final_results[
+                "predicted_doublet"
+            ].values.astype(bool)
+            adata.uns["scrublet"] = {"batched_by": batch_key}
 
-            client = get_client()
-            futures = [
-                client.submit(
-                    _run_scrublet, adata[adata.obs[batch_key] == batch].copy(), None
-                )
-                for batch in batches
-            ]
-            scrubbed = client.gather(futures)
         else:
+            batches = np.unique(adata.obs[batch_key])
             scrubbed = [
                 _run_scrublet(
                     adata[adata.obs[batch_key] == batch].copy(),
@@ -278,29 +309,23 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None):
                 for batch in batches
             ]
 
-        scrubbed_obs = pd.concat([scrub["obs"] for scrub in scrubbed])
+            scrubbed_obs = pd.concat([scrub["obs"] for scrub in scrubbed])
+            # Now reset the obs to get the scrublet scores
+            adata.obs = scrubbed_obs.loc[adata.obs_names.values]
 
-        # Now reset the obs to get the scrublet scores
+            # Save the .uns from each batch separately
 
-        adata.obs = scrubbed_obs.loc[adata.obs_names.values]
-
-        # Save the .uns from each batch separately
-
-        adata.uns["scrublet"] = {}
-        adata.uns["scrublet"]["batches"] = dict(
-            zip(batches, [scrub["uns"] for scrub in scrubbed])
-        )
-
-        # Record that we've done batched analysis, so e.g. the plotting
-        # function knows what to do.
-
-        adata.uns["scrublet"]["batched_by"] = batch_key
+            adata.uns["scrublet"] = {}
+            adata.uns["scrublet"]["batches"] = dict(
+                zip(batches, [scrub["uns"] for scrub in scrubbed])
+            )
+            adata.uns["scrublet"]["batched_by"] = batch_key
 
     else:
         adata_obs = adata.copy()
         if isinstance(adata_obs.X, DaskArray):
-            warnings.warn(
-                "Dask arrays are only supported for Scrublet with a batch key. We are computing the object to run Scrublet."
+            raise ValueError(
+                "Dask arrays are not supported for Scrublet without a batch key. Please provide a batch key."
             )
         scrubbed = _run_scrublet(adata_obs, adata_sim)
 
diff --git a/tests/dask/test_dask_scrublet.py b/tests/dask/test_dask_scrublet.py
new file mode 100644
index 00000000..ad3d11c2
--- /dev/null
+++ b/tests/dask/test_dask_scrublet.py
@@ -0,0 +1,49 @@
+from __future__ import annotations
+
+import cupy as cp
+import numpy as np
+import pytest
+from cupyx.scipy import sparse as cusparse
+from scanpy.datasets import paul15, pbmc3k
+
+import rapids_singlecell as rsc
+from rapids_singlecell._testing import (
+    as_dense_cupy_dask_array,
+    as_sparse_cupy_dask_array,
+)
+
+
+@pytest.mark.parametrize("data_kind", ["sparse", "dense"])
+def test_dask_scrublet(data_kind):
+    if data_kind == "sparse":
+        adata_1 = pbmc3k()[200:400].copy()
+        adata_2 = pbmc3k()[200:400].copy()
+        adata_2.X = cusparse.csr_matrix(adata_2.X.astype(np.float64))
+        adata_1.X = as_sparse_cupy_dask_array(adata_1.X.astype(np.float64))
+    elif data_kind == "dense":
+        adata_1 = paul15()[200:400].copy()
+        adata_2 = paul15()[200:400].copy()
+        adata_2.X = cp.array(adata_2.X.astype(np.float64))
+        adata_1.X = as_dense_cupy_dask_array(adata_1.X.astype(np.float64))
+    else:
+        raise ValueError(f"Unknown data_kind {data_kind}")
+
+    batch = np.random.randint(0, 2, size=adata_1.shape[0])
+    adata_1.obs["batch"] = batch
+    adata_2.obs["batch"] = batch
+    rsc.pp.scrublet(adata_1, batch_key="batch", verbose=False)
+
+    # sort adata_2 to compare results
+    batch_codes = adata_2.obs["batch"].astype("category").cat.codes
+    order = np.argsort(batch_codes)
+    adata_2 = adata_2[order]
+
+    rsc.pp.scrublet(adata_2, batch_key="batch", verbose=False)
+    adata_2 = adata_2[np.argsort(order)]
+
+    np.testing.assert_allclose(
+        adata_1.obs["doublet_score"], adata_2.obs["doublet_score"]
+    )
+    np.testing.assert_array_equal(
+        adata_1.obs["predicted_doublet"], adata_2.obs["predicted_doublet"]
+    )

From d918588fcc8bd15e26b08029bcd6cdba8c80a432 Mon Sep 17 00:00:00 2001
From: Intron7
Date: Fri, 18 Jul 2025 04:06:58 -0700
Subject: [PATCH 3/3] allow CPU to reduce memory

---
 .../preprocessing/_scrublet/__init__.py       | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/rapids_singlecell/preprocessing/_scrublet/__init__.py b/src/rapids_singlecell/preprocessing/_scrublet/__init__.py
index 7f23224a..90458380 100644
--- a/src/rapids_singlecell/preprocessing/_scrublet/__init__.py
+++ b/src/rapids_singlecell/preprocessing/_scrublet/__init__.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import gc
 import warnings
 from typing import TYPE_CHECKING
 
@@ -14,6 +15,7 @@
 
 from rapids_singlecell import preprocessing as pp
 from rapids_singlecell._compat import DaskArray
+from rapids_singlecell.get import X_to_GPU
 from rapids_singlecell.preprocessing._utils import _check_gpu_X
 
 from . import pipeline
@@ -188,8 +190,7 @@ def scrublet(
     def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None):
         # With no adata_sim we assume the regular use case, starting with raw
         # counts and simulating doublets
-        if isinstance(ad_obs.X, DaskArray):
-            ad_obs.X = ad_obs.X.compute()
+        ad_obs.X = X_to_GPU(ad_obs.X)
         if ad_sim is None:
             pp.filter_genes(ad_obs, min_cells=3, verbose=False)
@@ -246,9 +247,8 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None):
             verbose=verbose,
        )
 
-        return {"obs": ad_obs.obs, "uns": ad_obs.uns["scrublet"]}
-
-    _check_gpu_X(adata.X, allow_dask=True)
+        out = {"obs": ad_obs.obs, "uns": ad_obs.uns["scrublet"]}
+        return out
 
     if batch_key is not None:
         if batch_key not in adata.obs.keys():
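
Reviewer note (illustration, not part of the patches): the core trick in PATCH 2/3 is to sort rows by batch code, rechunk so that every Dask chunk holds exactly one batch, run the per-batch function once per chunk via map_blocks, and undo the sort afterwards. Below is a minimal self-contained sketch of that pattern in plain NumPy/Dask; the toy data and per_batch_score are hypothetical stand-ins for the real _process_batch_chunk.

import dask.array as da
import numpy as np
import pandas as pd

X = np.arange(24, dtype=np.float64).reshape(6, 4)  # 6 cells x 4 genes
batches = pd.Series(["b", "a", "b", "b", "a", "b"], dtype="category")

codes = batches.cat.codes.to_numpy()
order = np.argsort(codes)  # group rows so each batch is contiguous
X_sorted = da.from_array(X[order], chunks=(3, 4))

# One chunk per batch, so map_blocks sees a whole batch at a time
sizes = np.bincount(codes[order])  # here: [2, 4]
X_rechunked = X_sorted.rechunk((tuple(sizes), X.shape[1]))

def per_batch_score(block):
    # Toy stand-in for _process_batch_chunk: one score column per cell
    return block.mean(axis=1, keepdims=True)

scores = X_rechunked.map_blocks(
    per_batch_score,
    dtype=np.float64,
    chunks=(X_rechunked.chunks[0], 1),
).compute()

# Undo the sort so scores line up with the original row order again
scores = scores[np.argsort(order)]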
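
A usage sketch for the feature as a whole, assuming the helper as_sparse_cupy_dask_array from the test file above and a GPU-enabled Dask setup (illustrative only, not part of the series):

import numpy as np
import scanpy as sc

import rapids_singlecell as rsc
from rapids_singlecell._testing import as_sparse_cupy_dask_array

adata = sc.datasets.pbmc3k()
# Dask-backed X; each batch must fit into GPU memory
adata.X = as_sparse_cupy_dask_array(adata.X.astype(np.float64))
adata.obs["batch"] = np.random.randint(0, 2, size=adata.n_obs).astype(str)

# A batch key is required for Dask input; without one a ValueError is raised
rsc.pp.scrublet(adata, batch_key="batch", verbose=False)

# Only the per-cell columns are populated in this mode
print(adata.obs[["doublet_score", "predicted_doublet"]].head())
print(adata.uns["scrublet"])  # {'batched_by': 'batch'}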