Skip to content

Commit

Permalink
scgpt parametrized
Browse files Browse the repository at this point in the history
  • Loading branch information
janursa committed Sep 9, 2024
1 parent 3379386 commit ce35bda
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 27 deletions.
12 changes: 12 additions & 0 deletions src/methods/single_omics/scgpt/config.vsh.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,18 @@ functionality:
direction: input
example: resources_test/supplementary/finetuned_scGPT_adamson/vocab.json
default: resources_test/supplementary/finetuned_scGPT_adamson/vocab.json
- name: --n_bins
type: integer
direction: input
default: 51
- name: --batch_size
type: integer
direction: input
default: 16
- name: --condition
type: string
direction: input
default: cell_type

resources:
- type: python_script
Expand Down
54 changes: 27 additions & 27 deletions src/methods/single_omics/scgpt/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,15 @@

from scipy.sparse import issparse
import scipy as sp
import numpy as np
from einops import rearrange
from torch.nn.functional import softmax
from tqdm import tqdm
import pandas as pd

from torchtext.vocab import Vocab
from torchtext._torchtext import (
Vocab as VocabPybind,
)

sys.path.insert(0, "../")

import scgpt as scg
from scgpt.tasks import GeneEmbedding
from scgpt.tokenizer.gene_tokenizer import GeneVocab
Expand All @@ -44,30 +41,31 @@

## VIASH START
par = {
'multiomics_rna': 'resources_test/grn-benchmark/multiomics_rna.h5ad',
'tf_all': 'resources_test/prior/tf_all.csv',
'multiomics_rna': '../input/resources_test/grn-benchmark/multiomics_rna.h5ad',
'tf_all': '../input/resources_test/prior/tf_all.csv',
'prediction': 'output/prediction_scgpt.csv',
'max_n_links': 50000,
'model_file': 'resources_test/supplementary/finetuned_scGPT_adamson/best_model.pt',
'model_config_file': 'resources_test/supplementary/finetuned_scGPT_adamson/args.json',
'vocab_file': 'resources_test/supplementary/finetuned_scGPT_adamson/vocab.json'
'model_file': '../input/resources_test/supplementary/finetuned_scGPT_adamson/best_model.pt',
'model_config_file': '../input/resources_test/supplementary/finetuned_scGPT_adamson/args.json',
'vocab_file': '../input/resources_test/supplementary/finetuned_scGPT_adamson/vocab.json',
'n_bins': 51,
'batch_size': 16,
'condition': 'cell_type'
}
## VIASH END


# Load list of putative TFs
tf_all = np.loadtxt(par['tf_all'], dtype=str)

set_seed(42)
pad_token = "<pad>"
special_tokens = [pad_token, "<cls>", "<eoc>"]
n_hvg = 1200
n_bins = 51

mask_value = -1
pad_value = -2
batch_size = 16
batch_size = par['batch_size']
num_attn_layers = 11
n_input_bins = n_bins
n_input_bins = par['n_bins']


model_config_file = par['model_config_file']
Expand Down Expand Up @@ -134,7 +132,7 @@

print('Process rna-seq file')
import scanpy as sc
adata = sc.read(par['multiomics_rna.h5ad'])
adata = sc.read(par['multiomics_rna'])
adata.obs["celltype"] = adata.obs["cell_type"].astype("category")
adata.obs["str_batch"] = adata.obs["donor_id"].astype(str)
data_is_raw = False
Expand All @@ -158,16 +156,16 @@
)
preprocessor(adata, batch_key="str_batch")

print('Subsetting to HVGs')
sc.pp.highly_variable_genes(
adata,
layer=None,
n_top_genes=n_hvg,
batch_key="str_batch",
flavor="seurat_v3" if data_is_raw else "cell_ranger",
subset=False,
)
adata = adata[:, adata.var["highly_variable"]].copy()
# print('Subsetting to HVGs')
# sc.pp.highly_variable_genes(
# adata,
# layer=None,
# n_top_genes=n_hvg,
# batch_key="str_batch",
# flavor="seurat_v3" if data_is_raw else "cell_ranger",
# subset=False,
# )
# adata = adata[:, adata.var["highly_variable"]].copy()


input_layer_key = "X_binned"
Expand Down Expand Up @@ -197,7 +195,7 @@

src_key_padding_mask = all_gene_ids.eq(vocab[pad_token])

condition_ids = np.array(adata.obs["cell_type"].tolist())
condition_ids = np.array(adata.obs[par['condition']].tolist())

torch.cuda.empty_cache()
dict_sum_condition = {}
Expand Down Expand Up @@ -253,7 +251,7 @@
else:
dict_sum_condition[c] += outputs[index, :, :]
print('Average across groups of cell types')
groups = adata.obs.groupby('cell_type').groups
groups = adata.obs.groupby([par['condition']]).groups
dict_sum_condition_mean = dict_sum_condition.copy()
for i in groups.keys():
dict_sum_condition_mean[i] = dict_sum_condition_mean[i]/len(groups[i])
Expand All @@ -265,6 +263,8 @@
print('Format as df, melt, and subset')
net = pd.DataFrame(mean_grn, columns=gene_names, index=gene_names)
net = net.iloc[1:, 1:]

tf_all = np.intersect1d(tf_all, gene_names)
net = net[tf_all]

net_melted = net.reset_index() # Move index to a column for melting
Expand Down

0 comments on commit ce35bda

Please sign in to comment.