From 6191552ab6734ced8565aa8426d07ed0ce3da26f Mon Sep 17 00:00:00 2001
From: Luke Zappia
Date: Fri, 17 Jan 2025 08:38:00 +0100
Subject: [PATCH] Add fine-tuned scGPT (#17)

* Move scgpt to scgpt_zeroshot
* Create scgpt_finetuned component
* Fix name in scgpt_zeroshot
* Add scgpt_finetuned pre-processing
* Add scgpt_finetuned fine-tuning and embedding
* Add scgpt_finetuned to benchmark workflow
* Style scgpt_finetuned files
* Disable scgpt_finetuned tests

Requires GPU

* Remove docker run args from scgpt_finetuned
* Add model arguments to scgpt_finetuned

Fixed some things in scgpt_zeroshot...

* Exclude scgpt_finetuned from full benchmark runs

Need to sort out GPU infrastructure

---------

Co-authored-by: Robrecht Cannoodt
---
 scripts/run_benchmark/run_full_local.sh       |   2 +-
 scripts/run_benchmark/run_full_seqeracloud.sh |   1 +
 scripts/run_benchmark/run_test_local.sh       |   2 +-
 src/methods/scgpt_finetuned/config.vsh.yaml   |  65 +++
 .../scgpt_finetuned/scgpt_functions.py        | 288 +++++++++++++
 src/methods/scgpt_finetuned/script.py         | 396 ++++++++++++++++++
 .../{scgpt => scgpt_zeroshot}/config.vsh.yaml |  13 +-
 .../{scgpt => scgpt_zeroshot}/script.py       |   0
 src/workflows/run_benchmark/config.vsh.yaml   |   3 +-
 src/workflows/run_benchmark/main.nf           |   5 +-
 10 files changed, 765 insertions(+), 10 deletions(-)
 create mode 100644 src/methods/scgpt_finetuned/config.vsh.yaml
 create mode 100644 src/methods/scgpt_finetuned/scgpt_functions.py
 create mode 100644 src/methods/scgpt_finetuned/script.py
 rename src/methods/{scgpt => scgpt_zeroshot}/config.vsh.yaml (90%)
 rename src/methods/{scgpt => scgpt_zeroshot}/script.py (100%)

diff --git a/scripts/run_benchmark/run_full_local.sh b/scripts/run_benchmark/run_full_local.sh
index d823d79e..20e434b3 100755
--- a/scripts/run_benchmark/run_full_local.sh
+++ b/scripts/run_benchmark/run_full_local.sh
@@ -26,7 +26,7 @@ input_states: resources/datasets/**/state.yaml
 rename_keys: 'input_dataset:output_dataset;input_solution:output_solution'
 output_state: "state.yaml"
 publish_dir: "$publish_dir"
-settings: '{"methods_exclude": ["uce"]}'
+settings: '{"methods_exclude": ["uce", "scgpt_finetuned"]}'
 HERE
 
 # run the benchmark
diff --git a/scripts/run_benchmark/run_full_seqeracloud.sh b/scripts/run_benchmark/run_full_seqeracloud.sh
index bd88d0dc..91f392c2 100755
--- a/scripts/run_benchmark/run_full_seqeracloud.sh
+++ b/scripts/run_benchmark/run_full_seqeracloud.sh
@@ -18,6 +18,7 @@ input_states: s3://openproblems-data/resources/task_batch_integration/datasets/*
 rename_keys: 'input_dataset:output_dataset;input_solution:output_solution'
 output_state: "state.yaml"
 publish_dir: "$publish_dir"
+settings: '{"methods_exclude": ["scgpt_finetuned"]}'
 HERE
 
 tw launch https://github.com/openproblems-bio/task_batch_integration.git \
diff --git a/scripts/run_benchmark/run_test_local.sh b/scripts/run_benchmark/run_test_local.sh
index 2b72eeed..85e39583 100755
--- a/scripts/run_benchmark/run_test_local.sh
+++ b/scripts/run_benchmark/run_test_local.sh
@@ -21,7 +21,7 @@ input_states: resources_test/task_batch_integration/**/state.yaml
 rename_keys: 'input_dataset:output_dataset;input_solution:output_solution'
 output_state: "state.yaml"
 publish_dir: "$publish_dir"
-settings: '{"methods_exclude": ["uce"]}'
+settings: '{"methods_exclude": ["uce", "scgpt_finetuned"]}'
 HERE
 
 nextflow run . \
diff --git a/src/methods/scgpt_finetuned/config.vsh.yaml b/src/methods/scgpt_finetuned/config.vsh.yaml
new file mode 100644
index 00000000..bc5eb6cc
--- /dev/null
+++ b/src/methods/scgpt_finetuned/config.vsh.yaml
@@ -0,0 +1,65 @@
+__merge__: ../../api/base_method.yaml
+
+name: scgpt_finetuned
+label: scGPT (fine-tuned)
+summary: "A foundation model for single-cell biology (fine-tuned)"
+description: |
+  scGPT is a foundation model for single-cell biology based on a generative
+  pre-trained transformer and trained on a repository of over 33 million cells.
+
+  Here, we fine-tune the pre-trained model for the batch integration task.
+references:
+  doi:
+    - 10.1038/s41592-024-02201-0
+links:
+  documentation: https://scgpt.readthedocs.io/en/latest/
+  repository: https://github.com/bowang-lab/scGPT
+
+info:
+  method_types: [embedding]
+  preferred_normalization: counts
+  variants:
+    scgpt_finetuned_default:
+
+arguments:
+  - name: --model_name
+    type: string
+    description: String giving the name of the scGPT model to use
+    choices: ["scGPT_human", "scGPT_CP"]
+    default: "scGPT_human"
+  - name: --model
+    type: file
+    description: |
+      Path to the directory containing the scGPT model specified by model_name
+      or a .zip/.tar.gz archive to extract. If not given, the model will be
+      downloaded.
+    required: false
+  - name: --n_hvg
+    type: integer
+    default: 3000
+    description: Number of highly variable genes to use.
+
+resources:
+  - type: python_script
+    path: script.py
+  - path: /src/utils/read_anndata_partial.py
+  - path: scgpt_functions.py
+
+engines:
+  - type: docker
+    image: openproblems/base_pytorch_nvidia:1.0.0
+    # TODO: Try to find working installation of flash attention (flash-attn<1.0.5)
+    setup:
+      - type: python
+        pypi:
+          - gdown
+          - scgpt # Install from PyPI to get dependencies
+      - type: docker
+        # Force re-installing from GitHub to get bug fixes
+        run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git
+
+runners:
+  - type: executable
+  - type: nextflow
+    directives:
+      label: [midtime, midmem, midcpu, gpu]
diff --git a/src/methods/scgpt_finetuned/scgpt_functions.py b/src/methods/scgpt_finetuned/scgpt_functions.py
new file mode 100644
index 00000000..c03bc45c
--- /dev/null
+++ b/src/methods/scgpt_finetuned/scgpt_functions.py
@@ -0,0 +1,288 @@
+import time
+import warnings
+
+import numpy as np
+import scgpt
+import torch
+
+
+def prepare_data(
+    tokenized_train,
+    tokenized_valid,
+    train_batch_labels,
+    valid_batch_labels,
+    hyperparameters,
+    model_settings,
+    epoch,
+):
+    masked_values_train = scgpt.tokenizer.random_mask_value(
+        tokenized_train["values"],
+        mask_ratio=hyperparameters["mask_ratio"],
+        mask_value=model_settings["mask_value"],
+        pad_value=model_settings["pad_value"],
+    )
+    masked_values_valid = scgpt.tokenizer.random_mask_value(
+        tokenized_valid["values"],
+        mask_ratio=hyperparameters["mask_ratio"],
+        mask_value=model_settings["mask_value"],
+        pad_value=model_settings["pad_value"],
+    )
+    scgpt.logger.info(
+        f"Random masking at epoch {epoch:3d}, "
+        f"ratio of masked values in train: {(masked_values_train == model_settings['mask_value']).sum() / (masked_values_train - model_settings['pad_value']).count_nonzero():.4f}"
+    )
+
+    input_gene_ids_train, input_gene_ids_valid = (
+        tokenized_train["genes"],
+        tokenized_valid["genes"],
+    )
+    input_values_train, input_values_valid = masked_values_train, masked_values_valid
+    target_values_train, target_values_valid = (
+        tokenized_train["values"],
+        tokenized_valid["values"],
+    )
+
+    tensor_batch_labels_train = torch.from_numpy(train_batch_labels).long()
+    tensor_batch_labels_valid = torch.from_numpy(valid_batch_labels).long()
+
+    if model_settings["per_seq_batch_sample"]:
+        train_sort_ids = np.argsort(train_batch_labels)
+        input_gene_ids_train = input_gene_ids_train[train_sort_ids]
+        input_values_train = input_values_train[train_sort_ids]
+        target_values_train = target_values_train[train_sort_ids]
+        tensor_batch_labels_train = tensor_batch_labels_train[train_sort_ids]
+
+        valid_sort_ids = np.argsort(valid_batch_labels)
+        input_gene_ids_valid = input_gene_ids_valid[valid_sort_ids]
+        input_values_valid = input_values_valid[valid_sort_ids]
+        target_values_valid = target_values_valid[valid_sort_ids]
+        tensor_batch_labels_valid = tensor_batch_labels_valid[valid_sort_ids]
+
+    train_data_pt = {
+        "gene_ids": input_gene_ids_train,
+        "values": input_values_train,
+        "target_values": target_values_train,
+        "batch_labels": tensor_batch_labels_train,
+    }
+    valid_data_pt = {
+        "gene_ids": input_gene_ids_valid,
+        "values": input_values_valid,
+        "target_values": target_values_valid,
+        "batch_labels": tensor_batch_labels_valid,
+    }
+
+    return train_data_pt, valid_data_pt
+
+
+class SeqDataset(torch.utils.data.Dataset):
+    def __init__(self, data):
+        self.data = data
+
+    def __len__(self):
+        return self.data["gene_ids"].shape[0]
+
+    def __getitem__(self, idx):
+        return {k: v[idx] for k, v in self.data.items()}
+
+
+def prepare_dataloader(
+    data_pt,
+    batch_size,
+    shuffle,
+    intra_domain_shuffle,
+    drop_last,
+    num_workers,
+    per_seq_batch_sample,
+):
+    dataset = SeqDataset(data_pt)
+
+    if per_seq_batch_sample:
+        # Find the indices of samples in each seq batch
+        subsets = []
+        batch_labels_array = data_pt["batch_labels"].numpy()
+        for batch_label in np.unique(batch_labels_array):
+            batch_indices = np.where(batch_labels_array == batch_label)[0].tolist()
+            subsets.append(batch_indices)
+        data_loader = torch.utils.data.DataLoader(
+            dataset=dataset,
+            batch_sampler=scgpt.SubsetsBatchSampler(
+                subsets,
+                batch_size,
+                intra_subset_shuffle=intra_domain_shuffle,
+                inter_subset_shuffle=shuffle,
+                drop_last=drop_last,
+            ),
+            num_workers=num_workers,
+            pin_memory=True,
+        )
+        return data_loader
+
+    data_loader = torch.utils.data.DataLoader(
+        dataset=dataset,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        drop_last=drop_last,
+        num_workers=num_workers,
+        pin_memory=True,
+    )
+    return data_loader
+
+
+def train(
+    model,
+    loader,
+    scaler,
+    optimizer,
+    scheduler,
+    vocab,
+    criterion,
+    criterion_dab,
+    hyperparameters,
+    model_settings,
+    device,
+    epoch,
+):
+    model.train()
+
+    total_loss, total_mse, total_gepc = 0.0, 0.0, 0.0
+    total_error = 0.0
+    log_interval = hyperparameters["log_interval"]
+    start_time = time.time()
+
+    num_batches = len(loader)
+    for batch, batch_data in enumerate(loader):
+        input_gene_ids = batch_data["gene_ids"].to(device)
+        input_values = batch_data["values"].to(device)
+        target_values = batch_data["target_values"].to(device)
+        batch_labels = batch_data["batch_labels"].to(device)
+
+        src_key_padding_mask = input_gene_ids.eq(vocab[model_settings["pad_token"]])
+        with torch.cuda.amp.autocast(enabled=hyperparameters["amp"]):
+            output_dict = model(
+                input_gene_ids,
+                input_values,
+                src_key_padding_mask=src_key_padding_mask,
+                batch_labels=batch_labels if model_settings["DSBN"] else None,
+                MVC=hyperparameters["GEPC"],
+                ECS=hyperparameters["ecs_thres"] > 0,
+            )
+
+            masked_positions = input_values.eq(
+                model_settings["mask_value"]
+            )  # the positions to predict
+            loss = loss_mse = criterion(
+                output_dict["mlm_output"], target_values, masked_positions
+            )
+            if model_settings["explicit_zero_prob"]:
+                loss_zero_log_prob = scgpt.loss.criterion_neg_log_bernoulli(
+                    output_dict["mlm_zero_probs"], target_values, masked_positions
+                )
+                loss = loss + loss_zero_log_prob
+            if hyperparameters["GEPC"]:
+                loss_gepc = criterion(
+                    output_dict["mvc_output"], target_values, masked_positions
+                )
+                loss = loss + loss_gepc
+            if hyperparameters["GEPC"] and model_settings["explicit_zero_prob"]:
+                loss_gepc_zero_log_prob = scgpt.loss.criterion_neg_log_bernoulli(
+                    output_dict["mvc_zero_probs"], target_values, masked_positions
+                )
+                loss = loss + loss_gepc_zero_log_prob
+            if hyperparameters["ecs_thres"] > 0:
+                loss_ecs = 10 * output_dict["loss_ecs"]
+                loss = loss + loss_ecs
+            loss_dab = criterion_dab(output_dict["dab_output"], batch_labels)
+            loss = loss + hyperparameters["dab_weight"] * loss_dab
+
+        model.zero_grad()
+        scaler.scale(loss).backward()
+        scaler.unscale_(optimizer)
+        with warnings.catch_warnings(record=True) as w:
+            warnings.filterwarnings("always")
+            torch.nn.utils.clip_grad_norm_(
+                model.parameters(),
+                1.0,
+                error_if_nonfinite=False if scaler.is_enabled() else True,
+            )
+            if len(w) > 0:
+                scgpt.logger.warning(
+                    f"Found infinite gradient. This may be caused by the gradient "
+                    f"scaler. The current scale is {scaler.get_scale()}. This warning "
+                    "can be ignored if no longer occurs after autoscaling of the scaler."
+                )
+        scaler.step(optimizer)
+        scaler.update()
+
+        with torch.no_grad():
+            mre = scgpt.loss.masked_relative_error(
+                output_dict["mlm_output"], target_values, masked_positions
+            )
+
+        total_loss += loss.item()
+        total_mse += loss_mse.item()
+        total_gepc += loss_gepc.item() if hyperparameters["GEPC"] else 0.0
+        total_error += mre.item()
+        if batch % log_interval == 0 and batch > 0:
+            lr = scheduler.get_last_lr()[0]
+            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
+            cur_loss = total_loss / log_interval
+            cur_mse = total_mse / log_interval
+            cur_gepc = total_gepc / log_interval if hyperparameters["GEPC"] else 0.0
+            cur_error = total_error / log_interval
+            scgpt.logger.info(
+                f"| epoch {epoch:3d} | {batch:3d}/{num_batches:3d} batches | "
+                f"lr {lr:05.4f} | ms/batch {ms_per_batch:5.2f} | "
+                f"loss {cur_loss:5.2f} | mse {cur_mse:5.2f} | mre {cur_error:5.2f} | "
+                + (f"gepc {cur_gepc:5.2f} |" if hyperparameters["GEPC"] else "")
+            )
+            total_loss = 0
+            total_mse = 0
+            total_gepc = 0
+            total_error = 0
+            start_time = time.time()
+
+
+def evaluate(
+    model,
+    loader,
+    vocab,
+    criterion,
+    criterion_dab,
+    hyperparameters,
+    model_settings,
+    device,
+):
+    model.eval()
+    total_loss = 0.0
+    total_error = 0.0
+    total_dab = 0.0
+    total_num = 0
+    with torch.no_grad():
+        for batch_data in loader:
+            input_gene_ids = batch_data["gene_ids"].to(device)
+            input_values = batch_data["values"].to(device)
+            target_values = batch_data["target_values"].to(device)
+            batch_labels = batch_data["batch_labels"].to(device)
+
+            src_key_padding_mask = input_gene_ids.eq(vocab[model_settings["pad_token"]])
+            with torch.cuda.amp.autocast(enabled=hyperparameters["amp"]):
+                output_dict = model(
+                    input_gene_ids,
+                    input_values,
+                    src_key_padding_mask=src_key_padding_mask,
+                    batch_labels=batch_labels if model_settings["DSBN"] else None,
+                )
+                output_values = output_dict["mlm_output"]
+
+            masked_positions = input_values.eq(model_settings["mask_value"])
+            loss = criterion(output_values, target_values, masked_positions)
+            loss_dab = criterion_dab(output_dict["dab_output"], batch_labels)
+
+            total_loss += loss.item() * len(input_gene_ids)
+            total_error += scgpt.loss.masked_relative_error(
+                output_values, target_values, masked_positions
+            ).item() * len(input_gene_ids)
+            total_dab += loss_dab.item() * len(input_gene_ids)
+            total_num += len(input_gene_ids)
+
+    return total_loss / total_num, total_error / total_num
diff --git a/src/methods/scgpt_finetuned/script.py b/src/methods/scgpt_finetuned/script.py
new file mode 100644
index 00000000..7bee08bf
--- /dev/null
+++ b/src/methods/scgpt_finetuned/script.py
@@ -0,0 +1,396 @@
+import copy
+import json
+import os
+import shutil
+import sys
+import tarfile
+import tempfile
+import time
+import zipfile
+
+import anndata as ad
+import gdown
+import numpy as np
+import scgpt
+import torch
+from sklearn.model_selection import train_test_split
+
+## VIASH START
+# Note: this section is auto-generated by viash at runtime. To edit it, make changes
+# in config.vsh.yaml and then run `viash config inject config.vsh.yaml`.
+par = {
+    "input": "resources_test/.../input.h5ad",
+    "output": "output.h5ad",
+    "model_name": "scGPT_human",
+    "model": "scGPT_human",
+    "n_hvg": 3000,
+}
+meta = {"name": "scgpt_finetuned"}
+## VIASH END
+
+sys.path.append(meta["resources_dir"])
+from read_anndata_partial import read_anndata
+from scgpt_functions import evaluate, prepare_data, prepare_dataloader, train
+
+print(f"====== scGPT version {scgpt.__version__} ======", flush=True)
+
+print("\n>>> Reading input files...", flush=True)
+print(f"Input H5AD file: '{par['input']}'", flush=True)
+adata = read_anndata(par["input"], X="layers/counts", obs="obs", var="var", uns="uns")
+
+if adata.uns["dataset_organism"] != "homo_sapiens":
+    raise ValueError(
+        f"scGPT can only be used with human data "
+        f"(dataset_organism == \"{adata.uns['dataset_organism']}\")"
+    )
+
+adata.obs["str_batch"] = adata.obs["batch"].astype(str)
+adata.obs["batch_id"] = adata.obs["str_batch"].astype("category").cat.codes.values
+adata.var["feature_id"] = adata.var_names
+adata.var_names = adata.var["feature_name"]
+
+print(adata, flush=True)
+
+if par["model"] is None:
+    print(f"\n>>> Downloading '{par['model_name']}' model...", flush=True)
+    model_drive_ids = {
+        "scGPT_human": "1oWh_-ZRdhtoGQ2Fw24HP41FgLoomVo-y",
+        "scGPT_CP": "1_GROJTzXiAV8HB4imruOTk6PEGuNOcgB",
+    }
+    drive_path = (
+        f"https://drive.google.com/drive/folders/{model_drive_ids[par['model_name']]}"
+    )
+    model_temp = tempfile.TemporaryDirectory()
+    model_dir = model_temp.name
+    print(f"Downloading from '{drive_path}'", flush=True)
+    gdown.download_folder(drive_path, output=model_dir, quiet=True)
+else:
+    if os.path.isdir(par["model"]):
+        print("\n>>> Using model directory...", flush=True)
+        model_temp = None
+        model_dir = par["model"]
+    else:
+        model_temp = tempfile.TemporaryDirectory()
+        model_dir = model_temp.name
+
+        if zipfile.is_zipfile(par["model"]):
+            print("\n>>> Extracting model from .zip...", flush=True)
+            print(f".zip path: '{par['model']}'", flush=True)
+            with zipfile.ZipFile(par["model"], "r") as zip_file:
+                zip_file.extractall(model_dir)
+        elif tarfile.is_tarfile(par["model"]) and par["model"].endswith(
+            ".tar.gz"
+        ):
+            print("\n>>> Extracting model from .tar.gz...", flush=True)
+            print(f".tar.gz path: '{par['model']}'", flush=True)
+            with tarfile.open(par["model"], "r:gz") as tar_file:
+                tar_file.extractall(model_dir)
+            model_dir = os.path.join(model_dir, os.listdir(model_dir)[0])
+        else:
+            raise ValueError(
+                "The 'model' argument should be a directory, a .zip file or a .tar.gz file"
+            )
+
+model_config_file = f"{model_dir}/args.json"
+model_file = f"{model_dir}/best_model.pt"
+vocab_file = f"{model_dir}/vocab.json"
+print(f"Model directory: '{model_dir}'", flush=True)
+print(f"Model config file: '{model_config_file}'", flush=True)
+print(f"Model file: '{model_file}'", flush=True)
+print(f"Model vocabulary file: '{vocab_file}'", flush=True)
+
+print("\n>>> Loading model configuration...", flush=True)
+model_settings = {
+    # Input and preprocessing
+    "pad_token": "<pad>",
+    "special_tokens": ["<pad>", "<cls>", "<eoc>"],
+    "mask_value": -1,
+    "pad_value": -2,
+    "n_input_bins": 51,
+    # Other settings
+    "n_hvg": par["n_hvg"],
+    "max_seq_len": par["n_hvg"] + 1,
+    "per_seq_batch_sample": True,
+    "DSBN": True,
+    "explicit_zero_prob": True,
+}
+print("Model settings:", flush=True)
+for key, value in model_settings.items():
+    print(f"\t{key}: {value}", flush=True)
+vocab = scgpt.tokenizer.gene_tokenizer.GeneVocab.from_file(vocab_file)
+for token in model_settings["special_tokens"]:
+    if token not in vocab:
+        vocab.add_token(token)
+adata.var["id_in_vocab"] = [1 if gene in vocab else -1 for gene in adata.var_names]
+gene_ids_in_vocab = np.array(adata.var["id_in_vocab"])
+scgpt.logger.info(
+    f"Matched {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)} genes in vocabulary of {len(vocab)}",
+)
+adata = adata[:, adata.var["id_in_vocab"] >= 0].copy()
+with open(model_config_file, "r") as f:
+    pretrained_config = json.load(f)
+
+model_config = {
+    "embsize": pretrained_config["embsize"],
+    "nheads": pretrained_config["nheads"],
+    "d_hid": pretrained_config["d_hid"],
+    "nlayers": pretrained_config["nlayers"],
+    "n_layers_cls": pretrained_config["n_layers_cls"],
+}
+print("Model configuration:", flush=True)
+for key, value in model_config.items():
+    print(f"\t{key}: {value}", flush=True)
+
+print("\n>>> Preprocessing data...", flush=True)
+preprocessor = scgpt.preprocess.Preprocessor(
+    use_key="X",  # The key in adata.layers to use as raw data
+    filter_gene_by_counts=3,  # Number of counts for filtering genes
+    filter_cell_by_counts=False,  # Number of counts for filtering cells
+    normalize_total=1e4,  # Whether to normalize the raw data and to what sum
+    result_normed_key="X_normed",  # The key in adata.layers to store the normalized data
+    log1p=True,  # Whether to log1p the normalized data
+    result_log1p_key="X_log1p",  # The key in adata.layers to store the log1p data
+    subset_hvg=model_settings[
+        "n_hvg"
+    ],  # Whether to subset the raw data to highly variable genes and to what number
+    hvg_flavor="seurat_v3",  # The flavor of highly variable gene selection
+    binning=model_settings[
+        "n_input_bins"
+    ],  # Whether to bin the raw data and to what number of bins
+    result_binned_key="X_binned",  # The key in adata.layers to store the binned data
+)
+preprocessor(adata, batch_key="str_batch")
+print(adata, flush=True)
+
+print("\n>>> Splitting and tokenizing data...", flush=True)
+celltype_labels = np.array(adata.obs["cell_type"].to_list())
+(
+    train_data,
+    valid_data,
+    train_celltype_labels,
+    valid_celltype_labels,
+    train_batch_labels,
+    valid_batch_labels,
+) = train_test_split(
+    adata.X.A,
+    celltype_labels,
+    np.array(adata.obs["batch_id"].tolist()),
+    test_size=0.1,
+    shuffle=True,
+)
+
+vocab.set_default_index(vocab["<pad>"])
+gene_ids = np.array(vocab(adata.var_names.tolist()), dtype=int)
+tokenized_train = scgpt.tokenizer.tokenize_and_pad_batch(
+    train_data,
+    gene_ids,
+    max_len=model_settings["max_seq_len"],
+    vocab=vocab,
+    pad_token=model_settings["pad_token"],
+    pad_value=model_settings["pad_value"],
+    append_cls=True,  # Append <cls> token at the beginning
+    include_zero_gene=True,
+)
+scgpt.logger.info(
+    f"Number of training samples: {tokenized_train['genes'].shape[0]}, "
+    f"\n\tFeature length: {tokenized_train['genes'].shape[1]}"
+)
+tokenized_valid = scgpt.tokenizer.tokenize_and_pad_batch(
+    valid_data,
+    gene_ids,
+    max_len=model_settings["max_seq_len"],
+    vocab=vocab,
+    pad_token=model_settings["pad_token"],
+    pad_value=model_settings["pad_value"],
+    append_cls=True,  # Append <cls> token at the beginning
+    include_zero_gene=True,
+)
+scgpt.logger.info(
+    f"Number of validation samples: {tokenized_valid['genes'].shape[0]}, "
+    f"\n\tFeature length: {tokenized_valid['genes'].shape[1]}"
+)
+
+print("\n>>> Loading pre-trained model...", flush=True)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(f"Using '{device}' device", flush=True)
+
+hyperparameters = {
+    "n_tokens": len(vocab),
+    "GEPC": True,  # Gene expression modelling for cell objective
+    "ecs_thres": 0.8,  # Elastic cell similarity objective, 0.0 to 1.0, 0.0 to disable
+    "dab_weight": 1.0,  # DAR objective weight for batch correction
+    "mask_ratio": 0.4,
+    "epochs": 15,
+    "lr": 1e-4,
+    "batch_size": 64,
+    "dropout": 0.2,
+    "schedule_ratio": 0.9,  # Learning rate decay
+    "log_interval": 100,
+    "fast_transformer": False,  # TODO: Set True if flash-attn is installed
+    "pre_norm": False,
+    "amp": True,  # Automatic Mixed Precision
+}
+print("Hyperparameters:", flush=True)
+for key, value in hyperparameters.items():
+    print(f"\t{key}: {value}", flush=True)
+model = scgpt.model.TransformerModel(
+    hyperparameters["n_tokens"],
+    model_config["embsize"],
+    model_config["nheads"],
+    model_config["d_hid"],
+    model_config["nlayers"],
+    vocab=vocab,
+    dropout=hyperparameters["dropout"],
+    pad_token=model_settings["pad_token"],
+    pad_value=model_settings["pad_value"],
+    do_mvc=hyperparameters["GEPC"],
+    do_dab=True,
+    use_batch_labels=True,
+    num_batch_labels=len(set(adata.obs["batch_id"].tolist())),
+    domain_spec_batchnorm=model_settings["DSBN"],
+    n_input_bins=model_settings["n_input_bins"],
+    ecs_threshold=hyperparameters["ecs_thres"],
+    explicit_zero_prob=model_settings["explicit_zero_prob"],
+    use_fast_transformer=hyperparameters["fast_transformer"],
+    pre_norm=hyperparameters["pre_norm"],
+)
+scgpt.utils.load_pretrained(
+    model, torch.load(model_file, map_location=torch.device(device)), verbose=False
+)
+model.to(device)
+
+print("\n>>> Fine-tuning model...", flush=True)
+criterion = scgpt.loss.masked_mse_loss
+criterion_dab = torch.nn.CrossEntropyLoss()
+optimizer = torch.optim.Adam(
+    model.parameters(),
+    lr=hyperparameters["lr"],
+    eps=1e-4 if hyperparameters["amp"] else 1e-8,
+)
+scheduler = torch.optim.lr_scheduler.StepLR(
+    optimizer, 1, gamma=hyperparameters["schedule_ratio"]
+)
+scaler = torch.cuda.amp.GradScaler(enabled=hyperparameters["amp"])
+
+best_val_loss = float("inf")
+best_avg_bio = 0.0
+best_model = None
+
+for epoch in range(1, hyperparameters["epochs"] + 1):
+    epoch_start_time = time.time()
+    train_data_pt, valid_data_pt = prepare_data(
+        tokenized_train,
+        tokenized_valid,
+        train_batch_labels,
+        valid_batch_labels,
+        hyperparameters,
+        model_settings,
+        epoch,
+    )
+
+    train_loader = prepare_dataloader(
+        train_data_pt,
+        batch_size=hyperparameters["batch_size"],
+        shuffle=False,
+        intra_domain_shuffle=True,
+        drop_last=False,
+        num_workers=0,
+        per_seq_batch_sample=model_settings["per_seq_batch_sample"],
+    )
+
+    valid_loader = prepare_dataloader(
+        valid_data_pt,
+        batch_size=hyperparameters["batch_size"],
+        shuffle=False,
+        intra_domain_shuffle=False,
+        drop_last=False,
+        num_workers=0,
+        per_seq_batch_sample=model_settings["per_seq_batch_sample"],
+    )
+
+    train(
+        model,
+        train_loader,
+        scaler,
+        optimizer,
+        scheduler,
+        vocab,
+        criterion,
+        criterion_dab,
+        hyperparameters,
+        model_settings,
+        device,
+        epoch,
+    )
+
+    val_loss, val_mre = evaluate(
+        model,
+        valid_loader,
+        vocab,
+        criterion,
+        criterion_dab,
+        hyperparameters,
+        model_settings,
+        device,
+    )
+    elapsed = time.time() - epoch_start_time
+    scgpt.logger.info("-" * 89)
+    scgpt.logger.info(
+        f"| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | "
+        f"valid loss/mse {val_loss:5.4f} | mre {val_mre:5.4f}"
+    )
+    scgpt.logger.info("-" * 89)
+
+    if val_loss < best_val_loss:
+        best_val_loss = val_loss
+        best_model = copy.deepcopy(model)
+        best_model_epoch = epoch
+        scgpt.logger.info(f"Best model with score {best_val_loss:5.4f}")
+
+    scheduler.step()
+
+print(f"Best model: Epoch {best_model_epoch}, Val loss {best_val_loss}", flush=True)
+
+print("\n>>> Saving best model...", flush=True)
+best_model_dir = tempfile.TemporaryDirectory()
+shutil.copy(vocab_file, best_model_dir.name)
+shutil.copy(model_config_file, best_model_dir.name)
+torch.save(best_model.state_dict(), os.path.join(best_model_dir.name, "best_model.pt"))
+print(f"Best model directory: '{best_model_dir.name}'", flush=True)
+
+print("\n>>> Embedding data...", flush=True)
+embedded = scgpt.tasks.embed_data(
+    adata,
+    best_model_dir.name,
+    gene_col="feature_name",
+    batch_size=64,
+    use_fast_transformer=False,  # Disable fast-attn as not installed
+    device=device,
+    return_new_adata=True,
+)
+
+print("\n>>> Storing output...", flush=True)
+output = ad.AnnData(
+    obs=adata.obs[[]],
+    var=adata.var[[]],
+    obsm={
+        "X_emb": embedded.X,
+    },
+    uns={
+        "dataset_id": adata.uns["dataset_id"],
+        "normalization_id": adata.uns["normalization_id"],
+        "method_id": meta["name"],
+    },
+)
+print(output, flush=True)
+
+print("\n>>> Writing output to file...", flush=True)
+print(f"Output H5AD file: '{par['output']}'", flush=True)
+output.write_h5ad(par["output"], compression="gzip")
+
+print("\n>>> Cleaning up temporary directories...", flush=True)
+if model_temp is not None:
+    model_temp.cleanup()
+best_model_dir.cleanup()
+
+print("\n>>> Done!", flush=True)
diff --git a/src/methods/scgpt/config.vsh.yaml b/src/methods/scgpt_zeroshot/config.vsh.yaml
similarity index 90%
rename from src/methods/scgpt/config.vsh.yaml
rename to src/methods/scgpt_zeroshot/config.vsh.yaml
index ad6825c6..23ea5931 100644
--- a/src/methods/scgpt/config.vsh.yaml
+++ b/src/methods/scgpt_zeroshot/config.vsh.yaml
@@ -1,11 +1,12 @@
 __merge__: ../../api/base_method.yaml
 
-name: scgpt
-label: scGPT
-summary: "A foundation model for single-cell biology"
+name: scgpt_zeroshot
+label: scGPT (zero shot)
+summary: "A foundation model for single-cell biology (zero shot)"
 description: |
   scGPT is a foundation model for single-cell biology based on a generative
   pre-trained transformer and trained on a repository of over 33 million cells.
+  Here, we use zero-shot output from a pre-trained model to get an integrated embedding for the batch integration task.
 references:
@@ -19,9 +20,9 @@ info:
   method_types: [embedding]
   preferred_normalization: counts
   variants:
-    scgpt_default:
-    scgpt_cp:
-      model: "scGPT_CP"
+    scgpt_zeroshot_default:
+    scgpt_zeroshot_cp:
+      model_name: "scGPT_CP"
 
 arguments:
   - name: --model_name
diff --git a/src/methods/scgpt/script.py b/src/methods/scgpt_zeroshot/script.py
similarity index 100%
rename from src/methods/scgpt/script.py
rename to src/methods/scgpt_zeroshot/script.py
diff --git a/src/workflows/run_benchmark/config.vsh.yaml b/src/workflows/run_benchmark/config.vsh.yaml
index 269f5337..5cd9b339 100644
--- a/src/workflows/run_benchmark/config.vsh.yaml
+++ b/src/workflows/run_benchmark/config.vsh.yaml
@@ -101,7 +101,8 @@ dependencies:
   - name: methods/scalex
   - name: methods/scanorama
   - name: methods/scanvi
-  - name: methods/scgpt
+  - name: methods/scgpt_finetuned
+  - name: methods/scgpt_zeroshot
   - name: methods/scimilarity
   - name: methods/scprint
   - name: methods/scvi
diff --git a/src/workflows/run_benchmark/main.nf b/src/workflows/run_benchmark/main.nf
index 4254d5fa..4c9a9d3c 100644
--- a/src/workflows/run_benchmark/main.nf
+++ b/src/workflows/run_benchmark/main.nf
@@ -29,7 +29,10 @@ methods = [
   scalex,
   scanorama,
   scanvi,
-  scgpt.run(
+  scgpt_finetuned.run(
+    args: [model: file("s3://openproblems-work/cache/scGPT_human.zip")]
+  ),
+  scgpt_zeroshot.run(
     args: [model: file("s3://openproblems-work/cache/scGPT_human.zip")]
   ),
   scimilarity.run(
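
---

Editor's note (not part of the patch): since the new component is excluded from benchmark runs until GPU infrastructure is sorted out, it can still be exercised on its own. Below is a minimal sketch of a local invocation, assuming viash is installed and a CUDA GPU is available; `dataset.h5ad` and `scGPT_human.zip` are hypothetical stand-ins for a prepared input file and the cached model archive referenced in main.nf.

    # Run the fine-tuned scGPT component directly via viash.
    # Passing --model skips the Google Drive download in script.py.
    viash run src/methods/scgpt_finetuned/config.vsh.yaml -- \
      --input dataset.h5ad \
      --output output.h5ad \
      --model_name scGPT_human \
      --model scGPT_human.zip \
      --n_hvg 3000

The same arguments apply to scgpt_zeroshot, which shares the --model_name and --model options but skips the fine-tuning loop.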