Skip to content

Commit

Permalink
Add code for visualizing embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
lawlite19 committed Oct 26, 2018
1 parent b3be604 commit af9a4c9
Show file tree
Hide file tree
Showing 17 changed files with 159 additions and 4 deletions.
Binary file modified code/.DS_Store
Binary file not shown.
Binary file added code/tensorflow-tools/.DS_Store
Binary file not shown.
Binary file added code/tensorflow-tools/log/.DS_Store
Binary file not shown.
2 changes: 2 additions & 0 deletions code/tensorflow-tools/log/checkpoint
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
model_checkpoint_path: "model.ckpt-0"
all_model_checkpoint_paths: "model.ckpt-0"
Binary file not shown.
Binary file added code/tensorflow-tools/log/images.ckpt.index
Binary file not shown.
Binary file added code/tensorflow-tools/log/images.ckpt.meta
Binary file not shown.
100 changes: 100 additions & 0 deletions code/tensorflow-tools/log/metadata.tsv
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
7
2
1
0
4
1
4
9
5
9
0
6
9
0
1
5
9
7
3
4
9
6
6
5
4
0
7
4
0
1
3
1
3
4
7
2
7
1
2
1
1
7
4
2
3
5
1
2
4
4
6
3
5
5
6
0
4
1
9
5
7
8
9
3
7
4
6
4
3
0
7
0
2
9
1
7
3
2
9
7
7
6
2
7
8
4
7
3
6
1
3
6
9
3
1
4
1
7
6
9
Binary file added code/tensorflow-tools/log/mnist_10k_sprite.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Binary file added code/tensorflow-tools/log/model.ckpt-0.index
Binary file not shown.
Binary file added code/tensorflow-tools/log/model.ckpt-0.meta
Binary file not shown.
9 changes: 9 additions & 0 deletions code/tensorflow-tools/log/projector_config.pbtxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
embeddings {
tensor_name: "embedding:0"
metadata_path: "metadata.tsv"
sprite {
image_path: "mnist_10k_sprite.png"
single_image_dim: 28
single_image_dim: 28
}
}
43 changes: 43 additions & 0 deletions code/tensorflow-tools/tensorflow_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#-*- coding: utf-8 -*-
# Author: Lawlite
# Date: 2017/07/26
# Associate Blog: http://lawlite.me/2017/06/24/Tensorflow学习-工具相关/#1、可视化embedding
# License: MIT

import numpy as np
import tensorflow as tf
from tensorflow.contrib.tensorboard.plugins import projector
from tensorflow.examples.tutorials.mnist import input_data
import os

MNIST_DATA_PATH = 'MNIST_data'
LOG_DIR = 'log'
SPRITE_IMAGE_FILE = 'mnist_10k_sprite.png'
META_DATA_FILE = 'metadata.tsv'
IMAGE_NUM = 100

# The metadata file and the checkpoint are both written into LOG_DIR, so make
# sure it exists before the first write (np.savetxt does not create parent
# directories).
if not os.path.isdir(LOG_DIR):
    os.makedirs(LOG_DIR)

mnist = input_data.read_data_sets(MNIST_DATA_PATH, one_hot=False)
# Take the first IMAGE_NUM test images; each row becomes one embedding point.
plot_array = mnist.test.images[:IMAGE_NUM]
# Save the matching labels, one integer per line, as the projector metadata.
np.savetxt(os.path.join(LOG_DIR, META_DATA_FILE),
           mnist.test.labels[:IMAGE_NUM], fmt='%d')


# Visualizing an embedding takes three steps.
with tf.Session() as sess:
    # 1. Put the 2-D matrix into a Variable.
    embeddings_var = tf.Variable(plot_array, name='embedding')
    # This is the only variable created here, so initializing it once is
    # sufficient.
    sess.run(embeddings_var.initializer)

    # 2. Save the variable to a checkpoint file.
    saver = tf.train.Saver([embeddings_var])
    saver.save(sess, os.path.join(LOG_DIR, "model.ckpt"), global_step=0)

    # 3. Associate the metadata and the sprite image with the embedding.
    # Both paths are given as bare file names; the metadata file is written
    # to LOG_DIR above, and the sprite image is not generated by this script,
    # so it is expected to already be present there.
    summary_writer = tf.summary.FileWriter(LOG_DIR)
    config = projector.ProjectorConfig()
    embedding = config.embeddings.add()
    embedding.tensor_name = embeddings_var.name
    embedding.metadata_path = META_DATA_FILE
    embedding.sprite.image_path = SPRITE_IMAGE_FILE
    # Size of a single thumbnail inside the sprite image.
    embedding.sprite.single_image_dim.extend([28, 28])
    projector.visualize_embeddings(summary_writer, config)
    # Release the event-file handle held by the writer.
    summary_writer.close()
Binary file modified code/triplet-loss/.DS_Store
Binary file not shown.
6 changes: 3 additions & 3 deletions code/triplet-loss/test_triplet_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,11 +264,11 @@ def test_batch_hard_triplet_loss():

#test_pairwise_distances()
#test_gradients_pairwise_distances()
#test_anchor_positive_triplet_mask()
#test_anchor_negative_triplet_mask()
#test_triplet_mask()
#test_pairwise_distances_are_positive()
test_batch_all_triplet_loss()
#test_batch_all_triplet_loss()
#test_anchor_positive_triplet_mask()
#test_anchor_negative_triplet_mask()
#test_batch_hard_triplet_loss()


Expand Down
3 changes: 2 additions & 1 deletion code/triplet-loss/triplet_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ def _get_triplet_mask(labels):

def batch_all_triplet_loss(labels, embeddings, margin, squared=False):
'''
triplet loss of a batch
triplet loss of a batch, 注意这里的loss一般不是收敛的,因为是计算的semi-hard和hard的距离均值,因为每次是先选择出semi-hard和hard
triplet, 那么上次优化后的可能就选择不到了,所以loss并不会收敛,但是fraction_postive_triplets是收敛的,因为随着优化占的比例是越来越少的
-------------------------------
Args:
labels: 标签数据,shape = (batch_size,)
Expand Down

0 comments on commit af9a4c9

Please sign in to comment.