* Move scgpt to scgpt_zeroshot
* Create scgpt_finetuned component
* Fix name in scgpt_zeroshot
* Add scgpt_finetuned pre-processing
* Add scgpt_finetuned fine-tuning and embedding
* Add scgpt_finetuned to benchmark workflow
* Style scgpt_finetuned files
* Disable scgpt_finetuned tests
  Requires GPU
* Remove docker run args from scgpt_finetuned
* Add model arguments to scgpt_finetuned
  Fixed some things in scgpt_zeroshot...
* Exclude scgpt_finetuned from full benchmark runs
  Need to sort out GPU infrastructure

---------

Co-authored-by: Robrecht Cannoodt <[email protected]>
Showing 10 changed files with 765 additions and 10 deletions.
@@ -0,0 +1,65 @@
__merge__: ../../api/base_method.yaml

name: scgpt_finetuned
label: scGPT (fine-tuned)
summary: "A foundation model for single-cell biology (fine-tuned)"
description: |
  scGPT is a foundation model for single-cell biology based on a generative
  pre-trained transformer and trained on a repository of over 33 million cells.
  Here, we fine-tune the pre-trained model for the batch integration task.
references:
  doi:
    - 10.1038/s41592-024-02201-0
links:
  documentation: https://scgpt.readthedocs.io/en/latest/
  repository: https://github.com/bowang-lab/scGPT

info:
  method_types: [embedding]
  preferred_normalization: counts
  variants:
    scgpt_finetuned_default:

arguments:
  - name: --model_name
    type: string
    description: String giving the name of the scGPT model to use
    choices: ["scGPT_human", "scGPT_CP"]
    default: "scGPT_human"
  - name: --model
    type: file
    description: |
      Path to the directory containing the scGPT model specified by model_name
      or a .zip/.tar.gz archive to extract. If not given, the model will be
      downloaded.
    required: false
  - name: --n_hvg
    type: integer
    default: 3000
    description: Number of highly variable genes to use.

resources:
  - type: python_script
    path: script.py
  - path: /src/utils/read_anndata_partial.py
  - path: scgpt_functions.py

engines:
  - type: docker
    image: openproblems/base_pytorch_nvidia:1.0.0
    # TODO: Try to find working installation of flash attention (flash-attn<1.0.5)
    setup:
      - type: python
        pypi:
          - gdown
          - scgpt # Install from PyPI to get dependencies
      - type: docker
        # Force re-installing from GitHub to get bug fixes
        run: pip install --upgrade --no-deps --force-reinstall git+https://github.com/bowang-lab/scGPT.git

runners:
  - type: executable
  - type: nextflow
    directives:
      label: [midtime, midmem, midcpu, gpu]
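The arguments defined above are exposed to the component's script.py as a Viash-injected `par` dictionary. A minimal sketch of what that dictionary could look like at runtime, assuming the merged base_method.yaml contributes the usual `input` and `output` arguments (they are not declared in this file):

# Hypothetical view of the `par` dictionary inside script.py; the file names are
# placeholders and `input`/`output` are assumed to come from base_method.yaml.
par = {
    "input": "dataset.h5ad",
    "output": "output.h5ad",
    "model_name": "scGPT_human",
    "model": None,  # when None, the chosen pre-trained checkpoint is downloaded
    "n_hvg": 3000,
}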
@@ -0,0 +1,288 @@
import time
import warnings

import numpy as np
import scgpt
import torch


def prepare_data(
    tokenized_train,
    tokenized_valid,
    train_batch_labels,
    valid_batch_labels,
    hyperparameters,
    model_settings,
    epoch,
):
    masked_values_train = scgpt.tokenizer.random_mask_value(
        tokenized_train["values"],
        mask_ratio=hyperparameters["mask_ratio"],
        mask_value=model_settings["mask_value"],
        pad_value=model_settings["pad_value"],
    )
    masked_values_valid = scgpt.tokenizer.random_mask_value(
        tokenized_valid["values"],
        mask_ratio=hyperparameters["mask_ratio"],
        mask_value=model_settings["mask_value"],
        pad_value=model_settings["pad_value"],
    )
    scgpt.logger.info(
        f"Random masking at epoch {epoch:3d}, "
        f"ratio of masked values in train: {(masked_values_train == model_settings['mask_value']).sum() / (masked_values_train - model_settings['pad_value']).count_nonzero():.4f}"
    )

    input_gene_ids_train, input_gene_ids_valid = (
        tokenized_train["genes"],
        tokenized_valid["genes"],
    )
    input_values_train, input_values_valid = masked_values_train, masked_values_valid
    target_values_train, target_values_valid = (
        tokenized_train["values"],
        tokenized_valid["values"],
    )

    tensor_batch_labels_train = torch.from_numpy(train_batch_labels).long()
    tensor_batch_labels_valid = torch.from_numpy(valid_batch_labels).long()

    if model_settings["per_seq_batch_sample"]:
        # Sort cells by batch label so each mini-batch can be drawn from a single batch
        train_sort_ids = np.argsort(train_batch_labels)
        input_gene_ids_train = input_gene_ids_train[train_sort_ids]
        input_values_train = input_values_train[train_sort_ids]
        target_values_train = target_values_train[train_sort_ids]
        tensor_batch_labels_train = tensor_batch_labels_train[train_sort_ids]

        valid_sort_ids = np.argsort(valid_batch_labels)
        input_gene_ids_valid = input_gene_ids_valid[valid_sort_ids]
        input_values_valid = input_values_valid[valid_sort_ids]
        target_values_valid = target_values_valid[valid_sort_ids]
        tensor_batch_labels_valid = tensor_batch_labels_valid[valid_sort_ids]

    train_data_pt = {
        "gene_ids": input_gene_ids_train,
        "values": input_values_train,
        "target_values": target_values_train,
        "batch_labels": tensor_batch_labels_train,
    }
    valid_data_pt = {
        "gene_ids": input_gene_ids_valid,
        "values": input_values_valid,
        "target_values": target_values_valid,
        "batch_labels": tensor_batch_labels_valid,
    }

    return train_data_pt, valid_data_pt
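
# Illustrative usage sketch (not taken from the scGPT repository): the dicts below
# contain only the keys this module actually reads; the numeric values and the dummy
# tokenized data are assumptions, the real inputs come from the component's script.py.
def _example_prepare_data():
    n_cells, seq_len = 8, 5
    dummy_tokens = {
        "genes": torch.randint(0, 100, (n_cells, seq_len)),
        "values": torch.rand(n_cells, seq_len),
    }
    batch_labels = np.array([0, 0, 1, 1, 0, 1, 0, 1])
    hyperparameters = {
        "mask_ratio": 0.4,   # fraction of non-pad values masked each epoch
        "amp": True,         # mixed precision in train()/evaluate()
        "GEPC": True,        # enable the MVC (gene expression prediction) objective
        "ecs_thres": 0.8,    # >0 switches on the elastic cell similarity loss
        "dab_weight": 1.0,   # weight of the domain adaptation (batch) loss
        "log_interval": 100,
    }
    model_settings = {
        "mask_value": -1,
        "pad_value": -2,
        "pad_token": "<pad>",
        "per_seq_batch_sample": True,  # sample each mini-batch from one sequencing batch
        "DSBN": True,                  # domain-specific batch normalisation
        "explicit_zero_prob": True,
    }
    return prepare_data(
        dummy_tokens, dummy_tokens, batch_labels, batch_labels,
        hyperparameters, model_settings, epoch=1,
    )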

class SeqDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return self.data["gene_ids"].shape[0]

    def __getitem__(self, idx):
        return {k: v[idx] for k, v in self.data.items()}


def prepare_dataloader(
    data_pt,
    batch_size,
    shuffle,
    intra_domain_shuffle,
    drop_last,
    num_workers,
    per_seq_batch_sample,
):
    dataset = SeqDataset(data_pt)

    if per_seq_batch_sample:
        # Find the indices of samples in each seq batch
        subsets = []
        batch_labels_array = data_pt["batch_labels"].numpy()
        for batch_label in np.unique(batch_labels_array):
            batch_indices = np.where(batch_labels_array == batch_label)[0].tolist()
            subsets.append(batch_indices)
        data_loader = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_sampler=scgpt.SubsetsBatchSampler(
                subsets,
                batch_size,
                intra_subset_shuffle=intra_domain_shuffle,
                inter_subset_shuffle=shuffle,
                drop_last=drop_last,
            ),
            num_workers=num_workers,
            pin_memory=True,
        )
        return data_loader

    data_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
        pin_memory=True,
    )
    return data_loader
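
# Illustrative sketch (assumed wiring): the output of prepare_data() feeds
# prepare_dataloader(). Batch size and worker count are placeholders; the
# component's script.py supplies the real settings.
def _example_prepare_dataloader(train_data_pt, valid_data_pt, per_seq_batch_sample=True):
    train_loader = prepare_dataloader(
        train_data_pt,
        batch_size=64,
        shuffle=False,              # no shuffling across sequencing batches
        intra_domain_shuffle=True,  # shuffle cells within each sequencing batch
        drop_last=False,
        num_workers=0,
        per_seq_batch_sample=per_seq_batch_sample,
    )
    valid_loader = prepare_dataloader(
        valid_data_pt,
        batch_size=64,
        shuffle=False,
        intra_domain_shuffle=False,
        drop_last=False,
        num_workers=0,
        per_seq_batch_sample=per_seq_batch_sample,
    )
    return train_loader, valid_loader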

def train(
    model,
    loader,
    scaler,
    optimizer,
    scheduler,
    vocab,
    criterion,
    criterion_dab,
    hyperparameters,
    model_settings,
    device,
    epoch,
):
    model.train()

    total_loss, total_mse, total_gepc = 0.0, 0.0, 0.0
    total_error = 0.0
    log_interval = hyperparameters["log_interval"]
    start_time = time.time()

    num_batches = len(loader)
    for batch, batch_data in enumerate(loader):
        input_gene_ids = batch_data["gene_ids"].to(device)
        input_values = batch_data["values"].to(device)
        target_values = batch_data["target_values"].to(device)
        batch_labels = batch_data["batch_labels"].to(device)

        src_key_padding_mask = input_gene_ids.eq(vocab[model_settings["pad_token"]])
        with torch.cuda.amp.autocast(enabled=hyperparameters["amp"]):
            output_dict = model(
                input_gene_ids,
                input_values,
                src_key_padding_mask=src_key_padding_mask,
                batch_labels=batch_labels if model_settings["DSBN"] else None,
                MVC=hyperparameters["GEPC"],
                ECS=hyperparameters["ecs_thres"] > 0,
            )

            masked_positions = input_values.eq(
                model_settings["mask_value"]
            )  # the positions to predict
            loss = loss_mse = criterion(
                output_dict["mlm_output"], target_values, masked_positions
            )
            if model_settings["explicit_zero_prob"]:
                loss_zero_log_prob = scgpt.loss.criterion_neg_log_bernoulli(
                    output_dict["mlm_zero_probs"], target_values, masked_positions
                )
                loss = loss + loss_zero_log_prob
            if hyperparameters["GEPC"]:
                loss_gepc = criterion(
                    output_dict["mvc_output"], target_values, masked_positions
                )
                loss = loss + loss_gepc
            if hyperparameters["GEPC"] and model_settings["explicit_zero_prob"]:
                loss_gepc_zero_log_prob = scgpt.loss.criterion_neg_log_bernoulli(
                    output_dict["mvc_zero_probs"], target_values, masked_positions
                )
                loss = loss + loss_gepc_zero_log_prob
            if hyperparameters["ecs_thres"] > 0:
                loss_ecs = 10 * output_dict["loss_ecs"]
                loss = loss + loss_ecs
            loss_dab = criterion_dab(output_dict["dab_output"], batch_labels)
            loss = loss + hyperparameters["dab_weight"] * loss_dab

        model.zero_grad()
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        with warnings.catch_warnings(record=True) as w:
            warnings.filterwarnings("always")
            torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                1.0,
                error_if_nonfinite=False if scaler.is_enabled() else True,
            )
            if len(w) > 0:
                scgpt.logger.warning(
                    f"Found infinite gradient. This may be caused by the gradient "
                    f"scaler. The current scale is {scaler.get_scale()}. This warning "
                    "can be ignored if it no longer occurs after autoscaling of the scaler."
                )
        scaler.step(optimizer)
        scaler.update()

        with torch.no_grad():
            mre = scgpt.loss.masked_relative_error(
                output_dict["mlm_output"], target_values, masked_positions
            )

        total_loss += loss.item()
        total_mse += loss_mse.item()
        total_gepc += loss_gepc.item() if hyperparameters["GEPC"] else 0.0
        total_error += mre.item()
        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            cur_mse = total_mse / log_interval
            cur_gepc = total_gepc / log_interval if hyperparameters["GEPC"] else 0.0
            cur_error = total_error / log_interval
            scgpt.logger.info(
                f"| epoch {epoch:3d} | {batch:3d}/{num_batches:3d} batches | "
                f"lr {lr:05.4f} | ms/batch {ms_per_batch:5.2f} | "
                f"loss {cur_loss:5.2f} | mse {cur_mse:5.2f} | mre {cur_error:5.2f} |"
                + (f" gepc {cur_gepc:5.2f} |" if hyperparameters["GEPC"] else "")
            )
            total_loss = 0
            total_mse = 0
            total_gepc = 0
            total_error = 0
            start_time = time.time()
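
# Illustrative sketch of the objects train() expects (assumed construction; the real
# setup lives in the component's script.py). masked_mse_loss as the reconstruction
# criterion and the Adam/StepLR/GradScaler choices are assumptions, not part of this diff.
def _example_training_setup(model, lr=1e-4, amp=True):
    criterion = scgpt.loss.masked_mse_loss           # masked reconstruction loss (assumed)
    criterion_dab = torch.nn.CrossEntropyLoss()      # domain (batch) classifier loss for dab_output
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, eps=1e-4 if amp else 1e-8)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
    scaler = torch.cuda.amp.GradScaler(enabled=amp)  # gradient scaling for mixed precision
    return criterion, criterion_dab, optimizer, scheduler, scaler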

def evaluate(
    model,
    loader,
    vocab,
    criterion,
    criterion_dab,
    hyperparameters,
    model_settings,
    device,
):
    model.eval()
    total_loss = 0.0
    total_error = 0.0
    total_dab = 0.0
    total_num = 0
    with torch.no_grad():
        for batch_data in loader:
            input_gene_ids = batch_data["gene_ids"].to(device)
            input_values = batch_data["values"].to(device)
            target_values = batch_data["target_values"].to(device)
            batch_labels = batch_data["batch_labels"].to(device)

            src_key_padding_mask = input_gene_ids.eq(
                vocab[model_settings["pad_token"]]
            )
            with torch.cuda.amp.autocast(enabled=hyperparameters["amp"]):
                output_dict = model(
                    input_gene_ids,
                    input_values,
                    src_key_padding_mask=src_key_padding_mask,
                    batch_labels=batch_labels if model_settings["DSBN"] else None,
                )
                output_values = output_dict["mlm_output"]

                masked_positions = input_values.eq(model_settings["mask_value"])
                loss = criterion(output_values, target_values, masked_positions)
                loss_dab = criterion_dab(output_dict["dab_output"], batch_labels)

            total_loss += loss.item() * len(input_gene_ids)
            total_error += scgpt.loss.masked_relative_error(
                output_values, target_values, masked_positions
            ).item() * len(input_gene_ids)
            total_dab += loss_dab.item() * len(input_gene_ids)
            total_num += len(input_gene_ids)

    return total_loss / total_num, total_error / total_num
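
# Illustrative fine-tuning loop over the helpers above (assumed driver code; the
# component's script.py implements the real loop, including tokenization and model
# construction, which are omitted here). Re-uses the _example_* sketches defined above.
def _example_finetune(model, vocab, tokenized_train, tokenized_valid,
                      train_batch_labels, valid_batch_labels,
                      hyperparameters, model_settings, device, epochs=5):
    import copy

    criterion, criterion_dab, optimizer, scheduler, scaler = _example_training_setup(
        model, amp=hyperparameters["amp"]
    )
    best_val_loss, best_model = float("inf"), None
    for epoch in range(1, epochs + 1):
        # Re-draw the masked positions every epoch
        train_data_pt, valid_data_pt = prepare_data(
            tokenized_train, tokenized_valid, train_batch_labels, valid_batch_labels,
            hyperparameters, model_settings, epoch,
        )
        train_loader, valid_loader = _example_prepare_dataloader(
            train_data_pt, valid_data_pt, model_settings["per_seq_batch_sample"]
        )
        train(model, train_loader, scaler, optimizer, scheduler, vocab, criterion,
              criterion_dab, hyperparameters, model_settings, device, epoch)
        val_loss, val_mre = evaluate(model, valid_loader, vocab, criterion,
                                     criterion_dab, hyperparameters, model_settings, device)
        scgpt.logger.info(f"epoch {epoch}: valid loss {val_loss:.4f}, valid mre {val_mre:.4f}")
        if val_loss < best_val_loss:
            best_val_loss, best_model = val_loss, copy.deepcopy(model)
        scheduler.step()  # decay the learning rate once per epoch
    return best_model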