5 changes: 4 additions & 1 deletion src/cellflow/model/_cellflow.py
@@ -56,6 +56,7 @@ def __init__(self, adata: ad.AnnData, solver: Literal["otfm", "genot"] = "otfm")
self._dataloader: TrainSampler | None = None
self._trainer: CellFlowTrainer | None = None
self._validation_data: dict[str, ValidationData] = {}
self._validation_adata: dict[str, ad.AnnData] = {}
Collaborator: can we make these attributes of the callbacks directly?
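For illustration, a minimal sketch of what this suggestion could look like from the caller's side, assuming the callback accepted the validation AnnData itself (the `validation_adata` keyword is hypothetical and not the current signature):

```python
import anndata as ad
import numpy as np

from cellflow.training import PCADecodedMetrics2

# Toy reference AnnData carrying the PCA loadings and mean used for decoding.
adata_ref = ad.AnnData(X=np.random.rand(20, 10).astype(np.float32))
adata_ref.varm["PCs"] = np.random.rand(10, 5)
adata_ref.varm["X_mean"] = np.random.rand(10, 1)

# Toy validation AnnData with the condition column the callback expects.
adata_val = adata_ref.copy()
adata_val.obs["condition"] = np.random.choice(["A", "B"], size=adata_val.n_obs)

# Hypothetical construction if the AnnData lived on the callback rather than on
# CellFlow: no _validation_adata dict on the model, no special-casing in the trainer.
callback = PCADecodedMetrics2(
    ref_adata=adata_ref,
    metrics=["r_squared"],
    validation_adata={"val": adata_val},  # hypothetical keyword; in this PR it goes through add_validation_adata
)
```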

self._solver: _otfm.OTFlowMatching | _genot.GENOT | None = None
self._condition_dim: int | None = None
self._vf: _velocity_field.ConditionalVelocityField | _velocity_field.GENOTConditionalVelocityField | None = None
@@ -225,6 +226,7 @@ def prepare_validation_data(
n_conditions_on_log_iteration=n_conditions_on_log_iteration,
n_conditions_on_train_end=n_conditions_on_train_end,
)
self._validation_adata[name] = adata
self._validation_data[name] = val_data

def prepare_model(
@@ -498,7 +500,8 @@ def prepare_model(
)
else:
raise NotImplementedError(f"Solver must be an instance of OTFlowMatching or GENOT, got {type(self.solver)}")
self._trainer = CellFlowTrainer(solver=self.solver) # type: ignore[arg-type]
validation_adata = self._validation_adata or {}
self._trainer = CellFlowTrainer(solver=self.solver, validation_adata=validation_adata) # type: ignore[arg-type]

def train(
self,
2 changes: 2 additions & 0 deletions src/cellflow/training/__init__.py
@@ -5,6 +5,7 @@
LoggingCallback,
Metrics,
PCADecodedMetrics,
PCADecodedMetrics2,
VAEDecodedMetrics,
WandbLogger,
)
@@ -19,6 +20,7 @@
"WandbLogger",
"CallbackRunner",
"PCADecodedMetrics",
"PCADecodedMetrics2",
"PCADecoder",
"VAEDecodedMetrics",
]
90 changes: 90 additions & 0 deletions src/cellflow/training/_callbacks.py
@@ -14,6 +14,7 @@
compute_scalar_mmd,
compute_sinkhorn_div,
)
from cellflow.solvers import _genot, _otfm

__all__ = [
"BaseCallback",
Expand All @@ -23,6 +24,7 @@
"WandbLogger",
"CallbackRunner",
"PCADecodedMetrics",
"PCADecodedMetrics2",
"VAEDecodedMetrics",
]

@@ -266,6 +268,94 @@ def on_log_iteration(
return metrics


class PCADecodedMetrics2(Metrics):
"""Callback to compute metrics on true validation data during training

Parameters
----------
ref_adata
An :class:`~anndata.AnnData` object with the reference data containing
``adata.varm["X_mean"]`` and ``adata.varm["PCs"]``.
metrics
List of metrics to compute. Supported metrics are ``"r_squared"``, ``"mmd"``,
``"sinkhorn_div"``, and ``"e_distance"``.
metric_aggregations
List of aggregation functions to use for each metric. Supported aggregations are ``"mean"``
and ``"median"``.
condition_id_key
Key in :attr:`~anndata.AnnData.obs` that defines the condition id.
log_prefix
Prefix to add to the log keys.
"""

def __init__(
self,
ref_adata: ad.AnnData,
metrics: list[Literal["r_squared", "mmd", "sinkhorn_div", "e_distance"]],
metric_aggregations: list[Literal["mean", "median"]] = None,
condition_id_key: str = "condition",
log_prefix: str = "pca_decoded_2_",
):
super().__init__(metrics, metric_aggregations)
self.pcs = ref_adata.varm["PCs"]
self.means = ref_adata.varm["X_mean"]
self.reconstruct_data = lambda x: x @ np.transpose(self.pcs) + np.transpose(self.means)
self.condition_id_key = condition_id_key
self.log_prefix = log_prefix

def add_validation_adata(
Collaborator: this can be part of __init__, can't it?
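A rough sketch of that refactor, assuming `validation_adata` becomes an optional constructor argument (signature and defaults are illustrative, not the merged code):

```python
import anndata as ad
import numpy as np

from cellflow.training import Metrics


class PCADecodedMetrics2(Metrics):
    """Sketch only: same callback, but the validation AnnData is taken in __init__."""

    def __init__(
        self,
        ref_adata: ad.AnnData,
        metrics: list[str],
        validation_adata: dict[str, ad.AnnData] | None = None,  # new, optional
        metric_aggregations: list[str] | None = None,
        condition_id_key: str = "condition",
        log_prefix: str = "pca_decoded_2_",
    ):
        super().__init__(metrics, metric_aggregations)
        self.pcs = ref_adata.varm["PCs"]
        self.means = ref_adata.varm["X_mean"]
        self.reconstruct_data = lambda x: x @ self.pcs.T + self.means.T
        self.condition_id_key = condition_id_key
        self.log_prefix = log_prefix
        # Keeping the AnnData on the callback removes both add_validation_adata
        # and the isinstance(...) injection loop in CellFlowTrainer.train.
        self.validation_adata = validation_adata or {}
```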

self,
validation_adata: dict[str, ad.AnnData],
) -> None:
self.validation_adata = validation_adata

def on_log_iteration(
self,
valid_source_data: dict[str, dict[str, ArrayLike]],
valid_true_data: dict[str, dict[str, ArrayLike]],
valid_pred_data: dict[str, dict[str, ArrayLike]],
solver: _genot.GENOT | _otfm.OTFlowMatching,
) -> dict[str, float]:
"""Called at each validation/log iteration to reconstruct the data and compute metrics on the reconstruction

Parameters
----------
valid_source_data
Source data in nested dictionary format with same keys as ``valid_true_data``
valid_true_data
Validation data in nested dictionary format with same keys as ``valid_pred_data``
valid_pred_data
Predicted data in nested dictionary format with same keys as ``valid_true_data``
solver
:class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT`
solver with a conditional velocity field.
"""
true_counts = {}
for name in self.validation_adata.keys():
true_counts[name] = {}
conditions_adata = set(self.validation_adata[name].obs[self.condition_id_key].unique())
conditions_pred = valid_pred_data[name].keys()
for cond in conditions_adata & conditions_pred:
true_counts[name][cond] = self.validation_adata[name][
self.validation_adata[name].obs[self.condition_id_key] == cond
].X.toarray()
LeonStadelmann (May 21, 2025): Could we potentially go OOM here when using real data? I was not sure @MUCDK

MUCDK (May 22, 2025): should be fine, because we first subset the adata before densifying.

two things here:

  1. provide the option to extract either adata.X or any adata.layers[key], similar to how we do it in e.g. moscot
  2. bear in mind that adata.X can be both sparse and dense
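One possible way to address both points, sketched under the assumption that a `layer` option were added to the callback (the helper name and the `layer` parameter are illustrative, not part of this PR):

```python
import anndata as ad
import numpy as np
import scipy.sparse as sp


def _to_dense(adata_subset: ad.AnnData, layer: str | None = None) -> np.ndarray:
    """Return the expression matrix of an already-subset AnnData as a dense array.

    Reads adata.X by default, or adata.layers[layer] if a layer key is given,
    and handles both sparse and dense storage. Densifying only after the
    per-condition subset keeps memory usage bounded.
    """
    x = adata_subset.X if layer is None else adata_subset.layers[layer]
    return x.toarray() if sp.issparse(x) else np.asarray(x)
```

The unconditional `.X.toarray()` in the loop above would then be replaced by a call to this helper, with the layer key exposed through the callback's constructor.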


predicted_data_decoded = jtu.tree_map(self.reconstruct_data, valid_pred_data)

metrics = super().on_log_iteration(true_counts, predicted_data_decoded)
metrics = {f"{self.log_prefix}{k}": v for k, v in metrics.items()}
return metrics

def on_train_end(
self,
valid_source_data: dict[str, dict[str, ArrayLike]],
valid_true_data: dict[str, dict[str, ArrayLike]],
valid_pred_data: dict[str, dict[str, ArrayLike]],
solver: _genot.GENOT | _otfm.OTFlowMatching,
) -> dict[str, float]:
return self.on_log_iteration(valid_source_data, valid_true_data, valid_pred_data, solver)


class VAEDecodedMetrics(Metrics):
"""Callback to compute metrics on decoded validation data during training

9 changes: 8 additions & 1 deletion src/cellflow/training/_trainer.py
@@ -1,14 +1,15 @@
from collections.abc import Sequence
from typing import Any, Literal

import anndata as ad
import jax
import numpy as np
from numpy.typing import ArrayLike
from tqdm import tqdm

from cellflow.data._dataloader import TrainSampler, ValidationSampler
from cellflow.solvers import _genot, _otfm
from cellflow.training._callbacks import BaseCallback, CallbackRunner
from cellflow.training._callbacks import BaseCallback, CallbackRunner, PCADecodedMetrics2


class CellFlowTrainer:
@@ -31,12 +32,14 @@ class CellFlowTrainer:
def __init__(
self,
solver: _otfm.OTFlowMatching | _genot.GENOT,
validation_adata: dict[str, ad.AnnData],
seed: int = 0,
):
if not isinstance(solver, (_otfm.OTFlowMatching | _genot.GENOT)):
raise NotImplementedError(f"Solver must be an instance of OTFlowMatching or GENOT, got {type(solver)}")

self.solver = solver
self.validation_adata = validation_adata
self.rng_subsampling = np.random.default_rng(seed)
self.training_logs: dict[str, Any] = {}

@@ -103,6 +106,10 @@ def train(
self.training_logs = {"loss": []}
rng = jax.random.PRNGKey(0)

for callback in callbacks:
if isinstance(callback, PCADecodedMetrics2):
callback.add_validation_adata(self.validation_adata)
Collaborator: I don't think this would be needed if validation_adata was an attribute of the callback, would it?


# Initiate callbacks
valid_loaders = valid_loaders or {}
crun = CallbackRunner(
23 changes: 23 additions & 0 deletions tests/trainer/test_callbacks.py
@@ -1,6 +1,7 @@
import anndata as ad
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import pytest


@@ -18,6 +19,28 @@ def test_pca_reconstruction(self, adata_pca: ad.AnnData, metrics):
assert reconstruction.shape == adata_pca.X.shape
assert jnp.allclose(reconstruction, adata_pca.layers["counts"])

def test_pca_decoded_2(self, adata_pca: ad.AnnData):
from cellflow.solvers import OTFlowMatching
from cellflow.training import PCADecodedMetrics2

adata_gt = adata_pca.copy()
adata_gt.obs["condition"] = np.random.choice(["A", "B"], size=adata_pca.shape[0])

decoded_metrics_callback = PCADecodedMetrics2(
ref_adata=adata_pca, metrics=["r_squared"], condition_id_key="condition"
)

callbacks = [decoded_metrics_callback]
for e in callbacks:
if isinstance(e, PCADecodedMetrics2):
e.add_validation_adata({"test": adata_gt})

valid_pred_data = {"test": {"A": np.random.random((2, 10)), "B": np.random.random((2, 10))}}

res = decoded_metrics_callback.on_log_iteration({}, {}, valid_pred_data, OTFlowMatching)
assert "pca_decoded_2_test_r_squared_mean" in res
assert isinstance(res["pca_decoded_2_test_r_squared_mean"], float)

@pytest.mark.parametrize("metrics", [["r_squared"]])
def test_vae_reconstruction(self, metrics):
from scvi.data import synthetic_iid