From f99e5ac9702e33db2abfa030ace6a65851507ef5 Mon Sep 17 00:00:00 2001 From: Malte Benedikt Kuehl Date: Tue, 18 Mar 2025 12:35:00 +0100 Subject: [PATCH 01/11] Add permutation test --- docs/usage/usage.md | 13 +- pertpy/tools/__init__.py | 2 + .../_differential_gene_expression/__init__.py | 5 +- .../_pydeseq2.py | 2 +- .../_simple_tests.py | 141 +++++++++++++++++- .../test_simple_tests.py | 28 +++- 6 files changed, 180 insertions(+), 11 deletions(-) diff --git a/docs/usage/usage.md b/docs/usage/usage.md index 58a30c03..2a5ff02f 100644 --- a/docs/usage/usage.md +++ b/docs/usage/usage.md @@ -136,6 +136,7 @@ Pertpy provides utilities to conduct differential gene expression tests through tools.EdgeR tools.WilcoxonTest tools.TTest + tools.PermutationTest tools.Statsmodels ``` @@ -563,9 +564,9 @@ including cell line annotation, bulk RNA and protein expression data. Available databases for cell line metadata: -- [The Cancer Dependency Map Project at Broad](https://depmap.org/portal/) -- [The Cancer Dependency Map Project at Sanger](https://depmap.sanger.ac.uk/) -- [Genomics of Drug Sensitivity in Cancer (GDSC)](https://www.cancerrxgene.org/) +- [The Cancer Dependency Map Project at Broad](https://depmap.org/portal/) +- [The Cancer Dependency Map Project at Sanger](https://depmap.sanger.ac.uk/) +- [Genomics of Drug Sensitivity in Cancer (GDSC)](https://www.cancerrxgene.org/) ### Compound @@ -573,7 +574,7 @@ The Compound module enables the retrieval of various types of information relate Available databases for compound metadata: -- [PubChem](https://pubchem.ncbi.nlm.nih.gov/) +- [PubChem](https://pubchem.ncbi.nlm.nih.gov/) ### Mechanism of Action @@ -581,7 +582,7 @@ This module aims to retrieve metadata of mechanism of action studies related to Available databases for mechanism of action metadata: -- [CLUE](https://clue.io/) +- [CLUE](https://clue.io/) ### Drug @@ -589,7 +590,7 @@ This module allows for the retrieval of Drug target information. Available databases for drug metadata: -- [chembl](https://www.ebi.ac.uk/chembl/) +- [chembl](https://www.ebi.ac.uk/chembl/) ```{eval-rst} .. autosummary:: diff --git a/pertpy/tools/__init__.py b/pertpy/tools/__init__.py index 10565cee..8322e418 100644 --- a/pertpy/tools/__init__.py +++ b/pertpy/tools/__init__.py @@ -47,6 +47,7 @@ def __init__(self, *args, **kwargs): DE_EXTRAS = ["formulaic", "pydeseq2"] EdgeR = lazy_import("pertpy.tools._differential_gene_expression", "EdgeR", DE_EXTRAS) # edgeR will be imported via rpy2 +PermutationTest = lazy_import("pertpy.tools._differential_gene_expression", "PermutationTest", DE_EXTRAS) PyDESeq2 = lazy_import("pertpy.tools._differential_gene_expression", "PyDESeq2", DE_EXTRAS) Statsmodels = lazy_import("pertpy.tools._differential_gene_expression", "Statsmodels", DE_EXTRAS + ["statsmodels"]) TTest = lazy_import("pertpy.tools._differential_gene_expression", "TTest", DE_EXTRAS) @@ -62,6 +63,7 @@ def __init__(self, *args, **kwargs): "PyDESeq2", "WilcoxonTest", "TTest", + "PermutationTest", "Statsmodels", "DistanceTest", "Distance", diff --git a/pertpy/tools/_differential_gene_expression/__init__.py b/pertpy/tools/_differential_gene_expression/__init__.py index 6cc925dd..19a3383c 100644 --- a/pertpy/tools/_differential_gene_expression/__init__.py +++ b/pertpy/tools/_differential_gene_expression/__init__.py @@ -2,7 +2,7 @@ from ._dge_comparison import DGEEVAL from ._edger import EdgeR from ._pydeseq2 import PyDESeq2 -from ._simple_tests import SimpleComparisonBase, TTest, WilcoxonTest +from ._simple_tests import PermutationTest, SimpleComparisonBase, TTest, WilcoxonTest from ._statsmodels import Statsmodels __all__ = [ @@ -14,6 +14,7 @@ "SimpleComparisonBase", "WilcoxonTest", "TTest", + "PermutationTest", ] -AVAILABLE_METHODS = [Statsmodels, EdgeR, PyDESeq2, WilcoxonTest, TTest] +AVAILABLE_METHODS = [Statsmodels, EdgeR, PyDESeq2, WilcoxonTest, TTest, PermutationTest] diff --git a/pertpy/tools/_differential_gene_expression/_pydeseq2.py b/pertpy/tools/_differential_gene_expression/_pydeseq2.py index d4360e28..f191debc 100644 --- a/pertpy/tools/_differential_gene_expression/_pydeseq2.py +++ b/pertpy/tools/_differential_gene_expression/_pydeseq2.py @@ -42,7 +42,7 @@ def fit(self, **kwargs) -> pd.DataFrame: **kwargs: Keyword arguments specific to DeseqDataSet(), except for `n_cpus` which will use all available CPUs minus one if the argument is not passed. """ try: - usable_cpus = len(os.sched_getaffinity(0)) + usable_cpus = len(os.sched_getaffinity(0)) # type: ignore # os.sched_getaffinity is not available on Windows and macOS except AttributeError: usable_cpus = os.cpu_count() diff --git a/pertpy/tools/_differential_gene_expression/_simple_tests.py b/pertpy/tools/_differential_gene_expression/_simple_tests.py index 824569f9..1cb1c4ca 100644 --- a/pertpy/tools/_differential_gene_expression/_simple_tests.py +++ b/pertpy/tools/_differential_gene_expression/_simple_tests.py @@ -4,12 +4,14 @@ from abc import abstractmethod from collections.abc import Mapping, Sequence from types import MappingProxyType +from typing import Union import numpy as np import pandas as pd import scipy.stats import statsmodels from anndata import AnnData +from joblib import Parallel, delayed from pandas.core.api import DataFrame as DataFrame from scipy.sparse import diags, issparse from tqdm.auto import tqdm @@ -152,7 +154,7 @@ def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> float: class TTest(SimpleComparisonBase): - """Perform a unpaired or paired T-test""" + """Perform a unpaired or paired T-test.""" @staticmethod def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> float: @@ -160,3 +162,140 @@ def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> float: return scipy.stats.ttest_rel(x0, x1, **kwargs).pvalue else: return scipy.stats.ttest_ind(x0, x1, **kwargs).pvalue + + +class PermutationTest(SimpleComparisonBase): + """Perform a permutation test. + + The permutation test relies on another test (e.g. TTest) to perform the actual comparison + based on permuted data. The p-value is then calculated based on the distribution of the test + statistic under the null hypothesis. + + For paired tests, each paired observation is permuted together and distributed randoml between + the two groups. For unpaired tests, all observations are permuted independently. + + The null hypothesis for the unpaired test is that all observations come from the same underlying + distribution and have been randomly assigned to one of the samples. + + The null hypothesis for the paired permutation test is that the observations within each pair are + drawn from the same underlying distribution and that their assignment to a sample is random. + """ + + @staticmethod + def _test( + x0: np.ndarray, + x1: np.ndarray, + paired: bool, + test: type["SimpleComparisonBase"] = TTest, + n_permutations: int = 100, + **kwargs, + ) -> float: + """Perform a permutation test. + + Args: + x0: Array with baseline values. + x1: Array with values to compare. + paired: Indicates whether to perform a paired test + test: The test to use for the actual comparison. + n_permutations: Number of permutations to perform. + **kwargs: kwargs passed to the test function + """ + + def call_test(x0, x1, **kwargs): + """Perform the actual test.""" + return test._test(x0, x1, paired, **kwargs) + + if paired: + return scipy.stats.permutation_test( + [x0, x1], + statistic=call_test, + n_resamples=n_permutations, + permutation_type="samples", + **kwargs, + ).pvalue + else: + return scipy.stats.permutation_test( + [x0, x1], + statistic=call_test, + n_resamples=n_permutations, + permutation_type="independent", + **kwargs, + ).pvalue + + @classmethod + def compare_groups( + cls, + adata: AnnData, + column: str, + baseline: str, + groups_to_compare: str | Sequence[str], + test: type["SimpleComparisonBase"] = TTest, + n_permutations: int = 100, + *, + paired_by: str | None = None, + mask: str | None = None, + layer: str | None = None, + fit_kwargs: Mapping = MappingProxyType({}), + test_kwargs: Mapping = MappingProxyType({}), + ) -> DataFrame: + """Perform a comparison between groups using a permutation test. + + Args: + adata: Annotated data object. + column: Column in `adata.obs` that contains the groups to compare. + baseline: Reference group. + groups_to_compare: Groups to compare against the baseline. + test: The test to use for the actual comparison after permutation. Default is TTest. + n_permutations: Number of permutations to perform. + paired_by: Column in `adata.obs` to use for pairing. + mask: Mask to apply to the data. + layer: Layer to use for the comparison. + fit_kwargs: Additional kwargs passed to the test function. + test_kwargs: Additional kwargs passed to the test function. + """ + if len(fit_kwargs): + warnings.warn("fit_kwargs not used for simple tests.", UserWarning, stacklevel=2) + paired = paired_by is not None + model = cls(adata, mask=mask, layer=layer) + if groups_to_compare is None: + # compare against all other + groups_to_compare = sorted(set(model.adata.obs[column]) - {baseline}) + if isinstance(groups_to_compare, str): + groups_to_compare = [groups_to_compare] + + def _get_idx(column, value): + mask = model.adata.obs[column] == value + if paired: + dummies = pd.get_dummies(model.adata.obs[paired_by], sparse=True).sparse.to_coo().tocsr() + if not np.all(np.sum(dummies, axis=0) == 2): + raise ValueError("Pairing is only possible with exactly two values per group") + # Use matrix multiplication to only retreive those dummy entries that are associated with the current `value`. + # Convert to COO matrix to get rows/cols + # row indices refers to the indices of rows that have `column == value` (equivalent to np.where(mask)[0]) + # col indices refers to the numeric index of each "pair" in obs_names + ind_mat = diags(mask.values, dtype=bool) @ dummies + if not np.all(np.sum(ind_mat, axis=0) == 1): + raise ValueError("Pairing is only possible with exactly two values per group") + ind_mat = ind_mat.tocoo() + return ind_mat.row[np.argsort(ind_mat.col)] + else: + return np.where(mask)[0] + + test_kwargs.__setattr__("test", test) + test_kwargs.__setattr__("n_permutations", n_permutations) + + res_dfs = [] + baseline_idx = _get_idx(column, baseline) + + comparison_indices = [_get_idx(column, group_to_compare) for group_to_compare in groups_to_compare] + res_dfs = Parallel(n_jobs=-1)( + delayed(model._compare_single_group)(baseline_idx, comparison_idx, paired=paired, **test_kwargs) + for comparison_idx in comparison_indices + ) + res_dfs = [ + df.assign( + comparison=f"{group_to_compare}_vs_{baseline if baseline is not None else 'rest'}", + ) + for df, group_to_compare in zip(res_dfs, groups_to_compare, strict=False) + ] + return fdr_correction(pd.concat(res_dfs)) diff --git a/tests/tools/_differential_gene_expression/test_simple_tests.py b/tests/tools/_differential_gene_expression/test_simple_tests.py index eb2acf09..0b2869b1 100644 --- a/tests/tools/_differential_gene_expression/test_simple_tests.py +++ b/tests/tools/_differential_gene_expression/test_simple_tests.py @@ -2,7 +2,7 @@ import pandas as pd import pytest from pandas.core.api import DataFrame as DataFrame -from pertpy.tools._differential_gene_expression import SimpleComparisonBase, TTest, WilcoxonTest +from pertpy.tools._differential_gene_expression import PermutationTest, SimpleComparisonBase, TTest, WilcoxonTest @pytest.mark.parametrize( @@ -61,6 +61,32 @@ def test_t(test_adata_minimal, paired_by, expected): assert actual[gene] == pytest.approx(expected[gene], abs=0.02) +@pytest.mark.parametrize( + "paired_by,expected", + [ + pytest.param( + None, + {"gene1": {"p_value": 2.13e-26, "log_fc": -5.14}, "gene2": {"p_value": 0.96, "log_fc": -0.016}}, + id="unpaired", + ), + pytest.param( + "pairing", + {"gene1": {"p_value": 1.63e-26, "log_fc": -5.14}, "gene2": {"p_value": 0.85, "log_fc": -0.016}}, + id="paired", + ), + ], +) +def test_permutation(test_adata_minimal, paired_by, expected): + """Test that t-test gives the correct values.""" + res_df = PermutationTest.compare_groups( + adata=test_adata_minimal, column="condition", baseline="A", groups_to_compare="B", paired_by=paired_by + ) + assert isinstance(res_df, DataFrame), "PermutationTest.compare_groups should return a DataFrame" + actual = res_df.loc[:, ["variable", "p_value", "log_fc"]].set_index("variable").to_dict(orient="index") + for gene in expected: + assert actual[gene] == pytest.approx(expected[gene], abs=0.02) + + @pytest.mark.parametrize("seed", range(10)) def test_simple_comparison_pairing(test_adata_minimal, seed): """Test that paired samples are properly matched in a paired test""" From 66f228e2d21afd45797f31f671b5af19cbaf7da7 Mon Sep 17 00:00:00 2001 From: Malte Benedikt Kuehl Date: Tue, 18 Mar 2025 13:50:10 +0100 Subject: [PATCH 02/11] Fix test kwargs update --- pertpy/tools/_differential_gene_expression/_simple_tests.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pertpy/tools/_differential_gene_expression/_simple_tests.py b/pertpy/tools/_differential_gene_expression/_simple_tests.py index 1cb1c4ca..93bdc9d9 100644 --- a/pertpy/tools/_differential_gene_expression/_simple_tests.py +++ b/pertpy/tools/_differential_gene_expression/_simple_tests.py @@ -281,15 +281,15 @@ def _get_idx(column, value): else: return np.where(mask)[0] - test_kwargs.__setattr__("test", test) - test_kwargs.__setattr__("n_permutations", n_permutations) + test_kwargs_mutable = dict(test_kwargs) + test_kwargs_mutable.update({"test": test, "n_permutations": n_permutations}) res_dfs = [] baseline_idx = _get_idx(column, baseline) comparison_indices = [_get_idx(column, group_to_compare) for group_to_compare in groups_to_compare] res_dfs = Parallel(n_jobs=-1)( - delayed(model._compare_single_group)(baseline_idx, comparison_idx, paired=paired, **test_kwargs) + delayed(model._compare_single_group)(baseline_idx, comparison_idx, paired=paired, **test_kwargs_mutable) for comparison_idx in comparison_indices ) res_dfs = [ From 74e3ec1c9b7941e05fbc8890de9e59c26945774a Mon Sep 17 00:00:00 2001 From: Malte Benedikt Kuehl Date: Tue, 18 Mar 2025 15:26:23 +0100 Subject: [PATCH 03/11] Add n_jobs argument and change test to check for significance agreement with both TTest and Wilcoxontest and add seed --- .../_simple_tests.py | 17 +++++++---- .../test_simple_tests.py | 28 +++++++++++++------ 2 files changed, 31 insertions(+), 14 deletions(-) diff --git a/pertpy/tools/_differential_gene_expression/_simple_tests.py b/pertpy/tools/_differential_gene_expression/_simple_tests.py index 93bdc9d9..f95c86cf 100644 --- a/pertpy/tools/_differential_gene_expression/_simple_tests.py +++ b/pertpy/tools/_differential_gene_expression/_simple_tests.py @@ -4,7 +4,6 @@ from abc import abstractmethod from collections.abc import Mapping, Sequence from types import MappingProxyType -from typing import Union import numpy as np import pandas as pd @@ -167,7 +166,7 @@ def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> float: class PermutationTest(SimpleComparisonBase): """Perform a permutation test. - The permutation test relies on another test (e.g. TTest) to perform the actual comparison + The permutation test relies on another test (e.g. WilcoxonTest) to perform the actual comparison based on permuted data. The p-value is then calculated based on the distribution of the test statistic under the null hypothesis. @@ -186,8 +185,9 @@ def _test( x0: np.ndarray, x1: np.ndarray, paired: bool, - test: type["SimpleComparisonBase"] = TTest, + test: type["SimpleComparisonBase"] = WilcoxonTest, n_permutations: int = 100, + seed: int = 0, **kwargs, ) -> float: """Perform a permutation test. @@ -211,6 +211,7 @@ def call_test(x0, x1, **kwargs): statistic=call_test, n_resamples=n_permutations, permutation_type="samples", + rng=seed, **kwargs, ).pvalue else: @@ -219,6 +220,7 @@ def call_test(x0, x1, **kwargs): statistic=call_test, n_resamples=n_permutations, permutation_type="independent", + rng=seed, **kwargs, ).pvalue @@ -229,8 +231,10 @@ def compare_groups( column: str, baseline: str, groups_to_compare: str | Sequence[str], - test: type["SimpleComparisonBase"] = TTest, + test: type["SimpleComparisonBase"] = WilcoxonTest, n_permutations: int = 100, + n_jobs: int = -1, + seed: int = 0, *, paired_by: str | None = None, mask: str | None = None, @@ -247,6 +251,7 @@ def compare_groups( groups_to_compare: Groups to compare against the baseline. test: The test to use for the actual comparison after permutation. Default is TTest. n_permutations: Number of permutations to perform. + n_jobs: Number of parallel jobs to use. paired_by: Column in `adata.obs` to use for pairing. mask: Mask to apply to the data. layer: Layer to use for the comparison. @@ -282,13 +287,13 @@ def _get_idx(column, value): return np.where(mask)[0] test_kwargs_mutable = dict(test_kwargs) - test_kwargs_mutable.update({"test": test, "n_permutations": n_permutations}) + test_kwargs_mutable.update({"test": test, "n_permutations": n_permutations, "seed": seed}) res_dfs = [] baseline_idx = _get_idx(column, baseline) comparison_indices = [_get_idx(column, group_to_compare) for group_to_compare in groups_to_compare] - res_dfs = Parallel(n_jobs=-1)( + res_dfs = Parallel(n_jobs=n_jobs)( delayed(model._compare_single_group)(baseline_idx, comparison_idx, paired=paired, **test_kwargs_mutable) for comparison_idx in comparison_indices ) diff --git a/tests/tools/_differential_gene_expression/test_simple_tests.py b/tests/tools/_differential_gene_expression/test_simple_tests.py index 0b2869b1..df9c53db 100644 --- a/tests/tools/_differential_gene_expression/test_simple_tests.py +++ b/tests/tools/_differential_gene_expression/test_simple_tests.py @@ -77,14 +77,26 @@ def test_t(test_adata_minimal, paired_by, expected): ], ) def test_permutation(test_adata_minimal, paired_by, expected): - """Test that t-test gives the correct values.""" - res_df = PermutationTest.compare_groups( - adata=test_adata_minimal, column="condition", baseline="A", groups_to_compare="B", paired_by=paired_by - ) - assert isinstance(res_df, DataFrame), "PermutationTest.compare_groups should return a DataFrame" - actual = res_df.loc[:, ["variable", "p_value", "log_fc"]].set_index("variable").to_dict(orient="index") - for gene in expected: - assert actual[gene] == pytest.approx(expected[gene], abs=0.02) + """Test that permutation test gives the correct values. + + Reference values have been computed in R using wilcox.test + """ + for test in [TTest, WilcoxonTest]: + res_df = PermutationTest.compare_groups( + adata=test_adata_minimal, + column="condition", + baseline="A", + groups_to_compare="B", + paired_by=paired_by, + n_permutations=100, + test=test, + seed=0, + ) + assert isinstance(res_df, DataFrame), "PermutationTest.compare_groups should return a DataFrame" + actual = res_df.loc[:, ["variable", "p_value", "log_fc"]].set_index("variable").to_dict(orient="index") + for gene in expected: + assert (expected[gene]["p_value"] < 0.05) == (actual[gene]["p_value"] < 0.05) + assert actual[gene]["log_fc"] == pytest.approx(expected[gene]["log_fc"], abs=0.02) @pytest.mark.parametrize("seed", range(10)) From 64b71143b2c6c9f71f8c501d31d7528191fc5328 Mon Sep 17 00:00:00 2001 From: Malte Benedikt Kuehl Date: Tue, 18 Mar 2025 17:04:48 +0100 Subject: [PATCH 04/11] Simplify and generalize compare_groups by adding most important permutation arguments, passing others through kwargs --- .../_simple_tests.py | 166 ++++++------------ .../test_simple_tests.py | 4 +- 2 files changed, 56 insertions(+), 114 deletions(-) diff --git a/pertpy/tools/_differential_gene_expression/_simple_tests.py b/pertpy/tools/_differential_gene_expression/_simple_tests.py index f95c86cf..9fcb1ae1 100644 --- a/pertpy/tools/_differential_gene_expression/_simple_tests.py +++ b/pertpy/tools/_differential_gene_expression/_simple_tests.py @@ -4,6 +4,7 @@ from abc import abstractmethod from collections.abc import Mapping, Sequence from types import MappingProxyType +from typing import Optional import numpy as np import pandas as pd @@ -95,9 +96,28 @@ def compare_groups( paired_by: str | None = None, mask: str | None = None, layer: str | None = None, + n_permutations: int = 100, + permutation_test: type["SimpleComparisonBase"] | None = None, fit_kwargs: Mapping = MappingProxyType({}), test_kwargs: Mapping = MappingProxyType({}), + n_jobs: int = -1, ) -> DataFrame: + """Perform a comparison between groups. + + Args: + adata (AnnData): Data with observations to compare. + column (str): Column in `adata.obs` that contains the groups to compare. + baseline (str): Reference group. + groups_to_compare (str | Sequence[str]): Groups to compare against the baseline. If None, all other groups are compared. + paired_by (str | None): Column in `adata.obs` to use for pairing. If None, an unpaired test is performed. + mask (str | None): Mask to apply to the data. + layer (str | None): Layer to use for the comparison. + n_permutations (int): Number of permutations to perform if a permutation test is used. + permutation_test (type[SimpleComparisonBase] | None): Test to use after permutation if a permutation test is used. + fit_kwargs (Mapping): Not used for simple tests. + test_kwargs (Mapping): Additional kwargs passed to the test function. + n_jobs (int): Number of parallel jobs to use. + """ if len(fit_kwargs): warnings.warn("fit_kwargs not used for simple tests.", UserWarning, stacklevel=2) paired = paired_by is not None @@ -128,13 +148,25 @@ def _get_idx(column, value): res_dfs = [] baseline_idx = _get_idx(column, baseline) - for group_to_compare in groups_to_compare: - comparison_idx = _get_idx(column, group_to_compare) - res_dfs.append( - model._compare_single_group(baseline_idx, comparison_idx, paired=paired, **test_kwargs).assign( - comparison=f"{group_to_compare}_vs_{baseline if baseline is not None else 'rest'}" - ) + + if permutation_test: + test_kwargs = dict(test_kwargs) + test_kwargs["n_permutations"] = n_permutations + test_kwargs["test"] = permutation_test + elif permutation_test is None and cls.__name__ == "PermutationTest": + raise ValueError("PermutationTest requires a permutation_test argument") + + comparison_indices = [_get_idx(column, group_to_compare) for group_to_compare in groups_to_compare] + res_dfs = Parallel(n_jobs=n_jobs)( + delayed(model._compare_single_group)(baseline_idx, comparison_idx, paired=paired, **test_kwargs) + for comparison_idx in comparison_indices + ) + res_dfs = [ + df.assign( + comparison=f"{group_to_compare}_vs_{baseline if baseline is not None else 'rest'}", ) + for df, group_to_compare in zip(res_dfs, groups_to_compare, strict=False) + ] return fdr_correction(pd.concat(res_dfs)) @@ -170,7 +202,7 @@ class PermutationTest(SimpleComparisonBase): based on permuted data. The p-value is then calculated based on the distribution of the test statistic under the null hypothesis. - For paired tests, each paired observation is permuted together and distributed randoml between + For paired tests, each paired observation is permuted together and distributed randomly between the two groups. For unpaired tests, all observations are permuted independently. The null hypothesis for the unpaired test is that all observations come from the same underlying @@ -187,7 +219,6 @@ def _test( paired: bool, test: type["SimpleComparisonBase"] = WilcoxonTest, n_permutations: int = 100, - seed: int = 0, **kwargs, ) -> float: """Perform a permutation test. @@ -195,112 +226,23 @@ def _test( Args: x0: Array with baseline values. x1: Array with values to compare. - paired: Indicates whether to perform a paired test + paired: Whether to perform a paired test test: The test to use for the actual comparison. n_permutations: Number of permutations to perform. - **kwargs: kwargs passed to the test function + **kwargs: kwargs passed to the permutation test function, not the test function after permutation. """ - def call_test(x0, x1, **kwargs): + def call_test(data_baseline, data_comparison, **kwargs): """Perform the actual test.""" - return test._test(x0, x1, paired, **kwargs) - - if paired: - return scipy.stats.permutation_test( - [x0, x1], - statistic=call_test, - n_resamples=n_permutations, - permutation_type="samples", - rng=seed, - **kwargs, - ).pvalue - else: - return scipy.stats.permutation_test( - [x0, x1], - statistic=call_test, - n_resamples=n_permutations, - permutation_type="independent", - rng=seed, - **kwargs, - ).pvalue - - @classmethod - def compare_groups( - cls, - adata: AnnData, - column: str, - baseline: str, - groups_to_compare: str | Sequence[str], - test: type["SimpleComparisonBase"] = WilcoxonTest, - n_permutations: int = 100, - n_jobs: int = -1, - seed: int = 0, - *, - paired_by: str | None = None, - mask: str | None = None, - layer: str | None = None, - fit_kwargs: Mapping = MappingProxyType({}), - test_kwargs: Mapping = MappingProxyType({}), - ) -> DataFrame: - """Perform a comparison between groups using a permutation test. - - Args: - adata: Annotated data object. - column: Column in `adata.obs` that contains the groups to compare. - baseline: Reference group. - groups_to_compare: Groups to compare against the baseline. - test: The test to use for the actual comparison after permutation. Default is TTest. - n_permutations: Number of permutations to perform. - n_jobs: Number of parallel jobs to use. - paired_by: Column in `adata.obs` to use for pairing. - mask: Mask to apply to the data. - layer: Layer to use for the comparison. - fit_kwargs: Additional kwargs passed to the test function. - test_kwargs: Additional kwargs passed to the test function. - """ - if len(fit_kwargs): - warnings.warn("fit_kwargs not used for simple tests.", UserWarning, stacklevel=2) - paired = paired_by is not None - model = cls(adata, mask=mask, layer=layer) - if groups_to_compare is None: - # compare against all other - groups_to_compare = sorted(set(model.adata.obs[column]) - {baseline}) - if isinstance(groups_to_compare, str): - groups_to_compare = [groups_to_compare] - - def _get_idx(column, value): - mask = model.adata.obs[column] == value - if paired: - dummies = pd.get_dummies(model.adata.obs[paired_by], sparse=True).sparse.to_coo().tocsr() - if not np.all(np.sum(dummies, axis=0) == 2): - raise ValueError("Pairing is only possible with exactly two values per group") - # Use matrix multiplication to only retreive those dummy entries that are associated with the current `value`. - # Convert to COO matrix to get rows/cols - # row indices refers to the indices of rows that have `column == value` (equivalent to np.where(mask)[0]) - # col indices refers to the numeric index of each "pair" in obs_names - ind_mat = diags(mask.values, dtype=bool) @ dummies - if not np.all(np.sum(ind_mat, axis=0) == 1): - raise ValueError("Pairing is only possible with exactly two values per group") - ind_mat = ind_mat.tocoo() - return ind_mat.row[np.argsort(ind_mat.col)] - else: - return np.where(mask)[0] - - test_kwargs_mutable = dict(test_kwargs) - test_kwargs_mutable.update({"test": test, "n_permutations": n_permutations, "seed": seed}) - - res_dfs = [] - baseline_idx = _get_idx(column, baseline) - - comparison_indices = [_get_idx(column, group_to_compare) for group_to_compare in groups_to_compare] - res_dfs = Parallel(n_jobs=n_jobs)( - delayed(model._compare_single_group)(baseline_idx, comparison_idx, paired=paired, **test_kwargs_mutable) - for comparison_idx in comparison_indices - ) - res_dfs = [ - df.assign( - comparison=f"{group_to_compare}_vs_{baseline if baseline is not None else 'rest'}", - ) - for df, group_to_compare in zip(res_dfs, groups_to_compare, strict=False) - ] - return fdr_correction(pd.concat(res_dfs)) + return test._test(data_baseline, data_comparison, paired, **kwargs) + + # Set a seed for reproducibility if not already set + kwargs["rng"] = kwargs.get("rng", 0) + + return scipy.stats.permutation_test( + [x0, x1], + statistic=call_test, + n_resamples=n_permutations, + permutation_type=("samples" if paired else "independent"), + **kwargs, + ).pvalue diff --git a/tests/tools/_differential_gene_expression/test_simple_tests.py b/tests/tools/_differential_gene_expression/test_simple_tests.py index df9c53db..6e7d898b 100644 --- a/tests/tools/_differential_gene_expression/test_simple_tests.py +++ b/tests/tools/_differential_gene_expression/test_simple_tests.py @@ -89,8 +89,8 @@ def test_permutation(test_adata_minimal, paired_by, expected): groups_to_compare="B", paired_by=paired_by, n_permutations=100, - test=test, - seed=0, + permutation_test=test, + test_kwargs={"rng": 0}, ) assert isinstance(res_df, DataFrame), "PermutationTest.compare_groups should return a DataFrame" actual = res_df.loc[:, ["variable", "p_value", "log_fc"]].set_index("variable").to_dict(orient="index") From 325f1db9457c0a0d7676f7f59470402d6aceb318 Mon Sep 17 00:00:00 2001 From: Malte Benedikt Kuehl Date: Tue, 18 Mar 2025 17:30:36 +0100 Subject: [PATCH 05/11] Make permutation_test argument optional but raise warning if not provided --- pertpy/tools/_differential_gene_expression/_simple_tests.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pertpy/tools/_differential_gene_expression/_simple_tests.py b/pertpy/tools/_differential_gene_expression/_simple_tests.py index 9fcb1ae1..1ef2593f 100644 --- a/pertpy/tools/_differential_gene_expression/_simple_tests.py +++ b/pertpy/tools/_differential_gene_expression/_simple_tests.py @@ -12,6 +12,7 @@ import statsmodels from anndata import AnnData from joblib import Parallel, delayed +from lamin_utils import logger from pandas.core.api import DataFrame as DataFrame from scipy.sparse import diags, issparse from tqdm.auto import tqdm @@ -151,10 +152,9 @@ def _get_idx(column, value): if permutation_test: test_kwargs = dict(test_kwargs) - test_kwargs["n_permutations"] = n_permutations - test_kwargs["test"] = permutation_test + test_kwargs.update({"test": cls, "n_permutations": n_permutations}) elif permutation_test is None and cls.__name__ == "PermutationTest": - raise ValueError("PermutationTest requires a permutation_test argument") + logger.warning("No permutation test specified. Using WilcoxonTest as default.") comparison_indices = [_get_idx(column, group_to_compare) for group_to_compare in groups_to_compare] res_dfs = Parallel(n_jobs=n_jobs)( From 3e81976ee47f46ccb28dc78277ec397816c34437 Mon Sep 17 00:00:00 2001 From: Malte Benedikt Kuehl Date: Tue, 18 Mar 2025 22:03:27 +0100 Subject: [PATCH 06/11] Make test case a bit stricter again for significant values, enable returning statistic from tests and fix bug where the permutation_test was not applied --- .../_simple_tests.py | 64 +++++++++++++------ .../test_simple_tests.py | 11 ++-- 2 files changed, 52 insertions(+), 23 deletions(-) diff --git a/pertpy/tools/_differential_gene_expression/_simple_tests.py b/pertpy/tools/_differential_gene_expression/_simple_tests.py index 1ef2593f..b3e7e735 100644 --- a/pertpy/tools/_differential_gene_expression/_simple_tests.py +++ b/pertpy/tools/_differential_gene_expression/_simple_tests.py @@ -2,9 +2,9 @@ import warnings from abc import abstractmethod -from collections.abc import Mapping, Sequence +from collections.abc import Callable, Mapping, Sequence from types import MappingProxyType -from typing import Optional +from typing import Optional, Union import numpy as np import pandas as pd @@ -97,7 +97,7 @@ def compare_groups( paired_by: str | None = None, mask: str | None = None, layer: str | None = None, - n_permutations: int = 100, + n_permutations: int = 1000, permutation_test: type["SimpleComparisonBase"] | None = None, fit_kwargs: Mapping = MappingProxyType({}), test_kwargs: Mapping = MappingProxyType({}), @@ -152,7 +152,7 @@ def _get_idx(column, value): if permutation_test: test_kwargs = dict(test_kwargs) - test_kwargs.update({"test": cls, "n_permutations": n_permutations}) + test_kwargs.update({"test": permutation_test, "n_permutations": n_permutations}) elif permutation_test is None and cls.__name__ == "PermutationTest": logger.warning("No permutation test specified. Using WilcoxonTest as default.") @@ -177,22 +177,22 @@ class WilcoxonTest(SimpleComparisonBase): """ @staticmethod - def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> float: + def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, return_attribute: str = "pvalue", **kwargs) -> float: if paired: - return scipy.stats.wilcoxon(x0, x1, **kwargs).pvalue + return scipy.stats.wilcoxon(x0, x1, **kwargs).__getattribute__(return_attribute) else: - return scipy.stats.mannwhitneyu(x0, x1, **kwargs).pvalue + return scipy.stats.mannwhitneyu(x0, x1, **kwargs).__getattribute__(return_attribute) class TTest(SimpleComparisonBase): """Perform a unpaired or paired T-test.""" @staticmethod - def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> float: + def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, return_attribute: str = "pvalue", **kwargs) -> float: if paired: - return scipy.stats.ttest_rel(x0, x1, **kwargs).pvalue + return scipy.stats.ttest_rel(x0, x1, **kwargs).__getattribute__(return_attribute) else: - return scipy.stats.ttest_ind(x0, x1, **kwargs).pvalue + return scipy.stats.ttest_ind(x0, x1, **kwargs).__getattribute__(return_attribute) class PermutationTest(SimpleComparisonBase): @@ -217,32 +217,60 @@ def _test( x0: np.ndarray, x1: np.ndarray, paired: bool, - test: type["SimpleComparisonBase"] = WilcoxonTest, - n_permutations: int = 100, + test: type["SimpleComparisonBase"] | Callable = WilcoxonTest, + n_permutations: int = 1000, + return_attribute: str = "pvalue", **kwargs, ) -> float: """Perform a permutation test. + This function relies on another test (e.g. WilcoxonTest) to generate a test statistic for each permutation. + + .. code-block:: python + from pertpy.tools import PermutationTest, WilcoxonTest + + # Using rank-sum statistic + p_value = PermutationTest._test(x0, x1, paired=True, test=WilcoxonTest, n_permutations=1000, rng=0) + + + # Using a custom test statistic + def compare_means(x0, x1, paired): + # paired logic not implemented here + return np.mean(x1) - np.mean(x0) + + + p_value = PermutationTest._test(x0, x1, paired=False, test=compare_means, n_permutations=1000, rng=0) + Args: x0: Array with baseline values. x1: Array with values to compare. paired: Whether to perform a paired test - test: The test to use for the actual comparison. + test: The class or function to generate the test statistic from permuted data. n_permutations: Number of permutations to perform. + return_attribute: Attribute to return from the test statistic. **kwargs: kwargs passed to the permutation test function, not the test function after permutation. """ + if test is PermutationTest: + raise ValueError( + "The `test` argument cannot be `PermutationTest`. Use a base test like `WilcoxonTest` or `TTest`." + ) - def call_test(data_baseline, data_comparison, **kwargs): + def call_test(data_baseline, data_comparison, axis: int | None = None, **kwargs): """Perform the actual test.""" - return test._test(data_baseline, data_comparison, paired, **kwargs) + # Setting the axis allows the operation to be vectorized + if axis is not None: + kwargs.update({"axis": axis}) + + if not hasattr(test, "_test"): + return test(data_baseline, data_comparison, paired, **kwargs) - # Set a seed for reproducibility if not already set - kwargs["rng"] = kwargs.get("rng", 0) + return test._test(data_baseline, data_comparison, paired, return_attribute="statistic", **kwargs) return scipy.stats.permutation_test( [x0, x1], statistic=call_test, n_resamples=n_permutations, permutation_type=("samples" if paired else "independent"), + vectorized=hasattr(test, "_test"), **kwargs, - ).pvalue + ).__getattribute__(return_attribute) diff --git a/tests/tools/_differential_gene_expression/test_simple_tests.py b/tests/tools/_differential_gene_expression/test_simple_tests.py index 6e7d898b..ec7771ea 100644 --- a/tests/tools/_differential_gene_expression/test_simple_tests.py +++ b/tests/tools/_differential_gene_expression/test_simple_tests.py @@ -81,22 +81,23 @@ def test_permutation(test_adata_minimal, paired_by, expected): Reference values have been computed in R using wilcox.test """ - for test in [TTest, WilcoxonTest]: + for permutation_test in [TTest, WilcoxonTest]: res_df = PermutationTest.compare_groups( adata=test_adata_minimal, column="condition", baseline="A", groups_to_compare="B", paired_by=paired_by, - n_permutations=100, - permutation_test=test, + n_permutations=200, + permutation_test=permutation_test, test_kwargs={"rng": 0}, ) assert isinstance(res_df, DataFrame), "PermutationTest.compare_groups should return a DataFrame" actual = res_df.loc[:, ["variable", "p_value", "log_fc"]].set_index("variable").to_dict(orient="index") for gene in expected: - assert (expected[gene]["p_value"] < 0.05) == (actual[gene]["p_value"] < 0.05) - assert actual[gene]["log_fc"] == pytest.approx(expected[gene]["log_fc"], abs=0.02) + assert (actual[gene]["p_value"] < 0.05) == (expected[gene]["p_value"] < 0.05) + if actual[gene]["p_value"] < 0.05: + assert actual[gene] == pytest.approx(expected[gene], abs=0.02) @pytest.mark.parametrize("seed", range(10)) From 14736b3a9cc97a97001b0f02b82284a40c641511 Mon Sep 17 00:00:00 2001 From: Malte Benedikt Kuehl Date: Tue, 18 Mar 2025 22:17:42 +0100 Subject: [PATCH 07/11] Remove unnecessary import --- pertpy/tools/_differential_gene_expression/_simple_tests.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pertpy/tools/_differential_gene_expression/_simple_tests.py b/pertpy/tools/_differential_gene_expression/_simple_tests.py index b3e7e735..118f80b0 100644 --- a/pertpy/tools/_differential_gene_expression/_simple_tests.py +++ b/pertpy/tools/_differential_gene_expression/_simple_tests.py @@ -4,7 +4,6 @@ from abc import abstractmethod from collections.abc import Callable, Mapping, Sequence from types import MappingProxyType -from typing import Optional, Union import numpy as np import pandas as pd From 5873b87089acb3066f3dbf5b6407e2529711a3c1 Mon Sep 17 00:00:00 2001 From: Malte Benedikt Kuehl Date: Wed, 19 Mar 2025 10:22:14 +0100 Subject: [PATCH 08/11] Remove parallelization and return statistic and p-value everywhere --- .../_simple_tests.py | 107 ++++++++++-------- 1 file changed, 61 insertions(+), 46 deletions(-) diff --git a/pertpy/tools/_differential_gene_expression/_simple_tests.py b/pertpy/tools/_differential_gene_expression/_simple_tests.py index 118f80b0..117639a8 100644 --- a/pertpy/tools/_differential_gene_expression/_simple_tests.py +++ b/pertpy/tools/_differential_gene_expression/_simple_tests.py @@ -3,14 +3,15 @@ import warnings from abc import abstractmethod from collections.abc import Callable, Mapping, Sequence +from inspect import signature from types import MappingProxyType +from typing import Any import numpy as np import pandas as pd import scipy.stats import statsmodels from anndata import AnnData -from joblib import Parallel, delayed from lamin_utils import logger from pandas.core.api import DataFrame as DataFrame from scipy.sparse import diags, issparse @@ -35,7 +36,7 @@ def fdr_correction( class SimpleComparisonBase(MethodBase): @staticmethod @abstractmethod - def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> float: + def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> dict[str, float]: """Perform a statistical test between values in x0 and x1. If `paired` is True, x0 and x1 must be of the same length and ordered such that @@ -73,16 +74,16 @@ def _compare_single_group( x0 = x0.tocsc() x1 = x1.tocsc() - res = [] + res: list[dict[str, float | np.ndarray]] = [] for var in tqdm(self.adata.var_names): tmp_x0 = x0[:, self.adata.var_names == var] tmp_x0 = np.asarray(tmp_x0.todense()).flatten() if issparse(tmp_x0) else tmp_x0.flatten() tmp_x1 = x1[:, self.adata.var_names == var] tmp_x1 = np.asarray(tmp_x1.todense()).flatten() if issparse(tmp_x1) else tmp_x1.flatten() - pval = self._test(tmp_x0, tmp_x1, paired, **kwargs) + test_result = self._test(tmp_x0, tmp_x1, paired, **kwargs) mean_x0 = np.mean(tmp_x0) mean_x1 = np.mean(tmp_x1) - res.append({"variable": var, "p_value": pval, "log_fc": np.log2(mean_x1) - np.log2(mean_x0)}) + res.append({"variable": var, "log_fc": np.log2(mean_x1) - np.log2(mean_x0)}, **test_result) return pd.DataFrame(res).sort_values("p_value") @classmethod @@ -97,7 +98,7 @@ def compare_groups( mask: str | None = None, layer: str | None = None, n_permutations: int = 1000, - permutation_test: type["SimpleComparisonBase"] | None = None, + permutation_test_statistic: type["SimpleComparisonBase"] | None = None, fit_kwargs: Mapping = MappingProxyType({}), test_kwargs: Mapping = MappingProxyType({}), n_jobs: int = -1, @@ -113,7 +114,7 @@ def compare_groups( mask (str | None): Mask to apply to the data. layer (str | None): Layer to use for the comparison. n_permutations (int): Number of permutations to perform if a permutation test is used. - permutation_test (type[SimpleComparisonBase] | None): Test to use after permutation if a permutation test is used. + permutation_test_statistic (type[SimpleComparisonBase] | None): Test to use after permutation if a permutation test is used. fit_kwargs (Mapping): Not used for simple tests. test_kwargs (Mapping): Additional kwargs passed to the test function. n_jobs (int): Number of parallel jobs to use. @@ -146,26 +147,21 @@ def _get_idx(column, value): else: return np.where(mask)[0] + if permutation_test_statistic: + test_kwargs = dict(test_kwargs) + test_kwargs.update({"test_statistic": permutation_test_statistic, "n_permutations": n_permutations}) + elif permutation_test_statistic is None and cls.__name__ == "PermutationTest": + logger.warning("No permutation test statistic specified. Using TTest statistic as default.") + res_dfs = [] baseline_idx = _get_idx(column, baseline) - - if permutation_test: - test_kwargs = dict(test_kwargs) - test_kwargs.update({"test": permutation_test, "n_permutations": n_permutations}) - elif permutation_test is None and cls.__name__ == "PermutationTest": - logger.warning("No permutation test specified. Using WilcoxonTest as default.") - - comparison_indices = [_get_idx(column, group_to_compare) for group_to_compare in groups_to_compare] - res_dfs = Parallel(n_jobs=n_jobs)( - delayed(model._compare_single_group)(baseline_idx, comparison_idx, paired=paired, **test_kwargs) - for comparison_idx in comparison_indices - ) - res_dfs = [ - df.assign( - comparison=f"{group_to_compare}_vs_{baseline if baseline is not None else 'rest'}", + for group_to_compare in groups_to_compare: + comparison_idx = _get_idx(column, group_to_compare) + res_dfs.append( + model._compare_single_group(baseline_idx, comparison_idx, paired=paired, **test_kwargs).assign( + comparison=f"{group_to_compare}_vs_{baseline if baseline is not None else 'rest'}" + ) ) - for df, group_to_compare in zip(res_dfs, groups_to_compare, strict=False) - ] return fdr_correction(pd.concat(res_dfs)) @@ -176,22 +172,34 @@ class WilcoxonTest(SimpleComparisonBase): """ @staticmethod - def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, return_attribute: str = "pvalue", **kwargs) -> float: + def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> dict[str, float]: + """Perform an unpaired or paired Wilcoxon/Mann-Whitney-U test.""" + if paired: - return scipy.stats.wilcoxon(x0, x1, **kwargs).__getattribute__(return_attribute) + test_result = scipy.stats.wilcoxon(x0, x1, **kwargs) else: - return scipy.stats.mannwhitneyu(x0, x1, **kwargs).__getattribute__(return_attribute) + test_result = scipy.stats.mannwhitneyu(x0, x1, **kwargs) + + return { + "p_value": test_result.pvalue, + "statistic": test_result.statistic, + } class TTest(SimpleComparisonBase): """Perform a unpaired or paired T-test.""" @staticmethod - def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, return_attribute: str = "pvalue", **kwargs) -> float: + def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> dict[str, float]: if paired: - return scipy.stats.ttest_rel(x0, x1, **kwargs).__getattribute__(return_attribute) + test_result = scipy.stats.ttest_rel(x0, x1, **kwargs) else: - return scipy.stats.ttest_ind(x0, x1, **kwargs).__getattribute__(return_attribute) + test_result = scipy.stats.ttest_ind(x0, x1, **kwargs) + + return { + "p_value": test_result.pvalue, + "statistic": test_result.statistic, + } class PermutationTest(SimpleComparisonBase): @@ -216,11 +224,10 @@ def _test( x0: np.ndarray, x1: np.ndarray, paired: bool, - test: type["SimpleComparisonBase"] | Callable = WilcoxonTest, + test_statistic: type["SimpleComparisonBase"] | Callable = WilcoxonTest, n_permutations: int = 1000, - return_attribute: str = "pvalue", **kwargs, - ) -> float: + ) -> dict[str, float]: """Perform a permutation test. This function relies on another test (e.g. WilcoxonTest) to generate a test statistic for each permutation. @@ -244,32 +251,40 @@ def compare_means(x0, x1, paired): x0: Array with baseline values. x1: Array with values to compare. paired: Whether to perform a paired test - test: The class or function to generate the test statistic from permuted data. + test_statistic: The class or function to generate the test statistic from permuted data. If a function is passed, it must have the signature `test_statistic(x0, x1, paired[, axis], **kwargs)`. If it accepts the parameter axis, vectorization will be used. n_permutations: Number of permutations to perform. - return_attribute: Attribute to return from the test statistic. **kwargs: kwargs passed to the permutation test function, not the test function after permutation. """ - if test is PermutationTest: + if test_statistic is PermutationTest: raise ValueError( - "The `test` argument cannot be `PermutationTest`. Use a base test like `WilcoxonTest` or `TTest`." + "The `test_statistic` argument cannot be `PermutationTest`. Use a base test like `WilcoxonTest` or `TTest`." ) + vectorized = hasattr(test_statistic, "_test") or "axis" in signature(test_statistic).parameters + def call_test(data_baseline, data_comparison, axis: int | None = None, **kwargs): """Perform the actual test.""" - # Setting the axis allows the operation to be vectorized - if axis is not None: - kwargs.update({"axis": axis}) + if not hasattr(test_statistic, "_test"): + if vectorized: + return test_statistic(data_baseline, data_comparison, paired, axis, **kwargs) - if not hasattr(test, "_test"): - return test(data_baseline, data_comparison, paired, **kwargs) + return test_statistic(data_baseline, data_comparison, paired, **kwargs) - return test._test(data_baseline, data_comparison, paired, return_attribute="statistic", **kwargs) + if vectorized: + kwargs.update({"axis": axis}) - return scipy.stats.permutation_test( + return test_statistic._test(data_baseline, data_comparison, paired, **kwargs) + + test_result = scipy.stats.permutation_test( [x0, x1], statistic=call_test, n_resamples=n_permutations, permutation_type=("samples" if paired else "independent"), - vectorized=hasattr(test, "_test"), + vectorized=vectorized, **kwargs, - ).__getattribute__(return_attribute) + ) + + return { + "p_value": test_result.pvalue, + "statistic": test_result.statistic, + } From 676b4f00e65f6e1acc1000d6c01f0bb8bc4630a1 Mon Sep 17 00:00:00 2001 From: Malte Benedikt Kuehl Date: Wed, 19 Mar 2025 10:24:05 +0100 Subject: [PATCH 09/11] Remove parallelization and return statistic and p-value everywhere --- .../_differential_gene_expression/_simple_tests.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/pertpy/tools/_differential_gene_expression/_simple_tests.py b/pertpy/tools/_differential_gene_expression/_simple_tests.py index 117639a8..9bda95a9 100644 --- a/pertpy/tools/_differential_gene_expression/_simple_tests.py +++ b/pertpy/tools/_differential_gene_expression/_simple_tests.py @@ -109,12 +109,14 @@ def compare_groups( adata (AnnData): Data with observations to compare. column (str): Column in `adata.obs` that contains the groups to compare. baseline (str): Reference group. - groups_to_compare (str | Sequence[str]): Groups to compare against the baseline. If None, all other groups are compared. + groups_to_compare (str | Sequence[str]): Groups to compare against the baseline. If None, all other groups + are compared. paired_by (str | None): Column in `adata.obs` to use for pairing. If None, an unpaired test is performed. mask (str | None): Mask to apply to the data. layer (str | None): Layer to use for the comparison. n_permutations (int): Number of permutations to perform if a permutation test is used. - permutation_test_statistic (type[SimpleComparisonBase] | None): Test to use after permutation if a permutation test is used. + permutation_test_statistic (type[SimpleComparisonBase] | None): Test to use after permutation if a + permutation test is used. fit_kwargs (Mapping): Not used for simple tests. test_kwargs (Mapping): Additional kwargs passed to the test function. n_jobs (int): Number of parallel jobs to use. @@ -251,13 +253,16 @@ def compare_means(x0, x1, paired): x0: Array with baseline values. x1: Array with values to compare. paired: Whether to perform a paired test - test_statistic: The class or function to generate the test statistic from permuted data. If a function is passed, it must have the signature `test_statistic(x0, x1, paired[, axis], **kwargs)`. If it accepts the parameter axis, vectorization will be used. + test_statistic: The class or function to generate the test statistic from permuted data. If a function is + passed, it must have the signature `test_statistic(x0, x1, paired[, axis], **kwargs)`. If it accepts the + parameter axis, vectorization will be used. n_permutations: Number of permutations to perform. **kwargs: kwargs passed to the permutation test function, not the test function after permutation. """ if test_statistic is PermutationTest: raise ValueError( - "The `test_statistic` argument cannot be `PermutationTest`. Use a base test like `WilcoxonTest` or `TTest`." + "The `test_statistic` argument cannot be `PermutationTest`. Use a base test like `WilcoxonTest` or" + " `TTest`." ) vectorized = hasattr(test_statistic, "_test") or "axis" in signature(test_statistic).parameters From 442b6037cacce142a77bb5dfd4e09dc43e0911c5 Mon Sep 17 00:00:00 2001 From: Malte Benedikt Kuehl Date: Wed, 19 Mar 2025 11:07:03 +0100 Subject: [PATCH 10/11] Remove parallelization and return statistic and p-value everywhere --- .../_simple_tests.py | 30 +++++++++---------- .../test_simple_tests.py | 6 ++-- 2 files changed, 17 insertions(+), 19 deletions(-) diff --git a/pertpy/tools/_differential_gene_expression/_simple_tests.py b/pertpy/tools/_differential_gene_expression/_simple_tests.py index 9bda95a9..b2306726 100644 --- a/pertpy/tools/_differential_gene_expression/_simple_tests.py +++ b/pertpy/tools/_differential_gene_expression/_simple_tests.py @@ -5,7 +5,6 @@ from collections.abc import Callable, Mapping, Sequence from inspect import signature from types import MappingProxyType -from typing import Any import numpy as np import pandas as pd @@ -83,7 +82,7 @@ def _compare_single_group( test_result = self._test(tmp_x0, tmp_x1, paired, **kwargs) mean_x0 = np.mean(tmp_x0) mean_x1 = np.mean(tmp_x1) - res.append({"variable": var, "log_fc": np.log2(mean_x1) - np.log2(mean_x0)}, **test_result) + res.append({"variable": var, "log_fc": np.log2(mean_x1) - np.log2(mean_x0), **test_result}) return pd.DataFrame(res).sort_values("p_value") @classmethod @@ -101,7 +100,6 @@ def compare_groups( permutation_test_statistic: type["SimpleComparisonBase"] | None = None, fit_kwargs: Mapping = MappingProxyType({}), test_kwargs: Mapping = MappingProxyType({}), - n_jobs: int = -1, ) -> DataFrame: """Perform a comparison between groups. @@ -119,7 +117,6 @@ def compare_groups( permutation test is used. fit_kwargs (Mapping): Not used for simple tests. test_kwargs (Mapping): Additional kwargs passed to the test function. - n_jobs (int): Number of parallel jobs to use. """ if len(fit_kwargs): warnings.warn("fit_kwargs not used for simple tests.", UserWarning, stacklevel=2) @@ -207,18 +204,17 @@ def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> dict[str, f class PermutationTest(SimpleComparisonBase): """Perform a permutation test. - The permutation test relies on another test (e.g. WilcoxonTest) to perform the actual comparison - based on permuted data. The p-value is then calculated based on the distribution of the test - statistic under the null hypothesis. + The permutation test relies on another test statistic (e.g. t-statistic or your own) to obtain a p-value through + random permutations of the data and repeated generation of the test statistic. - For paired tests, each paired observation is permuted together and distributed randomly between - the two groups. For unpaired tests, all observations are permuted independently. + For paired tests, each paired observation is permuted together and distributed randomly between the two groups. For + unpaired tests, all observations are permuted independently. - The null hypothesis for the unpaired test is that all observations come from the same underlying - distribution and have been randomly assigned to one of the samples. + The null hypothesis for the unpaired test is that all observations come from the same underlying distribution and + have been randomly assigned to one of the samples. - The null hypothesis for the paired permutation test is that the observations within each pair are - drawn from the same underlying distribution and that their assignment to a sample is random. + The null hypothesis for the paired permutation test is that the observations within each pair are drawn from the + same underlying distribution and that their assignment to a sample is random. """ @staticmethod @@ -271,14 +267,16 @@ def call_test(data_baseline, data_comparison, axis: int | None = None, **kwargs) """Perform the actual test.""" if not hasattr(test_statistic, "_test"): if vectorized: - return test_statistic(data_baseline, data_comparison, paired, axis, **kwargs) + return test_statistic(data_baseline, data_comparison, paired=paired, axis=axis, **kwargs)[ + "statistic" + ] - return test_statistic(data_baseline, data_comparison, paired, **kwargs) + return test_statistic(data_baseline, data_comparison, paired=paired, **kwargs)["statistic"] if vectorized: kwargs.update({"axis": axis}) - return test_statistic._test(data_baseline, data_comparison, paired, **kwargs) + return test_statistic._test(data_baseline, data_comparison, paired, **kwargs)["statistic"] test_result = scipy.stats.permutation_test( [x0, x1], diff --git a/tests/tools/_differential_gene_expression/test_simple_tests.py b/tests/tools/_differential_gene_expression/test_simple_tests.py index ec7771ea..c9693eff 100644 --- a/tests/tools/_differential_gene_expression/test_simple_tests.py +++ b/tests/tools/_differential_gene_expression/test_simple_tests.py @@ -81,15 +81,15 @@ def test_permutation(test_adata_minimal, paired_by, expected): Reference values have been computed in R using wilcox.test """ - for permutation_test in [TTest, WilcoxonTest]: + for statistic in [TTest, WilcoxonTest]: res_df = PermutationTest.compare_groups( adata=test_adata_minimal, column="condition", baseline="A", groups_to_compare="B", paired_by=paired_by, - n_permutations=200, - permutation_test=permutation_test, + n_permutations=1000, + permutation_test_statistic=statistic, test_kwargs={"rng": 0}, ) assert isinstance(res_df, DataFrame), "PermutationTest.compare_groups should return a DataFrame" From 8ae69ce3a7b4d017200f072c2e44f84ba179099c Mon Sep 17 00:00:00 2001 From: Malte Benedikt Kuehl Date: Mon, 7 Apr 2025 10:27:47 +0200 Subject: [PATCH 11/11] Fix docstring and examples of permutation test --- .../_simple_tests.py | 53 +++++++++---------- 1 file changed, 24 insertions(+), 29 deletions(-) diff --git a/pertpy/tools/_differential_gene_expression/_simple_tests.py b/pertpy/tools/_differential_gene_expression/_simple_tests.py index b2306726..26e8fab8 100644 --- a/pertpy/tools/_differential_gene_expression/_simple_tests.py +++ b/pertpy/tools/_differential_gene_expression/_simple_tests.py @@ -104,19 +104,19 @@ def compare_groups( """Perform a comparison between groups. Args: - adata (AnnData): Data with observations to compare. - column (str): Column in `adata.obs` that contains the groups to compare. - baseline (str): Reference group. - groups_to_compare (str | Sequence[str]): Groups to compare against the baseline. If None, all other groups + adata: Data with observations to compare. + column: Column in `adata.obs` that contains the groups to compare. + baseline: Reference group. + groups_to_compare: Groups to compare against the baseline. If None, all other groups are compared. - paired_by (str | None): Column in `adata.obs` to use for pairing. If None, an unpaired test is performed. - mask (str | None): Mask to apply to the data. - layer (str | None): Layer to use for the comparison. - n_permutations (int): Number of permutations to perform if a permutation test is used. - permutation_test_statistic (type[SimpleComparisonBase] | None): Test to use after permutation if a - permutation test is used. - fit_kwargs (Mapping): Not used for simple tests. - test_kwargs (Mapping): Additional kwargs passed to the test function. + paired_by: Column in `adata.obs` to use for pairing. If None, an unpaired test is performed. + mask: Mask to apply to the data. + layer: Layer to use for the comparison. + n_permutations: Number of permutations to perform if a permutation test is used. + permutation_test_statistic: The statistic to use if performing a permutation test. If None, the default + t-statistic from `TTest` is used. + fit_kwargs: Unused argument for compatibility with the `MethodBase` interface, do not specify. + test_kwargs: Additional kwargs passed to the test function. """ if len(fit_kwargs): warnings.warn("fit_kwargs not used for simple tests.", UserWarning, stacklevel=2) @@ -230,21 +230,6 @@ def _test( This function relies on another test (e.g. WilcoxonTest) to generate a test statistic for each permutation. - .. code-block:: python - from pertpy.tools import PermutationTest, WilcoxonTest - - # Using rank-sum statistic - p_value = PermutationTest._test(x0, x1, paired=True, test=WilcoxonTest, n_permutations=1000, rng=0) - - - # Using a custom test statistic - def compare_means(x0, x1, paired): - # paired logic not implemented here - return np.mean(x1) - np.mean(x0) - - - p_value = PermutationTest._test(x0, x1, paired=False, test=compare_means, n_permutations=1000, rng=0) - Args: x0: Array with baseline values. x1: Array with values to compare. @@ -254,11 +239,21 @@ def compare_means(x0, x1, paired): parameter axis, vectorization will be used. n_permutations: Number of permutations to perform. **kwargs: kwargs passed to the permutation test function, not the test function after permutation. + + Examples: + You can use the `PermutationTest` class to perform a permutation test with a custom test statistic or an + existing test statistic like `TTest`. The test statistic must be a class that implements the `_test` method + or a function that takes the arguments `x0`, `x1`, `paired` and `**kwargs`. + + >>> from pertpy.tools import PermutationTest, TTest + >>> # Perform a permutation test with a t-statistic + >>> p_value = PermutationTest._test(x0, x1, paired=True, test=TTest, n_permutations=1000, rng=0) + >>> # Perform a permutation test with a custom test statistic + >>> p_value = PermutationTest._test(x0, x1, paired=False, test=your_custom_test_statistic) """ if test_statistic is PermutationTest: raise ValueError( - "The `test_statistic` argument cannot be `PermutationTest`. Use a base test like `WilcoxonTest` or" - " `TTest`." + "The `test_statistic` argument cannot be `PermutationTest`. Use a base test like `TTest` or a custom test." ) vectorized = hasattr(test_statistic, "_test") or "axis" in signature(test_statistic).parameters