-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy path eval.py
76 lines (60 loc) · 2.74 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# coding: utf-8
from __future__ import print_function
import tensorflow as tf
from preprocessing import preprocessing_factory
import reader
import model
import time
import os
# Command-line flags for this evaluation script.
# loss_model selects the preprocessing function (see main()); model_file and
# image_file are the checkpoint to restore and the image to stylize.
# NOTE(review): image_size is defined here but not read by main() — confirm
# whether anything else still depends on it before removing.
tf.app.flags.DEFINE_string('loss_model', 'vgg_16', 'The name of the architecture to evaluate. '
                           'You can view all the supported models in nets/nets_factory.py')
tf.app.flags.DEFINE_integer('image_size', 256, 'Image size to train.')
tf.app.flags.DEFINE_string("model_file", "models.ckpt",
                           "Path to the trained style-transfer model checkpoint to restore.")
tf.app.flags.DEFINE_string("image_file", "a.jpg",
                           "Path to the input image to stylize (PNG or JPEG).")
FLAGS = tf.app.flags.FLAGS
def _get_image_size(image_file):
    """Return ``(height, width)`` of the image stored at ``image_file``.

    Files whose name ends in 'png' (case-insensitive) are decoded as PNG and
    everything else as JPEG. Decoding happens in a throwaway graph whose
    session is closed on exit; ``Session().as_default()`` alone would leave
    the session open.
    """
    with open(image_file, 'rb') as img:
        image_bytes = img.read()
    with tf.Graph().as_default():
        with tf.Session() as sess:
            if image_file.lower().endswith('png'):
                decoded = tf.image.decode_png(image_bytes)
            else:
                decoded = tf.image.decode_jpeg(image_bytes)
            image = sess.run(decoded)
    return image.shape[0], image.shape[1]


def _ensure_dir(path):
    """Create directory ``path`` (including parents) unless it already exists."""
    # Python-2-compatible stand-in for os.makedirs(path, exist_ok=True); it
    # tolerates the directory appearing between a check and the call.
    try:
        os.makedirs(path)
    except OSError:
        if not os.path.isdir(path):
            raise


def main(_):
    """Stylize FLAGS.image_file with the checkpoint FLAGS.model_file.

    The input image is decoded once to learn its size. It is then loaded
    through ``reader.get_image`` using the preprocessing function chosen by
    FLAGS.loss_model, run through ``model.net`` in inference mode, and the
    result is written as JPEG to ``generated/res.jpg``. The image size, the
    inference/encode time and the output path are logged at INFO level.
    """
    # Get image's height and width.
    height, width = _get_image_size(FLAGS.image_file)
    tf.logging.info('Image size: %dx%d', width, height)

    with tf.Graph().as_default():
        # Using the Session itself as the context manager closes it on exit.
        with tf.Session() as sess:
            # Read image data.
            image_preprocessing_fn, _ = preprocessing_factory.get_preprocessing(
                FLAGS.loss_model,
                is_training=False)
            image = reader.get_image(FLAGS.image_file, height, width, image_preprocessing_fn)

            # Add batch dimension
            image = tf.expand_dims(image, 0)
            generated = model.net(image, training=False)
            # NOTE(review): a plain cast assumes model.net already outputs
            # values in [0, 255] — TODO confirm.
            generated = tf.cast(generated, tf.uint8)
            # Remove batch dimension
            generated = tf.squeeze(generated, [0])
            # Build the encode op up front so the timed section below measures
            # only inference + encoding, not graph construction or disk I/O.
            encoded = tf.image.encode_jpeg(generated)

            # Restore model variables.
            saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V1)
            sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
            # Restore from an absolute path (as the original did) without
            # mutating the global FLAGS object.
            model_file = os.path.abspath(FLAGS.model_file)
            saver.restore(sess, model_file)

            # Make sure 'generated' directory exists.
            generated_file = 'generated/res.jpg'
            _ensure_dir('generated')

            # Generate image data first and only then open the output file,
            # so a failed run does not leave an empty/truncated res.jpg behind.
            start_time = time.time()
            image_data = sess.run(encoded)
            end_time = time.time()
            with open(generated_file, 'wb') as img:
                img.write(image_data)
            tf.logging.info('Elapsed time: %fs', end_time - start_time)
            tf.logging.info('Done. Please check %s.', generated_file)
if __name__ == '__main__':
    # Make the INFO-level progress messages from main() visible, then let
    # tf.app.run parse the command-line flags and call main explicitly.
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.app.run(main=main)