MNIST Handwritten Digit Recognition in Keras

Gregor Koehler

In this article we'll build a simple neural network and train it on a GPU-enabled server to recognize handwritten digits using the MNIST dataset. Training a classifier on the MNIST dataset is regarded as the hello world of image recognition. To follow along here, you should have a basic understanding of the Multilayer Perceptron class of neural networks.

MNIST contains 70,000 images of handwritten digits: 60,000 for training and 10,000 for testing. The images are grayscale, 28x28 pixels, and centered to reduce preprocessing and get started quicker.

Keras is a high-level neural network API focused on user friendliness, fast prototyping, modularity and extensibility. It works with deep learning frameworks like TensorFlow, Theano and CNTK, so we can get right into building and training a neural network without a lot of fuss.

Let's get started by setting up our environment with Keras using TensorFlow as the backend.

Setting up the Environment

First, we have to install the TensorFlow and Keras packages. We do this in a separate runtime so that we can save it, export it as a new environment, and never have to install the packages again.

Solving environment: ...working... done
## Package Plan ##
environment location: /opt/conda
added / updated specs:
- cudatoolkit=8
- h5py
- tensorflow-gpu
Total download size: 732.8 MB
The following NEW packages will be INSTALLED:
_tflow_180_select: 1.0-gpu anaconda
absl-py: 0.2.2-py36_0 anaconda
astor: 0.6.2-py36_0 anaconda
bleach: 1.5.0-py36_0 anaconda
cudatoolkit: 8.0-3 anaconda
cudnn: 7.0.5-cuda8.0_0 anaconda
cupti: 8.0.61-0 anaconda
gast: 0.2.0-py36_0 anaconda
grpcio: 1.12.0-py36hdbcaa40_0 anaconda
h5py: 2.8.0-py36ha1f6525_0 anaconda
hdf5: 1.10.2-hba1933b_1 anaconda
html5lib: 0.9999999-py36_0 anaconda
libprotobuf: 3.5.2-h6f1eeef_0 anaconda
markdown: 2.6.11-py36_0 anaconda
protobuf: 3.5.2-py36hf484d3e_0 anaconda
tensorboard: 1.8.0-py36hf484d3e_0 anaconda
tensorflow: 1.8.0-hb381393_0 anaconda
tensorflow-base: 1.8.0-py36h4df133c_0 anaconda
tensorflow-gpu: 1.8.0-h7b35bdc_0 anaconda
termcolor: 1.1.0-py36_1 anaconda
werkzeug: 0.14.1-py36_0 anaconda
The following packages will be UPDATED:
ca-certificates: 2017.08.26-h1d4fec5_0 --> 2018.03.07-0 anaconda
certifi: 2017.11.5-py36hf29ccca_0 --> 2018.4.16-py36_0 anaconda
conda: 4.5.4-py36_0 conda-forge --> 4.5.4-py36_0 anaconda
openssl: 1.0.2n-hb7f436b_0 --> 1.0.2o-h20670df_0 anaconda
Downloading and Extracting Packages
Preparing transaction: ...working... done
Verifying transaction: ...working... done
Executing transaction: ...working... done
Collecting keras
Installing collected packages: scipy, pyyaml, keras-preprocessing, keras-applications, keras
Successfully installed keras-2.2.0 keras-applications-1.0.2 keras-preprocessing-1.0.1 pyyaml-3.12 scipy-1.1.0

We'll also import the dataset to cache it.

Downloading data from https://s3.amazonaws.com/img-datasets/mnist.npz
11493376/11490434 [==============================] - 10s 1us/step
Using TensorFlow backend.

Preparing the Dataset

These package imports are pretty standard — we'll get back to the Keras-specific imports further down.

Mon Feb 17 11:05:24 2020
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.33.01    Driver Version: 440.33.01    CUDA Version: 10.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla K80           On   | 00000000:00:04.0 Off |                    0 |
| N/A   34C    P8    28W / 149W |      6MiB / 11441MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+
Using TensorFlow backend.

Now we'll load the dataset using this handy function which splits the MNIST data into train and test sets.

Let's inspect a few examples. The MNIST dataset contains only grayscale images. For more advanced image datasets, we'll have the three color channels (RGB).

In order to train our neural network to classify images we first have to unroll the height $\times$ width pixel format into one big vector, the input vector. So its length must be $28 \cdot 28 = 784$. But first, let's graph the distribution of our pixel values.
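As a rough sketch of how such a distribution can be computed, the snippet below bins the pixel values of a single image with NumPy. The image is a synthetic stand-in (a dark background with one bright stroke), not an actual MNIST sample, and the names `fake_digit`, `counts` and `edges` are our own.

```python
import numpy as np

# Synthetic 28x28 grayscale "digit": black background with a bright vertical
# stroke. This is a stand-in for an MNIST image, not real data.
fake_digit = np.zeros((28, 28), dtype=np.uint8)
fake_digit[4:24, 13:15] = 255

# Count how many pixels fall into each of 8 equal-width bins over [0, 256].
counts, edges = np.histogram(fake_digit, bins=8, range=(0, 256))

for left, right, count in zip(edges[:-1], edges[1:], counts):
    print(f"{int(left):3d} to {int(right):3d}: {count} pixels")
```

Almost all pixels land in the lowest bin (background) and the rest in the highest bin (the stroke), which is the same two-peaked shape we would expect from real digit images.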

As expected, the pixel values range from 0 to 255: the background majority close to 0, and those close to 255 representing the digit.

Normalizing the input data helps to speed up the training. Also, it reduces the chance of getting stuck in local optima, since we're using stochastic gradient descent to find the optimal weights for the network.

Let's reshape each input image into a single vector and normalize the pixel values to lie between 0 and 1.
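A minimal sketch of this step in plain NumPy is shown below. The array `X` is random data standing in for the training images (an assumption, so that the snippet runs without downloading anything); only the reshaping and scaling logic matters here.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for the image array: 5 random 28x28 uint8 "images" (not MNIST).
X = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)

# Unroll each 28x28 image into a 784-element vector, convert to float32,
# and scale the pixel values from [0, 255] down to [0, 1].
X_flat = X.reshape(X.shape[0], 28 * 28).astype("float32") / 255.0

print(X.shape, "->", X_flat.shape)
print("min:", X_flat.min(), "max:", X_flat.max())
```

Converting to a float type before dividing matters: the raw pixels are unsigned 8-bit integers, and we want fractional values after scaling.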

X_train shape (60000, 28, 28)
y_train shape (60000,)
X_test shape (10000, 28, 28)
y_test shape (10000,)
Train matrix shape (60000, 784)
Test matrix shape (10000, 784)

So far the truth (Y in machine learning lingo) we'll use for training still holds integer values from 0 to 9.
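One way to check how often each digit occurs is NumPy's `np.unique` with `return_counts=True`. The label array `y` below is a tiny hand-written example of our own, not the real training labels.

```python
import numpy as np

# Toy label vector (assumption): ten integer class labels between 0 and 9.
y = np.array([5, 0, 4, 1, 9, 2, 1, 3, 1, 4], dtype=np.uint8)

# labels: the distinct values, sorted; counts: how often each one appears.
labels, counts = np.unique(y, return_counts=True)
print(labels, counts)
```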

(array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=uint8), array([5923, 6742, 5958, 6131, 5842, 5421, 5918, 6265, 5851, 5949]))

Let's encode our categories - digits from 0 to 9 - using one-hot encoding. The result is a vector with a length equal to the number of categories. The vector is all zeroes except in the position for the respective category. Thus a '5' will be represented by [0,0,0,0,0,1,0,0,0,0]: ten entries, with the single 1 at index 5 (counting from 0).
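To make the encoding concrete, here is a small NumPy sketch. The helper `one_hot` is a hypothetical function written for this illustration; Keras ships its own utility for this (`to_categorical`), which we do not reimplement exactly here.

```python
import numpy as np

def one_hot(labels, n_classes=10):
    """Return a (len(labels), n_classes) matrix with a single 1 per row."""
    labels = np.asarray(labels, dtype=int)
    encoded = np.zeros((labels.size, n_classes), dtype=np.float32)
    # Row i gets a 1 in the column given by labels[i].
    encoded[np.arange(labels.size), labels] = 1.0
    return encoded

print(one_hot([5])[0].astype(int).tolist())  # -> [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]
print(one_hot([0, 9, 5]).shape)              # -> (3, 10)
```

Since the classes start at 0, the digit 5 puts its 1 in the sixth position of the ten-element vector.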

Shape before one-hot encoding: (60000,)
Shape after one-hot encoding: (60000, 10)

Building the Network

Let's turn to Keras to build a neural network.

Our densely-connected network with two hidden layers.

Our pixel vector serves as the input. It is followed by two hidden layers of 512 nodes each, which give the model enough complexity for recognizing digits. For the multi-class classification we add another densely-connected (or fully-connected) layer for the 10 different output classes. For this network architecture we can use the Keras Sequential Model. We can stack layers using the .add() method.

When adding the first layer in the Sequential Model we need to specify the input shape so Keras can create the appropriate matrices. For all remaining layers the shape is inferred automatically.
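To get a feeling for the size of those matrices, we can count the trainable parameters of a 784-512-512-10 stack by hand. The sketch assumes every dense layer has one weight per input-output pair plus one bias per output unit; `dense_params` is a hypothetical helper of our own.

```python
def dense_params(n_in, n_out):
    """Weights (n_in * n_out) plus one bias per output unit."""
    return n_in * n_out + n_out

# Input vector, two hidden layers, output layer.
layer_sizes = [784, 512, 512, 10]

per_layer = [dense_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:])]
total_params = sum(per_layer)

print(per_layer, total_params)  # -> [401920, 262656, 5130] 669706
```

This also shows why only the first layer needs an explicit input shape: each later layer's input size is simply the previous layer's output size.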

In order to introduce nonlinearities into the network and elevate it beyond the capabilities of a simple perceptron we also add activation functions to the hidden layers. Keras takes care of the differentiation needed for training via backpropagation behind the scenes, so we don't have to implement the details ourselves.
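The following NumPy sketch illustrates why the activation functions matter. Without them, two stacked linear layers collapse into a single linear map. We use ReLU as the example activation (an assumption, since the text above does not name one) and leave out biases for brevity; all weights are random.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))    # 4 samples with 8 features each
W1 = rng.normal(size=(8, 16))  # first layer weights
W2 = rng.normal(size=(16, 3))  # second layer weights

def relu(z):
    """Rectified linear unit: keep positive values, zero out the rest."""
    return np.maximum(z, 0.0)

stacked_linear = (x @ W1) @ W2    # two layers, no activation
single_linear = x @ (W1 @ W2)     # one equivalent linear layer
with_relu = relu(x @ W1) @ W2     # two layers with a nonlinearity between

print(np.allclose(stacked_linear, single_linear))  # -> True
print(np.allclose(with_relu, single_linear))       # -> False
```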

We also add dropout as a way to prevent overfitting. During training, dropout randomly sets a fraction of a layer's outputs to zero at each update, so that the network doesn't rely too much on very few nodes. At test time dropout is switched off and all nodes are used.
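Dropout as it is commonly implemented zeroes a random fraction of a layer's outputs during training and scales the surviving values up so that the expected sum stays the same. Below is a minimal NumPy sketch of that idea ("inverted dropout"). The function name, the rate of 0.2 and the all-ones activations are illustrative assumptions, not values taken from the model above.

```python
import numpy as np

def dropout(activations, rate, rng, training=True):
    """Zero out roughly `rate` of the entries and rescale the rest."""
    if not training or rate == 0.0:
        # At inference time the activations pass through unchanged.
        return activations
    keep_mask = rng.random(activations.shape) >= rate
    return activations * keep_mask / (1.0 - rate)

rng = np.random.default_rng(0)
h = np.ones((2, 10))  # pretend hidden-layer outputs

dropped = dropout(h, rate=0.2, rng=rng)
untouched = dropout(h, rate=0.2, rng=rng, training=False)

print(dropped)
print(untouched)
```

With a rate of 0.2, each surviving entry is divided by 0.8, so the values in `dropped` are either 0 or 1.25.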

The last layer consists of connections for our 10 classes and the softmax activation which is standard for multi-class targets.
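For reference, softmax turns the 10 raw output scores into positive values that sum to 1, so they can be read as class probabilities. The sketch below is a plain NumPy version with the usual max-subtraction for numerical stability; the `logits` values are made up.

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax; subtracting the row max avoids overflow in exp."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

# One sample with 10 made-up class scores; class 0 has the highest score.
logits = np.array([[2.0, 1.0, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])
probs = softmax(logits)

print(probs.round(3))
print("sum:", probs.sum(), "predicted class:", probs.argmax(axis=-1)[0])
```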

Compiling and Training the Model

Now that the model is in place we configure the learning process using .compile(). Here we specify our loss function (or objective function). For our setting categorical cross entropy fits the bill, but in general other loss functions are available.
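To see what categorical cross entropy measures, here is a small NumPy sketch: for each sample it takes the negative log of the probability the model assigned to the true class, then averages over the batch. The function and the example predictions are our own illustration, and the clipping constant `eps` is an assumption to avoid taking the log of 0.

```python
import numpy as np

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean negative log-probability of the true class."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))

# One sample whose true class is 3 (one-hot encoded).
y_true = np.zeros((1, 10))
y_true[0, 3] = 1.0

# A clueless prediction: every class gets probability 0.1.
uniform = np.full((1, 10), 0.1)

# A confident, correct prediction: 0.91 on class 3, 0.01 elsewhere.
confident = np.full((1, 10), 0.01)
confident[0, 3] = 0.91

uniform_loss = categorical_crossentropy(y_true, uniform)
confident_loss = categorical_crossentropy(y_true, confident)
print(uniform_loss, confident_loss)
```

The uniform guess gives a loss of ln(10), roughly 2.30, which is a useful baseline: a 10-class model with a higher loss than that is doing worse than guessing.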

As for the optimizer of choice we'll use Adam with default settings. We could also instantiate an optimizer and set parameters before passing it to model.compile() but for this example the defaults will do.
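For readers curious what Adam does with each gradient, below is a sketch of the textbook update rule in NumPy, applied to a toy quadratic. It is not the Keras implementation. The hyperparameter values are the defaults we believe Keras uses (learning rate 0.001, beta_1 0.9, beta_2 0.999, epsilon 1e-7), stated here as assumptions, and `adam_minimize` is a hypothetical helper.

```python
import numpy as np

def adam_minimize(grad_fn, w0, steps, lr=0.001, beta_1=0.9, beta_2=0.999, eps=1e-7):
    """Run `steps` Adam updates starting from w0 and return the final weights."""
    w = np.asarray(w0, dtype=float).copy()
    m = np.zeros_like(w)  # running mean of the gradients
    v = np.zeros_like(w)  # running mean of the squared gradients
    for t in range(1, steps + 1):
        g = grad_fn(w)
        m = beta_1 * m + (1 - beta_1) * g
        v = beta_2 * v + (1 - beta_2) * g * g
        # Bias correction: both averages start at zero and are too small early on.
        m_hat = m / (1 - beta_1 ** t)
        v_hat = v / (1 - beta_2 ** t)
        w = w - lr * m_hat / (np.sqrt(v_hat) + eps)
    return w

# Toy objective f(w) = sum(w**2) with gradient 2*w.
w_start = np.array([1.0, -2.0])
w_one_step = adam_minimize(lambda w: 2 * w, w_start, steps=1)
w_final = adam_minimize(lambda w: 2 * w, w_start, steps=100)

print(w_one_step, w_final)
```

Note that the first step moves each weight by almost exactly the learning rate, regardless of the gradient's magnitude. This per-parameter scaling is what distinguishes Adam from plain stochastic gradient descent.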

We also choose which metrics will be evaluated during training and testing. We can pass any list of metrics - even build metrics ourselves - and have them displayed during training/testing. We'll stick to accuracy for now.
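Accuracy for one-hot targets boils down to comparing the index of the largest predicted probability with the index of the 1 in the target vector. The NumPy sketch below uses three classes and four made-up samples to keep it readable; the `accuracy` function is our own illustration.

```python
import numpy as np

def accuracy(y_true_one_hot, y_pred_probs):
    """Fraction of samples whose highest-probability class is the true class."""
    true_classes = np.argmax(y_true_one_hot, axis=1)
    predicted_classes = np.argmax(y_pred_probs, axis=1)
    return float(np.mean(true_classes == predicted_classes))

# True classes 0, 1, 2, 1 as one-hot rows.
y_true = np.eye(3)[[0, 1, 2, 1]]

# Made-up predictions; the last row wrongly favours class 0.
y_pred = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.3, 0.3, 0.4],
    [0.6, 0.3, 0.1],
])

acc = accuracy(y_true, y_pred)
print(acc)  # -> 0.75
```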

Having compiled our model we can now start the training process. We have to specify how many times we want to iterate on the whole training set (epochs) and how many samples we use for one update to the model's weights (batch size). Generally the bigger the batch, the more stable our stochastic gradient descent updates will be. But beware of GPU memory limitations! We're going for a batch size of 128 and 20 epochs.
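The relationship between batch size and the number of weight updates can be sketched in a few lines of plain Python. The helper names are our own, and unlike a real training loop this sketch does not shuffle the samples between epochs.

```python
import math

def updates_per_epoch(n_samples, batch_size):
    """Number of weight updates needed to see every sample once."""
    return math.ceil(n_samples / batch_size)

def batch_slices(n_samples, batch_size):
    """Yield index slices that cover the data in consecutive mini-batches."""
    for start in range(0, n_samples, batch_size):
        yield slice(start, min(start + batch_size, n_samples))

n_updates = updates_per_epoch(60_000, 128)
last_batch = list(batch_slices(60_000, 128))[-1]
last_batch_size = last_batch.stop - last_batch.start

print(n_updates, last_batch_size)  # -> 469 96
```

With 60,000 training images and a batch size of 128, one epoch consists of 469 updates, the last of which uses a smaller batch of 96 samples. Multiplying by the number of epochs gives the total number of updates for a training run.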

To get a handle on our training progress we also graph the learning curve for our model looking at the loss and accuracy.

In order to work with the trained model and evaluate its performance we're saving the model in /results/.

Train on 60000 samples, validate on 10000 samples
Epoch 1/20 - 4s - loss: 0.2479 - acc: 0.9258 - val_loss: 0.1036 - val_acc: 0.9666
Epoch 2/20 - 3s - loss: 0.0997 - acc: 0.9694 - val_loss: 0.0812 - val_acc: 0.9740
Epoch 3/20 - 3s - loss: 0.0734 - acc: 0.9773 - val_loss: 0.0694 - val_acc: 0.9790
Epoch 4/20 - 3s - loss: 0.0545 - acc: 0.9820 - val_loss: 0.0765 - val_acc: 0.9756
Epoch 5/20 - 3s - loss: 0.0469 - acc: 0.9841 - val_loss: 0.0671 - val_acc: 0.9791
Epoch 6/20 - 3s - loss: 0.0389 - acc: 0.9876 - val_loss: 0.0612 - val_acc: 0.9820
Epoch 7/20 - 3s - loss: 0.0334 - acc: 0.9884 - val_loss: 0.0665 - val_acc: 0.9808
Epoch 8/20 - 3s - loss: 0.0287 - acc: 0.9902 - val_loss: 0.0618 - val_acc: 0.9831
Epoch 9/20 - 3s - loss: 0.0291 - acc: 0.9901 - val_loss: 0.0596 - val_acc: 0.9828
Epoch 10/20 - 3s - loss: 0.0251 - acc: 0.9917 - val_loss: 0.0705 - val_acc: 0.9799
Epoch 11/20 - 3s - loss: 0.0242 - acc: 0.9922 - val_loss: 0.0708 - val_acc: 0.9817
Epoch 12/20 - 3s - loss: 0.0205 - acc: 0.9932 - val_loss: 0.0751 - val_acc: 0.9810
Epoch 13/20 - 3s - loss: 0.0233 - acc: 0.9923 - val_loss: 0.0668 - val_acc: 0.9842
Epoch 14/20 - 3s - loss: 0.0191 - acc: 0.9940 - val_loss: 0.0887 - val_acc: 0.9791
Epoch 15/20 - 3s - loss: 0.0178 - acc: 0.9944 - val_loss: 0.0794 - val_acc: 0.9834
Epoch 16/20 - 3s - loss: 0.0179 - acc: 0.9941 - val_loss: 0.0707 - val_acc: 0.9836
Epoch 17/20 - 3s - loss: 0.0180 - acc: 0.9947 - val_loss: 0.0873 - val_acc: 0.9801
Epoch 18/20 - 3s - loss: 0.0181 - acc: 0.9938 - val_loss: 0.0723 - val_acc: 0.9836
Epoch 19/20 - 3s - loss: 0.0147 - acc: 0.9954 - val_loss: 0.0786 - val_acc: 0.9834
Epoch 20/20 - 3s - loss: 0.0145 - acc: 0.9956 - val_loss: 0.0741 - val_acc: 0.9822
Saved trained model at /results/keras_mnist.h5

This learning curve looks quite good! We see that the loss on the training set is decreasing rapidly for the first two epochs. This shows the network is learning to classify the digits pretty fast. For the test set the loss does not decrease as fast: after the first few epochs it settles between roughly 0.06 and 0.09 while the training loss keeps falling. Test accuracy nevertheless stays at about 98%, so our model generalizes well to unseen data.

Evaluate the Model's Performance

It's time to reap the fruits of our neural network training. Let's see how well the model performs on the test set. The model.evaluate() method computes the loss and any metric defined when compiling the model. So in our case the accuracy is computed on the 10,000 testing examples using the network weights given by the saved model.

Test Loss 0.07407343140110847
Test Accuracy 0.9822

This accuracy looks very good! But let's stay neutral here and evaluate both correctly and incorrectly classified examples. We'll look at 9 examples each.
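Splitting the test set into hits and misses only needs the predicted class per example and the true labels. The sketch below uses short made-up arrays in place of real model predictions; the variable names are our own.

```python
import numpy as np

# Toy ground truth and toy predictions (assumptions, not model output).
y_test = np.array([7, 2, 1, 0, 4, 1, 4, 9])
predicted_classes = np.array([7, 2, 1, 0, 4, 1, 9, 9])

# Indices of the examples the "model" got right and wrong.
correct_idx = np.nonzero(predicted_classes == y_test)[0]
incorrect_idx = np.nonzero(predicted_classes != y_test)[0]

print(len(correct_idx), "classified correctly")
print(len(incorrect_idx), "classified incorrectly")

# Pick up to 9 examples of each kind for plotting.
correct_sample = correct_idx[:9]
incorrect_sample = incorrect_idx[:9]
```

With real predictions, `predicted_classes` would be the argmax over the 10 softmax outputs for each test image.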

9822 classified correctly
178 classified incorrectly

As we can see, the wrong predictions are quite forgivable, since some of these digits are hard to recognize even for a human reader.

In summary we used Keras with a TensorFlow backend on a GPU-enabled server to train a neural network to recognize handwritten digits in about a minute of training time (20 epochs at roughly 3 seconds each) - all that without having to spin up any compute instances, only using our browser.

In the next article of this series (coming soon) we'll harness the power of GPUs even more to train more complex neural networks which include convolutional layers.
