predict.py
#!/usr/bin/env python
import click as ck
import numpy as np
import pandas as pd
from tensorflow.keras.models import load_model
from subprocess import Popen, PIPE
import time
from utils import Ontology, NAMESPACES
from aminoacids import to_onehot
import gzip
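
# Example usage (file names are illustrative; flags are defined below):
#   python predict.py -if data/sequences.fa.gz -of results.tsv.gz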

MAXLEN = 2000


@ck.command()
@ck.option('--in-file', '-if', help='Input FASTA file', required=True)
@ck.option('--out-file', '-of', default='results.tsv.gz', help='Output result file')
@ck.option('--go-file', '-gf', default='data/go.obo', help='Gene Ontology')
@ck.option('--model-file', '-mf', default='data/model.h5', help='Tensorflow model file')
@ck.option('--terms-file', '-tf', default='data/terms.pkl', help='List of predicted terms')
@ck.option('--annotations-file', '-af', default='data/train_data.pkl', help='Experimental annotations')
@ck.option('--chunk-size', '-cs', default=1000, help='Number of sequences to read at a time')
@ck.option('--diamond-file', '-df', default='data/test_diamond.res', help='Diamond Mapping file')
@ck.option('--threshold', '-t', default=0.1, help='Prediction threshold')
@ck.option('--batch-size', '-bs', default=32, help='Batch size for prediction model')
@ck.option('--alpha', '-a', default=0.5, help='Alpha weight parameter')
def main(in_file, out_file, go_file, model_file, terms_file, annotations_file,
chunk_size, diamond_file, threshold, batch_size, alpha):
# Load GO and read list of all terms
go = Ontology(go_file, with_rels=True)
terms_df = pd.read_pickle(terms_file)
terms = terms_df['terms'].values.flatten()
# Read known experimental annotations
annotations = {}
df = pd.read_pickle(annotations_file)
for row in df.itertuples():
annotations[row.proteins] = set(row.prop_annotations)
go.calculate_ic(annotations.values())
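    # Build DIAMOND-based predictions: each GO term annotated to a similar
    # protein is scored by the similarity-weighted fraction of hits that
    # carry that term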
diamond_preds = {}
mapping = {}
with gzip.open(diamond_file, 'rt') as f:
for line in f:
it = line.strip().split()
if it[0] not in mapping:
mapping[it[0]] = {}
mapping[it[0]][it[1]] = float(it[2])
for prot_id, sim_prots in mapping.items():
annots = {}
allgos = set()
total_score = 0.0
for p_id, score in sim_prots.items():
allgos |= annotations[p_id]
total_score += score
allgos = list(sorted(allgos))
sim = np.zeros(len(allgos), dtype=np.float32)
for j, go_id in enumerate(allgos):
s = 0.0
for p_id, score in sim_prots.items():
if go_id in annotations[p_id]:
s += score
sim[j] = s / total_score
for go_id, score in zip(allgos, sim):
annots[go_id] = score
diamond_preds[prot_id] = annots
# Load CNN model
model = load_model(model_file)
# Alphas for the latest model
alphas = {NAMESPACES['mf']: 0.55, NAMESPACES['bp']: 0.59, NAMESPACES['cc']: 0.46}
# Alphas for the cafa2 model
# alphas = {NAMESPACES['mf']: 0.63, NAMESPACES['bp']: 0.68, NAMESPACES['cc']: 0.48}
start_time = time.time()
total_seq = 0
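    # Stream predictions chunk by chunk into a gzip-compressed TSV file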
w = gzip.open(out_file, 'wt')
for prot_ids, sequences in read_fasta(in_file, chunk_size):
total_seq += len(prot_ids)
deep_preds = {}
ids, data = get_data(sequences)
preds = model.predict(data, batch_size=batch_size)
assert preds.shape[1] == len(terms)
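        # A long sequence is split into windows by get_data(); keep the
        # maximum score per term across all of its windows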
for i, j in enumerate(ids):
prot_id = prot_ids[j]
if prot_id not in deep_preds:
deep_preds[prot_id] = {}
for l in range(len(terms)):
if preds[i, l] >= 0.01: # Filter out very low scores
if terms[l] not in deep_preds[prot_id]:
deep_preds[prot_id][terms[l]] = preds[i, l]
else:
deep_preds[prot_id][terms[l]] = max(
deep_preds[prot_id][terms[l]], preds[i, l])
# Combine diamond preds and deepgo
for prot_id in prot_ids:
annots = {}
if prot_id in diamond_preds:
for go_id, score in diamond_preds[prot_id].items():
annots[go_id] = score * alphas[go.get_namespace(go_id)]
for go_id, score in deep_preds[prot_id].items():
if go_id in annots:
annots[go_id] += (1 - alphas[go.get_namespace(go_id)]) * score
else:
annots[go_id] = (1 - alphas[go.get_namespace(go_id)]) * score
# Propagate scores with ontology structure
gos = list(annots.keys())
for go_id in gos:
for g_id in go.get_anchestors(go_id):
if g_id in annots:
annots[g_id] = max(annots[g_id], annots[go_id])
else:
annots[g_id] = annots[go_id]
sannots = sorted(annots.items(), key=lambda x: x[1], reverse=True)
for go_id, score in sannots:
if score >= threshold:
                    w.write('\t'.join([
                        prot_id, go_id, go.get_namespace(go_id),
                        go.get_term(go_id)['name'],
                        '%.2f' % go.get_ic(go_id),
                        '%.3f' % score]) + '\n')
w.write('\n')
w.close()
total_time = time.time() - start_time
    print('Total prediction time for %d sequences: %.2f seconds' % (total_seq, total_time))


def read_fasta(filename, chunk_size):
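    """Stream sequences from a gzip-compressed FASTA file.

    Yields (info, seqs) tuples of at most chunk_size entries, where each
    info entry is the first whitespace-delimited token of a FASTA header.
    """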
seqs = list()
info = list()
seq = ''
inf = ''
with gzip.open(filename, 'rt') as f:
for line in f:
line = line.strip()
if line.startswith('>'):
if seq != '':
seqs.append(seq)
info.append(inf)
if len(info) == chunk_size:
yield (info, seqs)
seqs = list()
info = list()
seq = ''
inf = line[1:].split()[0]
else:
seq += line
seqs.append(seq)
info.append(inf)
yield (info, seqs)


def get_data(sequences):
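    """One-hot encode sequences for the CNN input.

    Sequences longer than MAXLEN are split into overlapping windows with a
    stride of MAXLEN - 128; the returned ids list maps each encoded row back
    to the index of its source sequence so that per-window scores can be
    recombined into a single prediction later.
    """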
pred_seqs = []
ids = []
for i, seq in enumerate(sequences):
if len(seq) > MAXLEN:
st = 0
while st < len(seq):
pred_seqs.append(seq[st: st + MAXLEN])
ids.append(i)
st += MAXLEN - 128
else:
pred_seqs.append(seq)
ids.append(i)
n = len(pred_seqs)
data = np.zeros((n, MAXLEN, 21), dtype=np.float32)
for i in range(n):
seq = pred_seqs[i]
data[i, :, :] = to_onehot(seq)
return ids, data


if __name__ == '__main__':
main()