Skip to content

Commit

Permalink
add classification example
Browse files Browse the repository at this point in the history
  • Loading branch information
yabata committed Aug 17, 2017
1 parent ff784f2 commit dd2b47d
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 0 deletions.
98 changes: 98 additions & 0 deletions doc/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -994,3 +994,101 @@ The function ``prn.BPTT()`` uses the Back Propagation Through Time algorithm and
.. code-block:: matlab
g_bptt = BPTT(net,data);
Classification (MNIST Data)
------------------------

In this example a neural network is used to learn to recognize handwritten digits.
Thefore the MNIST dataset hosted on `Yann LeCun's website`_ is used.
The data set consists of 60,000 data points for training and 10,000 data points for testing. To reduce the size of the data file, here only 25,000 data points for training and 5,000 for testing are used.
Each data point is defined by an 28X28 pixel image (784 numbers) and the
corresponing number represesented by a 10 element vector
(one element for each digit 0,1,2,3,4,5,6,7,8,9). For the number n,
only the n-th element is 1, all others are zero. So the vector [0 0 0 0 0 1 0 0 0 0] represents the number 5.
A more detailed explanation of the MNIST data can be found in the `Tensorflow tutorial`_.

.. _ Yann LeCun's website: http://yann.lecun.com/exdb/mnist/page
.. _Tensorflow tutorial: https://www.tensorflow.org/get_started/mnist/beginners

Python
^^^^^^^^^^^

At first the needed packages are imported. pickle for reading the data, matplotlib for plotting the results, numpy for its random function and pyrenn for the neural network.

::

import matplotlib as mpl
import matplotlib.pyplot as plt
import pickle
import numpy as np
import pyrenn as prn

Then the training input data P and output (target) data Y as well as the test input data Ptest and output data Ytest is read from the given pickle file. Each image is defined by the value of its 784 pixel, so P is a 2d array of size (784,Q), where Q is the number of data samples (25,000). Y is defined by an 10 element vector, which gives us a 2d array of size (10,Q=5000).


::

mnist = pickle.load( open( "MNIST_data.pkl", "rb" ) )
P = mnist['P']
Y = mnist['Y']
Ptest = mnist['Ptest']
Ytest = mnist['Ytest']

Then the neural network is created. Since we have a system with 28*28 inputs and 10 outputs, we need a neural network with the same number of inputs and outputs. For this system we choose a neural network with one hidden layer with 10 neurons. Since there is no interconnection between the images, we do not need a recurrent network and no delayed inputs, so we do not have to change the delay inputs.

::

net = prn.CreateNN([28*28,10,10])

Because training the network with all the available training data would need a lot of memory and time, we randomly extract a batch of 1000 data samples and use it to train the network. Because we want to use as much information of our data as possible, we only perform one iteration (k_max=1) and then extract a new batch. In this example we do this 20 times, so we train the net for 20 iterations, but each iteration with new training data.
``verbose=True`` activates diplaying the error during training.

::

batch_size = 1000
number_of_batches=20

for i in range(number_of_batches):
r = np.random.randint(0,25000-batch_size)
Ptrain = P[:,r:r+batch_size]
Ytrain = Y[:,r:r+batch_size]

#Train NN with training data Ptrain=input and Ytrain=target
#Set maximum number of iterations k_max
#Set termination condition for Error E_stop
#The Training will stop after k_max iterations or when the Error <=E_stop
net = prn.train_LM(Ptrain,Ytrain,net,
verbose=True,k_max=1,E_stop=1e-5)
print('Batch No. ',i,' of ',number_of_batches)
After the training is finished, we can use the neural network. Therefore we choose 9 random samples of the test data set and use the input to calculate the NN outputs
Then we can plot the results, comparing the output of the neural network (number above the image) with the training (image).

::

idx = np.random.randint(0,5000-9)
P_ = Ptest[:,idx:idx+9]
Y_ = prn.NNOut(P_,net)

fig = plt.figure(figsize=[11,7])
gs = mpl.gridspec.GridSpec(3,3)

for i in range(9):
ax = fig.add_subplot(gs[i])
y_ = np.argmax(Y_[:,i]) #find index with highest value in NN output
p_ = P_[:,i].reshape(28,28) #Convert input data for plotting
ax.imshow(p_) #plot input data
ax.set_xticks([])
ax.set_yticks([])
ax.set_title(str(y_), fontsize=18)
plt.show()

.. figure:: img/example_python_classification.png
:width: 95%
:align: center

Binary file added doc/img/example_python_classification.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
71 changes: 71 additions & 0 deletions python/examples/example_classification_mnist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import matplotlib as mpl
import matplotlib.pyplot as plt
import pickle
import numpy as np
import pyrenn as prn

###
#Read Example Data
mnist = pickle.load( open( "MNIST_data.pkl", "rb" ) )
P = mnist['P']
Y = mnist['Y']
Ptest = mnist['Ptest']
Ytest = mnist['Ytest']


###
#Create and train NN

#create recurrent neural network with 28*28 inputs,
#1 hidden layers with 10 neurons
#and 10 outputs (one for each possible class/number)
#the NN uses no delayed or recurrent inputs/connections
net = prn.CreateNN([28*28,10,10])

batch_size = 1000
number_of_batches=20

for i in range(number_of_batches):
r = np.random.randint(0,25000-batch_size)
Ptrain = P[:,r:r+batch_size]
Ytrain = Y[:,r:r+batch_size]

#Train NN with training data Ptrain=input and Ytrain=target
#Set maximum number of iterations k_max
#Set termination condition for Error E_stop
#The Training will stop after k_max iterations or when the Error <=E_stop
net = prn.train_LM(Ptrain,Ytrain,net,
verbose=True,k_max=1,E_stop=1e-5)
print('Batch No. ',i,' of ',number_of_batches)



###
#Select Test data

#Choose random number 0...5000-9
idx = np.random.randint(0,5000-9)
#Select 9 random Test input data
P_ = Ptest[:,idx:idx+9]
#Calculate NN Output for the 9 random test inputs
Y_ = prn.NNOut(P_,net)


###
#PLOT
fig = plt.figure(figsize=[11,7])
gs = mpl.gridspec.GridSpec(3,3)

for i in range(9):

ax = fig.add_subplot(gs[i])

y_ = np.argmax(Y_[:,i]) #find index with highest value in NN output
p_ = P_[:,i].reshape(28,28) #Convert input data for plotting

ax.imshow(p_) #plot input data
ax.set_xticks([])
ax.set_yticks([])
ax.set_title(str(y_), fontsize=18)

plt.show()

0 comments on commit dd2b47d

Please sign in to comment.