Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/cellflow/model/_cellflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(self, adata: ad.AnnData, solver: Literal["otfm", "genot"] = "otfm")
self._dataloader: TrainSampler | None = None
self._trainer: CellFlowTrainer | None = None
self._validation_data: dict[str, ValidationData] = {}
self._validation_adata: dict[str, ad.Anndata] = {}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we make these attributes of the callbacks directly?

self._solver: _otfm.OTFlowMatching | _genot.GENOT | None = None
self._condition_dim: int | None = None
self._vf: _velocity_field.ConditionalVelocityField | _velocity_field.GENOTConditionalVelocityField | None = None
Expand Down Expand Up @@ -225,6 +226,7 @@ def prepare_validation_data(
n_conditions_on_log_iteration=n_conditions_on_log_iteration,
n_conditions_on_train_end=n_conditions_on_train_end,
)
self._validation_adata[name] = adata
self._validation_data[name] = val_data

def prepare_model(
Expand Down Expand Up @@ -498,7 +500,8 @@ def prepare_model(
)
else:
raise NotImplementedError(f"Solver must be an instance of OTFlowMatching or GENOT, got {type(self.solver)}")
self._trainer = CellFlowTrainer(solver=self.solver) # type: ignore[arg-type]
validation_adata = self._validation_adata or {}
self._trainer = CellFlowTrainer(solver=self.solver, validation_adata=validation_adata) # type: ignore[arg-type]

def train(
self,
Expand Down
2 changes: 2 additions & 0 deletions src/cellflow/training/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
LoggingCallback,
Metrics,
PCADecodedMetrics,
PCADecodedMetrics2,
VAEDecodedMetrics,
WandbLogger,
)
Expand All @@ -19,6 +20,7 @@
"WandbLogger",
"CallbackRunner",
"PCADecodedMetrics",
"PCADecodedMetrics2",
"PCADecoder",
"VAEDecodedMetrics",
]
100 changes: 100 additions & 0 deletions src/cellflow/training/_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import jax.tree as jt
import jax.tree_util as jtu
import numpy as np
import scipy

from cellflow._types import ArrayLike
from cellflow.metrics._metrics import (
Expand All @@ -14,6 +15,7 @@
compute_scalar_mmd,
compute_sinkhorn_div,
)
from cellflow.solvers import _genot, _otfm

__all__ = [
"BaseCallback",
Expand All @@ -23,6 +25,7 @@
"WandbLogger",
"CallbackRunner",
"PCADecodedMetrics",
"PCADecodedMetrics2",
"VAEDecodedMetrics",
]

Expand Down Expand Up @@ -266,6 +269,103 @@ def on_log_iteration(
return metrics


class PCADecodedMetrics2(Metrics):
    """Callback to compute metrics between PCA-decoded predictions and true validation counts.

    Predictions are decoded with ``x @ PCs.T + X_mean.T`` using the arrays stored in
    ``ref_adata`` and compared, condition by condition, against counts taken from the
    validation :class:`~anndata.AnnData` objects.

    Parameters
    ----------
    ref_adata
        An :class:`~anndata.AnnData` object with the reference data containing
        ``adata.varm["X_mean"]`` and ``adata.varm["PCs"]``.
    metrics
        List of metrics to compute. Supported metrics are ``"r_squared"``, ``"mmd"``,
        ``"sinkhorn_div"``, and ``"e_distance"``.
    metric_aggregations
        List of aggregation functions to use for each metric. Supported aggregations are ``"mean"``
        and ``"median"``. Forwarded unchanged to the parent class.
    condition_id_key
        Key in :attr:`~anndata.AnnData.obs` that defines the condition id.
    layers
        Key in :attr:`~anndata.AnnData.layers` from which to get the counts.
        If :obj:`None`, use :attr:`~anndata.AnnData.X`.
    log_prefix
        Prefix to add to the log keys.
    validation_adata
        Optional mapping from validation set name to the :class:`~anndata.AnnData` holding
        its true counts. Can also be supplied later via :meth:`add_validation_adata`.
    """

    def __init__(
        self,
        ref_adata: ad.AnnData,
        metrics: list[Literal["r_squared", "mmd", "sinkhorn_div", "e_distance"]],
        metric_aggregations: list[Literal["mean", "median"]] | None = None,
        condition_id_key: str = "condition",
        layers: str | None = None,
        log_prefix: str = "pca_decoded_2_",
        validation_adata: dict[str, ad.AnnData] | None = None,
    ):
        super().__init__(metrics, metric_aggregations)
        self.pcs = ref_adata.varm["PCs"]
        self.means = ref_adata.varm["X_mean"]
        self.reconstruct_data = lambda x: x @ np.transpose(self.pcs) + np.transpose(self.means)
        self.condition_id_key = condition_id_key
        self.layers = layers
        self.log_prefix = log_prefix
        # Always define the attribute so that logging before any validation data has been
        # attached yields no comparisons instead of raising AttributeError.
        self.validation_adata: dict[str, ad.AnnData] = {} if validation_adata is None else validation_adata

    def add_validation_adata(
        self,
        validation_adata: dict[str, ad.AnnData],
    ) -> None:
        """Attach (or replace) the ground-truth validation data after construction.

        Parameters
        ----------
        validation_adata
            Mapping from validation set name to the :class:`~anndata.AnnData` with the
            true counts; names are matched against the keys of ``valid_pred_data``.
        """
        self.validation_adata = validation_adata

    def on_log_iteration(
        self,
        valid_source_data: dict[str, dict[str, ArrayLike]],
        valid_true_data: dict[str, dict[str, ArrayLike]],
        valid_pred_data: dict[str, dict[str, ArrayLike]],
        solver: _genot.GENOT | _otfm.OTFlowMatching,
    ) -> dict[str, float]:
        """Called at each validation/log iteration to reconstruct the data and compute metrics on the reconstruction

        Only ``valid_pred_data`` is read here; the true values come from the attached
        validation :class:`~anndata.AnnData` objects rather than from ``valid_true_data``.

        Parameters
        ----------
        valid_source_data
            Source data in nested dictionary format with same keys as ``valid_true_data``
        valid_true_data
            Validation data in nested dictionary format with same keys as ``valid_pred_data``
        valid_pred_data
            Predicted data in nested dictionary format with same keys as ``valid_true_data``
        solver
            :class:`~cellflow.solvers.OTFlowMatching` solver or :class:`~cellflow.solvers.GENOT`
            solver with a conditional velocity field.

        Returns
        -------
        Metrics returned by the parent class with ``log_prefix`` prepended to every key.
        """
        true_counts: dict[str, dict[str, ArrayLike]] = {}
        matched_predictions: dict[str, dict[str, ArrayLike]] = {}
        for name, adata in self.validation_adata.items():
            if name not in valid_pred_data:
                # No predictions were produced for this validation set; nothing to compare.
                continue
            predictions = valid_pred_data[name]
            condition_ids = adata.obs[self.condition_id_key]
            shared_conditions = set(condition_ids.unique()) & predictions.keys()

            true_counts[name] = {}
            for cond in shared_conditions:
                # Slice once and read either X or the requested layer from the same view.
                subset = adata[condition_ids == cond]
                counts = subset.X if self.layers is None else subset.layers[self.layers]
                true_counts[name][cond] = counts.toarray() if scipy.sparse.issparse(counts) else counts
            # Keep predictions restricted to the same conditions so that both nested
            # dictionaries handed to the parent class have identical keys.
            matched_predictions[name] = {cond: predictions[cond] for cond in shared_conditions}

        predicted_data_decoded = jtu.tree_map(self.reconstruct_data, matched_predictions)

        metrics = super().on_log_iteration(true_counts, predicted_data_decoded)
        return {f"{self.log_prefix}{k}": v for k, v in metrics.items()}

    def on_train_end(
        self,
        valid_source_data: dict[str, dict[str, ArrayLike]],
        valid_true_data: dict[str, dict[str, ArrayLike]],
        valid_pred_data: dict[str, dict[str, ArrayLike]],
        solver: _genot.GENOT | _otfm.OTFlowMatching,
    ) -> dict[str, float]:
        """Called at the end of training; computes the same metrics as :meth:`on_log_iteration`."""
        return self.on_log_iteration(valid_source_data, valid_true_data, valid_pred_data, solver)


class VAEDecodedMetrics(Metrics):
"""Callback to compute metrics on decoded validation data during training

Expand Down
9 changes: 8 additions & 1 deletion src/cellflow/training/_trainer.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from collections.abc import Sequence
from typing import Any, Literal

import anndata as ad
import jax
import numpy as np
from numpy.typing import ArrayLike
from tqdm import tqdm

from cellflow.data._dataloader import TrainSampler, ValidationSampler
from cellflow.solvers import _genot, _otfm
from cellflow.training._callbacks import BaseCallback, CallbackRunner
from cellflow.training._callbacks import BaseCallback, CallbackRunner, PCADecodedMetrics2


class CellFlowTrainer:
Expand All @@ -31,12 +32,14 @@ class CellFlowTrainer:
def __init__(
    self,
    solver: _otfm.OTFlowMatching | _genot.GENOT,
    validation_adata: dict[str, ad.AnnData] | None = None,
    seed: int = 0,
):
    """Store the solver and validation data and set up the subsampling RNG.

    Parameters
    ----------
    solver
        An :class:`OTFlowMatching` or :class:`GENOT` instance.
    validation_adata
        Mapping from validation set name to :class:`~anndata.AnnData`. If :obj:`None`,
        an empty mapping is stored, so callers that have no validation data can omit it.
    seed
        Seed for the NumPy generator stored in ``rng_subsampling``.

    Raises
    ------
    NotImplementedError
        If ``solver`` is neither an ``OTFlowMatching`` nor a ``GENOT`` instance.
    """
    if not isinstance(solver, (_otfm.OTFlowMatching | _genot.GENOT)):
        raise NotImplementedError(f"Solver must be an instance of OTFlowMatching or GENOT, got {type(solver)}")

    self.solver = solver
    # Normalise a missing argument to an empty mapping so later code can use it unconditionally.
    self.validation_adata = {} if validation_adata is None else validation_adata
    self.rng_subsampling = np.random.default_rng(seed)
    self.training_logs: dict[str, Any] = {}

Expand Down Expand Up @@ -103,6 +106,10 @@ def train(
self.training_logs = {"loss": []}
rng = jax.random.PRNGKey(0)

for callback in callbacks:
if isinstance(callback, PCADecodedMetrics2):
callback.add_validation_adata(self.validation_adata)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this would be needed if validation_adata was an attribute of the callback, would it?


# Initiate callbacks
valid_loaders = valid_loaders or {}
crun = CallbackRunner(
Expand Down
29 changes: 29 additions & 0 deletions tests/trainer/test_callbacks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import anndata as ad
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import pytest


Expand All @@ -18,6 +19,34 @@ def test_pca_reconstruction(self, adata_pca: ad.AnnData, metrics):
assert reconstruction.shape == adata_pca.X.shape
assert jnp.allclose(reconstruction, adata_pca.layers["counts"])

@pytest.mark.parametrize("sparse_matrix", [True, False])
@pytest.mark.parametrize("layers", [None, "test"])
def test_pca_decoded_2(self, adata_pca: ad.AnnData, sparse_matrix, layers):
    """Check that ``PCADecodedMetrics2`` logs a prefixed, aggregated metric for sparse/dense X and layers."""
    from cellflow.solvers import OTFlowMatching
    from cellflow.training import PCADecodedMetrics2

    # Local seeded generator keeps the test reproducible without touching global NumPy state.
    rng = np.random.default_rng(0)

    adata_gt = adata_pca.copy()
    # Alternate the two labels so both conditions are present on every run.
    adata_gt.obs["condition"] = np.resize(["A", "B"], adata_pca.shape[0])
    if not sparse_matrix:
        adata_gt.X = adata_gt.X.toarray()
    if layers is not None:
        adata_gt.layers[layers] = adata_gt.X.copy()

    decoded_metrics_callback = PCADecodedMetrics2(
        ref_adata=adata_pca, metrics=["r_squared"], condition_id_key="condition", layers=layers
    )
    decoded_metrics_callback.add_validation_adata({"test": adata_gt})

    # NOTE(review): the trailing dimension of 10 presumably matches the number of PCs in the
    # ``adata_pca`` fixture -- confirm against the fixture definition.
    valid_pred_data = {"test": {"A": rng.random((2, 10)), "B": rng.random((2, 10))}}

    res = decoded_metrics_callback.on_log_iteration({}, {}, valid_pred_data, OTFlowMatching)
    assert "pca_decoded_2_test_r_squared_mean" in res
    assert isinstance(res["pca_decoded_2_test_r_squared_mean"], float)

@pytest.mark.parametrize("metrics", [["r_squared"]])
def test_vae_reconstruction(self, metrics):
from scvi.data import synthetic_iid
Expand Down
Loading