diff --git a/docs/usage/usage.md b/docs/usage/usage.md index 58a30c03..2a5ff02f 100644 --- a/docs/usage/usage.md +++ b/docs/usage/usage.md @@ -136,6 +136,7 @@ Pertpy provides utilities to conduct differential gene expression tests through tools.EdgeR tools.WilcoxonTest tools.TTest + tools.PermutationTest tools.Statsmodels ``` @@ -563,9 +564,9 @@ including cell line annotation, bulk RNA and protein expression data. Available databases for cell line metadata: -- [The Cancer Dependency Map Project at Broad](https://depmap.org/portal/) -- [The Cancer Dependency Map Project at Sanger](https://depmap.sanger.ac.uk/) -- [Genomics of Drug Sensitivity in Cancer (GDSC)](https://www.cancerrxgene.org/) +- [The Cancer Dependency Map Project at Broad](https://depmap.org/portal/) +- [The Cancer Dependency Map Project at Sanger](https://depmap.sanger.ac.uk/) +- [Genomics of Drug Sensitivity in Cancer (GDSC)](https://www.cancerrxgene.org/) ### Compound @@ -573,7 +574,7 @@ The Compound module enables the retrieval of various types of information relate Available databases for compound metadata: -- [PubChem](https://pubchem.ncbi.nlm.nih.gov/) +- [PubChem](https://pubchem.ncbi.nlm.nih.gov/) ### Mechanism of Action @@ -581,7 +582,7 @@ This module aims to retrieve metadata of mechanism of action studies related to Available databases for mechanism of action metadata: -- [CLUE](https://clue.io/) +- [CLUE](https://clue.io/) ### Drug @@ -589,7 +590,7 @@ This module allows for the retrieval of Drug target information. Available databases for drug metadata: -- [chembl](https://www.ebi.ac.uk/chembl/) +- [chembl](https://www.ebi.ac.uk/chembl/) ```{eval-rst} .. 
autosummary:: diff --git a/pertpy/tools/__init__.py b/pertpy/tools/__init__.py index 10565cee..8322e418 100644 --- a/pertpy/tools/__init__.py +++ b/pertpy/tools/__init__.py @@ -47,6 +47,7 @@ def __init__(self, *args, **kwargs): DE_EXTRAS = ["formulaic", "pydeseq2"] EdgeR = lazy_import("pertpy.tools._differential_gene_expression", "EdgeR", DE_EXTRAS) # edgeR will be imported via rpy2 +PermutationTest = lazy_import("pertpy.tools._differential_gene_expression", "PermutationTest", DE_EXTRAS) PyDESeq2 = lazy_import("pertpy.tools._differential_gene_expression", "PyDESeq2", DE_EXTRAS) Statsmodels = lazy_import("pertpy.tools._differential_gene_expression", "Statsmodels", DE_EXTRAS + ["statsmodels"]) TTest = lazy_import("pertpy.tools._differential_gene_expression", "TTest", DE_EXTRAS) @@ -62,6 +63,7 @@ def __init__(self, *args, **kwargs): "PyDESeq2", "WilcoxonTest", "TTest", + "PermutationTest", "Statsmodels", "DistanceTest", "Distance", diff --git a/pertpy/tools/_differential_gene_expression/__init__.py b/pertpy/tools/_differential_gene_expression/__init__.py index 6cc925dd..19a3383c 100644 --- a/pertpy/tools/_differential_gene_expression/__init__.py +++ b/pertpy/tools/_differential_gene_expression/__init__.py @@ -2,7 +2,7 @@ from ._dge_comparison import DGEEVAL from ._edger import EdgeR from ._pydeseq2 import PyDESeq2 -from ._simple_tests import SimpleComparisonBase, TTest, WilcoxonTest +from ._simple_tests import PermutationTest, SimpleComparisonBase, TTest, WilcoxonTest from ._statsmodels import Statsmodels __all__ = [ @@ -14,6 +14,7 @@ "SimpleComparisonBase", "WilcoxonTest", "TTest", + "PermutationTest", ] -AVAILABLE_METHODS = [Statsmodels, EdgeR, PyDESeq2, WilcoxonTest, TTest] +AVAILABLE_METHODS = [Statsmodels, EdgeR, PyDESeq2, WilcoxonTest, TTest, PermutationTest] diff --git a/pertpy/tools/_differential_gene_expression/_pydeseq2.py b/pertpy/tools/_differential_gene_expression/_pydeseq2.py index d4360e28..f191debc 100644 --- 
a/pertpy/tools/_differential_gene_expression/_pydeseq2.py +++ b/pertpy/tools/_differential_gene_expression/_pydeseq2.py @@ -42,7 +42,7 @@ def fit(self, **kwargs) -> pd.DataFrame: **kwargs: Keyword arguments specific to DeseqDataSet(), except for `n_cpus` which will use all available CPUs minus one if the argument is not passed. """ try: - usable_cpus = len(os.sched_getaffinity(0)) + usable_cpus = len(os.sched_getaffinity(0)) # type: ignore # os.sched_getaffinity is not available on Windows and macOS except AttributeError: usable_cpus = os.cpu_count() diff --git a/pertpy/tools/_differential_gene_expression/_simple_tests.py b/pertpy/tools/_differential_gene_expression/_simple_tests.py index 824569f9..26e8fab8 100644 --- a/pertpy/tools/_differential_gene_expression/_simple_tests.py +++ b/pertpy/tools/_differential_gene_expression/_simple_tests.py @@ -2,7 +2,8 @@ import warnings from abc import abstractmethod -from collections.abc import Mapping, Sequence +from collections.abc import Callable, Mapping, Sequence +from inspect import signature from types import MappingProxyType import numpy as np @@ -10,6 +11,7 @@ import scipy.stats import statsmodels from anndata import AnnData +from lamin_utils import logger from pandas.core.api import DataFrame as DataFrame from scipy.sparse import diags, issparse from tqdm.auto import tqdm @@ -33,7 +35,7 @@ def fdr_correction( class SimpleComparisonBase(MethodBase): @staticmethod @abstractmethod - def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> float: + def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> dict[str, float]: """Perform a statistical test between values in x0 and x1. 
If `paired` is True, x0 and x1 must be of the same length and ordered such that @@ -71,16 +73,16 @@ def _compare_single_group( x0 = x0.tocsc() x1 = x1.tocsc() - res = [] + res: list[dict[str, float | np.ndarray]] = [] for var in tqdm(self.adata.var_names): tmp_x0 = x0[:, self.adata.var_names == var] tmp_x0 = np.asarray(tmp_x0.todense()).flatten() if issparse(tmp_x0) else tmp_x0.flatten() tmp_x1 = x1[:, self.adata.var_names == var] tmp_x1 = np.asarray(tmp_x1.todense()).flatten() if issparse(tmp_x1) else tmp_x1.flatten() - pval = self._test(tmp_x0, tmp_x1, paired, **kwargs) + test_result = self._test(tmp_x0, tmp_x1, paired, **kwargs) mean_x0 = np.mean(tmp_x0) mean_x1 = np.mean(tmp_x1) - res.append({"variable": var, "p_value": pval, "log_fc": np.log2(mean_x1) - np.log2(mean_x0)}) + res.append({"variable": var, "log_fc": np.log2(mean_x1) - np.log2(mean_x0), **test_result}) return pd.DataFrame(res).sort_values("p_value") @classmethod @@ -94,9 +96,28 @@ def compare_groups( paired_by: str | None = None, mask: str | None = None, layer: str | None = None, + n_permutations: int = 1000, + permutation_test_statistic: type["SimpleComparisonBase"] | None = None, fit_kwargs: Mapping = MappingProxyType({}), test_kwargs: Mapping = MappingProxyType({}), ) -> DataFrame: + """Perform a comparison between groups. + + Args: + adata: Data with observations to compare. + column: Column in `adata.obs` that contains the groups to compare. + baseline: Reference group. + groups_to_compare: Groups to compare against the baseline. If None, all other groups + are compared. + paired_by: Column in `adata.obs` to use for pairing. If None, an unpaired test is performed. + mask: Mask to apply to the data. + layer: Layer to use for the comparison. + n_permutations: Number of permutations to perform if a permutation test is used. + permutation_test_statistic: The statistic to use if performing a permutation test. If None, the default + t-statistic from `TTest` is used. 
+ fit_kwargs: Unused argument for compatibility with the `MethodBase` interface, do not specify. + test_kwargs: Additional kwargs passed to the test function. + """ if len(fit_kwargs): warnings.warn("fit_kwargs not used for simple tests.", UserWarning, stacklevel=2) paired = paired_by is not None @@ -125,6 +146,12 @@ def _get_idx(column, value): else: return np.where(mask)[0] + if permutation_test_statistic: + test_kwargs = dict(test_kwargs) + test_kwargs.update({"test_statistic": permutation_test_statistic, "n_permutations": n_permutations}) + elif permutation_test_statistic is None and cls.__name__ == "PermutationTest": + logger.warning("No permutation test statistic specified. Using TTest statistic as default.") + res_dfs = [] baseline_idx = _get_idx(column, baseline) for group_to_compare in groups_to_compare: @@ -144,19 +171,118 @@ class WilcoxonTest(SimpleComparisonBase): """ @staticmethod - def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> float: + def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> dict[str, float]: + """Perform an unpaired or paired Wilcoxon/Mann-Whitney-U test.""" + if paired: - return scipy.stats.wilcoxon(x0, x1, **kwargs).pvalue + test_result = scipy.stats.wilcoxon(x0, x1, **kwargs) else: - return scipy.stats.mannwhitneyu(x0, x1, **kwargs).pvalue + test_result = scipy.stats.mannwhitneyu(x0, x1, **kwargs) + + return { + "p_value": test_result.pvalue, + "statistic": test_result.statistic, + } class TTest(SimpleComparisonBase): - """Perform a unpaired or paired T-test""" + """Perform a unpaired or paired T-test.""" @staticmethod - def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> float: + def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> dict[str, float]: if paired: - return scipy.stats.ttest_rel(x0, x1, **kwargs).pvalue + test_result = scipy.stats.ttest_rel(x0, x1, **kwargs) else: - return scipy.stats.ttest_ind(x0, x1, **kwargs).pvalue + test_result = 
class PermutationTest(SimpleComparisonBase):
    """Perform a permutation test.

    The permutation test relies on another test statistic (e.g. t-statistic or your own) to obtain a p-value through
    random permutations of the data and repeated generation of the test statistic.

    For paired tests, each paired observation is permuted together and distributed randomly between the two groups. For
    unpaired tests, all observations are permuted independently.

    The null hypothesis for the unpaired test is that all observations come from the same underlying distribution and
    have been randomly assigned to one of the samples.

    The null hypothesis for the paired permutation test is that the observations within each pair are drawn from the
    same underlying distribution and that their assignment to a sample is random.
    """

    @staticmethod
    def _test(
        x0: np.ndarray,
        x1: np.ndarray,
        paired: bool,
        test_statistic: type["SimpleComparisonBase"] | Callable = TTest,
        n_permutations: int = 1000,
        **kwargs,
    ) -> dict[str, float]:
        """Perform a permutation test.

        This function relies on another test (e.g. `TTest`) to generate a test statistic for each permutation.

        Args:
            x0: Array with baseline values.
            x1: Array with values to compare.
            paired: Whether to perform a paired test. Paired tests use `permutation_type="samples"`, unpaired tests
                use `permutation_type="independent"`.
            test_statistic: The class or function to generate the test statistic from permuted data. A class must
                implement `_test`; it is always called with an `axis` keyword (vectorized). If a function is
                passed, it must have the signature `test_statistic(x0, x1, paired[, axis], **kwargs)` and return a
                mapping with a `"statistic"` entry. If it accepts the parameter `axis`, vectorization will be used.
                Defaults to `TTest`, matching the default announced by `compare_groups`.
            n_permutations: Number of permutations to perform.
            **kwargs: kwargs passed to `scipy.stats.permutation_test`, not to the test function after permutation.

        Returns:
            Dictionary with the keys `p_value` and `statistic` taken from the result of
            `scipy.stats.permutation_test`.

        Raises:
            ValueError: If `test_statistic` is `PermutationTest` itself.

        Examples:
            You can use the `PermutationTest` class to perform a permutation test with a custom test statistic or an
            existing test statistic like `TTest`. The test statistic must be a class that implements the `_test` method
            or a function that takes the arguments `x0`, `x1`, `paired` and `**kwargs`.

            >>> from pertpy.tools import PermutationTest, TTest
            >>> # Perform a permutation test with a t-statistic
            >>> result = PermutationTest._test(x0, x1, paired=True, test_statistic=TTest, n_permutations=1000, rng=0)
            >>> p_value = result["p_value"]
            >>> # Perform a permutation test with a custom test statistic
            >>> result = PermutationTest._test(x0, x1, paired=False, test_statistic=your_custom_test_statistic)
        """
        if test_statistic is PermutationTest:
            raise ValueError(
                "The `test_statistic` argument cannot be `PermutationTest`. Use a base test like `TTest` or a custom test."
            )

        # Classes exposing `_test` are always driven in vectorized mode; plain callables only if they accept `axis`.
        is_test_class = hasattr(test_statistic, "_test")
        if is_test_class:
            vectorized = True
        else:
            try:
                vectorized = "axis" in signature(test_statistic).parameters
            except (TypeError, ValueError):
                # Some callables cannot be introspected; fall back to calling them without `axis`.
                vectorized = False

        def call_test(data_baseline, data_comparison, axis: int | None = None, **stat_kwargs):
            """Evaluate the wrapped statistic on one (batch of) resample(s) and return only the statistic."""
            if vectorized:
                stat_kwargs["axis"] = axis

            if is_test_class:
                outcome = test_statistic._test(data_baseline, data_comparison, paired, **stat_kwargs)
            else:
                outcome = test_statistic(data_baseline, data_comparison, paired=paired, **stat_kwargs)

            return outcome["statistic"]

        test_result = scipy.stats.permutation_test(
            [x0, x1],
            statistic=call_test,
            n_resamples=n_permutations,
            permutation_type=("samples" if paired else "independent"),
            vectorized=vectorized,
            **kwargs,
        )

        return {
            "p_value": test_result.pvalue,
            "statistic": test_result.statistic,
        }
@pytest.mark.parametrize(
    "paired_by,expected",
    [
        pytest.param(
            None,
            {"gene1": {"p_value": 2.13e-26, "log_fc": -5.14}, "gene2": {"p_value": 0.96, "log_fc": -0.016}},
            id="unpaired",
        ),
        pytest.param(
            "pairing",
            {"gene1": {"p_value": 1.63e-26, "log_fc": -5.14}, "gene2": {"p_value": 0.85, "log_fc": -0.016}},
            id="paired",
        ),
    ],
)
def test_permutation(test_adata_minimal, paired_by, expected):
    """Check `PermutationTest.compare_groups` against reference values for each supported statistic.

    For both `TTest` and `WilcoxonTest` as permutation statistic, every gene must fall on the same side of the
    0.05 significance threshold as its reference p-value. Only for genes found significant are `p_value` and
    `log_fc` additionally compared to the reference with an absolute tolerance of 0.02.

    NOTE(review): the reference values were originally described as computed in R using wilcox.test, yet the same
    values are used for both statistics - confirm their provenance.
    """
    for statistic in (TTest, WilcoxonTest):
        result = PermutationTest.compare_groups(
            adata=test_adata_minimal,
            column="condition",
            baseline="A",
            groups_to_compare="B",
            paired_by=paired_by,
            n_permutations=1000,
            permutation_test_statistic=statistic,
            test_kwargs={"rng": 0},
        )
        assert isinstance(result, DataFrame), "PermutationTest.compare_groups should return a DataFrame"

        per_gene = result[["variable", "p_value", "log_fc"]].set_index("variable")
        observed = per_gene.to_dict(orient="index")
        for gene, reference in expected.items():
            is_significant = observed[gene]["p_value"] < 0.05
            assert is_significant == (reference["p_value"] < 0.05)
            if is_significant:
                assert observed[gene] == pytest.approx(reference, abs=0.02)