
jaredgorski/cnn4fun


cnn4fun

This is a rather basic convolutional neural network (CNN).

The cnn package contains a primary cnn.CNN class, along with convolution, max-pooling, and softmax activation layers at cnn.layers.Conv, cnn.layers.MaxPool, and cnn.layers.SoftMax, respectively. These layers can be configured, together with the learning rate, to fine-tune how the network trains. The network currently works with the MNIST handwritten digits dataset, which you can try by running python run_mnist.py.
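The three layer types can be sketched in pure Python as below. This is a minimal illustration under stated assumptions, not the cnn package's implementation: the helper names conv2d, max_pool, and softmax are hypothetical, feature maps are plain nested lists, kernels are square, and convolution is "valid" (no padding) and computed as cross-correlation, which is the usual convention in CNNs.

```python
import math

# Hypothetical helpers for illustration only; not the cnn package's API.
# Images and feature maps are plain lists of lists (height x width).

def conv2d(image, kernel, stride=1):
    """Valid (unpadded) 2D cross-correlation of one square kernel over one channel."""
    k = len(kernel)
    out_h = (len(image) - k) // stride + 1
    out_w = (len(image[0]) - k) // stride + 1
    out = []
    for i in range(out_h):
        row = []
        for j in range(out_w):
            # Sum of elementwise products between the kernel and the image patch
            total = 0.0
            for a in range(k):
                for b in range(k):
                    total += image[i * stride + a][j * stride + b] * kernel[a][b]
            row.append(total)
        out.append(row)
    return out


def max_pool(feature_map, size=2, stride=2):
    """Keep the largest value in each size x size window."""
    out_h = (len(feature_map) - size) // stride + 1
    out_w = (len(feature_map[0]) - size) // stride + 1
    return [
        [
            max(
                feature_map[i * stride + a][j * stride + b]
                for a in range(size)
                for b in range(size)
            )
            for j in range(out_w)
        ]
        for i in range(out_h)
    ]


def softmax(logits):
    """Turn raw class scores into probabilities that sum to 1."""
    shift = max(logits)  # subtracting the max avoids overflow in math.exp
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Toy 4x4 image whose pixel at (r, c) is 4 * r + c
image = [[4 * r + c for c in range(4)] for r in range(4)]
fmap = conv2d(image, [[1, 0], [0, -1]])  # 3x3 map; every entry is -5.0
pooled = max_pool(fmap)                  # 1x1 map: [[-5.0]]
print(fmap, pooled, softmax([1.0, 2.0, 3.0]))
```

In a full network, each Conv layer would apply several such kernels (num_kernels), and the SoftMax layer would first map the flattened feature maps to one score per class before normalizing.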

The network supports both grayscale and RGB images.
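One common way to support both kinds of input is to treat a grayscale image as a single channel and an RGB image as three, give each kernel one 2D slice per channel, and sum the per-channel responses into one output map. The sketch below assumes that approach and a hypothetical layout in which RGB pixels are stored as [r, g, b] triples; to_channels and conv_multichannel are illustrative names, and the cnn package may handle channels differently.

```python
def to_channels(image):
    """Split an image into a list of 2D channel planes.

    Assumes grayscale images are H x W lists of numbers and RGB images are
    H x W lists of [r, g, b] triples (a hypothetical layout).
    """
    first = image[0][0]
    if isinstance(first, (list, tuple)):
        n = len(first)
        return [[[pixel[c] for pixel in row] for row in image] for c in range(n)]
    return [image]  # grayscale: a single channel


def conv_multichannel(image, kernels, stride=1):
    """Valid convolution with one multi-channel kernel.

    kernels holds one square 2D kernel per input channel; the per-channel
    responses are summed into a single output feature map.
    """
    planes = to_channels(image)
    if len(planes) != len(kernels):
        raise ValueError("need one 2D kernel per input channel")
    k = len(kernels[0])
    out_h = (len(planes[0]) - k) // stride + 1
    out_w = (len(planes[0][0]) - k) // stride + 1
    return [
        [
            sum(
                plane[i * stride + a][j * stride + b] * kern[a][b]
                for plane, kern in zip(planes, kernels)
                for a in range(k)
                for b in range(k)
            )
            for j in range(out_w)
        ]
        for i in range(out_h)
    ]
```

With this design, the rest of the network does not need to know whether the input was grayscale or RGB: after the first convolution, both produce the same kind of 2D feature maps.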

To run the network on the MNIST dataset:

  1. Make sure Python 3 and pip are installed on your machine
  2. Clone this repo locally
  3. Install dependencies with pip install -r requirements.txt
  4. Run python run_mnist.py

Package usage

```python
# The cnn package must exist locally, whether cloned or copied into your project
import cnn

# Load training images (RGB or grayscale) and their labels, in matching order.
# get_ordered_images_list and get_ordered_labels_list are placeholders for
# your own data-loading code; they are not part of the cnn package.
training_images = get_ordered_images_list()
training_labels = get_ordered_labels_list()

# Define the list of classes
classes = ['cat', 'dog']

# Initialize the layer stack
layers = [
    cnn.layers.Conv(num_kernels=16, kernel_dimension=5, stride=1),
    cnn.layers.MaxPool(kernel_dimension=2, stride=2),
    cnn.layers.Conv(num_kernels=16, kernel_dimension=3, stride=1),
    cnn.layers.MaxPool(kernel_dimension=2, stride=2),
    cnn.layers.SoftMax(num_classes=len(classes)),  # one output per class
]

# Initialize the network object
net = cnn.CNN(layers)

# Train
net.train(training_images, training_labels, classes, num_epochs=20, rate=0.001)

# Get a test image and its label (get_dog_png is also a placeholder)
test_image = get_dog_png()
test_label = 'dog'

# Test the model's prediction: predict returns an index into classes
prediction_index = net.predict(test_image)

prediction = classes[prediction_index]
correct = prediction == test_label
```
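The spatial size of each feature map follows from each layer's kernel_dimension and stride. Assuming unpadded ("valid") windows, an n x n input becomes floor((n - kernel) / stride) + 1 pixels on each side. The sketch below traces the layer stack from the usage example for a 28x28 MNIST digit; output_size and trace_shapes are hypothetical helpers, and the final count of values reaching SoftMax additionally assumes that each Conv layer emits exactly num_kernels feature maps, which may not match the cnn package's internals.

```python
def output_size(n, kernel, stride):
    """Side length after a valid (unpadded) convolution or pooling window."""
    return (n - kernel) // stride + 1


# (layer name, kernel_dimension, stride), mirroring the layer stack in the usage example
stack = [("Conv", 5, 1), ("MaxPool", 2, 2), ("Conv", 3, 1), ("MaxPool", 2, 2)]


def trace_shapes(side, layers):
    """Return the side length before the first layer and after each layer."""
    sizes = [side]
    for _name, kernel, stride in layers:
        side = output_size(side, kernel, stride)
        sizes.append(side)
    return sizes


sizes = trace_shapes(28, stack)  # MNIST digits are 28x28 pixels
print(sizes)                     # [28, 24, 12, 10, 5]

# Assumption: the last Conv emits num_kernels=16 maps, so SoftMax sees 5 * 5 * 16 values
print(sizes[-1] * sizes[-1] * 16)  # 400
```

Because of the floor, any rows or columns that do not fill a whole window are dropped, so it can be worth checking these numbers before changing kernel sizes or strides.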

Tests

  • To run unit tests, run python -m pytest.
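pytest collects files named test_*.py (or *_test.py) and runs the functions in them whose names start with test_, reporting any failing plain assert statements. A test in that style might look like the hypothetical file below; it checks a standalone softmax helper rather than the cnn package's own SoftMax layer, whose API may differ.

```python
# Hypothetical file: test_softmax_example.py (not part of this repository)
import math


def softmax(logits):
    """Numerically stable softmax over a list of scores."""
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def test_softmax_sums_to_one():
    assert math.isclose(sum(softmax([2.0, 1.0, 0.1])), 1.0)


def test_softmax_preserves_order():
    probs = softmax([2.0, 1.0, 0.1])
    assert probs[0] > probs[1] > probs[2]


def test_softmax_handles_large_logits():
    # Without the max shift, math.exp(1000.0) would raise OverflowError
    probs = softmax([1000.0, 1000.0])
    assert math.isclose(probs[0], 0.5)
```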
