diff --git a/scripts/run_benchmark/run_full_local.sh b/scripts/run_benchmark/run_full_local.sh index 20e434b3..84a60b60 100755 --- a/scripts/run_benchmark/run_full_local.sh +++ b/scripts/run_benchmark/run_full_local.sh @@ -26,7 +26,7 @@ input_states: resources/datasets/**/state.yaml rename_keys: 'input_dataset:output_dataset;input_solution:output_solution' output_state: "state.yaml" publish_dir: "$publish_dir" -settings: '{"methods_exclude": ["uce", "scgpt_finetuned"]}' +settings: '{"methods_exclude": ["uce", "scgpt_finetuned", "cellplm"]}' HERE # run the benchmark diff --git a/scripts/run_benchmark/run_test_local.sh b/scripts/run_benchmark/run_test_local.sh index 85e39583..3d9de958 100755 --- a/scripts/run_benchmark/run_test_local.sh +++ b/scripts/run_benchmark/run_test_local.sh @@ -21,7 +21,7 @@ input_states: resources_test/task_batch_integration/**/state.yaml rename_keys: 'input_dataset:output_dataset;input_solution:output_solution' output_state: "state.yaml" publish_dir: "$publish_dir" -settings: '{"methods_exclude": ["uce", "scgpt_finetuned"]}' +settings: '{"methods_exclude": ["uce", "scgpt_finetuned", "cellplm]}' HERE nextflow run . \ diff --git a/src/methods/cellplm/config.vsh.yaml b/src/methods/cellplm/config.vsh.yaml new file mode 100755 index 00000000..df8b9e71 --- /dev/null +++ b/src/methods/cellplm/config.vsh.yaml @@ -0,0 +1,51 @@ +__merge__: ../../api/base_method.yaml + +name: cellplm + +label: CellPLM + +summary: "A foundation model pre-trained with cells as tokens." + +description: | + CellPLM is a pre-trained language model specifically designed for single-cell analysis that leverages the principles of natural language processing (NLP) to understand and process single-cell gene expression data. +references: + doi: + - 10.1101/2023.10.03.560734 +links: + documentation: https://github.com/OmicsML/CellPLM/tree/main/tutorials + repository: https://github.com/OmicsML/CellPLM + +info: + method_types: [embedding] + preferred_normalization: counts + +arguments: + - name: --model_name + type: string + description: String giving the CellPLM model to use + choices: ["20231027_85M", "20230926_85M"] + default: "20231027_85M" + - name: --model + type: file + description: Path to the directory containing CellPLM model files or a .zip/.tar.gz archive + required: true + +resources: + - type: python_script + path: script.py + - path: /src/utils/read_anndata_partial.py + - path: /src/utils/exit_codes.py + +engines: + - type: docker + image: openproblems/base_pytorch_nvidia:1.0.0 + setup: + - type: python + pypi: + - cellplm + +runners: + - type: executable + - type: nextflow + directives: + label: [midtime, midmem, midcpu, gpu] diff --git a/src/methods/cellplm/script.py b/src/methods/cellplm/script.py new file mode 100644 index 00000000..6e6a20c8 --- /dev/null +++ b/src/methods/cellplm/script.py @@ -0,0 +1,104 @@ +import os +import sys +import tarfile +import tempfile +import zipfile + +import anndata as ad +import torch +from CellPLM.pipeline.cell_embedding import CellEmbeddingPipeline +from CellPLM.utils import set_seed + +## VIASH START +# Note: this section is auto-generated by viash at runtime. To edit it, make changes +# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`. + +par = { + "input": "resources_test/.../input.h5ad", + "output": "output.h5ad", + "model": "20231027_85M", +} +meta = {"name": "cellplm"} +## VIASH END + +sys.path.append(meta["resources_dir"]) +from exit_codes import exit_non_applicable +from read_anndata_partial import read_anndata + +set_seed(24) +device = "cuda" if torch.cuda.is_available() else "cpu" +if device == "cpu": + import warnings + + warnings.warn("Loading CellPLM models requires a GPU, this run will fail") + +print("\n>>> Reading input files...", flush=True) +print(f"Input H5AD file: '{par['input']}'", flush=True) +adata = read_anndata(par["input"], X="layers/counts", obs="obs", var="var", uns="uns") + +if adata.uns["dataset_organism"] != "homo_sapiens": + exit_non_applicable( + f"CellPLM can only be used with human data " + f'(dataset_organism == "{adata.uns["dataset_organism"]}")' + ) + +print(adata, flush=True) + +print("\n>>> Getting model files...", flush=True) +# Available from https://www.dropbox.com/scl/fo/i5rmxgtqzg7iykt2e9uqm/h/ckpt?dl=0&subfolder_nav_tracking=1 +if os.path.isdir(par["model"]): + model_temp = None + model_dir = par["model"] +else: + model_temp = tempfile.TemporaryDirectory() + model_dir = model_temp.name + + if zipfile.is_zipfile(par["model"]): + print("Extracting CellPLM models from .zip...", flush=True) + with zipfile.ZipFile(par["model"], "r") as zip_file: + zip_file.extractall(model_dir) + elif tarfile.is_tarfile(par["model"]) and par["model"].endswith(".tar.gz"): + print("Extracting CellPLM models from .tar.gz...", flush=True) + with tarfile.open(par["model"], "r:gz") as tar_file: + tar_file.extractall(model_dir) + model_dir = os.path.join(model_dir, os.listdir(model_dir)[0]) + else: + raise ValueError( + "The 'model' argument should be a directory a .zip file or a .tar.gz file" + ) + +print(f"Model directory: '{model_dir}'", flush=True) + +print("\n>>> Creating embedding model pipeline...", flush=True) +pipeline = CellEmbeddingPipeline( + pretrain_prefix=par["model_name"], pretrain_directory=model_dir +) + +print("\n>>> Embedding data...", flush=True) +embedding = pipeline.predict(adata, device=device) +embedding = embedding.cpu().numpy() + +print("\n>>> Storing output...", flush=True) +output = ad.AnnData( + obs=adata.obs[[]], + var=adata.var[[]], + obsm={ + "X_emb": embedding, + }, + uns={ + "dataset_id": adata.uns["dataset_id"], + "normalization_id": adata.uns["normalization_id"], + "method_id": meta["name"], + }, +) +print(output) + +print("\n>>> Writing output to file...", flush=True) +print(f"Output H5AD file: '{par['output']}'", flush=True) +output.write_h5ad(par["output"], compression="gzip") + +if model_temp is not None: + print("\n>>> Cleaning up temporary directories...", flush=True) + model_temp.cleanup() + +print("\n>>> Done!", flush=True) diff --git a/src/workflows/run_benchmark/config.vsh.yaml b/src/workflows/run_benchmark/config.vsh.yaml index 5cd9b339..c53035fd 100644 --- a/src/workflows/run_benchmark/config.vsh.yaml +++ b/src/workflows/run_benchmark/config.vsh.yaml @@ -91,6 +91,7 @@ dependencies: - name: methods/batchelor_fastmnn - name: methods/batchelor_mnn_correct - name: methods/bbknn + - name: methods/cellplm - name: methods/combat - name: methods/geneformer - name: methods/harmony diff --git a/src/workflows/run_benchmark/main.nf b/src/workflows/run_benchmark/main.nf index e4e1ad32..5c34f1c0 100644 --- a/src/workflows/run_benchmark/main.nf +++ b/src/workflows/run_benchmark/main.nf @@ -19,6 +19,9 @@ methods = [ batchelor_fastmnn, batchelor_mnn_correct, bbknn, + cellplm.run( + args: [model: file("s3://openproblems-work/cache/cellplm-ckpt.zip")] + ), combat, geneformer, harmony,