Skip to content

Commit

Permalink
Add Geneformer (#6)
Browse files Browse the repository at this point in the history
* Add cxg_immune_cell_atlas as a test resource

* Add SCimilarity component

* Add SCimiliarity to benchmark workflow

* Update script to extract model

* Add SCimilarity model path to benchmark workflow

* Add base_method API to disable tests for SCimilarity

* Replace cxg_mouse_pancreas_atlas with cxg_immune_cell_atlas

* Style SCimiliarity script

* Remove test resources from SCimiliarity config

* Fix file names in test resources state.yaml

* Add scimilarity as dependency to benchmark workflow

* Update compute environment

* Update model file path

* Create geneformer files

* Set SCimilarity name in Python script

* Adjust container settings

Depend on base method config because of input model file

* Download dictionary files in script

* Prepare and tokenize data, attempt to embed

* Store and output embedding

* Add Geneformer to benchmark workflow

* Add argument to select model version to use

* Style Geneformer script

* Make Geneformer inherit from base_method for tests
  • Loading branch information
lazappi authored Nov 8, 2024
1 parent 4e21e40 commit 2158882
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 1 deletion.
58 changes: 58 additions & 0 deletions src/methods/geneformer/config.vsh.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
__merge__: /src/api/base_method.yaml

name: geneformer
label: Geneformer
summary: Geneformer is a foundation transformer model pretrained on a large-scale corpus of single cell transcriptomes
description: |
Geneformer is a foundation transformer model pretrained on a large-scale
corpus of single cell transcriptomes to enable context-aware predictions in
network biology. For this task, Geneformer is used to create a batch-corrected
cell embedding.
references:
doi:
- 10.1038/s41586-023-06139-9
- 10.1101/2024.08.16.608180
links:
documentation: https://geneformer.readthedocs.io/en/latest/index.html
repository: https://huggingface.co/ctheodoris/Geneformer

info:
preferred_normalization: counts
method_types: [embedding]
variants:
geneformer_12L_95M_i4096:
model: "gf-12L-95M-i4096"
geneformer_6L_30M_i2048:
model: "gf-6L-30M-i2048"
geneformer_12L_30M_i2048:
model: "gf-12L-30M-i2048"
geneformer_20L_95M_i4096:
model: "gf-20L-95M-i4096"

arguments:
- name: "--model"
type: "string"
description: String representing the Geneformer model to use
choices: ["gf-6L-30M-i2048", "gf-12L-30M-i2048", "gf-12L-95M-i4096", "gf-20L-95M-i4096"]
default: "gf-12L-95M-i4096"

resources:
- type: python_script
path: script.py
- path: /src/utils/read_anndata_partial.py

engines:
- type: docker
image: openproblems/base_pytorch_nvidia:1.0.0
setup:
- type: python
pip:
- pyarrow<15.0.0a0,>=14.0.1
- huggingface_hub
- git+https://huggingface.co/ctheodoris/Geneformer.git

runners:
- type: executable
- type: nextflow
directives:
label: [midtime, midmem, midcpu, gpu]
154 changes: 154 additions & 0 deletions src/methods/geneformer/script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import os
import sys
from tempfile import TemporaryDirectory

import anndata as ad
import numpy as np
import pandas as pd
from geneformer import EmbExtractor, TranscriptomeTokenizer
from huggingface_hub import hf_hub_download

## VIASH START
# Note: this section is auto-generated by viash at runtime. To edit it, make changes
# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`.
par = {
"input": "resources_test/task_batch_integration/cxg_immune_cell_atlas/dataset.h5ad",
"output": "output.h5ad",
"model": "gf-12L-95M-i4096",
}
meta = {"name": "geneformer"}
## VIASH END

n_processors = os.cpu_count()

print(">>> Reading input...", flush=True)
sys.path.append(meta["resources_dir"])
from read_anndata_partial import read_anndata

adata = read_anndata(par["input"], X="layers/counts", obs="obs", var="var", uns="uns")

if adata.uns["dataset_organism"] != "homo_sapiens":
raise ValueError(
f"Geneformer can only be used with human data "
f"(dataset_organism == '{adata.uns['dataset_organism']}')"
)

is_ensembl = all(var_name.startswith("ENSG") for var_name in adata.var_names)
if not is_ensembl:
raise ValueError(f"Geneformer requires adata.var_names to contain ENSEMBL gene ids")

print(f">>> Getting settings for model '{par['model']}'...", flush=True)
model_split = par["model"].split("-")
model_details = {
"layers": model_split[1],
"dataset": model_split[2],
"input_size": int(model_split[3][1:]),
}
print(model_details, flush=True)

print(">>> Getting model dictionary files...", flush=True)
if model_details["dataset"] == "95M":
dictionaries_subfolder = "geneformer"
elif model_details["dataset"] == "30M":
dictionaries_subfolder = "geneformer/gene_dictionaries_30m"
else:
raise ValueError(f"Invalid model dataset: {model_details['dataset']}")
print(f"Dictionaries subfolder: '{dictionaries_subfolder}'")

dictionary_files = {
"ensembl_mapping": hf_hub_download(
repo_id="ctheodoris/Geneformer",
subfolder=dictionaries_subfolder,
filename=f"ensembl_mapping_dict_gc{model_details['dataset']}.pkl",
),
"gene_median": hf_hub_download(
repo_id="ctheodoris/Geneformer",
subfolder=dictionaries_subfolder,
filename=f"gene_median_dictionary_gc{model_details['dataset']}.pkl",
),
"gene_name_id": hf_hub_download(
repo_id="ctheodoris/Geneformer",
subfolder=dictionaries_subfolder,
filename=f"gene_name_id_dict_gc{model_details['dataset']}.pkl",
),
"token": hf_hub_download(
repo_id="ctheodoris/Geneformer",
subfolder=dictionaries_subfolder,
filename=f"token_dictionary_gc{model_details['dataset']}.pkl",
),
}

print(">>> Creating working directory...", flush=True)
work_dir = TemporaryDirectory()
input_dir = os.path.join(work_dir.name, "input")
os.makedirs(input_dir)
tokenized_dir = os.path.join(work_dir.name, "tokenized")
os.makedirs(tokenized_dir)
embedding_dir = os.path.join(work_dir.name, "embedding")
os.makedirs(embedding_dir)
print(f"Working directory: '{work_dir.name}'", flush=True)

print(">>> Preparing data...", flush=True)
adata.var["ensembl_id"] = adata.var_names
adata.obs["n_counts"] = np.ravel(adata.X.sum(axis=1))
adata.write_h5ad(os.path.join(input_dir, "input.h5ad"))
print(adata)

print(">>> Tokenizing data...", flush=True)
special_token = model_details["dataset"] == "95M"
print(f"Input size: {model_details['input_size']}, Special token: {special_token}")
tokenizer = TranscriptomeTokenizer(
nproc=n_processors,
model_input_size=model_details["input_size"],
special_token=special_token,
gene_median_file=dictionary_files["gene_median"],
token_dictionary_file=dictionary_files["token"],
gene_mapping_file=dictionary_files["ensembl_mapping"],
)
tokenizer.tokenize_data(input_dir, tokenized_dir, "tokenized", file_format="h5ad")

print(f">>> Getting model files for model '{par['model']}'...", flush=True)
model_files = {
"model": hf_hub_download(
repo_id="ctheodoris/Geneformer",
subfolder=par["model"],
filename="model.safetensors",
),
"config": hf_hub_download(
repo_id="ctheodoris/Geneformer",
subfolder=par["model"],
filename="config.json",
),
}
model_dir = os.path.dirname(model_files["model"])

print(">>> Extracting embeddings...", flush=True)
embedder = EmbExtractor(
emb_mode="cell", max_ncells=None, token_dictionary_file=dictionary_files["token"]
)
embedder.extract_embs(
model_dir,
os.path.join(tokenized_dir, "tokenized.dataset"),
embedding_dir,
"embedding",
)
embedding = pd.read_csv(os.path.join(embedding_dir, "embedding.csv")).to_numpy()

print(">>> Storing outputs...", flush=True)
output = ad.AnnData(
obs=adata.obs[[]],
var=adata.var[[]],
obsm={
"X_emb": embedding,
},
uns={
"dataset_id": adata.uns["dataset_id"],
"normalization_id": adata.uns["normalization_id"],
"method_id": meta["name"],
},
)
print(output)

print(">>> Writing output AnnData to file...", flush=True)
output.write_h5ad(par["output"], compression="gzip")
print(">>> Done!")
2 changes: 1 addition & 1 deletion src/methods/scimilarity/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"model": "model_v1.1",
}
meta = {
"name": "scvi",
"name": "scimilarity",
}
## VIASH END

Expand Down
1 change: 1 addition & 0 deletions src/workflows/run_benchmark/config.vsh.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ dependencies:
- name: methods/batchelor_mnn_correct
- name: methods/bbknn
- name: methods/combat
- name: methods/geneformer
- name: methods/harmony
- name: methods/harmonypy
- name: methods/liger
Expand Down
1 change: 1 addition & 0 deletions src/workflows/run_benchmark/main.nf
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ methods = [
batchelor_mnn_correct,
bbknn,
combat,
geneformer,
harmony,
harmonypy,
liger,
Expand Down

0 comments on commit 2158882

Please sign in to comment.