
[Keras] Use a CNN to Build a Simple Classifier for MNIST

Last Updated on 2021-05-02 by Clay

Introduction

MNIST is a classic database of handwritten digits, covering the numbers 0-9. Today I will note how to use Keras to build a CNN classifier that recognizes these digits.

I should say up front that my model design is very simple: just a convolution layer + MaxPool + Flatten, connected to fully connected layers (Dense layers).

There are certainly better model designs out there. But since this is only a note on how to use Keras, I chose the model freely. Anyway, the results are not bad.

The post has four parts: MNIST introduction, CNN introduction, Keras introduction, and code.


MNIST Introduction

If you already know MNIST, you can skip this section.

MNIST is a very famous dataset of handwritten digits; it is like the "Hello World" of machine learning.

There are 60,000 training images and 10,000 test images. Each image is 28 x 28 pixels, and every pixel is a grayscale value.

Every image also has a label, which is why this dataset is so valuable. The labels are the digits 0-9, and for training we convert them to one-hot encoding.

The meaning of one-hot encoding is:

0 : [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
1 : [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]
...
8 : [0, 0, 0, 0, 0, 0, 0, 0, 1, 0]
9 : [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]

Like this.

The advantage is that it can represent non-continuous (categorical) data. The downside is that when there are many categories, memory is wasted, because each vector is almost entirely zeros.
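As a small illustration of the encoding described above, here is a minimal sketch in plain Python. The helper names one_hot and decode are my own for this example, not part of Keras; in the actual program the conversion is done by Keras's to_categorical.

```python
def one_hot(label, num_classes=10):
    """Return a list of zeros with a single 1 at position `label`."""
    if not 0 <= label < num_classes:
        raise ValueError("label must be in the range 0..num_classes-1")
    vector = [0] * num_classes
    vector[label] = 1
    return vector


def decode(vector):
    """Recover the integer label: the index of the largest entry."""
    return vector.index(max(vector))


for digit in (0, 1, 9):
    print(digit, ":", one_hot(digit))
```

Note that ten classes need vectors of length ten, so the digit 9 is the last position and there is no entry for 10.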


CNN Introduction

Similarly, if you have an understanding of CNN, you can skip this section.

But then again, I am not going to go into deep theory here.

CNN stands for Convolutional Neural Network. An early and influential paper on CNNs is this one: http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf

The architecture introduced in the paper is not far from what we use now: convolution layers + pooling, and so on.
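To make "convolution + pooling" a little more concrete, here is a rough NumPy sketch of the two operations on a single-channel image. It is only an illustration under simplifying assumptions (no padding, stride 1, one filter, no bias), and the function names conv2d_valid and max_pool2d are made up for this note. Like most deep learning libraries, it slides the kernel without flipping it.

```python
import numpy as np


def conv2d_valid(image, kernel):
    """Slide `kernel` over `image` (no padding, stride 1) and sum the products."""
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out


def max_pool2d(feature_map, size=2):
    """Keep the largest value in each non-overlapping size x size window."""
    h, w = feature_map.shape
    h, w = h - h % size, w - w % size  # drop rows/columns that do not fill a window
    windows = feature_map[:h, :w].reshape(h // size, size, w // size, size)
    return windows.max(axis=(1, 3))


image = np.arange(36, dtype=float).reshape(6, 6)   # toy 6 x 6 "image"
kernel = np.array([[1.0, 0.0, -1.0]] * 3)          # simple vertical-edge filter
features = conv2d_valid(image, kernel)             # 6 x 6 -> 4 x 4
pooled = max_pool2d(features)                      # 4 x 4 -> 2 x 2
print(features.shape, pooled.shape)
```

The convolution produces a feature map, and pooling halves its width and height while keeping the strongest responses.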


Keras Introduction

Keras is a deep learning framework that we program in Python. Its backend can be TensorFlow, CNTK, Theano, and so on; the most common one right now is TensorFlow.

Most of the mainstream models can be implemented quickly using Keras.

If you want to consult the Keras documentation, you can find it here: https://keras.io/


Program Code

import os
from keras.models import Sequential, load_model
from keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPool2D
from keras.utils import np_utils, plot_model
from keras.datasets import mnist
import matplotlib.pyplot as plt
import pandas as pd



First, we import the packages we need. By the way, if Keras does not work, you may need to install TensorFlow.

pip3 install tensorflow

If you have a GPU (and you have set up the CUDA environment):

pip3 install tensorflow-gpu

If your computer has no GPU, training will be slow.

In that case you can use Google Colab, which provides a free online GPU.

(X_train, Y_train), (X_test, Y_test) = mnist.load_data()
# Add a trailing channel axis (channels-last, the Keras default) and scale pixels to [0, 1]
x_train = X_train.reshape(60000, 28, 28, 1)/255
x_test = X_test.reshape(10000, 28, 28, 1)/255
# Convert the integer labels 0-9 to one-hot vectors
y_train = np_utils.to_categorical(Y_train)
y_test = np_utils.to_categorical(Y_test)



Output:

Downloading data from https://s3.amazonaws.com/img-datasets/mnist.npz
11493376/11490434 [==============================] - 2s 0us/step

We download the MNIST data that ships with Keras. x is the image data and y is the labels; the lowercase x_train/y_train are the preprocessed versions that go into the model.

model = Sequential()
model.add(Conv2D(filters=32, kernel_size=3, input_shape=(28, 28, 1), activation='relu', padding='same'))
model.add(MaxPool2D(pool_size=2))
model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(Dense(10, activation='softmax'))



This is our model structure. The convolution layer and MaxPool extract features and reduce the resolution of the image, and Flatten() turns the result into a one-dimensional vector.

Finally we connect the Dense layers, which output the 10 classes.
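To see how the shapes flow through such a model, the arithmetic can be sketched in plain Python. This assumes the input is a single-channel 28 x 28 image in channels-last layout, a 3 x 3 convolution with 'same' padding and 32 filters, 2 x 2 pooling, and Dense layers of 256 and 10 units. The helper functions are hypothetical and only do bookkeeping; they do not call Keras.

```python
def conv2d_same(height, width, channels_in, filters, kernel_size):
    """'same' padding keeps height and width; each filter has k*k*channels_in weights + 1 bias."""
    params = (kernel_size * kernel_size * channels_in + 1) * filters
    return (height, width, filters), params


def max_pool(height, width, channels, pool_size):
    """Pooling shrinks height and width and has no trainable parameters."""
    return (height // pool_size, width // pool_size, channels), 0


def dense(inputs, units):
    """Each unit has one weight per input plus a bias."""
    return units, (inputs + 1) * units


shape, conv_params = conv2d_same(28, 28, 1, filters=32, kernel_size=3)
shape, _ = max_pool(*shape, pool_size=2)
flat = shape[0] * shape[1] * shape[2]          # what Flatten() produces
hidden, dense1_params = dense(flat, 256)
outputs, dense2_params = dense(hidden, 10)
total_params = conv_params + dense1_params + dense2_params

print("after pooling:", shape)
print("flattened length:", flat)
print("total parameters:", total_params)
```

Under these assumptions almost all of the parameters sit in the first Dense layer, which is typical for a small CNN with a single pooling step.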

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=10, batch_size=64, verbose=1)
loss, accuracy = model.evaluate(x_test, y_test)
print('Test:')
print('Loss: %s\nAccuracy: %s' % (loss, accuracy))



Output (exact numbers vary from run to run):

10000/10000 [==============================] - 1s 51us/step
Test:
Loss: 0.04998782943555562
Accuracy: 0.9853
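For readers curious what the two reported numbers mean, here is a rough NumPy sketch of how categorical cross-entropy and accuracy can be computed by hand. It uses a made-up batch of two samples with three classes rather than real MNIST predictions, and the clipping value eps is my own assumption to avoid log(0); Keras's internal implementation may differ in its details.

```python
import numpy as np


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean over samples of -sum(true * log(predicted probability))."""
    y_pred = np.clip(y_pred, eps, 1.0)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))


def accuracy(y_true, y_pred):
    """Fraction of samples whose most probable class matches the label."""
    return float(np.mean(np.argmax(y_true, axis=1) == np.argmax(y_pred, axis=1)))


# Toy batch: one-hot labels and softmax-like predictions (each row sums to 1)
y_true = np.array([[1.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0]])
y_pred = np.array([[0.7, 0.2, 0.1],    # correct: class 0 is the most probable
                   [0.6, 0.3, 0.1]])   # wrong: class 0 beats the true class 1

toy_loss = categorical_crossentropy(y_true, y_pred)
toy_accuracy = accuracy(y_true, y_pred)
print("Loss:", toy_loss)
print("Accuracy:", toy_accuracy)
```

The loss only looks at the probability assigned to the correct class, so a confident wrong answer is penalized much more than a hesitant one.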

Full Code

# -*- coding: utf-8 -*-
import numpy as np
from keras.models import Sequential, load_model
from keras.layers import Dense, Flatten, Conv2D, MaxPool2D
from keras.utils import to_categorical
from keras.datasets import mnist
import matplotlib.pyplot as plt
import pandas as pd


# Mnist Dataset
(X_train, Y_train), (X_test, Y_test) = mnist.load_data()
# Add a trailing channel axis (channels-last, the Keras default) and scale pixels to [0, 1]
x_train = X_train.reshape(60000, 28, 28, 1)/255
x_test = X_test.reshape(10000, 28, 28, 1)/255
# Convert the integer labels 0-9 to one-hot vectors
y_train = to_categorical(Y_train)
y_test = to_categorical(Y_test)

# Model Structure
model = Sequential()
model.add(Conv2D(filters=32, kernel_size=3, input_shape=(28, 28, 1), activation='relu', padding='same'))
model.add(MaxPool2D(pool_size=2))
model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(Dense(10, activation='softmax'))
model.summary()  # summary() prints the table itself

# Train
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=10, batch_size=64, verbose=1)
# Test
loss, accuracy = model.evaluate(x_test, y_test)
print('Test:')
print('Loss: %s\nAccuracy: %s' % (loss, accuracy))

# Save model
model.save('./CNN_Mnist.h5')

# Load Model
model = load_model('./CNN_Mnist.h5')

# Display
def plot_img(n):
    # Show the n-th raw test image
    plt.imshow(X_test[n], cmap='gray')
    plt.show()


def all_img_predict(model):
    # Evaluate on the whole test set and print a confusion table
    model.summary()
    loss, accuracy = model.evaluate(x_test, y_test)
    print('Loss:', loss)
    print('Accuracy:', accuracy)
    # predict() returns class probabilities; argmax picks the most likely digit
    predict = np.argmax(model.predict(x_test), axis=1)
    print(pd.crosstab(Y_test.reshape(-1), predict, rownames=['Label'], colnames=['predict']))


def one_img_predict(model, n):
    # Predict only the n-th test image (the slice keeps the batch dimension)
    predict = np.argmax(model.predict(x_test[n:n+1]), axis=1)[0]
    print('Prediction:', predict)
    print('Answer:', Y_test[n])
    plot_img(n)


Of the last three functions, you only need to call all_img_predict() and one_img_predict() yourself; plot_img() is called by one_img_predict().

Feel free to try any model!
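To give a feel for the table that all_img_predict() prints, here is a small sketch of pd.crosstab on a handful of made-up labels and predictions. The numbers are invented for illustration and are not real model output.

```python
import numpy as np
import pandas as pd

# Made-up ground-truth labels and predictions for six images
labels = np.array([0, 1, 2, 3, 3, 9])
predictions = np.array([0, 1, 2, 3, 5, 9])

# Rows are the true digits, columns are the predicted digits
table = pd.crosstab(labels, predictions, rownames=['Label'], colnames=['predict'])
print(table)

# Fraction of images where the prediction matches the label
toy_accuracy = float(np.mean(labels == predictions))
print('Accuracy:', toy_accuracy)
```

Each cell counts how many images with a given label received a given prediction, so anything off the matching label/prediction cells is a mistake. Here one "3" was misread as a "5".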
