Skip to content

Commit

Permalink
update novel
Browse files Browse the repository at this point in the history
  • Loading branch information
rcannood committed Jan 8, 2025
1 parent 4e15a89 commit 1baca75
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 61 deletions.
15 changes: 13 additions & 2 deletions scripts/create_datasets/test_resources.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,31 @@ nextflow run . \
echo "Run one method"

for name in bmmc_cite/normal bmmc_cite/swap bmmc_multiome/normal bmmc_multiome/swap; do
echo "Run KNN on $name"
viash run src/methods/knnr_py/config.vsh.yaml -- \
--input_train_mod1 $OUTPUT_DIR/openproblems_neurips2021/$name/train_mod1.h5ad \
--input_train_mod2 $OUTPUT_DIR/openproblems_neurips2021/$name/train_mod2.h5ad \
--input_test_mod1 $OUTPUT_DIR/openproblems_neurips2021/$name/test_mod1.h5ad \
--output $OUTPUT_DIR/openproblems_neurips2021/$name/prediction.h5ad

# pre-train simple_mlp
rm -r $OUTPUT_DIR/openproblems_neurips2021/$name/models/simple_mlp/
echo "pre-train simple_mlp on $name"
[ -d $OUTPUT_DIR/openproblems_neurips2021/$name/models/simple_mlp/ ] && rm -r $OUTPUT_DIR/openproblems_neurips2021/$name/models/simple_mlp/
mkdir -p $OUTPUT_DIR/openproblems_neurips2021/$name/models/simple_mlp/
viash run src/methods/simple_mlp/train/config.vsh.yaml -- \
--input_train_mod1 $OUTPUT_DIR/openproblems_neurips2021/$name/train_mod1.h5ad \
--input_train_mod2 $OUTPUT_DIR/openproblems_neurips2021/$name/train_mod2.h5ad \
--input_test_mod1 $OUTPUT_DIR/openproblems_neurips2021/$name/test_mod1.h5ad \
--output $OUTPUT_DIR/openproblems_neurips2021/$name/models/simple_mlp/

echo "pre-train novel on $name"
[ -d $OUTPUT_DIR/openproblems_neurips2021/$name/models/novel/ ] && rm -r $OUTPUT_DIR/openproblems_neurips2021/$name/models/novel/
mkdir -p $OUTPUT_DIR/openproblems_neurips2021/$name/models/novel/
viash run src/methods/novel/train/config.vsh.yaml -- \
--input_train_mod1 $OUTPUT_DIR/openproblems_neurips2021/$name/train_mod1.h5ad \
--input_train_mod2 $OUTPUT_DIR/openproblems_neurips2021/$name/train_mod2.h5ad \
--input_test_mod1 $OUTPUT_DIR/openproblems_neurips2021/$name/test_mod1.h5ad \
--output $OUTPUT_DIR/openproblems_neurips2021/$name/models/novel

done

# only run this if you have access to the openproblems-data bucket
Expand Down
12 changes: 6 additions & 6 deletions src/methods/novel/predict/config.vsh.yaml
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
__merge__: ../../../api/comp_method_predict.yaml
name: novel_predict
arguments:
- name: "--input_transform"
type: file
direction: input
required: false
example: "lsi_transformer.pickle"

info:
test_setup:
with_model:
input_model: resources_test/task_predict_modality/openproblems_neurips2021/bmmc_cite/swap/models/novel

resources:
- type: python_script
path: script.py
Expand Down
29 changes: 14 additions & 15 deletions src/methods/novel/predict/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,27 @@


## VIASH START

par = {
'input_train_mod2': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_cite/normal/train_mod2.h5ad',
'input_test_mod1': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_cite/normal/test_mod1.h5ad',
'input_model': 'resources_test/predict_modality/neurips2021_bmmc_cite/model.pt',
'input_transform': 'transformer.pickle'
}
meta = {
'resources_dir': 'src/tasks/predict_modality/methods/novel',
'functionality_name': '171129'
}
## VIASH END

sys.path.append(meta['resources_dir'])
from helper_functions import ModelRegressionAtac2Gex, ModelRegressionAdt2Gex, ModelRegressionGex2Adt, ModelRegressionGex2Atac, ModalityMatchingDataset

input_model = f"{par['input_model']}/tensor.pt"
input_transform = f"{par['input_model']}/transform.pkl"
input_h5ad = f"{par['input_model']}/train_mod2.h5ad"

print("Load data", flush=True)

input_test_mod1 = ad.read_h5ad(par['input_test_mod1'])
input_train_mod2 = ad.read_h5ad(par['input_train_mod2'])
input_train_mod2 = ad.read_h5ad(input_h5ad)

mod1 = input_test_mod1.uns['modality']
mod2 = input_train_mod2.uns['modality']
Expand All @@ -46,48 +47,46 @@

input_test_mod1.X = input_test_mod1.layers['normalized'].tocsr()

# Remove vars that were removed from training set. Mostlyy only applicable for testing.
# Remove vars that were removed from training set. Mostly only applicable for testing.
if input_train_mod2.uns.get("removed_vars"):
rem_var = input_train_mod2.uns["removed_vars"]
input_test_mod1 = input_test_mod1[:, ~input_test_mod1.var_names.isin(rem_var)]

del input_train_mod2


model_fp = par['input_model']

print("Start predict", flush=True)

if mod1 == 'GEX' and mod2 == 'ADT':
model = ModelRegressionGex2Adt(n_vars_mod1,n_vars_mod2)
weight = torch.load(model_fp, map_location='cpu')
with open(par['input_transform'], 'rb') as f:
weight = torch.load(input_model, map_location='cpu')
with open(input_transform, 'rb') as f:
lsi_transformer_gex = pickle.load(f)

model.load_state_dict(weight)
input_test_mod1_ = lsi_transformer_gex.transform(input_test_mod1)

elif mod1 == 'GEX' and mod2 == 'ATAC':
model = ModelRegressionGex2Atac(n_vars_mod1,n_vars_mod2)
weight = torch.load(model_fp, map_location='cpu')
with open(par['input_transform'], 'rb') as f:
weight = torch.load(input_model, map_location='cpu')
with open(input_transform, 'rb') as f:
lsi_transformer_gex = pickle.load(f)

model.load_state_dict(weight)
input_test_mod1_ = lsi_transformer_gex.transform(input_test_mod1)

elif mod1 == 'ATAC' and mod2 == 'GEX':
model = ModelRegressionAtac2Gex(n_vars_mod1,n_vars_mod2)
weight = torch.load(model_fp, map_location='cpu')
with open(par['input_transform'], 'rb') as f:
weight = torch.load(input_model, map_location='cpu')
with open(input_transform, 'rb') as f:
lsi_transformer_gex = pickle.load(f)

model.load_state_dict(weight)
input_test_mod1_ = lsi_transformer_gex.transform(input_test_mod1)

elif mod1 == 'ADT' and mod2 == 'GEX':
model = ModelRegressionAdt2Gex(n_vars_mod1,n_vars_mod2)
weight = torch.load(model_fp, map_location='cpu')
weight = torch.load(input_model, map_location='cpu')

model.load_state_dict(weight)
input_test_mod1_ = input_test_mod1.to_df()
Expand All @@ -111,7 +110,7 @@
shape=outputs.shape,
uns={
'dataset_id': input_test_mod1.uns['dataset_id'],
'method_id': meta['functionality_name'],
'method_id': meta['name'],
},
)
adata.write_h5ad(par['output'], compression = "gzip")
Expand Down
10 changes: 2 additions & 8 deletions src/methods/novel/run/main.nf
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,10 @@ workflow run_wf {
output_ch = input_ch
| novel_train.run(
fromState: ["input_train_mod1", "input_train_mod2"],
toState: ["input_model": "output", "input_transform": "output_transform", "output_train_mod2": "output_train_mod2"]
toState: ["input_model": "output"]
)
| novel_predict.run(
fromState: { id, state ->
[
"input_train_mod2": state.output_train_mod2,
"input_test_mod1": state.input_test_mod1,
"input_model": state.input_model,
"input_transform": state.input_transform,
"output": state.output]},
fromState: ["input_test_mod1", "input_train_mod2", "input_model"],
toState: ["output": "output"]
)

Expand Down
13 changes: 0 additions & 13 deletions src/methods/novel/train/config.vsh.yaml
Original file line number Diff line number Diff line change
@@ -1,18 +1,5 @@
__merge__: ../../../api/comp_method_train.yaml
name: novel_train
arguments:
- name: --output_transform
type: file
description: "The output transform file"
required: false
default: "lsi_transformer.pickle"
direction: output
- name: --output_train_mod2
type: file
description: copy of the input with model dim in `.uns`
direction: output
default: "train_mod2.h5ad"
required: false
resources:
- path: script.py
type: python_script
Expand Down
45 changes: 28 additions & 17 deletions src/methods/novel/train/script.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import sys
import os
import math
import numpy as np

import torch
from torch.utils.data import DataLoader
Expand All @@ -17,26 +20,21 @@


## VIASH START

par = {
'input_train_mod1': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_cite/normal/train_mod1.h5ad',
'input_train_mod2': 'resources_test/predict_modality/openproblems_neurips2021/bmmc_cite/normal/train_mod2.h5ad',
'output_train_mod2': 'train_mod2.h5ad',
'output': 'model.pt'
'input_train_mod1': 'resources_test/task_predict_modality/openproblems_neurips2021/bmmc_multiome/normal/train_mod1.h5ad',
'input_train_mod2': 'resources_test/task_predict_modality/openproblems_neurips2021/bmmc_multiome/normal/train_mod2.h5ad',
'output': 'resources_test/task_predict_modality/openproblems_neurips2021/bmmc_multiome/normal/models/novel'
}

meta = {
'resources_dir': 'src/tasks/predict_modality/methods/novel',
'resources_dir': 'src/methods/novel',
}
## VIASH END


sys.path.append(meta['resources_dir'])
from helper_functions import train_and_valid, lsiTransformer, ModalityMatchingDataset
from helper_functions import ModelRegressionAtac2Gex, ModelRegressionAdt2Gex, ModelRegressionGex2Adt, ModelRegressionGex2Atac

print('Load data', flush=True)

input_train_mod1 = ad.read_h5ad(par['input_train_mod1'])
input_train_mod2 = ad.read_h5ad(par['input_train_mod2'])

Expand All @@ -53,8 +51,6 @@
del input_train_mod2

print('Start train', flush=True)


# Check for zero divide
zero_row = input_train_mod1.X.sum(axis=0) == 0

Expand All @@ -75,8 +71,13 @@

# reproduce train/test split from phase 1
batch = input_train_mod1.obs["batch"]
train_ix = [ k for k,v in enumerate(batch) if v not in {'s1d2', 's3d7'} ]
test_ix = [ k for k,v in enumerate(batch) if v in {'s1d2', 's3d7'} ]
test_batches = {'s1d2', 's3d7'}
# if none of phase1_batch is in batch, sample 25% of batch categories rounded up
if len(test_batches.intersection(set(batch))) == 0:
all_batches = batch.cat.categories.tolist()
test_batches = set(np.random.choice(all_batches, math.ceil(len(all_batches) * 0.25), replace=False))
train_ix = [ k for k,v in enumerate(batch) if v not in test_batches ]
test_ix = [ k for k,v in enumerate(batch) if v in test_batches ]

train_mod1 = input_train_mod1_df.iloc[train_ix, :]
train_mod2 = input_train_mod2_df.iloc[train_ix, :]
Expand Down Expand Up @@ -134,14 +135,24 @@
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00001806762345275399, weight_decay=0.0004084171379280058)

loss_fn = torch.nn.MSELoss()
train_and_valid(model, optimizer, loss_fn, dataloader_train, dataloader_test, par['output'], device)

# create dir for par['output']
os.makedirs(par['output'], exist_ok=True)

# determine filenames
output_model = f"{par['output']}/tensor.pt"
output_h5ad = f"{par['output']}/train_mod2.h5ad"
output_transform = f"{par['output']}/transform.pkl"

# train model
train_and_valid(model, optimizer, loss_fn, dataloader_train, dataloader_test, output_model, device)

# Add model dim for use in predict part
adata.uns["model_dim"] = {"mod1": n_vars_mod1, "mod2": n_vars_mod2}
if rem_var:
if rem_var is not None:
adata.uns["removed_vars"] = [rem_var[0]]
adata.write_h5ad(par['output_train_mod2'], compression="gzip")
adata.write_h5ad(output_h5ad, compression="gzip")

if mod1 != 'ADT':
with open(par['output_transform'], 'wb') as f:
with open(output_transform, 'wb') as f:
pickle.dump(lsi_transformer_gex, f)

0 comments on commit 1baca75

Please sign in to comment.