Skip to content

[Keras] 使用 CNN 進行 MNIST 的手寫數字辨識

Last Updated on 2021-04-10 by Clay

Mnist 是一個經典的手寫數字資料集,裡面的數字分別有從 0 到 9,共 10 種數字。今天我們的任務便是使用經典的 Keras 來搭建 CNN 的分類模型,以此來製作一個數字的分類器。

基本上原理的部份我想全部都留在『原理篇』裡面來闡述,今天就直接上 Code 吧!

事先聲明,我的模型選擇非常地淺白,只是單純地使用 Convolution 層 + MaxPool + Flatten,然後就是全部接 Dense 全連接層了。

還有更好的模型層選擇,我知道。不過為了方便筆記,就先這樣紀錄起來吧!何況效果也不怎麼差。

那麼,以下分成《MNIST 基本介紹》、《CNN 基本介紹》、《Keras 基本介紹》、《實戰部份》四大章節來講解。


MNIST 基本介紹

已經知道什麼是 Mnist dataset 的人可以直接跳過這一小節。

Mnist 資料集是一個廣為人知的手寫數字資料集,其地位可說是 Machine Learning 界的 Hello World 也不為過。

其中包含著 60000 張 Training data 的圖片、以及 10000 張 Test data 的圖片。聽說這總共 70000 張圖片是來自於高中學生以及人口普查的工作人員,每張的像素皆為 28 x 28,每個像素點都以一個灰階值來表示。

這個資料集珍貴的地方在於已經有標註 Label 了。分別是使用 one hot encoding 來標注 0 到 9。

one hot encoding 的意思像是:

0 : [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
1 : [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]
.
.
.
8 : [0, 0, 0, 0, 0, 0, 0, 0, 1, 0]
9: [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]

以上這樣的編碼形式。好處是可以處理非連續型的資料、壞處為當我們要採用的特徵數量太多時會白白浪費許多記憶體。


CNN 基本介紹

同樣地,若是對 CNN 有了解,也同樣可以跳過這小節。

不過話是這麼講,其實我也不會講什麼高深的理論。

CNN 的全名為 Convolution Neural Network,最早提出 CNN 這種概念的 paper 為這篇: http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf

裡面的介紹與現代人對 CNN 的理解差不多:使用 Convolution layer + pooling,再加上全連接層 ...... 大體上和現在差不多。


Keras 基本介紹

Keras 是一個深度學習的 API,使用 Python 編譯,其後端通常為 Tensorflow, CNTK, Theano。(現在的話,似乎大部分人都是使用 Tensorflow)

大部分主流的模型都可以使用 Keras 來快速地實現。

如果你想參閱他們官方的 document,你可以前往這裡:https://keras.io/


實戰部份

import os
from keras.models import Sequential, load_model
from keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPool2D
from keras.utils import np_utils, plot_model
from keras.datasets import mnist
import matplotlib.pyplot as plt
import pandas as pd



首先,我們需要匯入我們所需要的套件。順帶一提,如果 Keras 無法運行的話,建議安裝一下 tensorflow

pip3 install tensorflow

如果你有 GPU,那麼你該使用的是;

pip3 install tensorflow-gpu

如果你嫌自己電腦沒有 GPU、跑得慢,那麼,我推薦你使用 Google Colab。只要有 Google 的帳號都可以自由使用由 Google 提供的免費 GPU,對於小一點的項目(例如 Mnist),已經算得上相當相當快了。

Colab 的教學請看這裡: 如何使用 Google Colab 提供的免費 GPU

(X_train, Y_train), (X_test, Y_test) = mnist.load_data()
x_train = X_train.reshape(60000, 1, 28, 28)/255
x_test = X_test.reshape(10000, 1, 28, 28)/255
y_train = np_utils.to_categorical(Y_train)
y_test = np_utils.to_categorical(Y_test)


Output:

Downloading data from <a rel="noreferrer noopener" target="_blank" href="https://s3.amazonaws.com/img-datasets/mnist.npz">https://s3.amazonaws.com/img-datasets/mnist.npz</a>
11493376/11490434 [==============================] - 2s 0us/step

這裡我們是直接拿 Keras 內的 Mnist 資料,x 全部都是圖片的資料,分成 Training data 跟 Test data;y 全部都是 one-hot encoding 的 Label,代表著圖片裡的數字是多少,同樣也是分 Training data 跟 Test data。

model = Sequential()
model.add(Conv2D(filters=32, kernel_size=3, input_shape=(1, 28, 28), activation='relu', padding='same'))
model.add(MaxPool2D(pool_size=2, data_format='channels_first'))
model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(Dense(10, activation='softmax'))


這裡是我們建立模型的部份,可以看到我們先是建立了 Convolution 層,然後接 MaxPool 層簡化圖片像素,然後 Flattern 攤平維度,最後接 Dense 全連接層,然後就輸出那 10 個類別了。

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=10, batch_size=64, verbose=1)
loss, accuracy = model.evaluate(x_test, y_test)
print('Test:')
print('Loss: %s\nAccuracy: %s' % (loss, accuracy))


Output:

10000/10000 [==============================] - 1s 51us/step
Test:
Loss: 0.04998782943555562
Accuracy: 0.9853

完整程式

# -*- coding: utf-8 -*-
import os
from keras.models import Sequential, load_model
from keras.layers import Dense, Dropout, Flatten, Conv2D, MaxPool2D
from keras.utils import np_utils, plot_model
from keras.datasets import mnist
import matplotlib.pyplot as plt
import pandas as pd


# Mnist Dataset
(X_train, Y_train), (X_test, Y_test) = mnist.load_data()
x_train = X_train.reshape(60000, 1, 28, 28)/255
x_test = X_test.reshape(10000, 1, 28, 28)/255
y_train = np_utils.to_categorical(Y_train)
y_test = np_utils.to_categorical(Y_test)

# Model Structure
model = Sequential()
model.add(Conv2D(filters=32, kernel_size=3, input_shape=(1, 28, 28), activation='relu', padding='same'))
model.add(MaxPool2D(pool_size=2, data_format='channels_first'))
model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(Dense(10, activation='softmax'))
print(model.summary())

# Train
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(x_train, y_train, epochs=10, batch_size=64, verbose=1)

# Test
loss, accuracy = model.evaluate(x_test, y_test)
print('Test:')
print('Loss: %s\nAccuracy: %s' % (loss, accuracy))

# Save model
model.save('./CNN_Mnist.h5')

# Load Model
model = load_model('./CNN_Mnist.h5')

# Display
def plot_img(n):
    plt.imshow(X_test[n], cmap='gray')
    plt.show()


def all_img_predict(model):
    print(model.summary())
    loss, accuracy = model.evaluate(x_test, y_test)
    print('Loss:', loss)
    print('Accuracy:', accuracy)
    predict = model.predict_classes(x_test)
    print(pd.crosstab(Y_test.reshape(-1), predict, rownames=['Label'], colnames=['predict']))


def one_img_predict(model, n):
    predict = model.predict_classes(x_test)
    print('Prediction:', predict[n])
    print('Answer:', Y_test[n])
    plot_img(n)


最後的 3 個韓式,其實你只需要使用 all_img_predict() 以及 one_img_predict() 就好。

前者是顯示你訓練的模型的所有預測結果,後者是你輸入 Test 的 index (0-9999),然後顯示圖片、預測、答案。

大家不妨自由地試試看各種模型吧?

1 thought on “[Keras] 使用 CNN 進行 MNIST 的手寫數字辨識”

Leave a Reply