
[PyTorch] Use Early Stopping To Stop Model Training At A Better Convergence Time

Last Updated on 2021-08-25 by Clay

Early stopping is a technique used in machine learning and deep learning, and it means exactly what it says: stopping training early. In supervised learning, it is a way to find the point in time at which the model has converged.

Anyone with experience in training models knows that if a model is trained for too many iterations, it will overfit.

In other words, the model has learned the characteristics of our training data too well, so it performs extremely well on the training data but poorly on the test data.

That is, the model does not generalize well.

[Figure: a classic overfitting example]

This is a classic example of overfitting. At first glance, the model fits the training data very well, right?

Now suppose we encounter brand-new data, as shown in the figure below.

[Figure: the same model with new test data (green dots) added]

Once the test data (the green dots) is added, you can see that this model completely fails to predict data it has not seen. In effect, the model has only memorized the answers for the training data and cannot predict anything else.

[Figure: overfitting]
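To make the figures concrete, here is a tiny standard-library sketch that is not from the original post. The data follow y = 2x plus small, hand-picked noise values (an assumption for illustration). A degree-9 polynomial that passes exactly through all 10 training points "memorizes" them and gets zero training error, but it predicts badly at x = 10 and x = 11, which it has never seen, while a plain least-squares line generalizes fine.

```python
# Toy overfitting demo: memorizing polynomial vs. simple line.
# Assumption: true relationship is y = 2x, plus made-up noise on the training set.

def lagrange_predict(xs, ys, x):
    """Evaluate the polynomial that passes exactly through every (xs[i], ys[i])."""
    total = 0.0
    for j, (xj, yj) in enumerate(zip(xs, ys)):
        basis = 1.0
        for k, xk in enumerate(xs):
            if k != j:
                basis *= (x - xk) / (xj - xk)
        total += yj * basis
    return total


def fit_line(xs, ys):
    """Ordinary least-squares fit of y = a * x + b; returns a predict function."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var = sum((x - mean_x) ** 2 for x in xs)
    a = cov / var
    b = mean_y - a * mean_x
    return lambda x: a * x + b


def mse(predict, xs, ys):
    """Mean squared error of `predict` on the given points."""
    return sum((predict(x) - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


noise = [0.3, -0.2, 0.4, -0.5, 0.1, 0.2, -0.4, 0.3, -0.1, 0.2]
train_x = list(range(10))
train_y = [2 * x + e for x, e in zip(train_x, noise)]
test_x = [10, 11]                 # "brand-new" points neither model has seen
test_y = [2 * x for x in test_x]


def memorizer(x):
    # Degree-9 polynomial: as many parameters as training points.
    return lagrange_predict(train_x, train_y, x)


line = fit_line(train_x, train_y)  # only 2 parameters

print("train MSE:", mse(memorizer, train_x, train_y), mse(line, train_x, train_y))
print("test MSE: ", mse(memorizer, test_x, test_y), mse(line, test_x, test_y))
```

The memorizing polynomial wins on the training data and loses badly on the new points, which is exactly the pattern in the figures.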

So how do we know when the model has been trained into a good model? That is the focus of this article: early stopping.


Early Stopping To Prevent Overfitting

Before we start, we split off part of the training data as validation data. The validation data does not take part in training, so the model never sees it.

This step is very important.
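A minimal sketch of the split using only the standard library: shuffle the indices and hold out 20% of them for validation. The function name `split_indices`, the 20% fraction and the fixed seed are assumptions for illustration; the PyTorch code later in this post does the same job with `data.random_split`.

```python
import random


def split_indices(n, valid_fraction=0.2, seed=0):
    """Shuffle indices 0..n-1 and hold out a fraction of them for validation."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)          # reproducible shuffle
    n_valid = int(n * valid_fraction)
    return indices[n_valid:], indices[:n_valid]   # (train, validation)


train_idx, valid_idx = split_indices(60000)       # MNIST has 60,000 training images
print(len(train_idx), len(valid_idx))
```

Because the two index lists are disjoint, no validation sample is ever used to update the model's weights.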

If we compute the loss separately on the training data and on the validation data (the training loss and the validation loss), we will see something like the following:

[Figure: training loss and validation loss versus epochs]

The horizontal axis is the number of training epochs, which can be regarded as how long the model has been trained; the vertical axis is the loss on each dataset. The larger the loss, the worse the model's predictions on that data.
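With the whole curve recorded, as in the figure, the ideal stopping point is simply the epoch with the lowest validation loss. The short sketch below finds that epoch offline; the loss values are made up for illustration. Early stopping approximates this during training, when the losses of future epochs are not yet known.

```python
def best_epoch(valid_losses):
    """Return the 1-based epoch whose validation loss is lowest (first one on ties)."""
    best_index = min(range(len(valid_losses)), key=lambda i: valid_losses[i])
    return best_index + 1


# Hypothetical validation-loss curve: falls at first, then rises (overfitting).
valid_losses = [0.42, 0.31, 0.25, 0.22, 0.21, 0.23, 0.26, 0.30]
print(best_epoch(valid_losses))  # 5
```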

This is the principle of early stopping.

Since the model gradually starts to overfit, why not stop training when the validation loss starts to rise?

That's right. The validation data we split off can be used to evaluate how well the model is currently doing. Some people use the loss, others use accuracy... In fact, any metric can be used. The key point is to stop training at the moment the model starts to lose its ability to generalize.
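The principle does not depend on the framework, and it works with any metric as long as we know whether lower (loss) or higher (accuracy) is better. Below is a minimal sketch of a hypothetical `EarlyStopping` helper; the class name and its `min_delta` parameter are illustrative, not a PyTorch API. Unlike the sample code later in this post, which compares each epoch with the previous one, this version compares with the best value seen so far.

```python
class EarlyStopping:
    """Hypothetical helper: signals a stop when a validation metric stops improving.

    mode="min" for metrics such as loss, mode="max" for metrics such as accuracy.
    """

    def __init__(self, patience=2, mode="min", min_delta=0.0):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.patience = patience
        self.mode = mode
        self.min_delta = min_delta   # smallest change that counts as an improvement
        self.best = None
        self.bad_epochs = 0

    def _improved(self, value):
        if self.best is None:
            return True
        if self.mode == "min":
            return value < self.best - self.min_delta
        return value > self.best + self.min_delta

    def step(self, value):
        """Record one epoch's metric; return True when training should stop."""
        if self._improved(value):
            self.best = value
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=2, mode="max")   # e.g. validation accuracy
for epoch, acc in enumerate([0.91, 0.95, 0.96, 0.955, 0.958], start=1):
    if stopper.step(acc):
        print("stop at epoch", epoch, "best accuracy", stopper.best)
        break
```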


Sample Code

The following is a simple MNIST handwritten digit recognition program.

# coding: utf-8
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torchvision import datasets, transforms


# Model architecture
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(in_features=784, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=10),
            nn.LogSoftmax(dim=1)
        )

    def forward(self, input):
        return self.main(input)


# Train
def train(device, model, epochs, optimizer, loss_function, train_loader):
    for epoch in range(1, epochs+1):
        for times, data in enumerate(train_loader, 1):
            inputs = data[0].to(device)
            labels = data[1].to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Forward and backward propagation
            outputs = model(inputs.view(inputs.shape[0], -1))
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

            # Show progress
            if times % 100 == 0 or times == len(train_loader):
                print('[{}/{}, {}/{}] loss: {:.8}'.format(epoch, epochs, times, len(train_loader), loss.item()))

    return model


def test(device, model, test_loader):
    # Settings
    model.eval()
    total = 0
    correct = 0

    with torch.no_grad():
        for data in test_loader:
            inputs = data[0].to(device)
            labels = data[1].to(device)

            outputs = model(inputs.view(inputs.shape[0], -1))
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accuracy:', correct / total)


def main():
    # GPU device
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    print('Device state:', device)

    # Settings
    epochs = 100
    batch_size = 64
    lr = 0.002
    loss_function = nn.NLLLoss()
    model = Net().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Transform
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5,), (0.5,))]
    )

    # Data
    train_set = datasets.MNIST(root='MNIST', download=True, train=True, transform=transform)
    test_set = datasets.MNIST(root='MNIST', download=True, train=False, transform=transform)
    train_loader = data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = data.DataLoader(test_set, batch_size=batch_size, shuffle=False)

    # Train
    model = train(device, model, epochs, optimizer, loss_function, train_loader)

    # Test
    test(device, model, test_loader)


if __name__ == '__main__':
    main()



Output:

I trained for 100 epochs. As you can see, the final MNIST classification accuracy is only 0.9767.

For a simple dataset like MNIST, accuracy should actually be higher. My guess is that the model has started to overfit.

Below is the code after adding the early stopping mechanism. It may be a little harder to read, but only three changes have been made:

  1. Split the training data into a training set and a validation set
  2. After each training epoch, switch the model to evaluation mode and compute the loss on the validation set
  3. Set a patience value (if it is set to 2, training stops once the validation loss has risen 2 times in a row)
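The third change is the core of the mechanism. The sketch below replays the same rule used in `train()` (start from `the_last_loss = 100`, compare each epoch's validation loss with the previous epoch's, reset the counter whenever it does not rise) on a list of made-up losses, so its behaviour can be checked without training anything. The function name is hypothetical.

```python
def stop_epoch_previous_rule(valid_losses, patience=2, initial_last_loss=100):
    """Replay the patience rule: count consecutive epochs whose validation loss
    is higher than the previous epoch's, and stop when the count reaches
    `patience`. Returns the 1-based epoch at which training stops, or None."""
    last_loss = initial_last_loss
    trigger_times = 0
    for epoch, current_loss in enumerate(valid_losses, start=1):
        if current_loss > last_loss:
            trigger_times += 1
            if trigger_times >= patience:
                return epoch
        else:
            trigger_times = 0
        last_loss = current_loss
    return None


print(stop_epoch_previous_rule([0.50, 0.40, 0.45, 0.42, 0.43, 0.44]))  # 6
```

Because the counter resets whenever the loss dips below the previous epoch's value, a validation loss that zigzags upward may never trigger a stop; comparing against the best loss seen so far avoids that.
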
# coding: utf-8
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torchvision import datasets, transforms


# Model architecture
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(in_features=784, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=10),
            nn.LogSoftmax(dim=1)
        )

    def forward(self, input):
        return self.main(input)


# Train
def train(device, model, epochs, optimizer, loss_function, train_loader, valid_loader):
    # Early stopping
    the_last_loss = 100
    patience = 2
    trigger_times = 0

    for epoch in range(1, epochs+1):
        model.train()

        for times, data in enumerate(train_loader, 1):
            inputs = data[0].to(device)
            labels = data[1].to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Forward and backward propagation
            outputs = model(inputs.view(inputs.shape[0], -1))
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

            # Show progress
            if times % 100 == 0 or times == len(train_loader):
                print('[{}/{}, {}/{}] loss: {:.8}'.format(epoch, epochs, times, len(train_loader), loss.item()))

        # Early stopping
        the_current_loss = validation(model, device, valid_loader, loss_function)
        print('The current loss:', the_current_loss)

        if the_current_loss > the_last_loss:
            trigger_times += 1
            print('trigger times:', trigger_times)

            if trigger_times >= patience:
                print('Early stopping!\nStart to test process.')
                return model

        else:
            print('trigger times: 0')
            trigger_times = 0

        the_last_loss = the_current_loss

    return model


def validation(model, device, valid_loader, loss_function):
    # Settings
    model.eval()
    loss_total = 0

    # Test validation data
    with torch.no_grad():
        for data in valid_loader:
            inputs = data[0].to(device)
            labels = data[1].to(device)

            outputs = model(inputs.view(inputs.shape[0], -1))
            loss = loss_function(outputs, labels)
            loss_total += loss.item()

    return loss_total / len(valid_loader)


def test(device, model, test_loader):
    # Settings
    model.eval()
    total = 0
    correct = 0

    with torch.no_grad():
        for data in test_loader:
            inputs = data[0].to(device)
            labels = data[1].to(device)

            outputs = model(inputs.view(inputs.shape[0], -1))
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accuracy:', correct / total)


def main():
    # GPU device
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    print('Device state:', device)

    # Settings
    epochs = 100
    batch_size = 64
    lr = 0.002
    loss_function = nn.NLLLoss()
    model = Net().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Transform
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5,), (0.5,))]
    )

    # Data
    train_set = datasets.MNIST(root='MNIST', download=True, train=True, transform=transform)
    test_set = datasets.MNIST(root='MNIST', download=True, train=False, transform=transform)
   
    train_set_size = int(len(train_set) * 0.8)
    valid_set_size = len(train_set) - train_set_size
    train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size])

    train_loader = data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = data.DataLoader(test_set, batch_size=batch_size, shuffle=False)
    valid_loader = data.DataLoader(valid_set, batch_size=batch_size, shuffle=True)

    # Train
    model = train(device, model, epochs, optimizer, loss_function, train_loader, valid_loader)

    # Test
    test(device, model, test_loader)


if __name__ == '__main__':
    main()
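The `train()` function above returns the model from the epoch at which it stopped, that is, after the validation loss has already risen for two consecutive epochs. A common refinement is to remember the weights from the best epoch and restore them when training stops. Below is a framework-independent sketch under that assumption: `fit_with_early_stopping` and its callback parameters are hypothetical names, and a toy dictionary stands in for a real model. With PyTorch, `get_state` could return `copy.deepcopy(model.state_dict())` and `set_state` could be `model.load_state_dict`.

```python
import copy


def fit_with_early_stopping(train_one_epoch, validate, get_state, set_state,
                            max_epochs=100, patience=2):
    """Hypothetical helper: stop once the validation loss has not beaten its
    best value for `patience` epochs in a row, then restore the best state.

    Returns (best validation loss, epoch at which training stopped).
    """
    best_loss = float("inf")
    best_state = get_state()      # snapshot taken before any training
    bad_epochs = 0
    epoch = 0
    for epoch in range(1, max_epochs + 1):
        train_one_epoch()
        loss = validate()
        if loss < best_loss:
            best_loss, best_state, bad_epochs = loss, get_state(), 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                break
    set_state(best_state)         # roll back to the best epoch's weights
    return best_loss, epoch


# Toy stand-in for a model: its only "weight" is the number of epochs trained.
fake_losses = [0.40, 0.30, 0.25, 0.27, 0.29, 0.20]
model = {"epoch": 0}


def train_one_epoch():
    model["epoch"] += 1


def validate():
    return fake_losses[model["epoch"] - 1]


best_loss, stopped_at = fit_with_early_stopping(
    train_one_epoch, validate,
    get_state=lambda: copy.deepcopy(model),
    set_state=model.update,
    max_epochs=len(fake_losses), patience=2)
print(best_loss, stopped_at, model)  # 0.25 5 {'epoch': 3}
```

Note that the sixth, even lower loss is never reached: with `patience=2` the loop stops at epoch 5. That is the usual trade-off of a small patience value.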



Of course, this is only sample code.
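One small detail in `validation()`: `loss_total / len(valid_loader)` averages per-batch mean losses, so a smaller final batch counts as much as a full one. With the 80/20 split above, the 12,000 validation images give 187 full batches of 64 plus a last batch of 32. The effect is tiny here, but weighting each batch by its size gives the true per-sample mean. The sketch uses hypothetical numbers; in the PyTorch code, the equivalent change would be to accumulate `loss.item() * inputs.size(0)` and divide by `len(valid_loader.dataset)`.

```python
def mean_of_batch_means(batch_means):
    """What validation() computes: every batch counts equally."""
    return sum(batch_means) / len(batch_means)


def per_sample_mean(batch_means, batch_sizes):
    """Weight each batch's mean loss by its number of samples."""
    total = sum(m * n for m, n in zip(batch_means, batch_sizes))
    return total / sum(batch_sizes)


# Hypothetical numbers: two full batches and a smaller final batch.
batch_means = [0.20, 0.20, 0.50]
batch_sizes = [64, 64, 32]
print(mean_of_batch_means(batch_means))
print(per_sample_mean(batch_means, batch_sizes))
```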

MNIST is simple enough that almost any reasonable training run does well on it. In my opinion, early stopping does not necessarily improve the model, but it gives us a point in time at which the model has roughly converged, so we know when to stop. In some cases, early stopping may even make the results worse.

As a friend of mine put it: if I have the time, why don't I just fucking train 1,000 models in one go, save every one of them, compute each one's score on the test set, and then pick the best one?

Of course you can do that!
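If you do go the save-everything route, the selection step itself is simple. One caution: if the test set is used to pick the best model, the reported test score becomes optimistically biased, so the sketch below picks by validation loss and leaves the test set for the final report. The checkpoint names and losses are hypothetical; in PyTorch, each checkpoint could be written with `torch.save(model.state_dict(), path)`.

```python
def pick_best(checkpoints):
    """Return the checkpoint name with the lowest validation loss."""
    return min(checkpoints, key=checkpoints.get)


# Hypothetical checkpoint files mapped to their validation losses.
checkpoints = {"epoch_05.pt": 0.231, "epoch_10.pt": 0.198, "epoch_15.pt": 0.214}
print(pick_best(checkpoints))  # epoch_10.pt
```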

