Solve TODOs and verify DIALOGUE 1 #714

Open · wants to merge 6 commits into base: main

Binary file added .RData
Binary file not shown.
2 changes: 2 additions & 0 deletions .Rhistory
@@ -0,0 +1,2 @@
devtools::install_github(repo = "https://github.com/livnatje/DIALOGUE")
q()
126 changes: 93 additions & 33 deletions pertpy/tools/_dialogue.py
@@ -55,7 +55,7 @@ def __init__(self, sample_id: str, celltype_key: str, n_counts_key: str, n_mpcs:
self.n_mcps = n_mpcs

def _get_pseudobulks(
self, adata: AnnData, groupby: str, strategy: Literal["median", "mean"] = "median"
self, adata: AnnData, groupby: str, strategy: Literal["median", "mean"] = "mean"
) -> pd.DataFrame:
"""Return cell-averaged data by groupby.

@@ -83,45 +83,73 @@ def _get_pseudobulks(
return pseudobulk

def _pseudobulk_feature_space(
self, adata: AnnData, groupby: str, n_components: int = 50, feature_space_key: str = "X_pca"
self,
adata: AnnData,
groupby: str,
n_components: int = 50,
feature_space_key: str = "X_pca",
agg_func=np.median, # default to np.median; user can supply np.mean, etc.
) -> pd.DataFrame:
"""Return Cell-averaged components from a passed feature space.

TODO: consider merging with `get_pseudobulks`
Collaborator: Let's not remove these TODOs since they haven't been done and probably still should be done.

Collaborator Author: You are right, though the code looks very different now.

TODO: DIALOGUE recommends running PCA on each cell type separately before running PMD - this should be implemented as an option here.
"""Return cell-averaged components from a passed feature space using a custom aggregation function.

Args:
groupby: The key to groupby for pseudobulks.
n_components: The number of components to use.
feature_key: The key in adata.obsm for the feature space (e.g., "X_pca", "X_umap").
groupby: The key to group by for pseudobulks (e.g., a sample identifier).
n_components: The number of components (features) to use.
feature_space_key: The key in adata.obsm for the feature space (e.g., "X_pca", "X_umap").
agg_func: The aggregation function to use (e.g., np.median, np.mean). It should accept an `axis` argument.

Returns:
A pseudobulk DataFrame of the averaged components.
A pseudobulk DataFrame of the aggregated components, with samples as rows and components as columns.
"""
aggr = {}
# Loop over each category in the specified groupby column (assumed to be categorical)
for category in adata.obs.loc[:, groupby].cat.categories:
temp = adata.obs.loc[:, groupby] == category
aggr[category] = adata[temp].obsm[feature_space_key][:, :n_components].mean(axis=0)
aggr = pd.DataFrame(aggr)
return aggr
# Apply the user-provided aggregation function along axis 0 (averaging across cells)
aggr[category] = agg_func(adata[temp].obsm[feature_space_key][:, :n_components], axis=0)
# Create a DataFrame; keys become columns
aggr_df = pd.DataFrame(aggr)
# Transpose so that rows correspond to samples and columns to features
return aggr_df.T
Collaborator: I see that you've moved the transpose from where ct_preprocess is defined to here. However, this is going to cause a problem when _pseudobulk_feature_space isn't used, i.e. when agg_feature is False. Currently, agg_feature is set to True by default and never exposed to the user, but this is because of the incomplete implementation of the TODO on the previous line 660: https://github.com/scverse/pertpy/blob/main/pertpy/tools/_dialogue.py#L660, which I understand is very vague - I'm happy to hop on a call to explain. Can you change this so that either _pseudobulk_feature_space and _get_pseudobulks are merged, as suggested in the original TODO, or revert it back to the original implementation so that the stub still exists?
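For context, a minimal sketch of the merge the original TODO asks for (the name aggregate_pseudobulk, the use_rep switch, and all defaults are illustrative, not part of this PR):

import numpy as np
import pandas as pd
from anndata import AnnData

def aggregate_pseudobulk(
    adata: AnnData,
    groupby: str,
    use_rep: str | None = None,  # None -> aggregate adata.X (as _get_pseudobulks)
    n_components: int = 50,      # used only when use_rep points at an obsm embedding
    agg_func=np.median,
) -> pd.DataFrame:
    # One aggregator for both code paths: raw expression or an obsm feature space.
    aggr = {}
    for category in adata.obs[groupby].cat.categories:  # groupby assumed categorical
        mask = (adata.obs[groupby] == category).to_numpy()
        if use_rep is None:
            X = adata[mask].X
            X = X.toarray() if hasattr(X, "toarray") else np.asarray(X)
        else:
            X = adata[mask].obsm[use_rep][:, :n_components]
        aggr[category] = agg_func(X, axis=0)
    # Samples as rows, features as columns, matching the transpose introduced above.
    return pd.DataFrame(aggr).T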


def _scale_data(self, pseudobulks: pd.DataFrame, normalize: bool = True) -> np.ndarray:
"""Row-wise mean center and scale by the standard deviation.

def _scale_data(self, pseudobulks: pd.DataFrame, normalize: bool = True, cap: float = 0.01) -> np.ndarray:
"""Column-wise mean center and scale by the standard deviation,
and then cap extreme values based on quantiles.

This mimics the following R function (excluding row subsetting):

f <- function(X1){
if(param$center.flag){
X1 <- center.matrix(X1, dim = 2, sd.flag = TRUE)
X1 <- cap.mat(X1, cap = 0.01, MARGIN = 2)
}
X1 <- X1[samplesU, ]
return(X1)
}

Args:
pseudobulks: The pseudobulk PCA components.
normalize: Whether to mimic DIALOGUE behavior or not.
pseudobulks: The pseudobulk PCA components as a DataFrame (samples as rows, features as columns).
normalize: Whether to perform centering, scaling, and capping.
cap: The quantile threshold for capping. For example, cap=0.01 means that for each column, values
above the 99th percentile are set to the 99th percentile, and values below the 1st percentile are set to the 1st percentile.

Returns:
The scaled count matrix.
The processed (scaled and capped) matrix as a NumPy array.
"""
# TODO: the `scale` function we implemented to match the R `scale` fn should already contain this functionality.
# DIALOGUE doesn't scale the data before passing to multicca, unlike what is recommended by sparsecca.
# However, performing this scaling _does_ increase overall correlation of the end result
if normalize:
return pseudobulks.to_numpy()
Collaborator: I know it might seem strange that if normalize=True it's not actually normalized, but this was done because of the above comment - the R implementation has normalize=True and does not scale, so we wanted this part to match when setting all the hyperparameters to be the same across the two versions. Did you check that quantile-based capping brings the result closer to the R implementation? If so, then it's fine to flip the if-else logic into the correct orientation!

# Center and scale (column-wise: subtract the column mean and divide by the column std)
scaled = (pseudobulks - pseudobulks.mean()) / pseudobulks.std()

# Apply quantile-based capping column-wise.
capped = scaled.copy()
for col in scaled.columns:
lower = scaled[col].quantile(cap) # lower quantile (e.g., 1st percentile)
upper = scaled[col].quantile(1 - cap) # upper quantile (e.g., 99th percentile)
capped[col] = scaled[col].clip(lower=lower, upper=upper)

return capped.to_numpy()
else:
return ((pseudobulks - pseudobulks.mean()) / pseudobulks.std()).to_numpy()
return pseudobulks.to_numpy()
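As an aside, the per-column capping loop above can also be written with pandas' own column alignment, since DataFrame.quantile returns per-column bounds and DataFrame.clip accepts them with axis=1 (a behavior-equivalent sketch over the same scaled DataFrame and cap value):

lower = scaled.quantile(cap)      # Series of per-column lower bounds
upper = scaled.quantile(1 - cap)  # Series of per-column upper bounds
capped = scaled.clip(lower=lower, upper=upper, axis=1)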

def _concat_adata_mcp_scores(
self, adata: AnnData, ct_subs: dict[str, AnnData], mcp_scores: dict[str, np.ndarray], celltype_key: str
@@ -560,31 +588,61 @@ def _load(
ct_order: list[str],
agg_feature: bool = True,
normalize: bool = True,
subset_common: bool = True, # new optional parameter
) -> tuple[list, dict]:
"""Separates cell into AnnDatas by celltype_key and creates the multifactor PMD input.
"""Separates cells into AnnDatas by celltype_key and creates the multifactor PMD input.

Mimics DIALOGUE's `make.cell.types` and the pre-processing that occurs in DIALOGUE1.

Args:
adata: AnnData object generate celltype objects for
ct_order: The order of cell types
adata: AnnData object to generate celltype objects for.
ct_order: The order of cell types.
agg_feature: Whether to aggregate pseudobulks with some embeddings or not.
normalize: Whether to mimic DIALOGUE behavior or not.
subset_common: If True, restrict output to common samples across cell types.
Collaborator: Awesome, this functionality has been missing for a while - in fact, dialoguepy fails with an obscure error without it, and right now the user just has to know to do this beforehand, so this is great. Given that it's mandatory, it probably shouldn't be a parameter at all but should instead just run by default with a loud warning.
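For illustration, one possible shape of "run by default with a loud warning", meant to sit inside _load and reusing self.sample_id and ct_subs from this diff (the variable names and message are assumed, not part of the PR):

import warnings

all_samples = set().union(*(set(sub.obs[self.sample_id]) for sub in ct_subs.values()))
common_samples = set.intersection(*(set(sub.obs[self.sample_id]) for sub in ct_subs.values()))
dropped = all_samples - common_samples
if dropped:
    warnings.warn(
        f"Restricting DIALOGUE to {len(common_samples)} samples shared across all cell types; "
        f"dropping {len(dropped)} sample(s): {sorted(dropped)}",
        UserWarning,
        stacklevel=2,
    )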


Returns:
A celltype_label:array dictionary.
A tuple with:
- mcca_in: A list of pseudobulk matrices (one per cell type), with rows corresponding to sample IDs (if subset_common is True).
- ct_subs: A dictionary mapping each cell type to its corresponding AnnData subset (restricted to common samples if subset_common is True).
"""
# 1. Split the AnnData into cell-type–specific subsets.
ct_subs = {ct: adata[adata.obs[self.celltype_key] == ct].copy() for ct in ct_order}

# 2. Choose the aggregation function based on the flag.
fn = self._pseudobulk_feature_space if agg_feature else self._get_pseudobulks

# 3. Compute pseudobulk features for each cell type.
# Here, fn should return a DataFrame with sample IDs as the row index.
ct_aggr = {ct: fn(ad, self.sample_id) for ct, ad in ct_subs.items()} # type: ignore

# TODO: implement check (as in https://github.com/livnatje/DIALOGUE/blob/55da9be0a9bf2fcd360d9e11f63e30d041ec4318/R/DIALOGUE.main.R#L114-L119)
# that there are at least 5 shared samples here
# 4. Apply scaling/normalization to the aggregated data.
# We wrap the output back in a DataFrame to preserve the sample IDs.
ct_scaled = {
ct: pd.DataFrame(self._scale_data(df, normalize=normalize), index=df.index, columns=df.columns)
Collaborator: Ultra minor nitpick, but _scale_data should simply preserve the ordering of the df index and columns within the function, so that you don't need the pd.DataFrame wrap here. This is also the generally expected standard when calling a numpy function on a pandas DataFrame.
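A minimal way to follow this suggestion, reusing the clip/quantile idiom from above (a sketch, not the PR's code, shown as a free function for brevity): keep _scale_data in DataFrame-space end to end so index and columns survive and the pd.DataFrame wrap becomes unnecessary:

import pandas as pd

def _scale_data(pseudobulks: pd.DataFrame, normalize: bool = True, cap: float = 0.01) -> pd.DataFrame:
    if not normalize:
        return pseudobulks
    # Column-wise center/scale; pandas preserves index and columns throughout.
    scaled = (pseudobulks - pseudobulks.mean()) / pseudobulks.std()
    return scaled.clip(lower=scaled.quantile(cap), upper=scaled.quantile(1 - cap), axis=1)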

for ct, df in ct_aggr.items()
Collaborator: Thank you for changing this to `for ct, df` instead of `for ct, ad`, which was super misleading before.

}

# TODO: https://github.com/livnatje/DIALOGUE/blob/55da9be0a9bf2fcd360d9e11f63e30d041ec4318/R/DIALOGUE.main.R#L121-L131
ct_preprocess = {ct: self._scale_data(ad, normalize=normalize).T for ct, ad in ct_aggr.items()}
if subset_common:
Collaborator: If possible, I would try to move this section up to right after ct_subs is instantiated so that, in case there aren't the requisite samples, the code fails as quickly as possible. Furthermore, you won't have to reprocess all the various ct_X variables so that they match.
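Concretely, the suggested reordering could look something like this inside _load (an illustrative sketch using the attributes from this diff; it fails fast before any aggregation):

ct_subs = {ct: adata[adata.obs[self.celltype_key] == ct].copy() for ct in ct_order}
common = sorted(set.intersection(*(set(sub.obs[self.sample_id]) for sub in ct_subs.values())))
if len(common) < 5:
    raise ValueError("Cannot run DIALOGUE with less than 5 common samples across cell types.")
ct_subs = {ct: sub[sub.obs[self.sample_id].isin(common)].copy() for ct, sub in ct_subs.items()}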

# 5. Determine the set of common samples across all cell types (using the scaled data).
common_samples = set(ct_scaled[ct_order[0]].index)
for ct in ct_order[1:]:
common_samples = common_samples.intersection(set(ct_scaled[ct].index))
common_samples_sorted = sorted(common_samples)

# Check if there are at least 5 common samples.
if len(common_samples_sorted) < 5:
raise ValueError("Cannot run DIALOGUE with less than 5 common samples across cell types.")

# 6. Subset each scaled pseudobulk DataFrame to only the common samples.
ct_scaled = {ct: df.loc[common_samples_sorted] for ct, df in ct_scaled.items()}

# 7. Also, restrict each cell-type AnnData to cells belonging to one of the common samples.
for ct in ct_subs:
ct_subs[ct] = ct_subs[ct][ct_subs[ct].obs[self.sample_id].isin(common_samples_sorted)].copy()

mcca_in = [ct_preprocess[ct] for ct in ct_order]
# 8. Order the preprocessed pseudobulk matrices as a list in the order specified by ct_order.
mcca_in = [ct_scaled[ct] for ct in ct_order]

return mcca_in, ct_subs
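For reference, a hypothetical end-to-end call of the updated _load (dataset, keys, and cell-type names are assumed; the constructor arguments follow the __init__ signature shown at the top of this diff):

import pertpy as pt

dl = pt.tl.Dialogue(sample_id="sample", celltype_key="cell_type", n_counts_key="n_counts", n_mpcs=3)
mcca_in, ct_subs = dl._load(adata, ct_order=["T cells", "B cells"], agg_feature=True, normalize=True)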

@@ -633,6 +691,8 @@ def calculate_multifactor_PMD(

mcca_in, ct_subs = self._load(adata, ct_order=cell_types, agg_feature=agg_feature, normalize=normalize)

mcca_in = [df.to_numpy() if hasattr(df, "to_numpy") else df for df in mcca_in]

n_samples = mcca_in[0].shape[1]
if penalties is None:
try: