
[PyTorch] Tutorial(4) Train a model to classify MNIST dataset

“Use a toy dataset to train a classification model” is one of the simplest deep learning exercises.

Today I want to record how to use the MNIST handwritten digit recognition dataset to build a simple classifier in PyTorch.

This time the model is simpler than the previous CNN: it uses only fully connected layers.

By the way, I found that the official tutorial site does not have an MNIST walkthrough, so here is the official website for reference: https://pytorch.org/tutorials/


(Figure: model parameters)

If you don’t know what MNIST is, you can refer to: [Keras] Use CNN to build a simple classifier for the MNIST dataset

Below, I will briefly explain my code. The complete code is placed at the end of the article.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


Import the necessary packages.

# GPU
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print('GPU State:', device)


Check whether your GPU is available; if you have no GPU, the CPU will be used instead.

# Transform
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,)),]
)


torchvision has a module called transforms, which can combine multiple transform functions (given as a list) into a single pipeline. It is mainly used for image preprocessing.

transforms.ToTensor(): transforms a PIL.Image with values in [0, 255] into a FloatTensor of shape (C, H, W) with values in [0.0, 1.0]. Why not NumPy’s common HWC ordering? I have seen the explanation that convolution runs faster on the CHW layout.

transforms.Normalize(): the data needs to be converted to a Tensor first, and is then normalized with the given mean and standard deviation; here (0.5,) and (0.5,) map the [0, 1] range to [-1, 1].
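
As a quick sanity check (my own sketch, not part of the tutorial code), you can verify that Normalize((0.5,), (0.5,)) really maps [0, 1] to [-1, 1]:

import torch
from torchvision import transforms

# A random grayscale "image" in [0, 1], like ToTensor() would produce
x = torch.rand(1, 28, 28)

# Normalize((0.5,), (0.5,)) computes (x - 0.5) / 0.5
y = transforms.Normalize((0.5,), (0.5,))(x)

print(x.min().item(), x.max().item())  # roughly 0.0 ... 1.0
print(y.min().item(), y.max().item())  # roughly -1.0 ... 1.0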

# Data
trainSet = datasets.MNIST(root='MNIST', download=True, train=True, transform=transform)
testSet = datasets.MNIST(root='MNIST', download=True, train=False, transform=transform)
trainLoader = DataLoader(trainSet, batch_size=64, shuffle=True)
testLoader = DataLoader(testSet, batch_size=64, shuffle=False)


Here we prepare our data. If there is no “MNIST” folder under the current directory, torchvision will automatically download the dataset and create the folder for you.
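
If you want to confirm what the DataLoader actually yields, a quick check (my own addition, not required for training) looks like this:

images, labels = next(iter(trainLoader))
print(images.shape)  # torch.Size([64, 1, 28, 28]) -> (batch, channel, height, width)
print(labels.shape)  # torch.Size([64])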

# Model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(in_features=784, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=10),
            nn.LogSoftmax(dim=1)
        )

    def forward(self, x):
        return self.main(x)


net = Net().to(device)
print(net)


Output:

Net(
  (main): Sequential(
    (0): Linear(in_features=784, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=10, bias=True)
    (5): LogSoftmax()
  )
)

The above is the architecture of our model. First, we create a fully connected layer that takes the 784 input pixels and outputs 128 neurons, then connect to the next layer through the ReLU activation function (the max(0, x) mapping), and so on. Finally, since we want to classify 10 labels, that is, the 10 digits [0-9], the last layer outputs 10 values and is followed by LogSoftmax().

Here I only know that Softmax is mostly used for multi-class prediction, and LogSoftmax simply takes the logarithm of the Softmax output in a numerically more stable way. For the detailed principles, I will write another post to supplement this later.
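
One thing we can verify numerically (a small sketch of my own): LogSoftmax gives the same result as taking the log of Softmax, just computed more stably:

import torch
import torch.nn as nn

x = torch.randn(4, 10)  # a fake batch of 4 samples with 10 class scores each

log_softmax = nn.LogSoftmax(dim=1)(x)
log_of_softmax = torch.log(nn.Softmax(dim=1)(x))

# The two results match up to floating-point error
print(torch.allclose(log_softmax, log_of_softmax, atol=1e-6))  # True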

The forward() function is the so-called forward propagation process. Here I directly call self.main(), which is an approach many people choose.

# Parameters
epochs = 3
lr = 0.002
criterion = nn.NLLLoss()
optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)

Here are the parameter settings:

epochs: the number of training passes over the whole dataset.
lr: short for learning rate, the step size used by the optimizer during backpropagation.
criterion: the loss function. NLLLoss expects log-probabilities, which is exactly what LogSoftmax outputs (see the check after this list).
optimizer: the optimizer we use. With momentum, updates that keep pointing in the same direction accumulate and grow larger, while gradients that flip direction are damped; 0.9 is the usual setting. Limited by study time, I had no time to test other values.
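
As mentioned above, NLLLoss pairs with LogSoftmax; together they are equivalent to applying CrossEntropyLoss to raw logits. A small check of my own (the values here are made up for illustration):

import torch
import torch.nn as nn

logits = torch.randn(4, 10)          # fake raw model outputs
labels = torch.tensor([1, 0, 9, 3])  # fake targets

loss_a = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), labels)
loss_b = nn.CrossEntropyLoss()(logits, labels)

print(torch.allclose(loss_a, loss_b))  # True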

# Train
for epoch in range(epochs):
    running_loss = 0.0

    for times, data in enumerate(trainLoader):
        inputs, labels = data[0].to(device), data[1].to(device)
        inputs = inputs.view(inputs.shape[0], -1)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Print statistics
        running_loss += loss.item()
        if times % 100 == 99 or times+1 == len(trainLoader):
            # Average loss over the mini-batches since the last print
            print('[%d/%d, %d/%d] loss: %.3f' % (epoch+1, epochs, times+1, len(trainLoader), running_loss/(times % 100 + 1)))
            running_loss = 0.0

print('Training Finished.')


This is where we start to train the model. It is worth mentioning that since the first fully connected layer expects 784 input features (the same input as my Keras program), we need to use view() to flatten the inputs to match the input of the model.
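
To make the reshaping concrete (my own illustration, not part of the training code): a batch from the DataLoader has shape (64, 1, 28, 28), and view() flattens each image into a 784-dimensional vector:

import torch

inputs = torch.randn(64, 1, 28, 28)      # one batch as produced by the DataLoader
flat = inputs.view(inputs.shape[0], -1)  # keep the batch dimension, flatten the rest
print(flat.shape)                        # torch.Size([64, 784]), since 1*28*28 = 784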

# Test
correct = 0
total = 0

with torch.no_grad():
    for data in testLoader:
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        inputs = inputs.view(inputs.shape[0], -1)

        outputs = net(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (100*correct / total))

class_correct = [0 for i in range(10)]
class_total = [0 for i in range(10)]

with torch.no_grad():
    for data in testLoader:
        inputs, labels = data[0].to(device), data[1].to(device)
        inputs = inputs.view(inputs.shape[0], -1)

        outputs = net(inputs)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        for i in range(labels.size(0)):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1

for i in range(10):
    print('Accuracy of %d: %3f' % (i, (class_correct[i]/class_total[i])))


Output:

Accuracy of the network on the 10000 test images: 92 %
Accuracy of 0: 1.000000
Accuracy of 1: 0.972603
Accuracy of 2: 0.865672
Accuracy of 3: 0.857143
Accuracy of 4: 0.941176
Accuracy of 5: 0.875000
Accuracy of 6: 0.892857
Accuracy of 7: 0.898305
Accuracy of 8: 0.950000
Accuracy of 9: 0.896104

