# model.py
import tensorflow as tf
from tensorflow import keras
import metrics
import models.simple
import models.deepflow
import models.resnet
import models.densenet
import models.testmodel
import h5py
import json


def model_map(key):
    # Resolve a model name from the run configuration to its builder function.
    return {
        "simple_nn": models.simple.simple_nn,
        "simple_nn_with_dropout": models.simple.simple_nn_with_dropout,
        "simple_cnn_with_dropout": models.simple.simple_cnn_with_dropout,
        "deepflow": models.deepflow.deepflow,
        "deepflow_narrow": models.deepflow.deepflow_narrow,
        "resnet50": resnet50,
        "resnet18": resnet18,
        "densenet": densenet,
        "testmodel": testmodel
    }[key]


def optimizer_map(key):
    # Resolve an optimizer name from the run configuration to a factory that
    # builds the optimizer from the hyperparameters in `args`.

    # def get_decay_mom(args):
    #     lr = tf.train.exponential_decay(
    #         args["learning_rate"],
    #         tf.train.get_or_create_global_step(),
    #         args["epochs_per_decay"],
    #         args["learning_rate_decay"],
    #         staircase=True
    #     )
    #     return tf.keras.optimizers.Momentum(lr, args["momentum"])

    # def get_mom(args):
    #     return tf.keras.optimizers.Momentum(args["learning_rate"], args["momentum"])

    def get_adam(args):
        return tf.keras.optimizers.Adam(
            learning_rate=args["learning_rate"],
            beta_1=args["beta1"],
            epsilon=args["epsilon"]
        )

    def get_adam_def(args):
        # Adam with Keras default hyperparameters.
        return tf.keras.optimizers.Adam()

    def get_rmsprop(args):
        return tf.keras.optimizers.RMSprop(args["learning_rate"], momentum=args["momentum"])

    return {
        "adam": get_adam,
        "adam_def": get_adam_def,
        # "mom_decay": get_decay_mom,
        # "mom": get_mom,
        "rmsprop": get_rmsprop
    }[key]


def load_model(args, load_weights=True):
    # Build the architecture named in args["model"] and optionally restore
    # trained weights from the HDF5 checkpoint at args["model_hdf5"].
    m = model_map(args["model"])(args)
    if load_weights:
        m.load_weights(args["model_hdf5"])
    return m


def build_model(args, m=None):
    # Compile a model (building it first if none is supplied) with the
    # configured optimizer, sparse categorical cross-entropy loss and a
    # balanced-accuracy metric over args["noc"] classes.
    if m is None:
        m = model_map(args["model"])(args)
    optimizer = optimizer_map(args["optimizer"])(args)
    bal_acc = metrics.BalancedAccuracy(args["noc"])
    m.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.sparse_categorical_crossentropy,
        metrics=[bal_acc]
    )
    return m


def resnet50(args):
    # Channels-first input shape: (num_channels, width, height).
    s = (len(args["channels"]), args["image_width"], args["image_height"])
    model = models.resnet.ResnetBuilder.build_resnet_50(s, args["noc"])
    return model


def resnet18(args):
    s = (len(args["channels"]), args["image_width"], args["image_height"])
    model = models.resnet.ResnetBuilder.build_resnet_18(s, args["noc"])
    return model


def densenet(args):
    s = (len(args["channels"]), args["image_width"], args["image_height"])
    builder = models.densenet.DenseNet(
        input_shape=s,
        nb_classes=args["noc"],
        compression=args["compression"],
        dropout_rate=args["dropout"],
        dense_blocks=args.get("dense_blocks", 3),
        dense_layers=args.get("dense_layers", -1),
        growth_rate=args["growth_rate"],
        weight_decay=args["l2"],
        depth=args.get("model_depth", None),
        bottleneck=args["bottleneck"],
    )
    return builder.build_model()


def testmodel(args):
    m = models.testmodel.TestModel(args)
    # m.build((128, 3, 90, 90))
    return m
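

# Minimal usage sketch: the keys and values below are illustrative assumptions;
# the actual run configuration defines which keys each model and optimizer
# builder expects.
if __name__ == "__main__":
    example_args = {
        "model": "resnet18",          # any key accepted by model_map
        "optimizer": "adam_def",      # Adam with Keras default hyperparameters
        "noc": 4,                     # number of output classes (assumed)
        "channels": ["r", "g", "b"],  # one entry per input channel (assumed)
        "image_width": 90,            # assumed input width
        "image_height": 90,           # assumed input height
    }
    model = build_model(example_args)
    model.summary()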