Variational_autoencoder.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from keras.layers import Lambda, Input, Dense
from keras.models import Model
from keras.datasets import mnist
from keras.losses import binary_crossentropy
from keras import backend as K
import numpy as np
import argparse
import os
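# Reparameterization trick: rather than sampling z ~ N(mu, sigma^2) directly,
# draw epsilon ~ N(0, 1) and compute z = mu + sigma * epsilon, so gradients can
# flow from the decoder back through z_mean and z_log_var.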
def sampling(args):
    z_mean, z_log_var = args
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    # K.random_normal defaults to mean=0.0 and stddev=1.0
    epsilon = K.random_normal(shape=(batch, dim))
    return z_mean + K.exp(0.5 * z_log_var) * epsilon
# MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()
image_size = x_train.shape[1]
original_dim = image_size * image_size
x_train = np.reshape(x_train, [-1, original_dim])
x_test = np.reshape(x_test, [-1, original_dim])
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
# network parameters
input_shape = (original_dim,)
intermediate_dim = 512
batch_size = 128
latent_dim = 2
epochs = 50
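# latent_dim = 2 gives a two-dimensional latent space, so each image encodes to
# a point in a plane and distances between encodings can be computed directly.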
# VAE model = encoder + decoder
# build encoder model
inputs = Input(shape=input_shape, name='encoder_input')
x = Dense(intermediate_dim, activation='relu')(inputs)
z_mean = Dense(latent_dim, name='z_mean')(x)
z_log_var = Dense(latent_dim, name='z_log_var')(x)
z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])
# instantiate encoder model
encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')
# build decoder model
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
x = Dense(intermediate_dim, activation='relu')(latent_inputs)
outputs = Dense(original_dim, activation='sigmoid')(x)
# instantiate decoder model
decoder = Model(latent_inputs, outputs, name='decoder')
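# The encoder returns [z_mean, z_log_var, z]; indexing [2] selects the sampled
# z, which the decoder maps back to a reconstructed image.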
# instantiate VAE model
outputs = decoder(encoder(inputs)[2])
vae = Model(inputs, outputs, name='vae_mlp')
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    help_ = "Two comma-separated MNIST test-set image indices to compare (each between 0 and 9999)"
    parser.add_argument("-i",
                        "--indices",
                        help=help_)
    args = parser.parse_args()
    reconstruction_loss = binary_crossentropy(inputs, outputs)
    reconstruction_loss *= original_dim
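    # Closed-form KL divergence between N(z_mean, exp(z_log_var)) and N(0, I):
    # KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)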
    kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
    kl_loss = K.sum(kl_loss, axis=-1)
    kl_loss *= -0.5
    vae_loss = K.mean(reconstruction_loss + kl_loss)
    vae.add_loss(vae_loss)
    vae.compile(optimizer='adam')
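    # Because the loss was attached with add_loss, compile() takes no loss
    # argument and fit() needs no targets (hence validation_data=(x_test, None)).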
    if os.path.isfile('vae_mnist.h5'):
        vae.load_weights('vae_mnist.h5')
    else:
        # train the autoencoder, then cache the weights for later runs
        vae.fit(x_train,
                epochs=epochs,
                batch_size=batch_size,
                validation_data=(x_test, None))
        vae.save_weights('vae_mnist.h5')
    if args.indices:
        idx2img = args.indices.split(',')
        idx2img = list(map(int, idx2img))
        img = [x_test[idx2img[0]], x_test[idx2img[1]]]
        print("Image labels are", y_test[idx2img[0]], "and", y_test[idx2img[1]])
        compare_images = np.reshape(img, [-1, original_dim])
        # encoder.predict returns [z_mean, z_log_var, z]; z is a stochastic sample
        mean, log_var, vector = encoder.predict(compare_images)
        v1 = vector[0]
        v2 = vector[1]
        # Euclidean distance in the 2-D latent plane
        distance = np.sqrt((v1[0] - v2[0]) ** 2 + (v1[1] - v2[1]) ** 2)
        print('Latent vectors are', v1, 'and', v2)
        print('Distance between latent vectors is', distance)
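# Example invocation (the indices 3 and 7 are arbitrary test-set positions):
#   python Variational_autoencoder.py --indices 3,7
# This encodes the two test images, prints their digit labels and sampled
# latent vectors, and reports the Euclidean distance between the encodings.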