diff --git a/src/methods/drvi/config.vsh.yaml b/src/methods/drvi/config.vsh.yaml
new file mode 100644
index 00000000..4233224e
--- /dev/null
+++ b/src/methods/drvi/config.vsh.yaml
@@ -0,0 +1,89 @@
+# The API specifies which type of component this is.
+# It contains specifications for:
+#   - The input/output files
+#   - Common parameters
+#   - A unit test
+__merge__: ../../api/comp_method.yaml
+
+# A unique identifier for your component (required).
+# Can contain only lowercase letters or underscores.
+name: drvi
+# A relatively short label, used when rendering visualisations (required)
+label: DRVI
+# A one sentence summary of how this method works (required). Used when
+# rendering summary tables.
+summary: "DRVI is an unsupervised generative model capable of learning non-linear interpretable disentangled latent representations from single-cell count data."
+# A multi-line description of how this component works (required). Used
+# when rendering reference documentation.
+description: |
+  Disentangled Representation Variational Inference (DRVI) is an unsupervised deep generative model designed for integrating single-cell RNA sequencing (scRNA-seq) data across different batches.
+  It extends the variational autoencoder (VAE) framework by learning a latent representation that captures biological variation while disentangling and correcting for batch effects.
+  DRVI conditions both the encoder and decoder on batch covariates, allowing it to explicitly model and mitigate batch-specific variations during training.
+  By incorporating a KL-divergence regularization term, it balances data reconstruction with latent space structure, resulting in a unified embedding where similar cells cluster together regardless of batch.
+references:
+  doi:
+    - 10.1101/2024.11.06.622266
+#   bibtex:
+#     - |
+#       @article{foo,
+#         title={Foo},
+#         author={Bar},
+#         journal={Baz},
+#         year={2024}
+#       }
+links:
+  # URL to the documentation for this method (required).
+  documentation: https://drvi.readthedocs.io/latest/index.html
+  # URL to the code repository for this method (required).
+  repository: https://github.com/theislab/DRVI
+
+
+
+# Metadata for your component
+info:
+  # Which normalisation method this component prefers to use (required).
+  preferred_normalization: counts
+
+# Component-specific parameters (optional)
+arguments:
+  - name: --n_hvg
+    type: integer
+    default: 2000
+    description: Number of highly variable genes to use.
+  - name: --n_epochs
+    type: integer
+    default: 400
+    description: Number of epochs
+
+# Resources required to run the component
+resources:
+  # The script of your component (required)
+  - type: python_script
+    path: script.py
+  - path: /src/utils/read_anndata_partial.py
+
+  # Additional resources your script needs (optional)
+  # - type: file
+  #   path: weights.pt
+
+engines:
+  # Specifications for the Docker image for this component.
+  - type: docker
+    image: openproblems/base_pytorch_nvidia:1.0.0
+    # Add custom dependencies here (optional). For more information, see
+    # https://viash.io/reference/config/engines/docker/#setup .
+    setup:
+      - type: python
+        pypi:
+          - drvi-py==0.1.7
+          - torch==2.3.0
+          - torchvision==0.18.0
+    #   packages:
+
+runners:
+  # This platform allows running the component natively
+  - type: executable
+  # Allows turning the component into a Nextflow module / pipeline.
+  - type: nextflow
+    directives:
+      label: [midtime,midmem,lowcpu,gpu]
diff --git a/src/methods/drvi/script.py b/src/methods/drvi/script.py
new file mode 100644
index 00000000..f7f2f1ea
--- /dev/null
+++ b/src/methods/drvi/script.py
@@ -0,0 +1,125 @@
+import anndata as ad
+import scanpy as sc
+import drvi
+from drvi.model import DRVI
+from drvi.utils.misc import hvg_batch
+import pandas as pd
+import numpy as np
+import warnings
+import sys
+import scipy.sparse
+
+## VIASH START
+# Note: this section is auto-generated by viash at runtime. To edit it, make changes
+# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`.
+par = {
+    'input': 'resources_test/task_batch_integration/cxg_immune_cell_atlas/dataset.h5ad',
+    'output': 'output.h5ad',
+    'n_hvg': 2000,
+    'n_epochs': 400
+}
+meta = {
+    'name': 'drvi',
+    # Only used when running this script directly; viash replaces this block at runtime.
+    'resources_dir': 'src/utils'
+}
+## VIASH END
+
+# read_anndata_partial.py is listed as a component resource in config.vsh.yaml.
+sys.path.append(meta["resources_dir"])
+from read_anndata_partial import read_anndata
+
+print('Reading input files', flush=True)
+adata = read_anndata(
+    par['input'],
+    X='layers/counts',
+    obs='obs',
+    var='var',
+    uns='uns'
+)
+
+# Make sure the matrix holds integer counts: round any non-integer values.
+# For a sparse matrix only the stored entries are inspected and rounded, so
+# the full matrix is never densified.
+counts = adata.X
+is_sparse = scipy.sparse.issparse(counts)
+stored_values = counts.data if is_sparse else np.asarray(counts)
+if not np.allclose(stored_values, np.round(stored_values)):
+    warnings.warn("Non-integer values detected. Rounding to nearest integer.")
+    if is_sparse:
+        rounded = counts.copy()
+        rounded.data = np.round(rounded.data)
+        adata.X = rounded.astype(int)
+    else:
+        adata.X = np.round(counts).astype(int)
+
+adata.layers["counts"] = adata.X.copy()
+
+if par["n_hvg"]:
+    print(f"Select top {par['n_hvg']} high variable genes", flush=True)
+    # Column indices of the genes with the highest `hvg_score`, best first.
+    idx = adata.var["hvg_score"].to_numpy().argsort()[::-1][:par["n_hvg"]]
+    adata = adata[:, idx].copy()
+
+print('Train model with DRVI', flush=True)
+# Setup data
+DRVI.setup_anndata(
+    adata,
+    # DRVI accepts count data by default.
+    # Do not forget to change gene_likelihood if you provide a non-count data.
+    layer="counts",
+    # Always provide a list. DRVI can accept multiple covariates.
+    categorical_covariate_keys=["batch"],
+    # DRVI accepts count data by default.
+    # Set to false if you provide log-normalized data and use normal distribution (mse loss).
+    # NOTE(review): the counts layer is passed above, yet this is False -- confirm this is intended.
+    is_count_data=False,
+)
+
+# construct the model
+model = DRVI(
+    adata,
+    # Provide categorical covariates keys once again. Refer to advanced usages for more options.
+    categorical_covariates=["batch"],
+    n_latent=128,
+    # For encoder and decoder dims, provide a list of integers.
+    encoder_dims=[128, 128],
+    decoder_dims=[128, 128],
+)
+
+# train the model
+model.train(
+    max_epochs=par["n_epochs"],
+    early_stopping=False,
+    early_stopping_patience=20,
+    # mps
+    # accelerator="mps", devices=1,
+    # cpu
+    # accelerator="cpu", devices=1,
+    # gpu: no additional parameter
+    #
+    # No need to provide `plan_kwargs` if n_epochs >= 400.
+    plan_kwargs={
+        "n_epochs_kl_warmup": par["n_epochs"],
+    },
+)
+
+# Compute the embedding once and reuse it when building the output.
+latent = model.get_latent_representation()
+
+print("Store outputs", flush=True)
+output = ad.AnnData(
+    obs=adata.obs.copy(),
+    var=adata.var.copy(),
+    obsm={
+        "X_emb": latent,
+    },
+    uns={
+        "dataset_id": adata.uns.get("dataset_id", "unknown"),
+        "normalization_id": adata.uns.get("normalization_id", "unknown"),
+        "method_id": meta["name"],
+    },
+)
+
+print("Write output AnnData to file", flush=True)
+output.write_h5ad(par['output'], compression='gzip')