Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/cellmapper/_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@
will be mapped, and others will be ignored. For numerical data, this parameter is ignored
with a warning. Can be a single category string or a list of category strings."""

# Shared docstring fragment describing the ``n_batches`` parameter. It is registered
# with the DocstringProcessor under the key ``n_batches`` so that docstrings can pull
# it in via the ``%(n_batches)s`` placeholder.
_n_batches = """\
n_batches
Number of batches to use for Jaccard-based mapping matrix computation. If None (default),
compute in a single batch. Use batch processing to reduce memory usage for large datasets."""


d = DocstringProcessor(
t=_t,
Expand All @@ -114,4 +119,5 @@
use_rep=_use_rep,
knn_dist_metric=_knn_dist_metric,
subset_categories=_subset_categories,
n_batches=_n_batches,
)
4 changes: 4 additions & 0 deletions src/cellmapper/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,7 @@ class PackageConstants:

# Threshold for recommending spectral method over iterative for matrix powers
SPECTRAL_METHOD_THRESHOLD: int = 10

# Batch processing thresholds for Jaccard-based methods
JACCARD_BATCH_WARNING_CELLS: int = 100_000 # Warn if >100k cells with Jaccard methods
JACCARD_BATCH_WARNING_NEIGHBORS: int = 20 # Warn if >20 neighbors with Jaccard methods
3 changes: 3 additions & 0 deletions src/cellmapper/model/cellmapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ def compute_mapping_matrix(
self_edges: bool | None = None,
n_eigenvectors: int = 50,
eigen_solver: Literal["partial", "complete"] = "partial",
n_batches: int | None = None,
) -> None:
"""
Compute the mapping matrix for label transfer.
Expand All @@ -257,6 +258,7 @@ def compute_mapping_matrix(
Eigendecomposition method for spectral approach:
- "partial": Uses sparse eigendecomposition, faster (default)
- "complete": Uses complete eigendecomposition, exact for testing
%(n_batches)s

Returns
-------
Expand Down Expand Up @@ -292,6 +294,7 @@ def compute_mapping_matrix(
kernel_method=kernel_method,
symmetrize=symmetrize,
self_edges=self_edges,
n_batches=n_batches,
)

# Validate expected shape before creating mapping operator
Expand Down
36 changes: 25 additions & 11 deletions src/cellmapper/model/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from cellmapper.logging import logger
from cellmapper.model._knn_backend import get_backend
from cellmapper.model.neighbors import Neighbors
from cellmapper.utils import extract_neighbors_from_distances
from cellmapper.utils import compute_jaccard_kernel_matrix, extract_neighbors_from_distances


class Kernel:
Expand Down Expand Up @@ -218,6 +218,7 @@ def compute_kernel_matrix(
symmetrize: bool = False,
symmetrize_method: Literal["max", "mean"] = "max",
self_edges: bool = False,
n_batches: int | None = None,
**kwargs,
) -> None:
"""
Expand All @@ -232,6 +233,7 @@ def compute_kernel_matrix(
- "max": Take element-wise maximum between matrix and transpose (preserves strongest connections)
- "mean": Take element-wise average between matrix and transpose (smooths connections)
%(self_edges)s
%(n_batches)s
**kwargs
Additional keyword arguments for kernel computation.

Expand Down Expand Up @@ -262,14 +264,25 @@ def compute_kernel_matrix(
assert self.yx is not None, "yx neighbors must be computed"
n_neighbors = self.yx.n_neighbors

# Compute kernel matrix
kernel_matrix = (yx @ xx.T) + (yy @ xy.T)

if kernel_method == "jaccard":
kernel_matrix.data /= 4 * n_neighbors - kernel_matrix.data
elif kernel_method == "hnoca":
kernel_matrix.data /= 2 * n_neighbors - kernel_matrix.data
kernel_matrix.data = kernel_matrix.data**2
# Check if batch processing might be beneficial and warn user
assert yx is not None, "yx adjacency matrix must be available"
n_query_cells = self.yrep.shape[0]
n_reference_cells = self.xrep.shape[0]
if (
n_batches is None
and (
(n_query_cells > PackageConstants.JACCARD_BATCH_WARNING_CELLS)
or (n_reference_cells > PackageConstants.JACCARD_BATCH_WARNING_CELLS)
)
and n_neighbors > PackageConstants.JACCARD_BATCH_WARNING_NEIGHBORS
):
logger.warning(
f"Computing {kernel_method} kernel for {n_query_cells:,} x {n_reference_cells:,} cells with {n_neighbors} neighbors. "
f"Consider using batch processing (n_batches parameter) to reduce memory usage."
)

# Compute kernel matrix with optional batching using utility function
kernel_matrix = compute_jaccard_kernel_matrix(xx, yy, xy, yx, kernel_method, n_neighbors, n_batches)

elif kernel_method in PackageConstants.CONNECTIVITY_BASED_KERNELS:
# Validate self-mapping-only kernels
Expand Down Expand Up @@ -423,8 +436,9 @@ def __repr__(self):
# Kernel matrix info
if self.kernel_matrix is not None:
# Calculate sparsity percentage
total_elements = self.kernel_matrix.shape[0] * self.kernel_matrix.shape[1]
sparsity = self.kernel_matrix.nnz / total_elements
kernel_matrix = cast(csr_matrix, self.kernel_matrix) # for type checker
total_elements = kernel_matrix.shape[0] * kernel_matrix.shape[1]
sparsity = kernel_matrix.nnz / total_elements

kernel_info = [
f"kernel='{self.kernel_method}'",
Expand Down
9 changes: 6 additions & 3 deletions src/cellmapper/model/mapping_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ def __init__(
"""
# Extract matrix and metadata from Kernel object if provided

if isinstance(kernel_matrix, Kernel):
# We check the type name as a string to avoid issues with module reloading
# where `isinstance` can fail unexpectedly.
if type(kernel_matrix).__name__ == "Kernel" or isinstance(kernel_matrix, Kernel):
# This is a Kernel object
kernel_obj = kernel_matrix
actual_matrix = kernel_obj.kernel_matrix
Expand All @@ -93,8 +95,7 @@ def __init__(
if is_self_mapping is None:
is_self_mapping = kernel_obj._is_self_mapping

kernel_matrix = actual_matrix
else:
elif isinstance(kernel_matrix, csr_matrix | coo_matrix | csc_matrix | np.ndarray):
# This is a raw matrix
actual_matrix = kernel_matrix

Expand All @@ -103,6 +104,8 @@ def __init__(
n_rows, n_cols = actual_matrix.shape
is_self_mapping = n_rows == n_cols
logger.info("Inferred is_self_mapping=%s from matrix shape %s", is_self_mapping, actual_matrix.shape)
else:
raise ValueError(f"Unknown kernel_matrix type: {type(kernel_matrix)}")

self.is_self_mapping = is_self_mapping
self.eigen_solver = eigen_solver
Expand Down
163 changes: 162 additions & 1 deletion src/cellmapper/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
"""Utility functions for the CellMapper package."""

import gc

import anndata as ad
import numpy as np
import pandas as pd
from anndata import AnnData
from scipy.sparse import csr_matrix, issparse
from scipy.sparse import csr_matrix, issparse, vstack
from scipy.sparse.linalg import LinearOperator, svds
from sklearn.utils.extmath import randomized_svd

from cellmapper._docs import d
from cellmapper.constants import PackageConstants
from cellmapper.logging import logger

Expand Down Expand Up @@ -367,3 +370,161 @@ def rmatvec(v):
vt = vt[idx, :]

return u, s, vt


def apply_jaccard_transformation(kernel_matrix: csr_matrix, kernel_method: str, n_neighbors: int) -> None:
    """
    Apply the Jaccard or HNOCA transformation to a kernel matrix in-place.

    Parameters
    ----------
    kernel_matrix
        Sparse matrix to transform in-place. Only the explicitly stored values
        (``kernel_matrix.data``) are touched; the sparsity pattern is unchanged.
    kernel_method
        Method to use: "jaccard" or "hnoca"
    n_neighbors
        Number of neighbors used for normalization

    Raises
    ------
    ValueError
        If ``kernel_method`` is neither "jaccard" nor "hnoca". The matrix is
        left untouched in that case.

    Notes
    -----
    For Jaccard: kernel = intersection / (4*k - intersection)
    For HNOCA: kernel = (intersection / (2*k - intersection))^2

    Floating-point data is updated in the existing buffer to save memory.
    Integer or boolean data cannot hold the result of a true division, so it
    is promoted to float64 once before the transformation is applied.
    """
    # Validate first so an unknown method never leaves the matrix half-modified.
    if kernel_method == "jaccard":
        denominator_factor = 4
    elif kernel_method == "hnoca":
        denominator_factor = 2
    else:
        raise ValueError(f"Unknown kernel method: {kernel_method}. Expected 'jaccard' or 'hnoca'.")

    values = kernel_matrix.data
    if not np.issubdtype(values.dtype, np.inexact):
        # In-place true division is rejected by NumPy for integer/bool arrays.
        values = values.astype(np.float64)
        kernel_matrix.data = values

    values /= denominator_factor * n_neighbors - values
    if kernel_method == "hnoca":
        # Square in the same buffer instead of allocating a second array.
        np.square(values, out=values)


@d.dedent
def compute_jaccard_kernel_matrix(
    xx: csr_matrix,
    yy: csr_matrix,
    xy: csr_matrix,
    yx: csr_matrix,
    kernel_method: str,
    n_neighbors: int,
    n_batches: int | None = None,
) -> csr_matrix:
    """
    Build the Jaccard or HNOCA kernel matrix, optionally in batches.

    Parameters
    ----------
    xx, yy, xy, yx
        Adjacency matrices from neighbor computations
    kernel_method
        Kernel method to use: "jaccard" or "hnoca"
    n_neighbors
        Number of neighbors for normalization
    %(n_batches)s

    Returns
    -------
    csr_matrix
        Computed kernel matrix

    Notes
    -----
    When ``n_batches`` is given, the work is delegated to the batched
    implementation, which processes the computation in chunks to reduce
    memory usage for large datasets. Otherwise the full matrix is computed
    in one step and transformed in-place.
    """
    # Batched path: hand everything over to the chunked implementation.
    if n_batches is not None:
        return _compute_jaccard_kernel_batched(xx, yy, xy, yx, kernel_method, n_neighbors, n_batches)

    # Single-shot path: sum both products, then normalize the stored values.
    result = (yx @ xx.T) + (yy @ xy.T)
    apply_jaccard_transformation(result, kernel_method, n_neighbors)
    return result


@d.dedent
def _compute_jaccard_kernel_batched(
    xx: csr_matrix,
    yy: csr_matrix,
    xy: csr_matrix,
    yx: csr_matrix,
    kernel_method: str,
    n_neighbors: int,
    n_batches: int,
) -> csr_matrix:
    """
    Compute Jaccard or HNOCA kernel matrix using batch processing.

    The rows of ``yx`` and ``yy`` are split into contiguous chunks; each chunk
    is multiplied, transformed, and finally stacked back together row-wise.

    Parameters
    ----------
    xx, yy, xy, yx
        Adjacency matrices from neighbor computations
    kernel_method
        Kernel method to use: "jaccard" or "hnoca"
    n_neighbors
        Number of neighbors for normalization
    %(n_batches)s

    Returns
    -------
    csr_matrix
        Computed kernel matrix

    Raises
    ------
    ValueError
        If ``n_batches`` is smaller than 1.
    """
    if n_batches < 1:
        raise ValueError(f"n_batches must be a positive integer, got {n_batches}.")

    # Batch size is based on the number of rows of yx; never below one row so
    # that the range step below is valid.
    n_query = yx.shape[0]
    batch_size = max(1, int(np.ceil(n_query / n_batches)))

    # Stepping by batch_size skips the empty trailing batches that a fixed
    # range(n_batches) would produce when n_batches does not divide n_query.
    batch_starts = list(range(0, n_query, batch_size))
    if not batch_starts:
        # No rows at all: still run one empty batch so the result keeps its shape.
        batch_starts = [0]
    n_effective_batches = len(batch_starts)

    logger.info(
        "Computing %s kernel with %s batches (~%s query cells per batch)",
        kernel_method,
        n_effective_batches,
        f"{batch_size:,}",
    )

    # Both right-hand operands are loop-invariant: transpose and convert them
    # to CSR once up front instead of once per batch.
    xx_T = xx.T.tocsr()
    xy_T = xy.T.tocsr()

    batch_results = []
    for batch_number, start_idx in enumerate(batch_starts, start=1):
        end_idx = min(start_idx + batch_size, n_query)

        logger.debug(
            "Processing batch %s/%s: cells %s-%s",
            batch_number,
            n_effective_batches,
            f"{start_idx:,}",
            f"{end_idx:,}",
        )

        # Row slices of the two left-hand operands for this batch; the
        # intermediate products are unnamed so they are released right away.
        batch_kernel = (yx[start_idx:end_idx] @ xx_T) + (yy[start_idx:end_idx, :] @ xy_T)

        # Apply Jaccard/HNOCA transformation to this batch only
        apply_jaccard_transformation(batch_kernel, kernel_method, n_neighbors)
        batch_results.append(batch_kernel)

        # Best-effort memory cleanup between batches
        del batch_kernel
        gc.collect()

    # Combine batch results row-wise
    logger.info("Combining batch results...")
    kernel_matrix = vstack(batch_results, format="csr")

    # Ensure we return a csr_matrix (not csr_array)
    if not isinstance(kernel_matrix, csr_matrix):
        kernel_matrix = csr_matrix(kernel_matrix)

    # Final cleanup
    del batch_results, xx_T, xy_T
    gc.collect()

    return kernel_matrix
57 changes: 57 additions & 0 deletions tests/model/test_query_to_reference_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,3 +439,60 @@ def test_map_method_with_subset_categories(self, query_reference_adata):
assert "leiden_pred" in cmap.query.obs
predicted_categories = set(cmap.query.obs["leiden_pred"].dropna().unique())
assert predicted_categories.issubset(set(subset_cats))


class TestBatchProcessingCrossMapping:
    """Batch-processing checks for the query-to-reference (cross-mapping) mode."""

    def test_jaccard_batch_vs_standard_identical_results(self, query_reference_adata):
        """Batched and unbatched Jaccard kernels must agree exactly in cross-mapping mode."""
        query, reference = query_reference_adata

        def fit_jaccard(n_batches):
            # Build a fresh mapper and run neighbors + Jaccard kernel with the given batching.
            mapper = CellMapper(query, reference)
            mapper.compute_neighbors(n_neighbors=15, use_rep="X_pca", only_yx=False)
            mapper.compute_mapping_matrix(kernel_method="jaccard", n_batches=n_batches)
            return mapper

        # Reference run without batching; keep a copy of its kernel for comparison
        cm_standard = fit_jaccard(None)
        standard_kernel = cm_standard.knn.kernel_matrix.copy()

        # Same pipeline, split into three batches
        cm_batch = fit_jaccard(3)
        batch_kernel = cm_batch.knn.kernel_matrix

        # The two kernels have to match in shape, sparsity, and values
        assert standard_kernel.shape == batch_kernel.shape, "Kernel matrix shapes should match"
        assert standard_kernel.nnz == batch_kernel.nnz, "Number of non-zero elements should match"
        difference = standard_kernel - batch_kernel
        assert difference.nnz == 0, "Kernel matrices should be identical"

        # Downstream label transfer has to match as well
        cm_standard.map_obs(key="leiden")
        cm_batch.map_obs(key="leiden")

        standard_pred = cm_standard.query.obs["leiden_pred"]
        batch_pred = cm_batch.query.obs["leiden_pred"]
        assert standard_pred.equals(batch_pred), (
            "Label predictions should be identical between standard and batch computation"
        )

    @pytest.mark.parametrize(
        "kernel_method,n_batches", [("jaccard", None), ("jaccard", 2), ("hnoca", None), ("hnoca", 3)]
    )
    def test_jaccard_hnoca_batch_parametrized(self, query_reference_adata, kernel_method, n_batches):
        """Smoke-test Jaccard and HNOCA kernels with and without batching."""
        query, reference = query_reference_adata

        mapper = CellMapper(query, reference)
        mapper.compute_neighbors(n_neighbors=12, use_rep="X_pca", only_yx=False)
        mapper.compute_mapping_matrix(kernel_method=kernel_method, n_batches=n_batches)

        # The kernel must exist, be query x reference shaped, and be non-empty
        expected_shape = (query.n_obs, reference.n_obs)
        kernel = mapper.knn.kernel_matrix
        assert kernel is not None, "Kernel matrix should be computed"
        assert kernel.shape == expected_shape, f"Shape should match {expected_shape}"
        assert kernel.nnz > 0, "Kernel matrix should have non-zero elements"

        # Label transfer must run on top of the computed kernel
        mapper.map_obs(key="leiden")
        assert "leiden_pred" in mapper.query.obs, "Label predictions should be generated"
Loading
Loading