diff --git a/pertpy/tools/__init__.py b/pertpy/tools/__init__.py
index 45969728..69ba58ff 100644
--- a/pertpy/tools/__init__.py
+++ b/pertpy/tools/__init__.py
@@ -52,6 +52,7 @@ def __init__(self, *args, **kwargs):
 TTest = lazy_import("pertpy.tools._differential_gene_expression", "TTest", DE_EXTRAS)
 WilcoxonTest = lazy_import("pertpy.tools._differential_gene_expression", "WilcoxonTest", DE_EXTRAS)
+
 __all__ = [
 "Augur",
 "Cinemaot",
diff --git a/pertpy/tools/_milo.py b/pertpy/tools/_milo.py
index 10874ac5..9e74684a 100644
--- a/pertpy/tools/_milo.py
+++ b/pertpy/tools/_milo.py
@@ -3,7 +3,7 @@ import random
 import re
 from importlib.util import find_spec
-from typing import TYPE_CHECKING, Literal
+from typing import TYPE_CHECKING, Any, Literal
 import matplotlib.pyplot as plt
 import numpy as np
@@ -12,21 +12,32 @@ import seaborn as sns
 from anndata import AnnData
 from lamin_utils import logger
+from matplotlib.axes import Axes
+from matplotlib.cm import ScalarMappable
+from matplotlib.colors import Colormap, Normalize
 from mudata import MuData
 from pertpy._doc import _doc_params, doc_common_plot_args
 if TYPE_CHECKING:
-    from collections.abc import Sequence
+    from collections.abc import Collection, Sequence
 from matplotlib.axes import Axes
 from matplotlib.colors import Colormap
 from matplotlib.figure import Figure
-from scipy.sparse import csr_matrix
+from scipy.sparse import coo_matrix, csr_matrix, issparse, spmatrix
 from sklearn.metrics.pairwise import euclidean_distances
+def _is_counts(array: np.ndarray | spmatrix) -> bool:
+ """Check if the array is a count matrix."""
+ if issparse(array):
+ return bool(np.all(np.mod(array.data, 1) == 0))
+ else:
+ return bool(np.all(np.mod(array, 1) == 0))
+
+
 class Milo:
 """Python implementation of Milo."""
@@ -371,7 +382,8 @@ def da_nhoods(
 # Fit NB-GLM
 counts_filtered = count_mat[np.ix_(keep_nhoods, keep_smp)]
 lib_size_filtered = lib_size[keep_smp]
-count_mat_r = numpy2ri.py2rpy(counts_filtered)
+with localconverter(ro.default_converter + numpy2ri.converter):
+ count_mat_r = numpy2ri.py2rpy(counts_filtered)
 lib_size_r = FloatVector(lib_size_filtered)
 dge = edgeR.DGEList(counts=count_mat_r, lib_size=lib_size_r)
 dge = edgeR.calcNormFactors(dge, method="TMM")
@@ -878,6 +890,190 @@ def plot_nhood_graph(  # pragma: no cover  # noqa: D417
 plt.show()
 return None
+
+ @_doc_params(common_plot_args=doc_common_plot_args)
+ def plot_nhood_annotation(  # pragma: no cover  # noqa: D417
+ self,
+ mdata: MuData,
+ *,
+ adata_key: str = "milo",
+ annotation_key: str | None = "nhood_annotation",
+ alpha: float = 0.1,
+ min_logFC: float = 0.0,
+ min_size: int = 10,
+ plot_edges: bool = False,
+ title: str = "DA log-Fold Change",
+ color_map: Colormap | str | None = None,
+ palette: str | Sequence[str] | None = None,
+ ax: Axes | None = None,
+ return_fig: bool = False,
+ **kwargs,
+ ) -> Figure | None:
+ """Visualize Milo differential-abundance results on the neighborhood graph.
+
+ By default, neighborhoods are colored by filtered logFC (``|logFC|`` ≥ `min_logFC`
+ and SpatialFDR ≤ `alpha`). If `annotation_key` is provided, that column of
+ `mdata[adata_key].var` is used for coloring instead.
+
+ Args:
+ mdata: A MuData object with:
+ - mdata["milo"]: Milo-neighborhood AnnData (transposed).
+ - mdata[adata_key]: AnnData containing the neighborhood annotation in `.var`.
+ adata_key: Key for the AnnData within `mdata` that contains `.var[annotation_key]`.
+ Defaults to "milo".
+ annotation_key: Name of the `.var` column to use for coloring. If not None,
+ disables logFC-based coloring. Defaults to "nhood_annotation".
+ alpha: Significance threshold for SpatialFDR. Used only if `annotation_key` is None.
+ Defaults to 0.1.
+ min_logFC: Minimum absolute logFC to show. Used only if `annotation_key` is None.
+ Defaults to 0.0.
+ min_size: Scaling factor for node size. Actual size = `Nhood_size × min_size`.
+ Defaults to 10.
+ plot_edges: Whether to plot edges in the neighborhood overlap graph.
+ Defaults to False.
+ title: Title for the plot. Ignored if `annotation_key` is provided.
+ Defaults to "DA log-Fold Change".
+ color_map: Colormap to use for continuous (logFC) coloring.
+ palette: Colors to use for a categorical annotation, passed to `scanpy.pl.embedding`.
+ ax: Axes to plot on.
+ {common_plot_args}
+ **kwargs: Additional keyword arguments to pass directly to `scanpy.pl.embedding`.
+
+ Returns:
+ matplotlib.figure.Figure or None: The matplotlib Figure, if `return_fig` is True;
+ otherwise, displays the plot and returns None.
+
+ Examples:
+ >>> import pertpy as pt
+ >>> import scanpy as sc
+ >>> adata = pt.dt.bhattacherjee()
+ >>> milo = pt.tl.Milo()
+ >>> mdata = milo.load(adata)
+ >>> sc.pp.neighbors(mdata["rna"])
+ >>> sc.tl.umap(mdata["rna"])
+ >>> milo.make_nhoods(mdata["rna"])
+ >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident")
+ >>> milo.da_nhoods(
+ ...     mdata, design="~label", model_contrasts="labelwithdraw_15d_Cocaine-labelwithdraw_48h_Cocaine"
+ ... )
+ >>> milo.build_nhood_graph(mdata)
+ >>> milo.group_nhoods(mdata)
+ >>> milo.plot_nhood_annotation(mdata, annotation_key="nhood_groups")
+
+ Preview:
+ .. image:: /_static/docstring_previews/milo_nhood_annotation.png
+ """
+ # -------------------------------------------------------------------
+ # 1) Extract and copy the Milo neighborhood AnnData:
+ if "milo" not in mdata.mod:
+ raise KeyError('Cannot find "milo" modality in mdata. Did you run milo.build_nhood_graph()?')
+ nhood_adata: AnnData = mdata["milo"].T.copy()  # transpose to get nhoods as "cells"
+
+ # -------------------------------------------------------------------
+ # 2) If annotation_key is provided, skip the logFC logic and simply pull
+ # the annotation from mdata[adata_key].var. We assume that the neighborhood
+ # IDs in nhood_adata.obs.index also appear in mdata[adata_key].var.index.
+ if annotation_key is not None:
+ if adata_key not in mdata.mod:
+ raise KeyError(f'Cannot find "{adata_key}" modality in mdata.')
+ if annotation_key not in mdata[adata_key].var.columns:
+ raise KeyError(f'Cannot find "{annotation_key}" column in mdata["{adata_key}"].var.')
+ # Copy the annotation over to the neighborhood AnnData's obs,
+ # aligning neighborhood IDs by index:
+ annots = mdata[adata_key].var[annotation_key]
+ if not all(idx in annots.index for idx in nhood_adata.obs.index):
+ missing = set(nhood_adata.obs.index) - set(annots.index)
+ raise KeyError(f"The following neighborhood IDs are not found in mdata['{adata_key}'].var: {missing}")
+ nhood_adata.obs["graph_color"] = annots.reindex(nhood_adata.obs.index).values
+
+ # We do not filter by logFC or FDR in this mode; we just plot all neighborhoods.
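+ # Note: for a categorical annotation scanpy derives discrete colors from
+ # `palette`; `cmap` is only used when the plotted column is numeric.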
+ # Sorting: to push NaN or a particular annotation to the bottom you could
+ # sort by graph_color; for simplicity we plot in dataset order.
+
+ # Call scanpy's embedding plot:
+ fig = sc.pl.embedding(
+ nhood_adata,
+ basis="X_milo_graph",
+ color="graph_color",
+ cmap=color_map or "tab20",  # default to a discrete palette if none provided
+ size=nhood_adata.obs["Nhood_size"] * min_size,
+ edges=plot_edges,
+ neighbors_key="nhood",
+ sort_order=False,
+ frameon=False,
+ title=annotation_key,
+ palette=palette,
+ ax=ax,
+ show=False,
+ **kwargs,
+ )
+
+ if return_fig:
+ return fig
+ plt.show()
+ return None
+
+ # -------------------------------------------------------------------
+ # 3) Otherwise, annotation_key is None -> do the original logFC-based coloring:
+ if "Nhood_size" not in nhood_adata.obs.columns:
+ raise KeyError(
+ 'Cannot find "Nhood_size" column in nhood_adata.obs; please run milo.build_nhood_graph() first.'
+ )
+ if "logFC" not in nhood_adata.obs.columns or "SpatialFDR" not in nhood_adata.obs.columns:
+ raise KeyError(
+ 'Cannot find "logFC" / "SpatialFDR" columns in nhood_adata.obs; please run milo.da_nhoods() first.'
+ )
+
+ # Copy logFC into graph_color, then mask out nonsignificant / small-effect neighborhoods:
+ nhood_adata.obs["graph_color"] = nhood_adata.obs["logFC"]
+ nhood_adata.obs.loc[nhood_adata.obs["SpatialFDR"] > alpha, "graph_color"] = np.nan
+ nhood_adata.obs["abs_logFC"] = np.abs(nhood_adata.obs["logFC"])
+ nhood_adata.obs.loc[nhood_adata.obs["abs_logFC"] < min_logFC, "graph_color"] = np.nan
+
+ # Plot order: neighborhoods with large |logFC| on top
+ nhood_adata.obs.loc[nhood_adata.obs["graph_color"].isna(), "abs_logFC"] = np.nan
+ ordered = nhood_adata.obs.sort_values("abs_logFC", na_position="first").index
+ nhood_adata = nhood_adata[ordered]
+
+ # Determine symmetric color limits:
+ vmax = np.nanmax([nhood_adata.obs["graph_color"].max(), -nhood_adata.obs["graph_color"].min()])
+ vmin = -vmax
+
+ # Finally, call scanpy to draw the embedding:
+ fig = sc.pl.embedding(
+ nhood_adata,
+ basis="X_milo_graph",
+ color="graph_color",
+ cmap=color_map or "RdBu_r",
+ size=nhood_adata.obs["Nhood_size"] * min_size,
+ edges=plot_edges,
+ neighbors_key="nhood",
+ sort_order=False,
+ frameon=False,
+ title=title,
+ vmax=vmax,
+ vmin=vmin,
+ palette=palette,
+ ax=ax,
+ show=False,
+ **kwargs,
+ )
+
+ if return_fig:
+ return fig
+ plt.show()
+ return None
+
 @_doc_params(common_plot_args=doc_common_plot_args)
 def plot_nhood(  # pragma: no cover  # noqa: D417
 self,
@@ -1138,3 +1334,1410 @@ def plot_nhood_counts_by_cond(  # pragma: no cover  # noqa: D417
 plt.show()
 return None
+
+ def _group_nhoods_from_adjacency(
+ self,
+ adjacency: spmatrix,
+ da_res: pd.DataFrame,
+ is_da: np.ndarray,
+ merge_discord: bool = False,
+ overlap: int = 1,
+ max_lfc_delta: float | None = None,
+ subset_nhoods: list | np.ndarray | None = None,
+ ) -> np.ndarray:
+ """Group neighborhoods using filtered adjacency and Louvain clustering.
+
+ Filters the neighborhood adjacency matrix based on overlap, DA agreement,
+ and logFC similarity, then performs Louvain clustering on the resulting graph.
+
+ Args:
+ adjacency: Sparse square matrix (shape: N × N) containing overlap counts between neighborhoods.
+ da_res: DataFrame with one row per neighborhood (N rows), containing columns "SpatialFDR" and "logFC".
+ is_da: Boolean array of length N; True where a neighborhood is differentially abundant. + merge_discord: If False, remove edges between DA neighborhoods with opposite logFC signs. + Defaults to False. + overlap: Minimum overlap count required to retain an edge. Defaults to 1. + max_lfc_delta: If set, removes edges where the absolute difference in logFC exceeds this threshold. + Defaults to None (no filtering). + subset_nhoods: Optional subsetting of neighborhoods. Can be one of: + - Boolean mask of length N + - List/array of integer indices + - List/array of neighborhood names matching `da_res.index` + + Returns: + np.ndarray: Array of string cluster labels, of length equal to the number of selected neighborhoods. + These correspond to rows of `da_res` after subsetting. + """ + # 1) Optional subsetting of neighborhoods --------------------------------------------------- + # We allow subset_nhoods to be a boolean mask, a list of integer indices, or a list of names. + if subset_nhoods is not None: + # 1) boolean‐mask case first + if isinstance(subset_nhoods, pd.Series | np.ndarray) and subset_nhoods.dtype == bool: + if len(subset_nhoods) != adjacency.shape[0]: + raise ValueError("Boolean subset_nhoods must have length = number of neighborhoods.") + mask = np.asarray(subset_nhoods, dtype=bool) + + # 2) integer‐index or name list next + elif isinstance(subset_nhoods, list | np.ndarray): + arr = np.asarray(subset_nhoods) + # integer indices? + if np.issubdtype(arr.dtype, np.integer): + mask = np.zeros(adjacency.shape[0], dtype=bool) + mask[arr.astype(int)] = True + # name list? + else: + names = da_res.index.to_numpy(dtype=str) + mask = np.isin(names, arr.astype(str)) + + else: + raise ValueError("subset_nhoods must be a boolean mask, a list of indices, or a list of names.") + + adjacency = adjacency[mask, :][:, mask] + da_res = da_res.loc[mask].copy() + is_da = is_da[mask] + + M = adjacency.shape[0] + if da_res.shape[0] != M or is_da.shape[0] != M: + raise ValueError("Length of da_res and is_da must match adjacency dimension after subsetting.") + + # 2) Convert adjacency to CSR (if not already) and then to COO for a flat edge list ---------------- + adjacency = csr_matrix(adjacency) if not issparse(adjacency) else adjacency.tocsr() + + Acoo = adjacency.tocoo() + rows = Acoo.row # array of length E = number of nonzero edges + cols = Acoo.col + data = Acoo.data # the actual overlap counts + + # 3) Precompute logFC and sign arrays ------------------------------------------------------------------- + lfc_vals = da_res["logFC"].values # shape = (M,) + signs = np.sign(lfc_vals) # sign(lfc_i), shape = (M,) + + # 4) Build Boolean masks (length E) for each filter ------------------------------------------------------ + + # 4.1) “Discord” filter: if merge_discord=False, drop any edge (i,j) where both i,j are DA + # AND sign(lfc_i) * sign(lfc_j) < 0 (opposite signs). 
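+ # Worked example: suppose neighborhoods i (logFC = +2, DA) and j (logFC = -1, DA)
+ # share an edge. With merge_discord=False the edge is dropped because
+ # sign(+2) * sign(-1) < 0; with merge_discord=True it is kept.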
+ if merge_discord:
+ keep_discord = np.ones_like(data, dtype=bool)
+ else:
+ # For each edge k at (i=rows[k], j=cols[k]), check if both are DA AND signs differ
+ is_da_rows = is_da[rows]  # True if endpoint i is DA
+ is_da_cols = is_da[cols]  # True if endpoint j is DA
+ sign_rows = signs[rows]
+ sign_cols = signs[cols]
+
+ # discord_pair[k] = True if both DA and (signs multiply < 0)
+ discord_pair = (is_da_rows & is_da_cols) & ((sign_rows * sign_cols) < 0)
+ keep_discord = ~discord_pair
+
+ # 4.2) "Overlap" filter: drop any edge whose current weight < overlap
+ keep_overlap = np.ones_like(data, dtype=bool) if overlap <= 1 else data >= overlap
+
+ # 4.3) "Delta logFC" filter: drop any edge where |lfc_i - lfc_j| > max_lfc_delta
+ if max_lfc_delta is None:
+ keep_lfc = np.ones_like(data, dtype=bool)
+ else:
+ # Compute |lfc_vals[rows] - lfc_vals[cols]| vectorized
+ lfc_edge_diffs = np.abs(lfc_vals[rows] - lfc_vals[cols])
+ keep_lfc = lfc_edge_diffs <= max_lfc_delta
+
+ # 5) Combine all masks into a single "keep" mask
+ keep_mask = keep_discord & keep_overlap & keep_lfc
+
+ # 6) Rebuild a new, pruned adjacency in COO form (only edges where keep_mask=True)
+ new_rows = rows[keep_mask]
+ new_cols = cols[keep_mask]
+ new_data = data[keep_mask]
+
+ # For an unweighted graph (connectivity only) one could set `new_data = np.ones_like(new_rows)`;
+ # to mirror MiloR exactly, we preserve the original overlap counts until the final binarization.
+ pruned_adj = coo_matrix((new_data, (new_rows, new_cols)), shape=(M, M)).tocsr()
+
+ # 7) Binarize: every surviving edge -> 1
+ pruned_adj = (pruned_adj > 0).astype(int).tocsr()
+
+ # 8) Build an igraph graph from the final adjacency.
+ # scanpy's helper accepts the sparse 0/1 matrix directly, so no conversion
+ # to a dense matrix is needed after subsetting.
+ g = sc._utils.get_igraph_from_adjacency(pruned_adj, directed=False)
+
+ # 9) Run Louvain (multilevel) clustering on the unweighted graph.
+ # With no "weights" argument, igraph treats every edge as weight=1.
+ clustering = g.community_multilevel(weights=None)
+ labels = np.array(clustering.membership, dtype=str)  # length = M
+
+ # 10) Return the cluster labels (strings), in the same order as da_res.index.
+ # If subset_nhoods was not None, these labels correspond to rows where mask=True.
+ return labels
+
+ def group_nhoods(
+ self,
+ data: AnnData | MuData,
+ key: str | None = "milo",
+ da_res: pd.DataFrame | None = None,
+ da_fdr: float = 0.1,
+ overlap: int = 1,
+ max_lfc_delta: float | None = None,
+ merge_discord: bool = False,
+ subset_nhoods: pd.Series | np.ndarray | list[int] | list[str] | None = None,
+ ) -> None:
+ """Cluster Milo neighborhoods into groups (Louvain) and annotate `adata.var`.
+
+ A Python re-implementation of MiloR's `groupNhoods()`.
+ Given an AnnData (or a modality within a MuData) containing precomputed neighborhood
+ connectivity and differential abundance results, cluster DA neighborhoods (with optional
+ edge filters) and write back a categorical `"nhood_groups"` column in `adata.var`.
+
+ Args:
+ data: AnnData or MuData object. If MuData, `key` selects the modality to use.
+ key: Modality name within `data` when using a MuData. Defaults to "milo".
+ da_res: DataFrame of neighborhood-level results with index matching
+ `adata.var.index` and columns "SpatialFDR" and "logFC". If None, uses `adata.var`.
+ da_fdr: Threshold for SpatialFDR below which neighborhoods are considered
+ differentially abundant and included in the clustering. Defaults to 0.1.
+ overlap: Minimum adjacency weight to retain an edge; edges with weight below
+ this value are dropped. Defaults to 1.
+ max_lfc_delta: Maximum allowed absolute difference in logFC between two
+ neighborhoods; edges exceeding this value are dropped. Defaults to None.
+ merge_discord: If False, edges between two DA neighborhoods whose logFC
+ signs disagree are dropped. Defaults to False.
+ subset_nhoods: Boolean mask, list/array of integer indices, or list of
+ neighborhood ID strings to restrict clustering. Defaults to None.
+
+ Returns:
+ Updates `adata.var["nhood_groups"]` in place with each neighborhood's
+ Louvain group label. Neighborhoods not included in `subset_nhoods`
+ will have `pd.NA` in that column.
+
+ Examples:
+ >>> import pertpy as pt
+ >>> import scanpy as sc
+ >>> adata = pt.dt.bhattacherjee()
+ >>> milo = pt.tl.Milo()
+ >>> mdata = milo.load(adata)
+ >>> sc.pp.neighbors(mdata["rna"])
+ >>> milo.make_nhoods(mdata["rna"])
+ >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident")
+ >>> milo.da_nhoods(mdata, design="~label")
+ >>> milo.group_nhoods(mdata)
+ """
+ if isinstance(data, AnnData):
+ adata = data
+ elif isinstance(data, MuData):
+ if key is None:
+ raise ValueError("If `data` is a MuData object, `key` must be specified.")
+ adata = data[key]
+ else:
+ raise ValueError("`data` must be an AnnData or MuData object.")
+
+ # 1) Get or check `da_res`
+ if da_res is None:
+ da_res = adata.var
+ # If the user passed their own da_res, ensure the indexes match
+ elif not da_res.index.equals(adata.var.index):
+ raise ValueError("`da_res` index must match `adata.var.index`.")
+
+ # Ensure required columns exist
+ if "SpatialFDR" not in da_res.columns or "logFC" not in da_res.columns:
+ raise ValueError("`da_res` (adata.var) must contain columns 'SpatialFDR' and 'logFC'.")
+
+ # 2) Identify DA neighborhoods by FDR cutoff
+ fdr_values = da_res["SpatialFDR"].values
+ if np.all(pd.isna(fdr_values)):
+ raise ValueError("All `SpatialFDR` values are NA; cannot determine DA neighborhoods.")
+ is_da = fdr_values < da_fdr
+
+ n_da = int(is_da.sum())
+ if n_da == 0:
+ raise ValueError(f"No DA neighborhoods found at FDR < {da_fdr}.")
+
+ # 3) Extract adjacency
+ if "nhood_connectivities" not in adata.varp:
+ raise KeyError("`adata.varp` does not contain 'nhood_connectivities'. Did you run milo.build_nhood_graph()?")
+ adjacency = adata.varp["nhood_connectivities"]
+
+ # 4) Call the core worker to get string labels
+ labels = self._group_nhoods_from_adjacency(
+ adjacency=adjacency,
+ da_res=da_res,
+ is_da=is_da,
+ merge_discord=merge_discord,
+ overlap=overlap,
+ max_lfc_delta=max_lfc_delta,
+ subset_nhoods=subset_nhoods,
+ )
+
+ # 5) Write results back into `adata.var["nhood_groups"]`
+ N_full = adata.var.shape[0]
+ out = np.array([pd.NA] * N_full, dtype=object)
+
+ if subset_nhoods is None:
+ # no subsetting: every label goes into the full array
+ out[:] = labels
+ else:
+ # 1) Boolean-mask case first
+ if isinstance(subset_nhoods, pd.Series | np.ndarray) and subset_nhoods.dtype == bool:
+ if len(subset_nhoods) != N_full:
+ raise ValueError("Boolean subset_nhoods must have length = number of neighborhoods.")
+ mask_idx = np.asarray(subset_nhoods, dtype=bool)
+
+ # 2) Integer-index or name-list next
+ elif isinstance(subset_nhoods, list | np.ndarray):
+ arr = np.asarray(subset_nhoods)
+ if np.issubdtype(arr.dtype, np.integer):
+ mask_idx = np.zeros(N_full, dtype=bool)
+ mask_idx[arr.astype(int)] = True
+ else:
+ names = adata.var.index.to_numpy(dtype=str)
+ mask_idx = np.isin(names, arr.astype(str))
+
+ else:
+ raise ValueError("`subset_nhoods` must be a boolean mask, a list of indices, or a list of names.")
+
+ # 3) Place the M labels back into the N-length output
+ out[mask_idx] = labels
+
+ adata.var["nhood_groups"] = out
+
+ def _nhood_labels_to_cells_last_wins(
+ self,
+ mdata: MuData,
+ nhood_group_obs: str = "nhood_groups",
+ subset_nhoods: list | np.ndarray | None = None,
+ ) -> None:
+ """Map neighborhood group labels back to single cells (last group wins).
+
+ Assigns a neighborhood group label to each cell based on the neighborhoods
+ it belongs to. If a cell belongs to neighborhoods with different labels,
+ the label of the last group processed (in category order) wins, mirroring
+ MiloR. Operates in-place on `mdata["rna"].obs["nhood_groups"]`.
+
+ Args:
+ mdata: MuData object with:
+ - mdata["milo"]: contains `.var[nhood_group_obs]` and neighborhood indices.
+ - mdata["rna"]: must have `.obsm["nhoods"]` sparse binary matrix of shape (cells × neighborhoods).
+ nhood_group_obs: Column name in `mdata["milo"].var` holding the neighborhood group labels.
+ Must be categorical or convertible to categorical. Defaults to "nhood_groups".
+ subset_nhoods: Optional subset of neighborhood indices to consider. Can be:
+ - A boolean mask,
+ - A list/array of integer indices,
+ - A list/array of string IDs matching `mdata["milo"].var.index`.
+
+ Returns:
+ None: The results are written to `mdata["rna"].obs["nhood_groups"]` in place.
+ """ + nhood_mat = mdata["rna"].obsm["nhoods"] + + da_res = mdata["milo"].var.copy() + ### update for categorical nhood_group_obs to control order of levels + # This turns the nhood_group_obs into a CategoricalDtype if it isn't already + col = nhood_group_obs + # if it isn’t already a CategoricalDtype, cast it + if not isinstance(da_res[col].dtype, pd.api.types.CategoricalDtype): + da_res[col] = da_res[col].astype("category") + + nhood_mat = AnnData(X=nhood_mat) + nhood_mat.obs_names = mdata["rna"].obs_names + nhood_mat.var_names = [str(i) for i in range(nhood_mat.shape[1])] + + nhs_da_gr = da_res[nhood_group_obs].copy() + nhs_da_gr.index = da_res.index.to_numpy() + + # We want to drop NAs from the nhood_group_obs column, not the whole DataFrame + # nhood_gr = da_res.dropna()[nhood_group_obs].unique() + nhood_gr = da_res[nhood_group_obs].cat.categories + + nhs = nhood_mat.copy() + + # if(!is.null(subset.nhoods)){ + # nhs <- nhs[,subset.nhoods] + ## # ## Remove cells out of neighbourhoods of interest + # # nhs <- nhs[rowSums(nhs) > 0,] + # } + + if subset_nhoods is not None: + nhs = nhs[:, subset_nhoods] + nhs = nhs[np.asarray(nhs.X.sum(1)).ravel() > 0, :].copy() + + fake_meta = pd.DataFrame( + { + "CellID": nhs.obs_names[(np.asarray(nhs.X.sum(1).flatten()).ravel() != 0)], + # "Nhood_Group": [np.nan for _ in range((np.asarray(nhs.X.sum(1).flatten()).ravel() != 0).sum())], + "Nhood_Group": [pd.NA for _ in range((np.asarray(nhs.X.sum(1).flatten()).ravel() != 0).sum())], + } + ) + fake_meta.index = fake_meta["CellID"].copy() + + for i in range(len(nhood_gr)): + cur_nh_group = nhood_gr[i] + + nhood_x = nhs_da_gr == cur_nh_group + nhood_x = nhood_x[nhood_x] + nhood_x = nhood_x.index + nhood_x = np.asarray(nhood_x) + + nhs = nhs[nhs.X.sum(1) > 0, :].copy() + + mask = np.asarray(nhs[:, nhood_x].X.sum(1)).ravel() > 0 + nhood_gr_cells = nhs.obs_names[mask] + + fake_meta.loc[nhood_gr_cells, "Nhood_Group"] = np.where( + (fake_meta.loc[nhood_gr_cells, "Nhood_Group"]).isna(), + nhood_gr[i], + pd.NA, + ) + + mdata["rna"].obs["nhood_groups"] = pd.NA + mdata["rna"].obs.loc[fake_meta.CellID.to_list(), "nhood_groups"] = fake_meta.Nhood_Group.to_numpy() + + def _get_cells_in_nhoods( + self, + adata: AnnData, + nhood_ids: np.ndarray | list, + ) -> None: + """Compute number of neighborhood memberships per cell and store in `.obs`. + + For the selected neighborhoods, calculates how many of them each cell belongs to. + Stores the result in `adata.obs["in_nhoods"]`. + + Args: + adata: AnnData object with `.obsm["nhoods"]`, a binary matrix of shape (cells × neighborhoods). + nhood_ids: List or array of neighborhood indices to include in the count. + + Returns: + None: The result is stored in-place in `adata.obs["in_nhoods"]`. + """ + if not isinstance(nhood_ids, np.ndarray): + nhood_ids = np.asarray(nhood_ids, dtype=int) + in_nhoods = np.array(adata.obsm["nhoods"][:, nhood_ids].sum(1)) + adata.obs["in_nhoods"] = in_nhoods + + def _nhood_labels_to_cells_exclude_overlaps( + self, + mdata: MuData, + nhood_group_obs: str = "nhood_groups", + min_n_nhoods: int = 3, + ) -> None: + """Assign cells to a dominant neighborhood group, excluding ambiguous overlaps. + + For each neighborhood group, compute how many neighborhoods each cell belongs to. + Then, assign each cell to the group with the most memberships, if that count exceeds + `min_n_nhoods`. All other cells are left unassigned (NaN). + + Args: + mdata: MuData object with: + - `mdata["milo"].var[nhood_group_obs]`: categorical group labels for neighborhoods. 
+ - `mdata["rna"].obsm["nhoods"]`: binary matrix (cells × neighborhoods) indicating memberships. + nhood_group_obs: Name of the column in `mdata["milo"].var` containing group labels. + Defaults to "nhood_groups". + min_n_nhoods: Minimum number of neighborhoods from the same group a cell must belong to + in order to be assigned. Defaults to 3. + + Returns: + None: Results are written in-place to `mdata["rna"].obs["nhood_groups"]`. + """ + groups = mdata["milo"].var[nhood_group_obs].dropna().unique() + for g in groups: + nhoods_oi = mdata["milo"].var_names[mdata["milo"].var[nhood_group_obs] == g] + self._get_cells_in_nhoods(mdata["rna"], nhoods_oi) + mdata["rna"].obs[f"in_nhoods_{g}"] = mdata["rna"].obs["in_nhoods"].copy() + + ## Find most representative group (cell belongs to mostly to neighbourhoods of that group) + mdata["rna"].obs["nhood_groups"] = np.nan + mdata["rna"].obs["nhood_groups"] = mdata["rna"].obs[[f"in_nhoods_{g}" for g in groups]].idxmax(1) + ## Keep only if cell is in at least min_n_nhoods nhoods of the same group + mdata["rna"].obs.loc[ + ~(mdata["rna"].obs[[f"in_nhoods_{g}" for g in groups]] > min_n_nhoods).any(axis=1), "nhood_groups" + ] = np.nan + ### Remove the in_nhoods in nhood_groups columns + mdata["rna"].obs["nhood_groups"] = ( + mdata["rna"].obs["nhood_groups"].apply(lambda x: x.split("_")[-1] if isinstance(x, str) else x) + ) + mdata["rna"].obs["nhood_groups"] = mdata["rna"].obs["nhood_groups"].str.removeprefix("in_nhoods_") + + def annotate_cells_from_nhoods( + self, + mdata: MuData, + nhood_group_obs: str = "nhood_groups", + subset_nhoods: list[str] | None = None, + min_n_nhoods: int = 3, + mode: Literal["last_wins", "exclude_overlaps"] = "last_wins", + ) -> None: + """Assign neighborhood group labels to cells based on neighborhood membership. + + This function annotates cells in `mdata["rna"].obs` using group labels from + `mdata["milo"].var[nhood_group_obs]`. Supports two modes for resolving overlaps: + - "last_wins": Assign the last matching group label (default; mimics MiloR behavior). + - "exclude_overlaps": Assign only if the cell belongs to a minimum number of neighborhoods + from a single group. + + Args: + mdata: MuData object with: + - `mdata["milo"].var[nhood_group_obs]`: categorical group labels. + - `mdata["rna"].obsm["nhoods"]`: binary matrix of cell–neighborhood memberships. + nhood_group_obs: Column name in `mdata["milo"].var` with group labels. Defaults to "nhood_groups". + subset_nhoods: Optional list of neighborhood IDs to restrict annotation. If None, all neighborhoods are used. + min_n_nhoods: Minimum number of neighborhoods from the same group a cell must belong to in order to be assigned + (only used in mode `"exclude_overlaps"`). Defaults to 3. + mode: Strategy for resolving overlapping group assignments. One of: + - `"last_wins"`: Assign label from last matching neighborhood. + - `"exclude_overlaps"`: Assign only if group dominates cell’s memberships. + + Returns: + Updates `mdata["rna"].obs["nhood_groups"]` with the assigned labels in place. 
+
+ Examples:
+ >>> import pertpy as pt
+ >>> import scanpy as sc
+ >>> adata = pt.dt.bhattacherjee()
+ >>> milo = pt.tl.Milo()
+ >>> mdata = milo.load(adata)
+ >>> sc.pp.neighbors(mdata["rna"])
+ >>> milo.make_nhoods(mdata["rna"])
+ >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident")
+ >>> milo.da_nhoods(mdata, design="~label")
+ >>> milo.group_nhoods(mdata)
+ >>> milo.annotate_cells_from_nhoods(mdata, mode="last_wins")
+ >>> milo.annotate_cells_from_nhoods(mdata, mode="exclude_overlaps", min_n_nhoods=3)
+ """
+ if mode == "last_wins":
+ self._nhood_labels_to_cells_last_wins(mdata, nhood_group_obs, subset_nhoods)
+ elif mode == "exclude_overlaps":
+ self._nhood_labels_to_cells_exclude_overlaps(mdata, nhood_group_obs, min_n_nhoods)
+ else:
+ raise ValueError(f"Unknown mode '{mode}'. Use 'last_wins' or 'exclude_overlaps'.")
+
+ # def get_mean_expression(
+ # self,
+ # adata: AnnData,
+ # groupby: str,
+ # var_names: list[str],
+ # ) -> pd.DataFrame:
+ # """Compute the mean expression of selected genes stratified by a categorical grouping.
+
+ # Args:
+ # adata: AnnData object containing the expression matrix in `X` and categorical metadata in `obs`.
+ # groupby: Name of the column in `adata.obs` to group cells by.
+ # var_names: List of variable (gene) names for which to compute the mean expression.
+
+ # Returns:
+ # mean_df: A pandas DataFrame of shape (len(var_names), n_groups) where
+ # - rows are the genes in `var_names`
+ # - columns are the unique categories in `adata.obs[groupby]`
+ # - each entry is the average count of that gene over all cells in the corresponding group.
+ # """
+ # # 1) Subset the matrix to just the columns (genes) in var_names:
+ # subX = adata[:, var_names].X.copy()  # shape: (n_cells, n_genes)
+
+ # # 2) Build a one-hot (dummy) matrix of shape (n_cells, n_groups):
+ # groups = pd.get_dummies(adata.obs[groupby], drop_first=False)
+ # # groups.values is (n_cells, n_groups). groups.sum() is a Series: number of cells per group.
+ # n_per_group = groups.sum().astype(float)  # length = n_groups
+
+ # # 3) Compute sum_counts[gene i, group j] by matrix multiplication:
+ # # - If subX is sparse, convert to a CSR; otherwise treat as dense.
+ # if issparse(subX):
+ # subX = csr_matrix(subX)
+ # sum_counts = subX.T.dot(csr_matrix(groups.values))  # shape (n_genes, n_groups), sparse
+ # sum_counts = sum_counts.toarray()  # convert to dense (n_genes, n_groups)
+ # else:
+ # # dense case: subX is (n_cells, n_genes), so subX.T is (n_genes, n_cells),
+ # # dot with (n_cells, n_groups) -> (n_genes, n_groups)
+ # sum_counts = subX.T.dot(groups.values)
+
+ # # 4) Divide each column (group) by its total cell count to get means:
+ # # We want mean_counts[i, j] = sum_counts[i, j] / n_per_group[j].
+ # # n_per_group.values is shape (n_groups,), so broadcasting works.
+ # mean_mat = sum_counts / n_per_group.values[np.newaxis, :]
+
+ # # 5) Build a DataFrame, indexed by var_names, columns = groups.columns
+ # mean_df = pd.DataFrame(mean_mat, index=var_names, columns=groups.columns)
+ # return mean_df
+
+ def _run_edger_contrasts(
+ self,
+ pdata: AnnData,
+ nhood_group_obs: str,
+ *,
+ formula: str,
+ group_to_compare: str | None = None,
+ baseline: str | None = None,
+ subset_samples: list[str] | None = None,
+ ) -> pd.DataFrame:
+ """Run edgeR QLF tests on pseudobulk data using specified contrasts.
+
+ Performs differential expression analysis using edgeR's quasi-likelihood
+ F-tests on pseudobulked expression data.
+ Supports either a user-specified two-level contrast (`baseline` vs `group_to_compare`),
+ or one-vs-rest testing across all groups in `pdata.obs[nhood_group_obs]`.
+
+ Args:
+ pdata: AnnData object with pseudobulked expression data in `.X` and
+ sample-level annotations in `.obs`.
+ nhood_group_obs: Name of the `.obs` column used to define group membership for the contrast.
+ formula: R-style design formula (e.g., `"~ group"`). Used to generate design matrices in edgeR.
+ group_to_compare: Name of the group to compare (e.g., `"treated"`).
+ If provided along with `baseline`, a two-group test is run.
+ baseline: Reference group (e.g., `"control"`) for the two-group contrast.
+ subset_samples: Optional list of sample names to subset `pdata` before analysis.
+
+ Returns:
+ :class:`pandas.DataFrame`: Differential expression results with columns:
+ - `"variable"`: gene/feature name
+ - `"log_fc"`: log-fold change estimate
+ - `"p_value"`: raw p-value
+ - `"adj_p_value"`: multiple-testing corrected p-value
+ - `"group"` (optional): group name (only present in one-vs-rest mode)
+
+ Raises:
+ ValueError: If contrast groups are not present in the data or the input is malformed.
+
+ Example:
+ >>> de_df = milo._run_edger_contrasts(
+ ...     pdata, "condition", formula="~ condition", group_to_compare="treated", baseline="control"
+ ... )
+ """
+ if not _is_counts(pdata.X):
+ raise ValueError("`pdata.X` appears to contain continuous expression values, but edgeR requires raw counts.")
+
+ # Optionally subset the samples before testing (as documented above).
+ if subset_samples is not None:
+ pdata = pdata[pdata.obs_names.isin(subset_samples)].copy()
+
+ edger, limma, rstats, rbase = self._setup_rpy2()
+ import rpy2.robjects as ro
+ from rpy2.robjects import StrVector, baseenv, numpy2ri, pandas2ri
+ from rpy2.robjects.conversion import localconverter
+
+ results_list: list[Any] = []
+
+ # Single contrast vs one-vs-rest:
+ # If a specific two-level contrast was given:
+ if group_to_compare is not None and baseline is not None:
+ # subset pdata to only those two groups
+ pdata = pdata[pdata.obs[nhood_group_obs].isin([baseline, group_to_compare])].copy()
+ if pdata.shape[0] == 0:
+ raise ValueError(f"No samples found with {nhood_group_obs} in [{baseline}, {group_to_compare}].")
+
+ # Build the R count matrix (genes × samples)
+ count_mat = pdata.X.toarray().T if hasattr(pdata.X, "toarray") else np.asarray(pdata.X).T
+ with localconverter(ro.default_converter + numpy2ri.converter):
+ rmat = numpy2ri.py2rpy(count_mat)
+
+ r_colnames = StrVector(np.asarray(pdata.obs_names))
+ r_rownames = StrVector(np.asarray(pdata.var_names))
+ dim_list = ro.r.list(r_rownames, r_colnames)
+
+ assign_dim = baseenv["dimnames<-"]
+ rmat = assign_dim(rmat, dim_list)
+
+ # Build the DGEList
+ dge = edger.DGEList(counts=rmat)
+
+ # Build the R model matrix from the sample annotations, setting the levels of
+ # nhood_group_obs so that `baseline` is the reference level.
+ sample_obs = pdata.obs.copy()
+ sample_obs[nhood_group_obs] = pd.Categorical(
+ sample_obs[nhood_group_obs].values, categories=[baseline, group_to_compare]
+ )
+
+ with localconverter(ro.default_converter + pandas2ri.converter):
+ robs = pandas2ri.py2rpy(sample_obs)
+ design_r = rstats.model_matrix(rstats.as_formula(formula), robs)
+
+ # Fit the quasi-likelihood model
+ dge = edger.calcNormFactors(dge, method="TMM")
+
+ fit = edger.glmQLFit(dge, design_r, robust=True)
+
+ # Run the QLF test on the coefficient for `group_to_compare`
+ qlf = edger.glmQLFTest(fit, coef=nhood_group_obs + group_to_compare)
+ top = edger.topTags(qlf, sort_by="none", n=np.inf)[0]
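+ # topTags returns an R TopTags container; `[0]` extracts the underlying
+ # data.frame, which for glmQLFTest results has the columns
+ # logFC, logCPM, F, PValue, and FDR.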
+ # Convert top (an R data.frame) to pandas
+ with localconverter(ro.default_converter + pandas2ri.converter):
+ top_df = pandas2ri.rpy2py(top)
+
+ # Standardize the column names (they come as "logFC", "PValue", "FDR")
+ top_df = top_df.rename(columns={"FDR": "adj_p_value", "PValue": "p_value", "logFC": "log_fc"})
+ top_df = top_df.reset_index().rename(columns={"index": "variable"})
+ return top_df
+
+ else:
+ # Build the R count matrix (genes × samples)
+ count_mat = pdata.X.toarray().T if hasattr(pdata.X, "toarray") else np.asarray(pdata.X).T
+ with localconverter(ro.default_converter + numpy2ri.converter):
+ rmat = numpy2ri.py2rpy(count_mat)
+
+ r_colnames = StrVector(np.asarray(pdata.obs_names))
+ r_rownames = StrVector(np.asarray(pdata.var_names))
+ dim_list = ro.r.list(r_rownames, r_colnames)
+
+ assign_dim = baseenv["dimnames<-"]
+ rmat = assign_dim(rmat, dim_list)
+
+ # Build the DGEList and compute TMM normalization factors once
+ dge = edger.DGEList(counts=rmat)
+ dge = edger.calcNormFactors(dge, method="TMM")
+ sample_obs = pdata.obs.copy()
+
+ col = sample_obs[nhood_group_obs]
+ if isinstance(col.dtype, pd.api.types.CategoricalDtype):
+ unique_groups = col.cat.categories.tolist()
+ else:
+ unique_groups = col.unique().tolist()
+
+ results_list = []
+
+ for grp in unique_groups:
+ # Build a group-specific design matrix: `grp` vs "rest"
+ tmp_obs = pdata.obs.copy()
+
+ tmp_obs[nhood_group_obs] = [x if x == grp else "rest" for x in tmp_obs[nhood_group_obs]]
+ tmp_obs[nhood_group_obs] = pd.Categorical(tmp_obs[nhood_group_obs].values, categories=["rest", grp])
+ with localconverter(ro.default_converter + pandas2ri.converter):
+ robs = pandas2ri.py2rpy(tmp_obs)
+ design_r = rstats.model_matrix(rstats.as_formula(formula), robs)
+
+ # Fit the quasi-likelihood model for this contrast (the DGEList is reused)
+ fit_sub = edger.glmQLFit(dge, design_r, robust=True)
+
+ # QLF test on the coefficient for `grp`
+ qlf_sub = edger.glmQLFTest(fit_sub, coef=nhood_group_obs + grp)
+ top_sub = edger.topTags(qlf_sub, sort_by="none", n=np.inf)[0]
+
+ with localconverter(ro.default_converter + pandas2ri.converter):
+ top_df_sub = pandas2ri.rpy2py(top_sub)
+
+ top_df_sub = top_df_sub.rename(columns={"FDR": "adj_p_value", "PValue": "p_value", "logFC": "log_fc"})
+ top_df_sub = top_df_sub.reset_index().rename(columns={"index": "variable"})
+ top_df_sub["group"] = grp
+
+ results_list.append(top_df_sub[["variable", "log_fc", "p_value", "adj_p_value", "group"]])
+
+ # Concatenate and return
+ final_df = pd.concat(results_list, ignore_index=True)
+ return final_df
+
+ def _run_pydeseq2_contrasts(
+ self,
+ pdata: AnnData,
+ nhood_group_obs: str,
+ *,
+ formula: str,
+ group_to_compare: str | None = None,
+ baseline: str | None = None,
+ alpha: float = 0.05,
+ quiet: bool = True,
+ ) -> pd.DataFrame:
+ """Run PyDESeq2 differential testing on pseudobulked AnnData using a design formula.
+
+ Supports either a two-level contrast (`group_to_compare` vs `baseline`) or one-vs-rest
+ comparisons for all levels of a categorical column in `.obs`. Results are returned
+ as a tidy `DataFrame` compatible with downstream analysis.
+
+ Args:
+ pdata: Pseudobulk `AnnData` object with expression matrix in `.X` and
+ covariates in `.obs`, including `nhood_group_obs`.
+ nhood_group_obs: Name of the `.obs` column to use for contrast groups.
+ formula: R-style design formula (e.g., `"~ batch + group"`), passed directly to PyDESeq2.
+ group_to_compare: Name of the group to test against `baseline`. If None,
+ one-vs-rest mode is triggered.
+ baseline: Name of the baseline group for the contrast. Must be specified if `group_to_compare` is given.
+ alpha: FDR threshold passed to PyDESeq2's `DeseqStats`. Defaults to 0.05.
+ quiet: If True, suppresses progress messages from PyDESeq2. Defaults to True.
+
+ Returns:
+ :class:`pandas.DataFrame`: If `group_to_compare` and `baseline` are specified, returns a
+ single contrast result with columns:
+ - `"variable"`: feature name
+ - `"log_fc"`: log2 fold change
+ - `"p_value"`: raw p-value
+ - `"adj_p_value"`: FDR-corrected p-value
+
+ If no contrast is specified, performs one-vs-rest for each group and returns
+ a concatenated DataFrame with the same columns plus:
+ - `"group"`: the group tested against all others
+
+ Raises:
+ ImportError: If `pydeseq2` is not installed.
+ ValueError: If only one of `group_to_compare` or `baseline` is provided.
+
+ Example:
+ >>> de_df = milo._run_pydeseq2_contrasts(
+ ...     pdata,
+ ...     nhood_group_obs="condition",
+ ...     formula="~ condition",
+ ...     group_to_compare="treated",
+ ...     baseline="control",
+ ... )
+ """
+ if find_spec("pydeseq2") is None:
+ raise ImportError("pydeseq2 is required but not installed. Install with: pip install pydeseq2")
+ from pydeseq2.dds import DeseqDataSet
+ from pydeseq2.ds import DeseqStats
+
+ # Require both `group_to_compare` and `baseline`, or neither.
+ if (group_to_compare is not None) ^ (baseline is not None):
+ raise ValueError("You must supply either both `group_to_compare` and `baseline`, or neither.")
+
+ # 1) Single contrast branch
+ if group_to_compare is not None and baseline is not None:
+ # 1a) Build the DESeqDataSet using exactly the provided `formula`
+ dds = DeseqDataSet(adata=pdata, design=formula, quiet=quiet)
+ dds.deseq2()
+
+ # 1b) Run PyDESeq2 with the single contrast
+ stat_res = DeseqStats(
+ dds,
+ contrast=[nhood_group_obs, group_to_compare, baseline],
+ alpha=alpha,
+ quiet=quiet,
+ )
+ stat_res.summary()
+
+ # 1c) Collect the results into a pandas DataFrame
+ df = (
+ pd.DataFrame(stat_res.results_df)
+ .rename(
+ columns={
+ "log2FoldChange": "log_fc",
+ "pvalue": "p_value",
+ "padj": "adj_p_value",
+ }
+ )
+ .sort_values("p_value")
+ .reset_index(names=["variable"])
+ )
+ return df
+
+ # 2) One-vs-rest: get all levels of nhood_group_obs
+ col = pdata.obs[nhood_group_obs]
+ unique_groups = (
+ col.cat.categories.tolist() if isinstance(col.dtype, pd.api.types.CategoricalDtype) else col.unique().tolist()
+ )
+
+ all_results = []
+ for grp in unique_groups:
+ # 2a) Copy pdata so we can relabel as `grp` vs "rest"
+ tmp = pdata.copy()
+
+ # 2b) Ensure "rest" is a valid category, then recode everything not == grp -> "rest"
+ tmp.obs[nhood_group_obs] = tmp.obs[nhood_group_obs].cat.add_categories("rest")
+ tmp.obs[nhood_group_obs] = tmp.obs[nhood_group_obs].apply(lambda x, grp=grp: x if x == grp else "rest")
+ # Now tmp.obs[nhood_group_obs] has exactly two levels: grp and "rest"
+
+ # 2c) Build the DESeqDataSet on `tmp` using **the same** `formula`
+ # (The formula must reference nhood_group_obs so that "grp" vs "rest" is meaningful.)
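+ # Example: with formula="~ condition" (nhood_group_obs="condition") and grp="treated",
+ # the relabeled column has the levels ["rest", "treated"], so the contrast below
+ # tests "treated" against the reference "rest".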
+ dds = DeseqDataSet(adata=tmp, design=formula, quiet=quiet)
+ dds.deseq2()
+
+ # 2d) Run PyDESeq2 with contrast = [nhood_group_obs, grp, "rest"]
+ stat_res = DeseqStats(
+ dds,
+ contrast=[nhood_group_obs, grp, "rest"],
+ alpha=alpha,
+ quiet=quiet,
+ )
+ stat_res.summary()
+
+ # 2e) Extract the results, rename the columns, and attach group = grp
+ df = (
+ pd.DataFrame(stat_res.results_df)
+ .rename(
+ columns={
+ "log2FoldChange": "log_fc",
+ "pvalue": "p_value",
+ "padj": "adj_p_value",
+ }
+ )
+ .reset_index(names=["variable"])
+ .assign(group=grp)
+ .sort_values("p_value")
+ )
+
+ all_results.append(df)
+
+ # 3) Concatenate and return
+ final_df = pd.concat(all_results, ignore_index=True)
+ return final_df
+
+ def _filter_by_expr_edger(
+ self,
+ pdata: AnnData,
+ formula: str,
+ **kwargs,
+ ) -> None:
+ """Filter low-expressed genes from a pseudobulk AnnData object using edgeR.
+
+ This function uses `edgeR::filterByExpr()` via rpy2 to identify and retain
+ genes with sufficient expression for differential testing, based on the
+ provided design formula and expression thresholds.
+
+ The filtering is performed in-place by subsetting `pdata.var`.
+
+ Args:
+ pdata: Pseudobulk `AnnData` object with raw counts in `.X` and
+ covariates in `.obs`.
+ formula: R-style design formula (e.g., `"~ condition + batch"`), used to
+ compute the design matrix in edgeR.
+ **kwargs: Additional keyword arguments passed to `edgeR::filterByExpr()`,
+ e.g. `min_count` and `min_total_count` (rpy2 maps them to the R
+ arguments `min.count` and `min.total.count`).
+
+ Returns:
+ None: The function modifies `pdata` in place by subsetting `pdata.var`
+ to include only the retained genes.
+ """
+ edger, _, rstats, rbase = self._setup_rpy2()
+ import rpy2.robjects as ro
+ from rpy2.robjects import numpy2ri, pandas2ri
+ from rpy2.robjects.conversion import localconverter
+
+ counts = pdata.X
+ counts = counts.toarray().T if hasattr(counts, "toarray") else np.asarray(counts).T
+ with localconverter(ro.default_converter + numpy2ri.converter):
+ rcounts = numpy2ri.py2rpy(counts)
+ obs = pdata.obs
+ with localconverter(ro.default_converter + pandas2ri.converter):
+ robs = pandas2ri.py2rpy(obs)
+ rdesign = rstats.model_matrix(rstats.as_formula(formula), robs)
+ rkeep = edger.filterByExpr(rcounts, design=rdesign, **kwargs)
+ keep = np.asarray(rkeep, dtype=bool)
+
+ pdata._inplace_subset_var(keep)
+
+ def _filter_highly_variable_scanpy(
+ self,
+ pdata: AnnData,
+ n_top_genes: int = 7500,
+ target_sum: float = 1e6,
+ **kwargs,
+ ) -> None:
+ """Filter highly variable genes from a pseudobulk AnnData using Scanpy.
+
+ Normalizes and log-transforms raw count data if needed, then selects the
+ top `n_top_genes` most variable genes using `scanpy.pp.highly_variable_genes`.
+ Results are stored in-place by subsetting `pdata.var`.
+
+ Args:
+ pdata: AnnData object with raw or normalized pseudobulk expression in `.X`.
+ n_top_genes: Number of top variable genes to retain. Defaults to 7500.
+ target_sum: Target total count for normalization (used only if `.X` is raw counts).
+ Defaults to 1e6.
+ **kwargs: Additional keyword arguments passed to `scanpy.pp.highly_variable_genes()`.
+
+ Returns:
+ None: The function modifies `pdata` in place:
+ - Adds normalized expression to `pdata.layers["normalized"]`
+ - Filters `.var` to include only the top `n_top_genes` genes
+ """
+ pdata.layers["normalized"] = pdata.X.copy()
+ if _is_counts(pdata.X):
+ sc.pp.normalize_total(
+ pdata,
+ target_sum=target_sum,
+ layer="normalized",
+ )
+ sc.pp.log1p(pdata, layer="normalized")
+
+ sc.pp.highly_variable_genes(pdata, layer="normalized", n_top_genes=n_top_genes, subset=True, **kwargs)
+
+ def _filter_highly_variable_scran(
+ self,
+ pdata: AnnData,
+ n_top_genes: int,
+ ) -> None:
+ """Filter highly variable genes using R's scran and scuttle packages.
+
+ If `pdata.X` contains raw counts, normalization is performed using
+ `logNormCounts()` from `scuttle`. Otherwise, the matrix is assumed
+ to be already log-normalized.
+
+ The top `n_top_genes` most variable genes are selected using
+ `scran.modelGeneVar()` and `scran.getTopHVGs()` and used to subset
+ `pdata.var` in place.
+
+ Args:
+ pdata: AnnData object containing the pseudobulk expression matrix.
+ n_top_genes: Number of top highly variable genes to retain.
+
+ Returns:
+ None: The function modifies `pdata` in place by subsetting `.var`
+ to contain only the selected HVGs.
+ """
+ scran = self._try_import_bioc_library("scran")
+ scuttle = self._try_import_bioc_library("scuttle")
+ singlecellexperiment = self._try_import_bioc_library("SingleCellExperiment")
+
+ import rpy2.robjects as ro
+ from rpy2.robjects import ListVector, numpy2ri, pandas2ri
+ from rpy2.robjects.conversion import localconverter
+
+ # Densify if sparse: the numpy converter cannot handle sparse matrices.
+ counts = pdata.X
+ counts = counts.toarray() if hasattr(counts, "toarray") else np.asarray(counts)
+ with localconverter(ro.default_converter + numpy2ri.converter):
+ rcounts = numpy2ri.py2rpy(counts.T)
+ obs = pdata.obs
+ var = pdata.var
+
+ with localconverter(ro.default_converter + pandas2ri.converter):
+ robs = pandas2ri.py2rpy(obs)
+ rvar = pandas2ri.py2rpy(var)
+
+ if _is_counts(counts):
+ sce = singlecellexperiment.SingleCellExperiment(ListVector({"counts": rcounts}), colData=robs, rowData=rvar)
+ sce = scuttle.logNormCounts(sce)
+ else:
+ sce = singlecellexperiment.SingleCellExperiment(
+ ListVector({"logcounts": rcounts}), colData=robs, rowData=rvar
+ )
+
+ dec = scran.modelGeneVar(sce)
+ hvgs = scran.getTopHVGs(dec, n=n_top_genes)
+ hvgs = list(hvgs)
+
+ pdata._inplace_subset_var(hvgs)
+
+ def find_nhood_group_markers(
+ self,
+ data: AnnData | MuData,
+ group_to_compare: str | None = None,
+ baseline: str | None = None,
+ nhood_group_obs: str = "nhood_groups",
+ sample_col: str = "sample",
+ covariates: Collection[str] | None = None,
+ key: str = "rna",
+ pseudobulk_function: str = "sum",
+ layer: str | None = None,
+ target_sum: float = 1e6,
+ n_top_genes: int = 7500,
+ filter_method: str | None = "scanpy",
+ var_names: Collection[str] | None = None,
+ de_method: Literal["pydeseq2", "edger"] = "pydeseq2",
+ quiet: bool = True,
+ alpha: float = 0.05,
+ use_eb: bool = False,
+ **kwargs,
+ ) -> pd.DataFrame:
+ """Perform differential expression analysis on neighborhood groups in a MuData or AnnData object.
+
+ This function performs pseudobulk aggregation over neighborhood groupings and tests for differential
+ expression (DE) between groups using either `pydeseq2` or `edgeR`.
+
+ The MuData must contain a modality (default `"rna"`) used for pseudobulk aggregation.
+ Group labels must be stored in `nhood_group_obs`, and sample labels in `sample_col`.
+
+ If both `group_to_compare` and `baseline` are provided, a specific two-group contrast is tested.
+ Otherwise, one-vs-rest DE is performed for each level in `nhood_group_obs`.
+
+ Notes:
+ - All NAs in `nhood_group_obs` are removed before pseudobulk aggregation.
+ - If annotating neighborhood groups manually, you can introduce NAs beforehand to exclude neighborhoods.
+
+ Args:
+ data: A `MuData` or `AnnData` object.
+ group_to_compare: The group to compare (e.g. case) in a specific contrast. Must be in `nhood_group_obs`.
+ baseline: The baseline group (e.g. control). Must be in `nhood_group_obs` if `group_to_compare` is given.
+ nhood_group_obs: Column in `.obs` with neighborhood group labels. Must be categorical or string-typed.
+ sample_col: Column in `.obs` specifying sample identifiers.
+ covariates: Optional list of covariates to include in the DE model formula.
+ key: Name of the modality in `MuData` used for pseudobulk aggregation (default `"rna"`).
+ pseudobulk_function: Aggregation function used for pseudobulk (either `"sum"` or `"mean"`).
+ layer: Optional layer to use for aggregation. Defaults to `X` if not provided.
+ target_sum: Used for Scanpy-based normalization when filtering (default: `1e6`).
+ n_top_genes: Number of highly variable genes to retain (if filtering is applied).
+ filter_method: How to filter genes before DE analysis. One of `"scanpy"`, `"scran"`, or `"filterByExpr"`.
+ var_names: Optional list of variable (gene) names to restrict the analysis to. Overrides filtering if provided.
+ de_method: Differential expression method: `"pydeseq2"` (default) or `"edger"`.
+ quiet: Whether to suppress console output from PyDESeq2 (default: True).
+ alpha: Significance threshold passed to the DE test (used in PyDESeq2).
+ use_eb: If `True`, applies empirical Bayes shrinkage (not implemented yet, reserved for future limma-style methods).
+ **kwargs: Additional arguments passed to the filtering or DE methods (e.g., `min_count`,
+ `min_total_count` for edgeR filtering).
+
+ Returns:
+ :class:`pandas.DataFrame`: DE results with columns:
+ - "variable": Gene/feature name
+ - "log_fc": log2 fold change
+ - "p_value": Unadjusted p-value
+ - "adj_p_value": Multiple testing-corrected p-value
+ - "group": (only for one-vs-rest) the group being compared vs rest
+
+ Examples:
+ >>> import pertpy as pt
+ >>> import scanpy as sc
+ >>> adata = pt.dt.bhattacherjee()
+ >>> milo = pt.tl.Milo()
+ >>> mdata = milo.load(adata)
+ >>> sc.pp.neighbors(mdata["rna"])
+ >>> milo.make_nhoods(mdata["rna"])
+ >>> mdata = milo.count_nhoods(mdata, sample_col="orig.ident")
+ >>> milo.da_nhoods(mdata, design="~label")
+ >>> milo.group_nhoods(mdata)
+ >>> milo.annotate_cells_from_nhoods(mdata)
+ >>> milo.find_nhood_group_markers(mdata, group_to_compare="3", baseline="1", nhood_group_obs="nhood_groups")
+ >>> milo.find_nhood_group_markers(mdata, nhood_group_obs="nhood_groups")
+ """
+ func = pseudobulk_function
+
+ # 1) Resolve the AnnData and validate the inputs
+ if isinstance(data, AnnData):
+ adata = data
+ elif isinstance(data, MuData):
+ mdata = data
+ if key not in mdata.mod:
+ raise KeyError(f"Modality '{key}' not found in mdata; available keys: {list(mdata.mod.keys())}")
+ adata = mdata[key]
+ else:
+ raise TypeError("data must be an AnnData or MuData object.")
+
+ if nhood_group_obs not in adata.obs.columns:
+ raise KeyError(f"Column '{nhood_group_obs}' not found in adata.obs")
+
+ from pandas.api.types import CategoricalDtype
+
+ if not isinstance(adata.obs[nhood_group_obs].dtype, CategoricalDtype):
+ adata.obs[nhood_group_obs] = adata.obs[nhood_group_obs].astype("category")
+
+ n_non_na = adata.obs[nhood_group_obs].notna().sum()
+ if n_non_na == 0:
+ raise ValueError(f"No non-NA entries found in '{nhood_group_obs}'")
+
+ if sample_col not in adata.obs.columns:
+ raise KeyError(f"sample_col '{sample_col}' not in adata.obs")
+ for cov in covariates or []:
+ if cov not in adata.obs.columns:
+ raise KeyError(f"Covariate '{cov}' not found in adata.obs")
+
+ if pseudobulk_function not in ("sum", "mean"):
+ raise ValueError(f"pseudobulk_function must be 'sum' or 'mean', not '{pseudobulk_function}'")
+
+ if var_names is not None:
+ missing = set(var_names) - set(adata.var_names)
+ if missing:
+ raise KeyError(f"These var_names are not in adata.var_names: {missing}")
+
+ # Require both contrast groups, or neither.
+ if (group_to_compare is None) ^ (baseline is None):
+ raise ValueError("You must supply either both `group_to_compare` and `baseline`, or neither.")
+ if group_to_compare is not None and baseline is not None:
+ levels = adata.obs[nhood_group_obs].cat.categories.tolist()
+ if group_to_compare not in levels:
+ raise ValueError(f"group_to_compare '{group_to_compare}' not a level of '{nhood_group_obs}' ({levels})")
+ if baseline not in levels:
+ raise ValueError(f"baseline '{baseline}' not a level of '{nhood_group_obs}' ({levels})")
+ if group_to_compare == baseline:
+ raise ValueError("group_to_compare and baseline cannot be the same")
+
+ # Subset to cells that have a non-NA group label
+ mask = adata.obs[nhood_group_obs].notna()
+ tmp_data = adata[mask]
+
+ # 2) Build the list of categorical variables to aggregate by
+ covariates = [] if covariates is None else list(covariates)
+
+ if sample_col in covariates:
+ all_variables = [nhood_group_obs] + covariates
+ else:
+ all_variables = [sample_col, nhood_group_obs] + covariates
+
+ # 3) Pseudobulk aggregation
+ pdata = sc.get.aggregate(tmp_data, by=all_variables, func=func, axis=0, layer=layer)
+ pdata.X = pdata.layers[func].copy()
+
+ if pdata.obs[nhood_group_obs].nunique() < 2:
+ raise ValueError(f"After aggregation, '{nhood_group_obs}' has fewer than 2 groups; DEA cannot proceed.")
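+ # Example: covariates=["batch"] with nhood_group_obs="nhood_groups" yields the
+ # design formula "~batch + nhood_groups" built below.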
+ if group_to_compare is None and baseline is None:
+ levels_after = pdata.obs[nhood_group_obs].cat.categories.tolist()
+ if len(levels_after) < 2:
+ raise ValueError(
+ f"Need at least two groups in '{nhood_group_obs}' to run one-vs-rest; found {levels_after}"
+ )
+
+ # Build the design formula string
+ if not covariates:
+ base_formula = "~" + nhood_group_obs
+ else:
+ base_formula = "~" + " + ".join(covariates) + " + " + nhood_group_obs
+
+ # 1) If var_names is given, subset to exactly those genes and skip filtering.
+ if var_names is not None:
+ missing = set(var_names) - set(pdata.var_names)
+ if missing:
+ raise KeyError(f"Some var_names not found in pdata.var_names: {missing}")
+ pdata._inplace_subset_var(var_names)
+ # 2) If no var_names and n_top_genes is None or zero, skip filtering.
+ elif not n_top_genes:
+ pass
+ # 3) var_names is None and n_top_genes is a positive integer
+ elif filter_method == "scanpy":
+ self._filter_highly_variable_scanpy(pdata, n_top_genes, target_sum)
+ elif filter_method == "scran":
+ self._filter_highly_variable_scran(pdata, n_top_genes)
+ elif filter_method == "filterByExpr":
+ import inspect
+
+ sig = inspect.signature(self._filter_by_expr_edger)
+ valid_filter_keys = set(sig.parameters)
+ filter_kwargs = {}
+ for kwargs_key in valid_filter_keys & set(kwargs):
+ filter_kwargs[kwargs_key] = kwargs.pop(kwargs_key)
+ if not filter_kwargs:
+ filter_kwargs = {"min_count": 1, "min_total_count": 10, "min_prop": 0.1}
+ if not _is_counts(pdata.X):
+ raise ValueError("`pdata.X` appears to be continuous expression, but filterByExpr requires raw counts.")
+ self._filter_by_expr_edger(pdata, base_formula, **filter_kwargs)
+ else:
+ raise ValueError(f"filter_method must be 'scanpy', 'scran' or 'filterByExpr', not '{filter_method}'")
+
+ if de_method == "pydeseq2":
+ if not _is_counts(pdata.X):
+ raise ValueError("`pdata.X` does not appear to contain raw counts, but pydeseq2 requires raw counts.")
+ return self._run_pydeseq2_contrasts(
+ pdata,
+ nhood_group_obs=nhood_group_obs,
+ formula=base_formula,
+ group_to_compare=group_to_compare,
+ baseline=baseline,
+ alpha=alpha,
+ quiet=quiet,
+ )
+ elif de_method == "edger":
+ if not _is_counts(pdata.X):
+ raise ValueError("`pdata.X` does not appear to contain raw counts, but edgeR requires raw counts.")
+ return self._run_edger_contrasts(
+ pdata,
+ nhood_group_obs=nhood_group_obs,
+ formula=base_formula,
+ group_to_compare=group_to_compare,
+ baseline=baseline,
+ )
+ else:
+ raise ValueError(f"de_method must be one of 'pydeseq2' or 'edger', not '{de_method}'")
+
+ # def plot_heatmap_with_dot_and_colorbar(
+ # self,
+ # mean_df: pd.DataFrame,
+ # logfc_ser: pd.Series | None = None,
+ # cmap: str = "YlGnBu",
+ # dot_scale: float = 200.0,
+ # figsize: tuple[float, float] = (6, 10),
+ # panel_ratios: tuple[float, float, float] = (5, 0.6, 0.3),
+ # cbar_tick_count: int = 5,
+ # show_dot: bool = True,
+ # legend_on_right: bool = False,
+ # ) -> plt.Figure:
+ # """Marker heatmap of mean expression across groups, with optional logFC dots and a colorbar.
+
+ # Plot a figure with:
+ # • Left: heatmap of mean_df (genes × groups), WITHOUT its default colorbar.
+ # • (Optional) Middle: a single column of dots (size ∝ |logFC|), one per gene.
+ # • Right: a slim vertical colorbar (ggplot2 style) that applies to the heatmap.
+ # • (Optional) A size legend for logFC dots, either just to the right of the colorbar
+ # (legend_on_right=False, the default) or further to the right (legend_on_right=True).
+
+ # If show_dot=False, `logfc_ser` may be omitted (and is ignored).
+    # def plot_heatmap_with_dot_and_colorbar(
+    #     self,
+    #     mean_df: pd.DataFrame,
+    #     logfc_ser: pd.Series | None = None,
+    #     cmap: str = "YlGnBu",
+    #     dot_scale: float = 200.0,
+    #     figsize: tuple[float, float] = (6, 10),
+    #     panel_ratios: tuple[float, float, float] = (5, 0.6, 0.3),
+    #     cbar_tick_count: int = 5,
+    #     show_dot: bool = True,
+    #     legend_on_right: bool = False,
+    # ) -> plt.Figure:
+    #     """Marker heatmap of mean expression across groups, with optional logFC dots and a colorbar.
+
+    #     Plot a figure with:
+    #       • Left: heatmap of mean_df (genes × groups), WITHOUT its default colorbar.
+    #       • (Optional) Middle: a single column of dots (size ∝ |logFC|), one per gene.
+    #       • Right: a slim vertical colorbar (ggplot2 style) that applies to the heatmap.
+    #       • (Optional) A size legend for logFC dots, either just to the right of the colorbar
+    #         (legend_on_right=False, the default) or further to the right (legend_on_right=True).
+
+    #     If show_dot=False, `logfc_ser` may be omitted (and is ignored). If show_dot=True,
+    #     then `logfc_ser` must be provided and must match `mean_df.index`.
+
+    #     Parameters
+    #     ----------
+    #     mean_df : pandas.DataFrame, shape (n_genes, n_groups)
+    #         Rows = gene names; columns = group labels; values = mean expression.
+    #     logfc_ser : pandas.Series or None, default=None
+    #         If show_dot=True, this Series of length n_genes (indexed by gene names) gives
+    #         the logFC for each gene. If show_dot=False, you may leave this as None.
+    #     cmap : str, default="YlGnBu"
+    #         Colormap for the heatmap and its colorbar.
+    #     dot_scale : float, default=200.0
+    #         Controls the maximum dot area for the largest |logFC| (only used if show_dot=True).
+    #     figsize : tuple (W, H), default=(6, 10)
+    #         Total figure size in inches. Width W is split among panels according to ratios.
+    #     panel_ratios : tuple (r1, r2, r3), default=(5, 0.6, 0.3)
+    #         Relative widths for [heatmap, dot column, colorbar] when show_dot=True.
+    #         If show_dot=False, only r1 and r3 are used to split the width.
+    #     cbar_tick_count : int, default=5
+    #         Number of ticks on the vertical colorbar.
+    #     show_dot : bool, default=True
+    #         If True, draw the dot column (requires `logfc_ser`). If False, omit dots and
+    #         only draw [heatmap | colorbar].
+    #     legend_on_right : bool or float, default=False
+    #         If True, move the size legend further to the right of the figure, to avoid
+    #         overlap when the figure is narrow. If False, place it just to the right of
+    #         the colorbar (may overlap if the figure is very narrow). A float is used
+    #         directly as the x-position of the legend's bounding box.
+
+    #     Returns
+    #     -------
+    #     fig : matplotlib.figure.Figure
+
+    #     Examples
+    #     --------
+    #     >>> varnames = (
+    #     >>>     nhood_group_markers_results
+    #     >>>     .query('logFC >= 0.5')
+    #     >>>     .query('adj_PValue <= 0.01')
+    #     >>>     .sort_values("logFC", ascending=False)
+    #     >>>     .variable.to_list()
+    #     >>> )
+    #     >>> mean_df = milo.get_mean_expression(mdata["rna"], "nhood_groups", var_names=varnames)
+    #     >>> logfc_ser = (
+    #     >>>     nhood_group_markers_results
+    #     >>>     .query('logFC >= 0.5')
+    #     >>>     .query('adj_PValue <= 0.01')
+    #     >>>     .set_index("variable")
+    #     >>>     .logFC
+    #     >>> )
+    #     >>> fig = milo.plot_heatmap_with_dot_and_colorbar(
+    #     >>>     mean_df,
+    #     >>>     logfc_ser=logfc_ser,
+    #     >>>     cmap="YlGnBu",
+    #     >>>     dot_scale=200.0,
+    #     >>>     figsize=(6, 1.5 + len(logfc_ser) * 0.15),
+    #     >>>     panel_ratios=(5, 0.6, 0.3),
+    #     >>>     cbar_tick_count=5,
+    #     >>>     show_dot=True,
+    #     >>>     legend_on_right=True,
+    #     >>> )
+
+    #     """
+    #     # ────────────────────────────────
+    #     # 1) Validate / align logFC
+    #     # ────────────────────────────────
+    #     if show_dot:
+    #         if logfc_ser is None:
+    #             raise ValueError("`logfc_ser` must be provided when `show_dot=True`.")
+    #         genes = list(mean_df.index)
+    #         lfc_vals = logfc_ser.reindex(index=genes).fillna(0.0).values
+    #         n_genes = len(genes)
+    #     else:
+    #         genes = list(mean_df.index)
+    #         n_genes = len(genes)
+    #         lfc_vals = None
+
+    #     groups = list(mean_df.columns)
+
+    #     # ────────────────────────────────
+    #     # 2) Dot-size scaling (if needed)
+    #     # ────────────────────────────────
+    #     if show_dot:
+    #         max_abs_lfc = np.nanmax(np.abs(lfc_vals))
+    #         if max_abs_lfc == 0 or np.isnan(max_abs_lfc):
+    #             max_abs_lfc = 1.0
+
+    #     # ────────────────────────────────
+    #     # 3) Heatmap normalization
+    #     # ────────────────────────────────
+    #     vmin = mean_df.values.min()
+    #     vmax = mean_df.values.max()
+    #     norm = Normalize(vmin=vmin, vmax=vmax)
+    #     cmap_obj = plt.get_cmap(cmap)
+
+    #     # ────────────────────────────────
+    #     # 4) Build a GridSpec
+    #     # ────────────────────────────────
+    #     W, H = figsize
+    #     r1, r2, r3 = panel_ratios
+
+    #     if show_dot:
+    #         # three panels: [heatmap | dots | colorbar]
+    #         total_ratio = r1 + r2 + r3
+    #         width_ratios = [r1 / total_ratio, r2 / total_ratio, r3 / total_ratio]
+    #         fig = plt.figure(figsize=(W, H))
+    #         gs = fig.add_gridspec(nrows=1, ncols=3, width_ratios=width_ratios, wspace=0.02)
+    #         ax_heat = fig.add_subplot(gs[0, 0])
+    #     else:
+    #         # two panels: [heatmap | colorbar]
+    #         total_ratio = r1 + r3
+    #         width_ratios = [r1 / total_ratio, r3 / total_ratio]
+    #         fig = plt.figure(figsize=(W, H))
+    #         gs = fig.add_gridspec(nrows=1, ncols=2, width_ratios=width_ratios, wspace=0.02)
+    #         ax_heat = fig.add_subplot(gs[0, 0])
+
+    #     # ────────────────────────────────
+    #     # 5) Plot heatmap (no default colorbar)
+    #     # ────────────────────────────────
+    #     sns.heatmap(
+    #         mean_df,
+    #         ax=ax_heat,
+    #         cmap=cmap,
+    #         norm=norm,
+    #         cbar=False,
+    #         yticklabels=genes,
+    #         xticklabels=groups,
+    #         linewidths=0.5,
+    #         linecolor="gray",
+    #     )
+    #     ax_heat.set_ylabel("Gene", fontsize=10)
+    #     ax_heat.set_xlabel("Group", fontsize=10)
+    #     plt.setp(ax_heat.get_xticklabels(), rotation=45, ha="right", fontsize=8)
+    #     plt.setp(ax_heat.get_yticklabels(), rotation=0, fontsize=6)
+
+    #     # ────────────────────────────────
+    #     # 6) Dot panel (if requested)
+    #     # ────────────────────────────────
+    #     if show_dot:
+    #         ax_dot = fig.add_subplot(gs[0, 1])
+    #         for i, val in enumerate(lfc_vals):
+    #             if not np.isnan(val) and val != 0.0:
+    #                 area = (abs(val) / max_abs_lfc) * dot_scale
+    #                 ax_dot.scatter(0, i, s=area, color="black", alpha=0.8, edgecolors="none")
+    #         ax_dot.set_xlim(-0.5, 0.5)
+    #         ax_dot.set_ylim(n_genes - 0.5, -0.5)
+    #         ax_dot.set_xticks([])
+    #         ax_dot.set_yticks([])
+    #         ax_dot.set_title("logFC", pad=10, fontdict={"fontsize": 7})
+    #         ax_cbar = fig.add_subplot(gs[0, 2])
+    #     else:
+    #         ax_cbar = fig.add_subplot(gs[0, 1])
+
+    #     # ────────────────────────────────
+    #     # 7) Draw vertical colorbar for heatmap
+    #     # ────────────────────────────────
+    #     smap = ScalarMappable(norm=norm, cmap=cmap_obj)
+    #     smap.set_array([])
+
+    #     cbar = fig.colorbar(
+    #         smap, cax=ax_cbar, orientation="vertical", ticks=np.linspace(vmin, vmax, num=cbar_tick_count)
+    #     )
+    #     cbar.ax.tick_params(labelsize=8, length=4, width=1)
+    #     cbar.ax.set_title("Mean\nExpr.", fontsize=8, pad=6)
+    #     cbar.outline.set_linewidth(0.5)
+
+    #     # ────────────────────────────────
+    #     # 8) Add a size legend for the dot column (optional)
+    #     # ────────────────────────────────
+    #     if show_dot:
+    #         # Choose three reference |logFC| values: max, ½ max, ¼ max
+    #         ref_vals = np.array([max_abs_lfc, 0.5 * max_abs_lfc, 0.25 * max_abs_lfc])
+    #         legend_handles = []
+    #         legend_labels = []
+    #         for rv in ref_vals:
+    #             sz = (rv / max_abs_lfc) * dot_scale
+    #             # Empty scatter: creates a legend handle without drawing a stray point on ax_dot
+    #             handle = ax_dot.scatter([], [], s=sz, color="black", alpha=0.8, edgecolors="none")
+    #             legend_handles.append(handle)
+    #             legend_labels.append(f"|logFC| = {rv:.2f}")
+
+    #         # Legend x-position: an explicit float wins; otherwise 1.2 (far right) or 1.02 (next to colorbar)
+    #         bbox_x = (1.2 if isinstance(legend_on_right, bool) else legend_on_right) if legend_on_right else 1.02
+
+    #         fig.legend(
+    #             legend_handles,
+    #             legend_labels,
+    #             title="Dot size legend",
+    #             loc="center left",
+    #             bbox_to_anchor=(bbox_x, 0.5),
+    #             frameon=False,
+    #             fontsize=7,
+    #             title_fontsize=8,
+    #             handletextpad=0.5,
+    #             labelspacing=0.6,
+    #         )
+
+    #     plt.tight_layout()
+    #     return fig
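
The commented-out plotter above assembles its panels by hand: a one-row GridSpec with proportional `width_ratios`, a heatmap drawn without its default colorbar, and a `ScalarMappable` feeding a slim dedicated colorbar axis. A stripped-down, runnable sketch of that layout technique on toy data (all names and values here are illustrative):

    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.cm import ScalarMappable
    from matplotlib.colors import Normalize

    rng = np.random.default_rng(0)
    data = rng.random((12, 4))  # toy genes x groups matrix
    norm = Normalize(vmin=data.min(), vmax=data.max())

    fig = plt.figure(figsize=(5, 4))
    gs = fig.add_gridspec(1, 3, width_ratios=[5, 0.6, 0.3], wspace=0.05)
    ax_heat = fig.add_subplot(gs[0, 0])
    ax_dot = fig.add_subplot(gs[0, 1])
    ax_cbar = fig.add_subplot(gs[0, 2])

    ax_heat.imshow(data, aspect="auto", cmap="YlGnBu", norm=norm)
    sizes = rng.random(12) * 200  # dot area stands in for |logFC|
    ax_dot.scatter(np.zeros(12), np.arange(12), s=sizes, color="black")
    ax_dot.set_ylim(11.5, -0.5)  # match heatmap row order (top to bottom)
    ax_dot.set_xticks([])
    ax_dot.set_yticks([])

    smap = ScalarMappable(norm=norm, cmap="YlGnBu")
    smap.set_array([])
    fig.colorbar(smap, cax=ax_cbar)
    plt.show()
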
diff --git a/tests/tools/test_milo.py b/tests/tools/test_milo.py
index 8264ab4a..32f5924d 100644
--- a/tests/tools/test_milo.py
+++ b/tests/tools/test_milo.py
@@ -5,6 +5,7 @@ import pertpy as pt
 import pytest
 import scanpy as sc
+import scipy.sparse as sp
 from mudata import MuData
@@ -307,3 +308,722 @@ def test_add_nhood_expression_nhood_mean_range(add_nhood_expression_mdata, milo)
     nhood_cells = mdata["rna"].obs_names[mdata["rna"].obsm["nhoods"][:, nhood_ix].toarray().ravel() == 1]
     mean_gex = np.array(mdata["rna"][nhood_cells].X.mean(axis=0)).ravel()
     assert nhood_gex == pytest.approx(mean_gex, 0.0001)
+
+
+### NEW TESTS
+
+from scipy.sparse import csr_matrix
+
+
+@pytest.fixture
+def group_nhoods_mdata(adata, milo):
+    adata = adata.copy()
+    milo.make_nhoods(adata)
+
+    # Simulate experimental condition
+    rng = np.random.default_rng(seed=42)
+    adata.obs["condition"] = rng.choice(["ConditionA", "ConditionB"], size=adata.n_obs, p=[0.5, 0.5])
+    # we simulate differential abundance in NK cells
+    DA_cells = adata.obs["louvain"] == "1"
+    adata.obs.loc[DA_cells, "condition"] = rng.choice(["ConditionA", "ConditionB"], size=sum(DA_cells), p=[0.2, 0.8])
+
+    # Simulate replicates
+    adata.obs["replicate"] = rng.choice(["R1", "R2", "R3"], size=adata.n_obs)
+    adata.obs["sample"] = adata.obs["replicate"] + adata.obs["condition"]
+    milo_mdata = milo.count_nhoods(adata, sample_col="sample")
+    milo.da_nhoods(milo_mdata, design="~condition", solver="pydeseq2")
+
+    var = milo_mdata["milo"].var.copy()
+
+    n = var.shape[0]
+    k = max(1, int(0.1 * n))  # guarantee ~10% are "significant", at least 1
+
+    # Overwrite SpatialFDR with random values so a known fraction falls below threshold
+    rng = np.random.default_rng()
+    fdrs = rng.random(n)
+    da_idx = rng.choice(n, size=k, replace=False)
+    fdrs[da_idx] = rng.random(k) * 0.1
+    rng.shuffle(fdrs)
+    milo_mdata["milo"].var["SpatialFDR"] = fdrs
+
+    milo.build_nhood_graph(milo_mdata)
+    return milo_mdata
+
+
+def csr_to_r_dgCMatrix(csr: csr_matrix):
+    """Convert a SciPy CSR matrix into an R dgCMatrix using rpy2.
+
+    Returns an rpy2 Matrix object (class "dgCMatrix").
+    """
+    import rpy2.robjects as ro
+    from rpy2.robjects import FloatVector, IntVector, numpy2ri
+    from rpy2.robjects.conversion import localconverter
+
+    # 1) Convert CSR to COOrdinate form to extract row/col/data
+    coo = csr.tocoo()
+    # R is 1-based, so we must add 1 to Python's 0-based indices:
+    i_r = (coo.row + 1).astype(int)
+    j_r = (coo.col + 1).astype(int)
+    x_r = coo.data
+
+    # 2) Load the Matrix package in R (only if not already loaded)
+    ro.r("suppressPackageStartupMessages(library(Matrix))")
+
+    # 3) Build the sparseMatrix(...) call in R
+    #    `sparseMatrix(i=..., j=..., x=..., dims=c(nrow, ncol))` returns a dgCMatrix by default.
+    nrow, ncol = csr.shape
+    r_sparse = ro.r["sparseMatrix"]
+
+    # 4) Call sparseMatrix(i=IntVector(i_r), j=IntVector(j_r), x=FloatVector(x_r), dims=c(nrow, ncol))
+    with localconverter(ro.default_converter + numpy2ri.converter):
+        # Pass `dims` as an IntVector of length 2
+        dims_vec = IntVector([int(nrow), int(ncol)])
+        mat_r = r_sparse(
+            i=IntVector(i_r.tolist()),
+            j=IntVector(j_r.tolist()),
+            x=FloatVector(x_r.tolist()),
+            dims=dims_vec,
+            index1=ro.BoolVector([True]),  # tell R that i,j are 1-based
+        )
+
+    return mat_r
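
A quick usage sketch for the converter above (assumes a working rpy2 installation with R's Matrix package available):

    import numpy as np
    from scipy.sparse import csr_matrix

    m = csr_matrix(np.array([[0.0, 1.0], [2.0, 0.0]]))
    r_mat = csr_to_r_dgCMatrix(m)  # rpy2 object wrapping an R "dgCMatrix"
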
+
+
+# def _py_to_r(obj):
+#     """
+#     Convert a Python object into an R object, using only context-managed converters.
+
+#     - Any 2D array-like -> force to numpy.ndarray, then to R matrix
+#     - scipy.sparse -> dense numpy -> R matrix
+#     - pandas.DataFrame -> R data.frame
+#     - pandas.Series -> logical/int/float/character vector
+#     - 1D numpy array -> R vector
+#     - Python scalar / single-item list -> length-1 R vector
+#     - None -> R NULL
+#     """
+#     # Import rpy2 constructs lazily
+#     import rpy2.robjects as ro
+#     from rpy2.robjects import numpy2ri, pandas2ri
+#     from rpy2.robjects.conversion import localconverter
+
+#     # (A) scipy.sparse -> dense numpy -> R matrix
+#     if sp.issparse(obj):
+#         arr = obj.toarray()
+#         with localconverter(ro.default_converter + numpy2ri.converter):
+#             return numpy2ri.py2rpy(arr)
+
+#     # (B) pandas.DataFrame -> R data.frame
+#     if isinstance(obj, pd.DataFrame):
+#         with localconverter(ro.default_converter + pandas2ri.converter):
+#             return pandas2ri.py2rpy(obj)
+
+#     # (C) pandas.Series -> logical/int/float/character R vector
+#     if isinstance(obj, pd.Series):
+#         if obj.dtype == bool:
+#             return ro.BoolVector(obj.values.tolist())
+#         elif np.issubdtype(obj.dtype, np.integer):
+#             return ro.IntVector(obj.values.tolist())
+#         elif np.issubdtype(obj.dtype, np.floating):
+#             return ro.FloatVector(obj.values.tolist())
+#         else:
+#             return ro.StrVector(obj.astype(str).tolist())
+
+#     # (D) Force anything array-like into a NumPy array
+#     try:
+#         arr = np.asarray(obj)
+#     except Exception:
+#         arr = None
+
+#     if isinstance(arr, np.ndarray):
+#         # 1D array -> R vector
+#         if arr.ndim == 1:
+#             if arr.dtype == bool:
+#                 return ro.BoolVector(arr.tolist())
+#             elif np.issubdtype(arr.dtype, np.integer):
+#                 return ro.IntVector(arr.tolist())
+#             elif np.issubdtype(arr.dtype, np.floating):
+#                 return ro.FloatVector(arr.tolist())
+#             else:
+#                 return ro.StrVector(arr.astype(str).tolist())
+#         # 2D array -> R matrix
+#         if arr.ndim == 2:
+#             with localconverter(ro.default_converter + numpy2ri.converter):
+#                 return numpy2ri.py2rpy(arr)
+
+#     # (E) Python scalar or single-item list/tuple -> length-1 R vector
+#     if isinstance(obj, bool):
+#         return ro.BoolVector([obj])
+#     if isinstance(obj, int | np.integer):
+#         return ro.IntVector([int(obj)])
+#     if isinstance(obj, float | np.floating):
+#         return ro.FloatVector([float(obj)])
+#     if isinstance(obj, str):
+#         return ro.StrVector([obj])
+
+#     # (F) None -> R NULL
+#     if obj is None:
+#         return ro.NULL
+
+#     # (G) Python list of simple types -> convert to numpy array then recurse
+#     if isinstance(obj, list):
+#         return _py_to_r(np.asarray(obj))
+
+#     # (H) Otherwise, cannot convert
+#     raise ValueError(f"Cannot convert object of type {type(obj)} to R.")
+
+
+def _py_to_r(obj):
+    """Convert common Python objects to R using context-managed rpy2 converters."""
+    import rpy2.robjects as ro
+    from rpy2.robjects import numpy2ri, pandas2ri
+    from rpy2.robjects.conversion import localconverter
+
+    if isinstance(obj, np.ndarray):
+        with localconverter(ro.default_converter + numpy2ri.converter):
+            r_obj = numpy2ri.py2rpy(obj)
+        return r_obj
+    if isinstance(obj, pd.DataFrame):
+        with localconverter(ro.default_converter + pandas2ri.converter):
+            df = pandas2ri.py2rpy(obj)
+        return df
+    if obj is None:
+        return ro.NULL
+    # Scalars (bool/int/float/str) are handled by the default converter
+    with localconverter(ro.default_converter):
+        r_obj = ro.conversion.py2rpy(obj)
+    return r_obj
+ """ + # Import rpy2 inside the function + import rpy2.robjects as ro + + # Define the R function in R’s global environment: + rcode = r""" + .group_nhoods_from_adjacency_pycomp <- function(nhs, nhood.adj, da.res, is.da, + merge.discord=FALSE, + max.lfc.delta=NULL, + overlap=1, + subset.nhoods=NULL + ){ + # Force everything into a plain base‐R matrix + nhood.adj <- as.matrix(nhood.adj) + + if(is.null(colnames(nhs))){ + warning("No names attributed to nhoods. Converting indices to names") + colnames(nhs) <- as.character(seq_len(ncol(nhs))) + } + + # Subsetting logic (as in miloR) + if(!is.null(subset.nhoods)){ + if(mode(subset.nhoods) %in% c("character", "logical", "numeric")){ + if(mode(subset.nhoods) %in% c("character")){ + sub.log <- colnames(nhs) %in% subset.nhoods + } else if (mode(subset.nhoods) %in% c("numeric")) { + sub.log <- colnames(nhs) %in% colnames(nhs)[subset.nhoods] + } else{ + sub.log <- subset.nhoods + } + nhood.adj <- nhood.adj[sub.log, sub.log] + if(length(is.da) == ncol(nhs)){ + nhs <- nhs[sub.log] + is.da <- is.da[sub.log] + da.res <- da.res[sub.log, ] + } else{ + stop("Subsetting `is.da` vector length does not equal nhoods length") + } + } else{ + stop("Incorrect subsetting vector provided:", class(subset.nhoods)) + } + } else{ + if(length(is.da) != ncol(nhood.adj)){ + stop("Subsetting `is.da` vector length is not the same dimension as adjacency") + } + } + + # Discord‐filter + if(isFALSE(merge.discord)){ + discord.sign <- sign(da.res[is.da, 'logFC'] %*% t(da.res[is.da, 'logFC'])) < 0 + nhood.adj[is.da, is.da][discord.sign] <- 0 + } + + # Overlap‐filter + if(overlap > 1){ + nhood.adj[nhood.adj < overlap] <- 0 + } + + # max.lfc.delta‐filter + if(!is.null(max.lfc.delta)){ + lfc.diff <- sapply(da.res[,"logFC"], "-", da.res[,"logFC"]) + nhood.adj[abs(lfc.diff) > max.lfc.delta] <- 0 + } + + # Binarize + nhood.adj <- as.matrix((nhood.adj > 0) + 0) + + # Sanity checks + if(!isSymmetric(nhood.adj)){ + stop("Overlap matrix is not symmetric") + } + if(nrow(nhood.adj) != ncol(nhood.adj)){ + stop("Non-square distance matrix ‐ check nhood subsetting") + } + + return(nhood.adj) + } + """ + # Evaluate rcode in R's global environment (defines the function): + ro.r(rcode) + + # Now retrieve that function from R's globalenv and call it: + f = ro.globalenv[".group_nhoods_from_adjacency_pycomp"] + return f(nhs_r, nhood_adj_r, da_res_r, is_da_r, merge_discord_r, max_lfc_delta_r, overlap_r, subset_nhoods_r) + + +def _group_nhoods_from_adjacency_rcomp( + adjacency: sp.spmatrix, + da_res: pd.DataFrame, + is_da: np.ndarray, + merge_discord: bool = False, + overlap: int = 1, + max_lfc_delta: float | None = None, + subset_nhoods=None, +) -> np.ndarray: + """ + Pure‐Python implementation. 
+ """ + # 1) Subset if needed + if subset_nhoods is not None: + if isinstance(subset_nhoods, list | np.ndarray): + arr = np.asarray(subset_nhoods) + if np.issubdtype(arr.dtype, np.integer): + mask = np.zeros(adjacency.shape[0], dtype=bool) + mask[arr.astype(int)] = True + else: + names = np.array(da_res.index, dtype=str) + mask = np.isin(names, arr.astype(str)) + elif isinstance(subset_nhoods, pd.Series | np.ndarray) and getattr(subset_nhoods, "dtype", None) is bool: + if len(subset_nhoods) != adjacency.shape[0]: + raise ValueError("Boolean subset_nhoods length must match nhood count") + mask = np.asarray(subset_nhoods, dtype=bool) + else: + raise ValueError("subset_nhoods must be bool mask, index list, or name list") + + adjacency = adjacency[mask, :][:, mask] + da_res = da_res.loc[mask].copy() + is_da = is_da[mask] + else: + mask = np.ones(adjacency.shape[0], dtype=bool) + + M = adjacency.shape[0] + if da_res.shape[0] != M or is_da.shape[0] != M: + raise ValueError("da_res and is_da must match adjacency dimension after subsetting") + + # 2) Ensure CSR → COO + if not sp.issparse(adjacency): + adjacency = sp.csr_matrix(adjacency) + adjacency = adjacency.tocsr() + Acoo = adjacency.tocoo() + rows, cols, data = (np.asarray(Acoo.row, int), np.asarray(Acoo.col, int), np.asarray(Acoo.data, float)) + + # 3) Precompute logFC and signs + lfc_vals = da_res["logFC"].values + signs = np.sign(lfc_vals) + + # 4.1) Discord filter + if merge_discord: + keep_discord = np.ones_like(data, dtype=bool) + else: + is_da_rows = is_da[rows] + is_da_cols = is_da[cols] + sign_rows = signs[rows] + sign_cols = signs[cols] + discord_pair = (is_da_rows & is_da_cols) & (sign_rows * sign_cols < 0) + keep_discord = ~discord_pair + + # 4.2) Overlap filter + keep_overlap = np.ones_like(data, dtype=bool) if overlap <= 1 else data >= overlap + + # 4.3) ΔlogFC filter + if max_lfc_delta is None: + keep_lfc = np.ones_like(data, dtype=bool) + else: + diffs = np.abs(lfc_vals[rows] - lfc_vals[cols]) + keep_lfc = diffs <= max_lfc_delta + + # 5) Combine masks + keep_mask = keep_discord & keep_overlap & keep_lfc + + # 6) Reconstruct pruned adjacency, then binarize + new_rows = rows[keep_mask] + new_cols = cols[keep_mask] + new_data = data[keep_mask] + pruned = sp.coo_matrix((new_data, (new_rows, new_cols)), shape=(M, M)).tocsr() + pruned_bin = (pruned > 0).astype(int).toarray() + + return pruned_bin + + +@pytest.mark.parametrize( + "merge_discord_flag, overlap_val, max_lfc_val", + [ + (False, 1, None), + (True, 1, None), + (False, 2, None), + (False, 3, None), + (False, 5, None), + (False, 15, None), + (False, 100, None), + (False, 1, 0.5), + (False, 1, 1.0), + (False, 1, 2.0), + (False, 1, 3.0), + (False, 1, 4.0), + (False, 1, 5.0), + ], +) +def test_sparse_adjacency_filters_match_R(group_nhoods_mdata, merge_discord_flag, overlap_val, max_lfc_val): + """ + Compare the R version against the Python version for various + settings of merge_discord, overlap, and max_lfc_delta. 
+ """ + # 1) Extract inputs from fixture + mdata = group_nhoods_mdata.copy() + nhs = mdata["rna"].obsm["nhoods"].copy() + nhood_adj = mdata["milo"].varp["nhood_connectivities"].copy() # sparse + da_res = mdata["milo"].var.copy() # DataFrame with "logFC" + is_da = (da_res["SpatialFDR"].values < 0.1) & (da_res["logFC"].values > 0) + + # 2) Convert to R objects + nhs_r = _py_to_r(np.zeros((nhs.shape[0], nhs.shape[1]))) + nhood_adj_r = csr_to_r_dgCMatrix(nhood_adj) + da_res_r = _py_to_r(da_res) + is_da_r = _py_to_r(is_da) + merge_discord_r = _py_to_r(bool(merge_discord_flag)) + overlap_r = _py_to_r(int(overlap_val)) + max_lfc_delta_r = _py_to_r(max_lfc_val) # None → R NULL + subset_nhoods_r = _py_to_r(None) # always None here + + # 3) Call the R implementation + r_out = _group_nhoods_from_adjacency_r( + nhs_r, nhood_adj_r, da_res_r, is_da_r, merge_discord_r, max_lfc_delta_r, overlap_r, subset_nhoods_r + ) + + # 4) Convert R output → NumPy + import rpy2.robjects as ro + from rpy2.robjects import numpy2ri + from rpy2.robjects.conversion import localconverter + + if isinstance(r_out, ro.vectors.Matrix): + with localconverter(ro.default_converter + numpy2ri.converter): + adj_R = np.asarray(r_out) + else: + adj_R = np.asarray(r_out) + + assert adj_R.shape == nhood_adj.shape + + # 5) Call the Python implementation + adj_py = _group_nhoods_from_adjacency_rcomp( + nhood_adj, + da_res, + is_da, + merge_discord=merge_discord_flag, + overlap=overlap_val, + max_lfc_delta=max_lfc_val, + subset_nhoods=None, + ) + + # 6) Compare + assert adj_py.shape == adj_R.shape + assert np.array_equal(adj_py, adj_R), ( + f"Mismatch for (merge_discord={merge_discord_flag}, overlap={overlap_val}, max_lfc={max_lfc_val})" + ) + + +@pytest.mark.parametrize( + "nhood_group_obs, subset_nhoods, min_n_nhoods, mode", + [ + # default obs-name, all neighborhoods, last-wins + ("nhood_groups", None, 1, "last_wins"), + # default obs-name, only neighborhoods "0" and "1", last-wins + ("nhood_groups", None, 1, "last_wins"), + # default obs-name, all neighborhoods, exclude overlaps with threshold 2 + ("nhood_groups", None, 2, "exclude_overlaps"), + # custom obs-name, only neighborhood "2", exclude overlaps w/ threshold 3 + ("custom_lbls", None, 3, "exclude_overlaps"), + ], +) +def test_annotate_cells_from_nhoods_various( + group_nhoods_mdata, + nhood_group_obs, + subset_nhoods, + min_n_nhoods, + mode, +): + from pertpy.tools._milo import Milo + + milo = Milo() + + mdata = group_nhoods_mdata.copy() + + assert "SpatialFDR" in mdata["milo"].var.columns + milo.group_nhoods(mdata) + if nhood_group_obs == "custom_lbls": + # Create a custom nhood group obs column + mdata["milo"].var["custom_lbls"] = mdata["milo"].var["nhood_groups"].copy() + + # run the annotation + milo.annotate_cells_from_nhoods( + mdata, + nhood_group_obs=nhood_group_obs, + subset_nhoods=subset_nhoods, + min_n_nhoods=min_n_nhoods, + mode=mode, + ) + + if nhood_group_obs == "custom_lbls": + # Create a custom nhood group obs column + mdata["rna"].obs["custom_lbls"] = mdata["rna"].obs["nhood_groups"].copy() + + # the new column must exist + assert nhood_group_obs in mdata["rna"].obs.columns + + col = mdata["rna"].obs[nhood_group_obs] + + # type checks + assert col.dtype == object + assert len(col) == mdata["rna"].n_obs + + # non-annotated cells should be NaN + # annotated cells should only use labels from the chosen neighborhoods + # non_null = col.dropna().astype(str).values + # non_null = [ + # x for x in col.astype(str).values + # if x.lower() not in ("nan", "") + # ] + + # if 
+
+
+####
+# test find_nhood_group_markers
+
+
+@pytest.fixture
+def nhood_markers_mdata(milo):
+    # 1) Load pbmc3k and make a private copy
+    adata = sc.datasets.pbmc3k().copy()
+
+    adata.layers["counts"] = adata.X.copy()  # keep raw counts
+    sc.pp.normalize_total(adata, target_sum=1e4)  # normalize
+    sc.pp.log1p(adata)  # log transform
+    sc.pp.highly_variable_genes(adata, n_top_genes=2000)
+    sc.pp.pca(adata, n_comps=50)  # PCA
+    sc.pp.neighbors(adata, n_neighbors=10, n_pcs=30)  # build neighbors
+    sc.tl.louvain(adata, resolution=0.5)  # Louvain clustering
+    sc.tl.umap(adata)
+
+    # 2) Build neighborhoods
+    milo.make_nhoods(adata)
+
+    # 3) Simulate an experimental condition
+    rng = np.random.default_rng(seed=42)
+    adata.obs["condition"] = rng.choice(["ConditionA", "ConditionB"], size=adata.n_obs, p=[0.5, 0.5])
+    # bump one cluster to have DA
+    DA_cells = adata.obs["louvain"] == "1"
+    adata.obs.loc[DA_cells, "condition"] = rng.choice(["ConditionA", "ConditionB"], size=DA_cells.sum(), p=[0.2, 0.8])
+
+    # 4) Simulate replicates & build sample IDs
+    adata.obs["replicate"] = rng.choice(["R1", "R2", "R3"], size=adata.n_obs)
+    adata.obs["sample"] = adata.obs["replicate"] + "_" + adata.obs["condition"]
+
+    # 5) Count & test DA neighborhoods
+    mdata = milo.count_nhoods(adata, sample_col="sample")
+    milo.da_nhoods(mdata, design="~condition", solver="pydeseq2")
+
+    # 6) Overwrite SpatialFDR so we have ~10% "significant" at random
+    var = mdata["milo"].var
+    n = var.shape[0]
+    k = max(1, int(0.1 * n))
+    fdrs = rng.random(n)
+    da_idx = rng.choice(n, size=k, replace=False)
+    fdrs[da_idx] = rng.random(k) * 0.1
+    rng.shuffle(fdrs)
+    mdata["milo"].var["SpatialFDR"] = fdrs
+
+    # 7) Build the neighborhood graph
+    milo.build_nhood_graph(mdata)
+
+    milo.group_nhoods(mdata)
+
+    # 8) Annotate cells from those nhoods
+    milo.annotate_cells_from_nhoods(
+        mdata,
+        nhood_group_obs="nhood_groups",  # use the default
+        subset_nhoods=None,  # annotate across all groups
+        min_n_nhoods=1,  # default
+        mode="last_wins",  # default
+    )
+
+    return mdata
+
+
+@pytest.mark.parametrize(
+    "group_to_compare, baseline, expect_group_col, filter_method, de_method",
+    [
+        ("1", "0", False, "scanpy", "pydeseq2"),  # two-level contrast: no "group" column
+        (None, None, True, "scanpy", "pydeseq2"),  # one-vs-rest: should have "group" column
+        ("1", "0", False, "filterByExpr", "pydeseq2"),
"filterByExpr", "pydeseq2"), + ("1", "0", False, "scran", "pydeseq2"), + (None, None, True, "scran", "pydeseq2"), + ("1", "0", False, "scanpy", "edger"), + (None, None, True, "scanpy", "edger"), + ("1", "0", False, "filterByExpr", "edger"), + (None, None, True, "filterByExpr", "edger"), + ("1", "0", False, "scran", "edger"), + (None, None, True, "scran", "edger"), + ], +) +def test_find_markers_returns_expected_structure( + nhood_markers_mdata, milo, group_to_compare, baseline, expect_group_col, filter_method, de_method +): + """Test that find_nhood_group_markers returns a properly shaped DataFrame + and that the expected columns are present or absent.""" + + mdata = nhood_markers_mdata.copy() + df = milo.find_nhood_group_markers( + mdata, + group_to_compare=group_to_compare, + baseline=baseline, + nhood_group_obs="nhood_groups", + sample_col="sample", + covariates=None, + layer="counts", + filter_method=filter_method, + de_method=de_method, + ) + # Basic sanity checks + assert isinstance(df, pd.DataFrame) + # Must have at least one row and one variable + assert df.shape[0] > 0 + # Check core columns + core_cols = {"variable", "log_fc", "p_value", "adj_p_value"} + assert core_cols.issubset(df.columns) + # “group” only appears in one‐vs‐rest branch + if expect_group_col: + assert "group" in df.columns + # Should see at least two distinct group labels + groups = set(df["group"].unique()) + assert len(groups) > 1 + else: + assert "group" not in df.columns + + +def test_find_markers_two_level_effect_sizes(nhood_markers_mdata, milo): + """In the two-level case, log_fc should not all be zero + and p-values should be between 0 and 1.""" + + mdata = nhood_markers_mdata.copy() + + rng = np.random.default_rng(seed=42) + mdata["rna"].obs["nhood_groups"] = pd.Categorical( + rng.choice(["ConditionA", "ConditionB", pd.NA], size=mdata["rna"].n_obs) + ) + + res = milo.find_nhood_group_markers( + mdata, + group_to_compare="ConditionB", + baseline="ConditionA", + nhood_group_obs="nhood_groups", + sample_col="sample", + covariates=None, + layer="counts", + ) + # Check that we got at least one positive and one negative log fold–change + assert (res["log_fc"] > 0).any() + assert (res["log_fc"] < 0).any() + # p-values and adjusted p-values should lie in [0,1] + # assert res["p_value"].notna().all() + res = res.query("p_value == p_value") + assert ((res["p_value"] >= 0) & (res["p_value"] <= 1)).all() + res = res.query("adj_p_value == adj_p_value") + assert ((res["adj_p_value"] >= 0) & (res["adj_p_value"] <= 1)).all() + + +def test_find_markers_invalid_args_raises(nhood_markers_mdata, milo): + """Check that missing or bogus contrast arguments raise a ValueError.""" + + mdata = nhood_markers_mdata.copy() + with pytest.raises(ValueError): + # baseline without group_to_compare + milo.find_nhood_group_markers( + mdata, + group_to_compare=None, + baseline="ConditionA", + nhood_group_obs="nhood_groups", + sample_col="sample", + layer="counts", + ) + with pytest.raises(ValueError): + # group_to_compare == baseline + milo.find_nhood_group_markers( + mdata, + group_to_compare="ConditionA", + baseline="ConditionA", + nhood_group_obs="nhood_groups", + sample_col="sample", + layer="counts", + ) + with pytest.raises(KeyError): + # non‐existent column name + milo.find_nhood_group_markers( + mdata, + group_to_compare="ConditionB", + baseline="ConditionA", + nhood_group_obs="not_a_column", + sample_col="sample", + layer="counts", + )