forked from hunkim/DeepLearningZeroToAll
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlab-07-1-learning_rate_and_evaluation.py
142 lines (124 loc) · 4.19 KB
/
lab-07-1-learning_rate_and_evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# Lab 7 Learning rate and Evaluation
import tensorflow as tf
# Fix TensorFlow's random seed before any random ops are created, so that the
# randomly initialised W and b defined later in this script are repeatable
# from run to run.
tf.set_random_seed(777) # for reproducibility
# Training set: eight samples with three features each.
x_data = [
    [1, 2, 1],
    [1, 3, 2],
    [1, 3, 4],
    [1, 5, 5],
    [1, 7, 5],
    [1, 2, 5],
    [1, 6, 6],
    [1, 7, 7],
]
# One-hot training labels over three classes, one row per x_data row.
y_data = [
    [0, 0, 1],
    [0, 0, 1],
    [0, 0, 1],
    [0, 1, 0],
    [0, 1, 0],
    [0, 1, 0],
    [1, 0, 0],
    [1, 0, 0],
]

# Evaluate our model using this held-out test dataset.
x_test = [
    [2, 1, 1],
    [3, 1, 2],
    [3, 3, 4],
]
# Every test sample is labelled with the class at one-hot index 2.
y_test = [
    [0, 0, 1],
    [0, 0, 1],
    [0, 0, 1],
]
# Graph inputs: a batch of 3-feature samples and their one-hot labels over
# 3 classes. The leading None leaves the batch size open, so the same graph
# serves both the training feed and the test feed.
X = tf.placeholder("float", [None, 3])
Y = tf.placeholder("float", [None, 3])

# Trainable parameters of the linear layer, randomly initialised.
W = tf.Variable(tf.random_normal([3, 3]))
b = tf.Variable(tf.random_normal([3]))

# tf.nn.softmax computes softmax activations
# softmax = exp(logits) / reduce_sum(exp(logits), dim)
hypothesis = tf.nn.softmax(tf.matmul(X, W) + b)

# Cross entropy cost/loss: per-sample -sum(Y * log(p)), averaged over the batch.
# NOTE(review): taking tf.log of the softmax output directly presumably yields
# inf/nan once a probability underflows to 0 (compare the recorded lr = 1.5
# run at the end of this file). It is kept as-is because the recorded runs
# were produced with this exact formula.
cost = tf.reduce_mean(-tf.reduce_sum(Y * tf.log(hypothesis), axis=1))

# Try to change learning_rate and compare with the runs recorded at the end
# of this file: 1.5 blows up to nan, 1e-10 leaves the cost unchanged, and 0.1
# converges.
optimizer = tf.train.GradientDescentOptimizer(
    learning_rate=1e-10).minimize(cost)

# Correct prediction Test model
# tf.argmax replaces the deprecated alias tf.arg_max; the result is the same
# index of the largest entry along axis 1 (the class axis).
prediction = tf.argmax(hypothesis, axis=1)
is_correct = tf.equal(prediction, tf.argmax(Y, axis=1))
# Mean of the 0/1 correctness flags is the fraction of correct predictions.
accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))
# Launch graph: train for 201 steps on the training data, then evaluate the
# model on the test data. Everything that uses `sess` stays inside the `with`
# so the session is still open when it runs.
with tf.Session() as sess:
    # Initialize TensorFlow variables
    sess.run(tf.global_variables_initializer())

    # The feed and fetch list are identical on every step, so build them once.
    train_feed = {X: x_data, Y: y_data}
    train_fetches = [cost, W, optimizer]

    for step in range(201):
        # One gradient-descent update; the optimizer op's own result is unused.
        step_cost, weights, _ = sess.run(train_fetches, feed_dict=train_feed)
        print(step, step_cost, weights)

    # predict: class index chosen for each test sample (labels not needed)
    predicted_classes = sess.run(prediction, feed_dict={X: x_test})
    print("Prediction:", predicted_classes)

    # Calculate the accuracy: fraction of test samples classified correctly
    test_accuracy = sess.run(accuracy, feed_dict={X: x_test, Y: y_test})
    print("Accuracy: ", test_accuracy)
'''
when lr = 1.5
0 5.73203 [[-0.30548954 1.22985029 -0.66033536]
[-4.39069986 2.29670858 2.99386835]
[-3.34510708 2.09743214 -0.80419564]]
1 23.1494 [[ 0.06951046 0.29449689 -0.0999819 ]
[-1.95319986 -1.63627958 4.48935604]
[-0.90760708 -1.65020132 0.50593793]]
2 27.2798 [[ 0.44451016 0.85699677 -1.03748143]
[ 0.48429942 0.98872018 -0.57314301]
[ 1.52989244 1.16229868 -4.74406147]]
3 8.668 [[ 0.12396193 0.61504567 -0.47498202]
[ 0.22003263 -0.2470119 0.9268558 ]
[ 0.96035379 0.41933775 -3.43156195]]
4 5.77111 [[-0.9524312 1.13037777 0.08607888]
[-3.78651619 2.26245379 2.42393875]
[-3.07170963 3.14037919 -2.12054014]]
5 inf [[ nan nan nan]
[ nan nan nan]
[ nan nan nan]]
6 nan [[ nan nan nan]
[ nan nan nan]
[ nan nan nan]]
...
Prediction: [0 0 0]
Accuracy: 0.0
-------------------------------------------------
When lr = 1e-10
0 5.73203 [[ 0.80269563 0.67861295 -1.21728313]
[-0.3051686 -0.3032113 1.50825703]
[ 0.75722361 -0.7008909 -2.10820389]]
1 5.73203 [[ 0.80269563 0.67861295 -1.21728313]
[-0.3051686 -0.3032113 1.50825703]
[ 0.75722361 -0.7008909 -2.10820389]]
2 5.73203 [[ 0.80269563 0.67861295 -1.21728313]
[-0.3051686 -0.3032113 1.50825703]
[ 0.75722361 -0.7008909 -2.10820389]]
...
198 5.73203 [[ 0.80269563 0.67861295 -1.21728313]
[-0.3051686 -0.3032113 1.50825703]
[ 0.75722361 -0.7008909 -2.10820389]]
199 5.73203 [[ 0.80269563 0.67861295 -1.21728313]
[-0.3051686 -0.3032113 1.50825703]
[ 0.75722361 -0.7008909 -2.10820389]]
200 5.73203 [[ 0.80269563 0.67861295 -1.21728313]
[-0.3051686 -0.3032113 1.50825703]
[ 0.75722361 -0.7008909 -2.10820389]]
Prediction: [0 0 0]
Accuracy: 0.0
-------------------------------------------------
When lr = 0.1
0 5.73203 [[ 0.72881663 0.71536207 -1.18015325]
[-0.57753736 -0.12988332 1.60729778]
[ 0.48373488 -0.51433605 -2.02127004]]
1 3.318 [[ 0.66219079 0.74796319 -1.14612854]
[-0.81948912 0.03000021 1.68936598]
[ 0.23214608 -0.33772916 -1.94628811]]
2 2.0218 [[ 0.64342022 0.74127686 -1.12067163]
[-0.81161296 -0.00900121 1.72049117]
[ 0.2086665 -0.35079569 -1.909742 ]]
...
199 0.672261 [[-1.15377033 0.28146935 1.13632679]
[ 0.37484586 0.18958236 0.33544877]
[-0.35609841 -0.43973011 -1.25604188]]
200 0.670909 [[-1.15885413 0.28058422 1.14229572]
[ 0.37609792 0.19073224 0.33304682]
[-0.35536593 -0.44033223 -1.2561723 ]]
Prediction: [2 2 2]
Accuracy: 1.0
'''