diff --git a/doc/examples.rst b/doc/examples.rst index ed69da6..5378fdd 100644 --- a/doc/examples.rst +++ b/doc/examples.rst @@ -994,3 +994,101 @@ The function ``prn.BPTT()`` uses the Back Propagation Through Time algorithm and .. code-block:: matlab g_bptt = BPTT(net,data); + + +Classification (MNIST Data) +------------------------ + +In this example a neural network is used to learn to recognize handwritten digits. +Thefore the MNIST dataset hosted on `Yann LeCun's website`_ is used. +The data set consists of 60,000 data points for training and 10,000 data points for testing. To reduce the size of the data file, here only 25,000 data points for training and 5,000 for testing are used. +Each data point is defined by an 28X28 pixel image (784 numbers) and the +corresponing number represesented by a 10 element vector +(one element for each digit 0,1,2,3,4,5,6,7,8,9). For the number n, +only the n-th element is 1, all others are zero. So the vector [0 0 0 0 0 1 0 0 0 0] represents the number 5. +A more detailed explanation of the MNIST data can be found in the `Tensorflow tutorial`_. + +.. _ Yann LeCun's website: http://yann.lecun.com/exdb/mnist/page +.. _Tensorflow tutorial: https://www.tensorflow.org/get_started/mnist/beginners + +Python +^^^^^^^^^^^ + +At first the needed packages are imported. pickle for reading the data, matplotlib for plotting the results, numpy for its random function and pyrenn for the neural network. + +:: + + import matplotlib as mpl + import matplotlib.pyplot as plt + import pickle + import numpy as np + import pyrenn as prn + +Then the training input data P and output (target) data Y as well as the test input data Ptest and output data Ytest is read from the given pickle file. Each image is defined by the value of its 784 pixel, so P is a 2d array of size (784,Q), where Q is the number of data samples (25,000). Y is defined by an 10 element vector, which gives us a 2d array of size (10,Q=5000). + + +:: + + mnist = pickle.load( open( "MNIST_data.pkl", "rb" ) ) + P = mnist['P'] + Y = mnist['Y'] + Ptest = mnist['Ptest'] + Ytest = mnist['Ytest'] + +Then the neural network is created. Since we have a system with 28*28 inputs and 10 outputs, we need a neural network with the same number of inputs and outputs. For this system we choose a neural network with one hidden layer with 10 neurons. Since there is no interconnection between the images, we do not need a recurrent network and no delayed inputs, so we do not have to change the delay inputs. + + :: + + net = prn.CreateNN([28*28,10,10]) + +Because training the network with all the available training data would need a lot of memory and time, we randomly extract a batch of 1000 data samples and use it to train the network. Because we want to use as much information of our data as possible, we only perform one iteration (k_max=1) and then extract a new batch. In this example we do this 20 times, so we train the net for 20 iterations, but each iteration with new training data. +``verbose=True`` activates diplaying the error during training. + + :: + + batch_size = 1000 + number_of_batches=20 + + for i in range(number_of_batches): + r = np.random.randint(0,25000-batch_size) + Ptrain = P[:,r:r+batch_size] + Ytrain = Y[:,r:r+batch_size] + + #Train NN with training data Ptrain=input and Ytrain=target + #Set maximum number of iterations k_max + #Set termination condition for Error E_stop + #The Training will stop after k_max iterations or when the Error <=E_stop + net = prn.train_LM(Ptrain,Ytrain,net, + verbose=True,k_max=1,E_stop=1e-5) + print('Batch No. ',i,' of ',number_of_batches) + +After the training is finished, we can use the neural network. Therefore we choose 9 random samples of the test data set and use the input to calculate the NN outputs +Then we can plot the results, comparing the output of the neural network (number above the image) with the training (image). + + :: + + idx = np.random.randint(0,5000-9) + P_ = Ptest[:,idx:idx+9] + Y_ = prn.NNOut(P_,net) + + fig = plt.figure(figsize=[11,7]) + gs = mpl.gridspec.GridSpec(3,3) + + for i in range(9): + + ax = fig.add_subplot(gs[i]) + + y_ = np.argmax(Y_[:,i]) #find index with highest value in NN output + p_ = P_[:,i].reshape(28,28) #Convert input data for plotting + + ax.imshow(p_) #plot input data + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_title(str(y_), fontsize=18) + + plt.show() + +.. figure:: img/example_python_classification.png + :width: 95% + :align: center + diff --git a/doc/img/example_python_classification.png b/doc/img/example_python_classification.png new file mode 100644 index 0000000..43916ca Binary files /dev/null and b/doc/img/example_python_classification.png differ diff --git a/python/examples/example_classification_mnist.py b/python/examples/example_classification_mnist.py new file mode 100644 index 0000000..31df89a --- /dev/null +++ b/python/examples/example_classification_mnist.py @@ -0,0 +1,71 @@ +import matplotlib as mpl +import matplotlib.pyplot as plt +import pickle +import numpy as np +import pyrenn as prn + +### +#Read Example Data +mnist = pickle.load( open( "MNIST_data.pkl", "rb" ) ) +P = mnist['P'] +Y = mnist['Y'] +Ptest = mnist['Ptest'] +Ytest = mnist['Ytest'] + + +### +#Create and train NN + +#create recurrent neural network with 28*28 inputs, +#1 hidden layers with 10 neurons +#and 10 outputs (one for each possible class/number) +#the NN uses no delayed or recurrent inputs/connections +net = prn.CreateNN([28*28,10,10]) + +batch_size = 1000 +number_of_batches=20 + +for i in range(number_of_batches): + r = np.random.randint(0,25000-batch_size) + Ptrain = P[:,r:r+batch_size] + Ytrain = Y[:,r:r+batch_size] + + #Train NN with training data Ptrain=input and Ytrain=target + #Set maximum number of iterations k_max + #Set termination condition for Error E_stop + #The Training will stop after k_max iterations or when the Error <=E_stop + net = prn.train_LM(Ptrain,Ytrain,net, + verbose=True,k_max=1,E_stop=1e-5) + print('Batch No. ',i,' of ',number_of_batches) + + + +### +#Select Test data + +#Choose random number 0...5000-9 +idx = np.random.randint(0,5000-9) +#Select 9 random Test input data +P_ = Ptest[:,idx:idx+9] +#Calculate NN Output for the 9 random test inputs +Y_ = prn.NNOut(P_,net) + + +### +#PLOT +fig = plt.figure(figsize=[11,7]) +gs = mpl.gridspec.GridSpec(3,3) + +for i in range(9): + + ax = fig.add_subplot(gs[i]) + + y_ = np.argmax(Y_[:,i]) #find index with highest value in NN output + p_ = P_[:,i].reshape(28,28) #Convert input data for plotting + + ax.imshow(p_) #plot input data + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_title(str(y_), fontsize=18) + +plt.show()