Add fine-tuned scGPT (#17)

* Move scgpt to scgpt_zeroshot

* Create scgpt_finetuned component

* Fix name in scgpt_zeroshot

* Add scgpt_finetuned pre-processing

* Add scgpt_finetuned fine-tuning and embedding

* Add scgpt_finetuned to benchmark workflow

* Style scgpt_finetuned files

* Disable scgpt_finetuned tests

Requires GPU

* Remove docker run args from scgpt_finetuned

* Add model arguments to scgpt_finetuned

Fixed some things in scgpt_zeroshot...

* Exclude scgpt_finetuned from full benchmark runs

Need to sort out GPU infrastructure

---------

Co-authored-by: Robrecht Cannoodt <[email protected]>
lazappi and rcannood authored Jan 17, 2025
1 parent 058ebf0 commit 6191552
Showing 10 changed files with 765 additions and 10 deletions.
2 changes: 1 addition & 1 deletion scripts/run_benchmark/run_full_local.sh
@@ -26,7 +26,7 @@ input_states: resources/datasets/**/state.yaml
rename_keys: 'input_dataset:output_dataset;input_solution:output_solution'
output_state: "state.yaml"
publish_dir: "$publish_dir"
settings: '{"methods_exclude": ["uce"]}'
settings: '{"methods_exclude": ["uce", "scgpt_finetuned"]}'
HERE

# run the benchmark
1 change: 1 addition & 0 deletions scripts/run_benchmark/run_full_seqeracloud.sh
@@ -18,6 +18,7 @@ input_states: s3://openproblems-data/resources/task_batch_integration/datasets/*
rename_keys: 'input_dataset:output_dataset;input_solution:output_solution'
output_state: "state.yaml"
publish_dir: "$publish_dir"
settings: '{"methods_exclude": ["scgpt_finetuned"]}'
HERE

tw launch https://github.com/openproblems-bio/task_batch_integration.git \
2 changes: 1 addition & 1 deletion scripts/run_benchmark/run_test_local.sh
@@ -21,7 +21,7 @@ input_states: resources_test/task_batch_integration/**/state.yaml
rename_keys: 'input_dataset:output_dataset;input_solution:output_solution'
output_state: "state.yaml"
publish_dir: "$publish_dir"
settings: '{"methods_exclude": ["uce"]}'
settings: '{"methods_exclude": ["uce", "scgpt_finetuned"]}'
HERE

nextflow run . \
65 changes: 65 additions & 0 deletions src/methods/scgpt_finetuned/config.vsh.yaml
@@ -0,0 +1,65 @@
__merge__: ../../api/base_method.yaml

name: scgpt_finetuned
label: scGPT (fine-tuned)
summary: "A foundation model for single-cell biology (fine-tuned)"
description: |
  scGPT is a foundation model for single-cell biology based on a generative
  pre-trained transformer and trained on a repository of over 33 million cells.
  Here, we fine-tune the pre-trained model for the batch integration task.
references:
  doi:
    - 10.1038/s41592-024-02201-0
links:
  documentation: https://scgpt.readthedocs.io/en/latest/
  repository: https://github.com/bowang-lab/scGPT

info:
  method_types: [embedding]
  preferred_normalization: counts
  variants:
    scgpt_finetuned_default:

arguments:
  - name: --model_name
    type: string
    description: String giving the name of the scGPT model to use
    choices: ["scGPT_human", "scGPT_CP"]
    default: "scGPT_human"
  - name: --model
    type: file
    description: |
      Path to the directory containing the scGPT model specified by model_name
      or a .zip/.tar.gz archive to extract. If not given the model will be
      downloaded.
    required: false
  - name: --n_hvg
    type: integer
    default: 3000
    description: Number of highly variable genes to use.

resources:
  - type: python_script
    path: script.py
  - path: /src/utils/read_anndata_partial.py
  - path: scgpt_functions.py

engines:
  - type: docker
    image: openproblems/base_pytorch_nvidia:1.0.0
    # TODO: Try to find working installation of flash attention (flash-attn<1.0.5)
    setup:
      - type: python
        pypi:
          - gdown
          - scgpt # Install from PyPI to get dependencies
      - type: docker
        # Force re-installing from GitHub to get bug fixes
        run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git

runners:
  - type: executable
  - type: nextflow
    directives:
      label: [midtime, midmem, midcpu, gpu]
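Note (not part of the diff): at runtime, Viash makes the arguments declared above available to script.py as a `par` dictionary, overwriting the placeholder block between `## VIASH START` and `## VIASH END` with the resolved values. A minimal, hypothetical sketch of how the script might pick them up; the `input`/`output` keys are assumed to come from the merged base_method.yaml API and the placeholder paths are illustrative only:

## VIASH START
# Fallback values for running the script outside of Viash (hypothetical)
par = {
    "input": "input.h5ad",
    "output": "output.h5ad",
    "model_name": "scGPT_human",  # or "scGPT_CP"
    "model": None,                # if None, the model checkpoint is downloaded
    "n_hvg": 3000,
}
## VIASH END

model_name = par["model_name"]
n_hvg = par["n_hvg"]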
288 changes: 288 additions & 0 deletions src/methods/scgpt_finetuned/scgpt_functions.py
@@ -0,0 +1,288 @@
import time
import warnings

import numpy as np
import scgpt
import torch


def prepare_data(
    tokenized_train,
    tokenized_valid,
    train_batch_labels,
    valid_batch_labels,
    hyperparameters,
    model_settings,
    epoch,
):
    masked_values_train = scgpt.tokenizer.random_mask_value(
        tokenized_train["values"],
        mask_ratio=hyperparameters["mask_ratio"],
        mask_value=model_settings["mask_value"],
        pad_value=model_settings["pad_value"],
    )
    masked_values_valid = scgpt.tokenizer.random_mask_value(
        tokenized_valid["values"],
        mask_ratio=hyperparameters["mask_ratio"],
        mask_value=model_settings["mask_value"],
        pad_value=model_settings["pad_value"],
    )
    scgpt.logger.info(
        f"Random masking at epoch {epoch:3d},"
        f"ratio of masked values in train: {(masked_values_train == model_settings['mask_value']).sum() / (masked_values_train - model_settings['pad_value']).count_nonzero():.4f}"
    )

    input_gene_ids_train, input_gene_ids_valid = (
        tokenized_train["genes"],
        tokenized_valid["genes"],
    )
    input_values_train, input_values_valid = masked_values_train, masked_values_valid
    target_values_train, target_values_valid = (
        tokenized_train["values"],
        tokenized_valid["values"],
    )

    tensor_batch_labels_train = torch.from_numpy(train_batch_labels).long()
    tensor_batch_labels_valid = torch.from_numpy(valid_batch_labels).long()

    if model_settings["per_seq_batch_sample"]:
        train_sort_ids = np.argsort(train_batch_labels)
        input_gene_ids_train = input_gene_ids_train[train_sort_ids]
        input_values_train = input_values_train[train_sort_ids]
        target_values_train = target_values_train[train_sort_ids]
        tensor_batch_labels_train = tensor_batch_labels_train[train_sort_ids]

        valid_sort_ids = np.argsort(valid_batch_labels)
        input_gene_ids_valid = input_gene_ids_valid[valid_sort_ids]
        input_values_valid = input_values_valid[valid_sort_ids]
        target_values_valid = target_values_valid[valid_sort_ids]
        tensor_batch_labels_valid = tensor_batch_labels_valid[valid_sort_ids]

    train_data_pt = {
        "gene_ids": input_gene_ids_train,
        "values": input_values_train,
        "target_values": target_values_train,
        "batch_labels": tensor_batch_labels_train,
    }
    valid_data_pt = {
        "gene_ids": input_gene_ids_valid,
        "values": input_values_valid,
        "target_values": target_values_valid,
        "batch_labels": tensor_batch_labels_valid,
    }

    return train_data_pt, valid_data_pt


class SeqDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return self.data["gene_ids"].shape[0]

    def __getitem__(self, idx):
        return {k: v[idx] for k, v in self.data.items()}


def prepare_dataloader(
    data_pt,
    batch_size,
    shuffle,
    intra_domain_shuffle,
    drop_last,
    num_workers,
    per_seq_batch_sample,
):
    dataset = SeqDataset(data_pt)

    if per_seq_batch_sample:
        # Find the indices of samples in each seq batch
        subsets = []
        batch_labels_array = data_pt["batch_labels"].numpy()
        for batch_label in np.unique(batch_labels_array):
            batch_indices = np.where(batch_labels_array == batch_label)[0].tolist()
            subsets.append(batch_indices)
        data_loader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_sampler=scgpt.SubsetsBatchSampler(
                subsets,
                batch_size,
                intra_subset_shuffle=intra_domain_shuffle,
                inter_subset_shuffle=shuffle,
                drop_last=drop_last,
            ),
            num_workers=num_workers,
            pin_memory=True,
        )
        return data_loader

    data_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
        pin_memory=True,
    )
    return data_loader


def train(
    model,
    loader,
    scaler,
    optimizer,
    scheduler,
    vocab,
    criterion,
    criterion_dab,
    hyperparameters,
    model_settings,
    device,
    epoch,
):
    model.train()

    total_loss, total_mse, total_gepc = 0.0, 0.0, 0.0
    total_error = 0.0
    log_interval = hyperparameters["log_interval"]
    start_time = time.time()

    num_batches = len(loader)
    for batch, batch_data in enumerate(loader):
        input_gene_ids = batch_data["gene_ids"].to(device)
        input_values = batch_data["values"].to(device)
        target_values = batch_data["target_values"].to(device)
        batch_labels = batch_data["batch_labels"].to(device)

        src_key_padding_mask = input_gene_ids.eq(vocab[model_settings["pad_token"]])
        with torch.cuda.amp.autocast(enabled=hyperparameters["amp"]):
            output_dict = model(
                input_gene_ids,
                input_values,
                src_key_padding_mask=src_key_padding_mask,
                batch_labels=batch_labels if model_settings["DSBN"] else None,
                MVC=hyperparameters["GEPC"],
                ECS=hyperparameters["ecs_thres"] > 0,
            )

            masked_positions = input_values.eq(
                model_settings["mask_value"]
            )  # the positions to predict
            loss = loss_mse = criterion(
                output_dict["mlm_output"], target_values, masked_positions
            )
            if model_settings["explicit_zero_prob"]:
                loss_zero_log_prob = scgpt.loss.criterion_neg_log_bernoulli(
                    output_dict["mlm_zero_probs"], target_values, masked_positions
                )
                loss = loss + loss_zero_log_prob
            if hyperparameters["GEPC"]:
                loss_gepc = criterion(
                    output_dict["mvc_output"], target_values, masked_positions
                )
                loss = loss + loss_gepc
            if hyperparameters["GEPC"] and model_settings["explicit_zero_prob"]:
                loss_gepc_zero_log_prob = scgpt.loss.criterion_neg_log_bernoulli(
                    output_dict["mvc_zero_probs"], target_values, masked_positions
                )
                loss = loss + loss_gepc_zero_log_prob
            if hyperparameters["ecs_thres"] > 0:
                loss_ecs = 10 * output_dict["loss_ecs"]
                loss = loss + loss_ecs
            loss_dab = criterion_dab(output_dict["dab_output"], batch_labels)
            loss = loss + hyperparameters["dab_weight"] * loss_dab

        model.zero_grad()
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings("always")
            torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                1.0,
                error_if_nonfinite=False if scaler.is_enabled() else True,
            )
            if len(w) > 0:
                scgpt.logger.warning(
                    f"Found infinite gradient. This may be caused by the gradient "
                    f"scaler. The current scale is {scaler.get_scale()}. This warning "
                    "can be ignored if no longer occurs after autoscaling of the scaler."
                )
        scaler.step(optimizer)
        scaler.update()

        with torch.no_grad():
            mre = scgpt.loss.masked_relative_error(
                output_dict["mlm_output"], target_values, masked_positions
            )

        total_loss += loss.item()
        total_mse += loss_mse.item()
        total_gepc += loss_gepc.item() if hyperparameters["GEPC"] else 0.0
        total_error += mre.item()
        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            cur_mse = total_mse / log_interval
            cur_gepc = total_gepc / log_interval if hyperparameters["GEPC"] else 0.0
            cur_error = total_error / log_interval
            scgpt.logger.info(
                f"| epoch {epoch:3d} | {batch:3d}/{num_batches:3d} batches | "
                f"lr {lr:05.4f} | ms/batch {ms_per_batch:5.2f} | "
                f"loss {cur_loss:5.2f} | mse {cur_mse:5.2f} | mre {cur_error:5.2f} |"
                + (f"gepc {cur_gepc:5.2f} |" if hyperparameters["GEPC"] else "")
            )
            total_loss = 0
            total_mse = 0
            total_gepc = 0
            total_error = 0
            start_time = time.time()


def evaluate(
    model,
    loader,
    vocab,
    criterion,
    criterion_dab,
    hyperparameters,
    model_settings,
    device,
):
    model.eval()
    total_loss = 0.0
    total_error = 0.0
    total_dab = 0.0
    total_num = 0
    with torch.no_grad():
        for batch_data in loader:
            input_gene_ids = batch_data["gene_ids"].to(device)
            input_values = batch_data["values"].to(device)
            target_values = batch_data["target_values"].to(device)
            batch_labels = batch_data["batch_labels"].to(device)

            src_key_padding_mask = input_gene_ids.eq(vocab[model_settings["pad_token"]])
            with torch.cuda.amp.autocast(enabled=hyperparameters["amp"]):
                output_dict = model(
                    input_gene_ids,
                    input_values,
                    src_key_padding_mask=src_key_padding_mask,
                    batch_labels=batch_labels if model_settings["DSBN"] else None,
                )
                output_values = output_dict["mlm_output"]

                masked_positions = input_values.eq(model_settings["mask_value"])
                loss = criterion(output_values, target_values, masked_positions)
                loss_dab = criterion_dab(output_dict["dab_output"], batch_labels)

            total_loss += loss.item() * len(input_gene_ids)
            total_error += scgpt.loss.masked_relative_error(
                output_values, target_values, masked_positions
            ).item() * len(input_gene_ids)
            total_dab += loss_dab.item() * len(input_gene_ids)
            total_num += len(input_gene_ids)

    return total_loss / total_num, total_error / total_num
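
Note (not part of the diff): these helpers follow the scGPT integration tutorial and are meant to be driven per epoch from script.py once the model, vocabulary and tokenized data have been prepared. A rough, hypothetical sketch of that wiring is below; `model`, `vocab`, `device`, `tokenized_train`/`tokenized_valid`, the numpy batch-label arrays and the extra hyperparameter keys ("epochs", "batch_size", "lr", "schedule_ratio") are assumptions supplied by the calling script, and the masked-MSE criterion from scgpt.loss is used as in the tutorials.

# Hypothetical driver loop for the helpers above (not part of scgpt_functions.py).
import scgpt
import torch

criterion = scgpt.loss.masked_mse_loss          # masked reconstruction loss
criterion_dab = torch.nn.CrossEntropyLoss()     # domain-adversarial batch loss
optimizer = torch.optim.Adam(model.parameters(), lr=hyperparameters["lr"])
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, 1, gamma=hyperparameters["schedule_ratio"]
)
scaler = torch.cuda.amp.GradScaler(enabled=hyperparameters["amp"])

for epoch in range(1, hyperparameters["epochs"] + 1):
    # Re-mask the expression values each epoch and (optionally) sort cells by batch
    train_data_pt, valid_data_pt = prepare_data(
        tokenized_train,
        tokenized_valid,
        train_batch_labels,
        valid_batch_labels,
        hyperparameters,
        model_settings,
        epoch,
    )
    train_loader = prepare_dataloader(
        train_data_pt,
        batch_size=hyperparameters["batch_size"],
        shuffle=False,
        intra_domain_shuffle=True,
        drop_last=False,
        num_workers=0,
        per_seq_batch_sample=model_settings["per_seq_batch_sample"],
    )
    valid_loader = prepare_dataloader(
        valid_data_pt,
        batch_size=hyperparameters["batch_size"],
        shuffle=False,
        intra_domain_shuffle=False,
        drop_last=False,
        num_workers=0,
        per_seq_batch_sample=model_settings["per_seq_batch_sample"],
    )

    train(
        model, train_loader, scaler, optimizer, scheduler, vocab,
        criterion, criterion_dab, hyperparameters, model_settings, device, epoch,
    )
    valid_loss, valid_mre = evaluate(
        model, valid_loader, vocab, criterion, criterion_dab,
        hyperparameters, model_settings, device,
    )
    scgpt.logger.info(
        f"Epoch {epoch}: validation loss {valid_loss:.4f}, mre {valid_mre:.4f}"
    )
    scheduler.step()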
