-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathskim2python3.py
58 lines (42 loc) · 2.06 KB
/
skim2python3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import tensorflow as tf
import numpy as np
import h5py
import config_path as path
# Template ("z") branch: embeds a 140x140x3 template image.
# MobileNet backbone without its classifier head and with weights=None
# (random initialisation); the trained weights are loaded from
# branch_z.h5 at the bottom of this script.
# NOTE: Keras auto-names layers in creation order and those names end up in
# the re-saved .h5 files, so keep the construction order of this script stable.
input_z = tf.keras.Input(shape=(140, 140, 3))
z = tf.keras.applications.mobilenet.MobileNet(include_top=False, input_shape=(140, 140, 3), weights=None)(input_z)
# 4x4 average pool with stride 1. NOTE(review): branch_search consumes this
# output as a (1, 1, 1024) tensor, which implies the backbone yields a
# 4x4x1024 map for a 140x140 input — confirm with branch_z.summary().
z = tf.keras.layers.AveragePooling2D((4, 4), strides=1)(z)
branch_z = tf.keras.Model(inputs=input_z, outputs=z)
def tile(embed1):
    """Repeat ``embed1`` five times along each of its two middle axes.

    Wrapped in a ``Lambda`` layer by the search branch so that the
    (1, 1, 1024) template embedding is expanded spatially before being
    multiplied element-wise with the search feature map. The first and
    last axes (batch and channels) use a multiplier of 1 and are left as-is.
    Expects a rank-4 tensor, since four multiples are given.
    """
    multiples = [1, 5, 5, 1]
    return tf.tile(embed1, multiples)
# Search branch: scores a 256x256x3 search image against a template
# embedding of shape (1, 1, 1024) (the output of branch_z; see how `model`
# wires the two together). Input order is [template embedding, search image].
inputs_x = tf.keras.Input(shape=(256, 256, 3))
inputs_z_ = tf.keras.Input(shape=(1, 1, 1024))
# Repeat the template embedding spatially (5x5, via `tile`) so it lines up
# with the pooled search feature map in the Multiply below.
z_ = tf.keras.layers.Lambda(tile)(inputs_z_)
# Second, independent MobileNet backbone (no shared weights with branch_z),
# randomly initialised; trained weights come from branch_search_n.h5 below.
x = tf.keras.applications.mobilenet.MobileNet(include_top=False, input_shape=(256, 256, 3), weights=None)(inputs_x)
# NOTE(review): with a 256x256 input this pooled map is presumably 5x5x1024,
# matching the tiled template — confirm with branch_search.summary().
x = tf.keras.layers.AveragePooling2D((4, 4), strides=1)(x)
# Element-wise correlation of search features with the template embedding.
x = tf.keras.layers.Multiply()([x, z_])
x = tf.keras.layers.GlobalAveragePooling2D()(x)
# Also feed the raw template embedding (flattened) to the classifier head.
z_in = tf.keras.layers.Flatten()(inputs_z_)
x = tf.keras.layers.Concatenate()([x, z_in])
x = tf.keras.layers.Dropout(0.5)(x)
# Single sigmoid unit: a score in (0, 1), trained with binary cross-entropy
# in the (disabled) training recipe below.
pred = tf.keras.layers.Dense(1, activation='sigmoid')(x)
branch_search = tf.keras.Model(inputs=[inputs_z_, inputs_x], outputs=pred)
# End-to-end model: the template image (140x140x3) goes through branch_z,
# and the resulting embedding is compared by branch_search with the search
# image (256x256x3). Because `model` calls the same branch_z/branch_search
# objects, training `model` trains both branches' weights.
# `model` itself is only used by the disabled training recipe below; the
# weight conversion at the end of the script works on the branches directly.
inputs_1 = tf.keras.Input(shape=(140, 140, 3))
inputs_2 = tf.keras.Input(shape=(256, 256, 3))
output = branch_search([branch_z(inputs_1), inputs_2])
model = tf.keras.Model(inputs=[inputs_1, inputs_2], outputs=output)
# Disabled training recipe: reads template/search/label datasets from an
# HDF5 file and fits `model` end to end.
# NOTE(review): newer tf.keras optimizers may reject the `decay=` argument
# used below — confirm against the installed TensorFlow version before
# re-enabling this code.
# fdata = h5py.File('./Siam/Skim_data.h5','r')
# search = fdata['search']
# template = fdata['template']
# labels = fdata['label']
#
#
# model.compile(optimizer=tf.keras.optimizers.Adam(0.001, decay=1e-2),
# loss='binary_crossentropy',
# metrics=['accuracy'])
#
# model.fit([template, search], labels, epochs=20, batch_size=32, validation_split=0.1)
# Weight conversion: load each branch's stored weights into the architectures
# built above, then save the complete models (architecture + weights) under
# new "_3" file names. NOTE(review): per the script name, presumably this
# re-serialises the models (including the `tile` Lambda) under Python 3 — confirm.
_SKIM_DIR = '/home/space/Documents/experiment/LT_baseline4/modules/skim/'
_CONVERSIONS = (
    # (model, weights file to load, model file to write)
    (branch_search, 'branch_search_n.h5', 'branch_search_3.h5'),
    (branch_z, 'branch_z.h5', 'branch_z_3.h5'),
)
# Load everything first, then save, so that a failed load aborts before any
# output file is written.
for _net, _src_name, _ in _CONVERSIONS:
    _net.load_weights(_SKIM_DIR + _src_name)
for _net, _, _dst_name in _CONVERSIONS:
    _net.save(_SKIM_DIR + _dst_name)