From 2158882725c8b2bdb5909cd78171945bfb94d259 Mon Sep 17 00:00:00 2001
From: Luke Zappia
Date: Fri, 8 Nov 2024 17:34:58 +0100
Subject: [PATCH] Add Geneformer (#6)

* Add cxg_immune_cell_atlas as a test resource
* Add SCimilarity component
* Add SCimilarity to benchmark workflow
* Update script to extract model
* Add SCimilarity model path to benchmark workflow
* Add base_method API to disable tests for SCimilarity
* Replace cxg_mouse_pancreas_atlas with cxg_immune_cell_atlas
* Style SCimilarity script
* Remove test resources from SCimilarity config
* Fix file names in test resources state.yaml
* Add scimilarity as dependency to benchmark workflow
* Update compute environment
* Update model file path
* Create geneformer files
* Set SCimilarity name in Python script
* Adjust container settings (depend on base method config because of input model file)
* Download dictionary files in script
* Prepare and tokenize data, attempt to embed
* Store and output embedding
* Add Geneformer to benchmark workflow
* Add argument to select model version to use
* Style Geneformer script
* Make Geneformer inherit from base_method for tests
---
 src/methods/geneformer/config.vsh.yaml      |  58 ++++++++
 src/methods/geneformer/script.py            | 154 ++++++++++++++++++++
 src/methods/scimilarity/script.py           |   2 +-
 src/workflows/run_benchmark/config.vsh.yaml |   1 +
 src/workflows/run_benchmark/main.nf         |   1 +
 5 files changed, 215 insertions(+), 1 deletion(-)
 create mode 100644 src/methods/geneformer/config.vsh.yaml
 create mode 100644 src/methods/geneformer/script.py

diff --git a/src/methods/geneformer/config.vsh.yaml b/src/methods/geneformer/config.vsh.yaml
new file mode 100644
index 00000000..d571a4ad
--- /dev/null
+++ b/src/methods/geneformer/config.vsh.yaml
@@ -0,0 +1,58 @@
+__merge__: /src/api/base_method.yaml
+
+name: geneformer
+label: Geneformer
+summary: Geneformer is a foundation transformer model pretrained on a large-scale corpus of single-cell transcriptomes
+description: |
+  Geneformer is a foundation transformer model pretrained on a large-scale
+  corpus of single-cell transcriptomes to enable context-aware predictions in
+  network biology. For this task, Geneformer is used to create a batch-corrected
+  cell embedding.
+references:
+  doi:
+    - 10.1038/s41586-023-06139-9
+    - 10.1101/2024.08.16.608180
+links:
+  documentation: https://geneformer.readthedocs.io/en/latest/index.html
+  repository: https://huggingface.co/ctheodoris/Geneformer
+
+info:
+  preferred_normalization: counts
+  method_types: [embedding]
+  variants:
+    geneformer_12L_95M_i4096:
+      model: "gf-12L-95M-i4096"
+    geneformer_6L_30M_i2048:
+      model: "gf-6L-30M-i2048"
+    geneformer_12L_30M_i2048:
+      model: "gf-12L-30M-i2048"
+    geneformer_20L_95M_i4096:
+      model: "gf-20L-95M-i4096"
+
+arguments:
+  - name: "--model"
+    type: "string"
+    description: String representing the Geneformer model to use
+    choices: ["gf-6L-30M-i2048", "gf-12L-30M-i2048", "gf-12L-95M-i4096", "gf-20L-95M-i4096"]
+    default: "gf-12L-95M-i4096"
+
+resources:
+  - type: python_script
+    path: script.py
+  - path: /src/utils/read_anndata_partial.py
+
+engines:
+  - type: docker
+    image: openproblems/base_pytorch_nvidia:1.0.0
+    setup:
+      - type: python
+        pip:
+          - pyarrow<15.0.0a0,>=14.0.1
+          - huggingface_hub
+          - git+https://huggingface.co/ctheodoris/Geneformer.git
+
+runners:
+  - type: executable
+  - type: nextflow
+    directives:
+      label: [midtime, midmem, midcpu, gpu]
diff --git a/src/methods/geneformer/script.py b/src/methods/geneformer/script.py
new file mode 100644
index 00000000..eeab4332
--- /dev/null
+++ b/src/methods/geneformer/script.py
@@ -0,0 +1,154 @@
+import os
+import sys
+from tempfile import TemporaryDirectory
+
+import anndata as ad
+import numpy as np
+import pandas as pd
+from geneformer import EmbExtractor, TranscriptomeTokenizer
+from huggingface_hub import hf_hub_download
+
+## VIASH START
+# Note: this section is auto-generated by viash at runtime. To edit it, make changes
+# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`.
+par = {
+    "input": "resources_test/task_batch_integration/cxg_immune_cell_atlas/dataset.h5ad",
+    "output": "output.h5ad",
+    "model": "gf-12L-95M-i4096",
+}
+meta = {"name": "geneformer"}
+## VIASH END
+
+n_processors = os.cpu_count()
+
+print(">>> Reading input...", flush=True)
+sys.path.append(meta["resources_dir"])
+from read_anndata_partial import read_anndata
+
+adata = read_anndata(par["input"], X="layers/counts", obs="obs", var="var", uns="uns")
+
+if adata.uns["dataset_organism"] != "homo_sapiens":
+    raise ValueError(
+        f"Geneformer can only be used with human data "
+        f"(dataset_organism == '{adata.uns['dataset_organism']}')"
+    )
+
+is_ensembl = all(var_name.startswith("ENSG") for var_name in adata.var_names)
+if not is_ensembl:
+    raise ValueError("Geneformer requires adata.var_names to contain Ensembl gene IDs")
+
+print(f">>> Getting settings for model '{par['model']}'...", flush=True)
+model_split = par["model"].split("-")
+model_details = {
+    "layers": model_split[1],
+    "dataset": model_split[2],
+    "input_size": int(model_split[3][1:]),
+}
+print(model_details, flush=True)
+
+print(">>> Getting model dictionary files...", flush=True)
+if model_details["dataset"] == "95M":
+    dictionaries_subfolder = "geneformer"
+elif model_details["dataset"] == "30M":
+    dictionaries_subfolder = "geneformer/gene_dictionaries_30m"
+else:
+    raise ValueError(f"Invalid model dataset: {model_details['dataset']}")
+print(f"Dictionaries subfolder: '{dictionaries_subfolder}'", flush=True)
+
+dictionary_files = {
+    "ensembl_mapping": hf_hub_download(
+        repo_id="ctheodoris/Geneformer",
+        subfolder=dictionaries_subfolder,
+        filename=f"ensembl_mapping_dict_gc{model_details['dataset']}.pkl",
+    ),
+    "gene_median": hf_hub_download(
+        repo_id="ctheodoris/Geneformer",
+        subfolder=dictionaries_subfolder,
+        filename=f"gene_median_dictionary_gc{model_details['dataset']}.pkl",
+    ),
+    "gene_name_id": hf_hub_download(
+        repo_id="ctheodoris/Geneformer",
+        subfolder=dictionaries_subfolder,
+        filename=f"gene_name_id_dict_gc{model_details['dataset']}.pkl",
+    ),
+    "token": hf_hub_download(
+        repo_id="ctheodoris/Geneformer",
+        subfolder=dictionaries_subfolder,
+        filename=f"token_dictionary_gc{model_details['dataset']}.pkl",
+    ),
+}
+
+print(">>> Creating working directory...", flush=True)
+work_dir = TemporaryDirectory()
+input_dir = os.path.join(work_dir.name, "input")
+os.makedirs(input_dir)
+tokenized_dir = os.path.join(work_dir.name, "tokenized")
+os.makedirs(tokenized_dir)
+embedding_dir = os.path.join(work_dir.name, "embedding")
+os.makedirs(embedding_dir)
+print(f"Working directory: '{work_dir.name}'", flush=True)
+
+print(">>> Preparing data...", flush=True)
+adata.var["ensembl_id"] = adata.var_names
+adata.obs["n_counts"] = np.ravel(adata.X.sum(axis=1))
+adata.write_h5ad(os.path.join(input_dir, "input.h5ad"))
+print(adata)
+
+print(">>> Tokenizing data...", flush=True)
+special_token = model_details["dataset"] == "95M"
+print(f"Input size: {model_details['input_size']}, Special token: {special_token}", flush=True)
+tokenizer = TranscriptomeTokenizer(
+    nproc=n_processors,
+    model_input_size=model_details["input_size"],
+    special_token=special_token,
+    gene_median_file=dictionary_files["gene_median"],
+    token_dictionary_file=dictionary_files["token"],
+    gene_mapping_file=dictionary_files["ensembl_mapping"],
+)
+tokenizer.tokenize_data(input_dir, tokenized_dir, "tokenized", file_format="h5ad")
+
+print(f">>> Getting model files for model '{par['model']}'...", flush=True)
+model_files = {
+    "model": hf_hub_download(
+        repo_id="ctheodoris/Geneformer",
+        subfolder=par["model"],
+        filename="model.safetensors",
+    ),
+    "config": hf_hub_download(
+        repo_id="ctheodoris/Geneformer",
+        subfolder=par["model"],
+        filename="config.json",
+    ),
+}
+model_dir = os.path.dirname(model_files["model"])
+
+print(">>> Extracting embeddings...", flush=True)
+embedder = EmbExtractor(
+    emb_mode="cell", max_ncells=None, token_dictionary_file=dictionary_files["token"]
+)
+embedder.extract_embs(
+    model_dir,
+    os.path.join(tokenized_dir, "tokenized.dataset"),
+    embedding_dir,
+    "embedding",
+)
+embedding = pd.read_csv(os.path.join(embedding_dir, "embedding.csv")).to_numpy()
+
+print(">>> Storing outputs...", flush=True)
+output = ad.AnnData(
+    obs=adata.obs[[]],
+    var=adata.var[[]],
+    obsm={
+        "X_emb": embedding,
+    },
+    uns={
+        "dataset_id": adata.uns["dataset_id"],
+        "normalization_id": adata.uns["normalization_id"],
+        "method_id": meta["name"],
+    },
+)
+print(output)
+
+print(">>> Writing output AnnData to file...", flush=True)
+output.write_h5ad(par["output"], compression="gzip")
+print(">>> Done!")
diff --git a/src/methods/scimilarity/script.py b/src/methods/scimilarity/script.py
index 2da1790e..761d59a5 100644
--- a/src/methods/scimilarity/script.py
+++ b/src/methods/scimilarity/script.py
@@ -14,7 +14,7 @@
     "model": "model_v1.1",
 }
 meta = {
-    "name": "scvi",
+    "name": "scimilarity",
 }
 ## VIASH END
 
diff --git a/src/workflows/run_benchmark/config.vsh.yaml b/src/workflows/run_benchmark/config.vsh.yaml
index 51e482ab..d3cc2b55 100644
--- a/src/workflows/run_benchmark/config.vsh.yaml
+++ b/src/workflows/run_benchmark/config.vsh.yaml
@@ -85,6 +85,7 @@ dependencies:
   - name: methods/batchelor_mnn_correct
   - name: methods/bbknn
   - name: methods/combat
+  - name: methods/geneformer
   - name: methods/harmony
   - name: methods/harmonypy
   - name: methods/liger
diff --git a/src/workflows/run_benchmark/main.nf b/src/workflows/run_benchmark/main.nf
index 69322a1a..89564bd5 100644
--- a/src/workflows/run_benchmark/main.nf
+++ b/src/workflows/run_benchmark/main.nf
@@ -20,6 +20,7 @@ methods = [
  batchelor_mnn_correct,
  bbknn,
  combat,
+ geneformer,
  harmony,
  harmonypy,
  liger,
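
Reviewer note, not part of the patch: a minimal sketch of how the output written by the new Geneformer component could be sanity-checked after a local run. The file name "output.h5ad" is illustrative, and the expected keys are taken from the script in this patch; nothing below is committed code.

    import anndata as ad

    # Load the AnnData file written by the Geneformer component (path is an assumption)
    adata = ad.read_h5ad("output.h5ad")

    # The script stores the cell embedding in obsm["X_emb"] and records identifiers in uns
    print(adata.obsm["X_emb"].shape)  # one row per cell, one column per embedding dimension
    print(adata.uns["method_id"])     # expected to be "geneformer"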