Last Updated on 2021-08-25 by Clay
Early stopping is a technique used in machine learning and deep learning, and it means exactly what its name says: stopping training early. In supervised learning, it can be a way to find the point at which the model has converged.
People with experience in model training generally know that if a model is trained for too many iterations, it tends to overfit.
In other words, the model has learned the characteristics of our training data too closely, so it performs extremely well on the training data but very badly on the test data.
That is, the model generalizes poorly.
This is a classic figure illustrating overfitting. At first glance, the model predicts the training data very well, right?
Now suppose we encounter brand new data, as shown in the figure below.
Once the test data (the green dots) is added, you will find that this model is completely unable to predict data it has not seen. In effect, the model has almost only memorized the answers to the training data and cannot predict other data at all.
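To make "memorizing the answers" concrete, here is a deliberately extreme toy sketch. It is not a real neural network: the task (odd or even), the lookup-table "model", and the fallback guess of 0 are all invented for illustration.

```python
# Toy task (invented for illustration): the label is 1 for odd numbers, 0 for even.
train_x = list(range(0, 20))
train_y = [x % 2 for x in train_x]
test_x = list(range(20, 40))  # inputs the "model" has never seen
test_y = [x % 2 for x in test_x]

# A "model" that only memorizes: a lookup table of the training answers.
memory = dict(zip(train_x, train_y))


def predict(x):
    # Unseen inputs fall back to a fixed guess of 0 (an arbitrary assumption).
    return memory.get(x, 0)


def accuracy(xs, ys):
    # Fraction of inputs whose prediction matches the label.
    return sum(predict(x) == y for x, y in zip(xs, ys)) / len(xs)


print("train accuracy:", accuracy(train_x, train_y))
print("test accuracy:", accuracy(test_x, test_y))
```

The lookup table is perfect on the data it has stored and no better than a constant guess on anything else, which is the same gap between training and test performance described above, pushed to its limit.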
So, how do we know when the model has been trained into a good model? That is the focus of this article: early stopping.
Early Stopping To Prevent Overfitting
Before we start, we set aside a portion of the data as validation data. The validation data does not take part in model training, so the model never sees it.
This step is very important.
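As a minimal sketch of this split, using only the standard library: the helper name `split_indices`, the 20% validation fraction, and the fixed seed are assumptions made for this illustration. In practice a library helper such as `torch.utils.data.random_split` does the same job.

```python
import random


def split_indices(n_samples, valid_fraction=0.2, seed=0):
    """Shuffle the indices 0..n_samples-1 and split them into train and validation parts."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # seeded so the split is reproducible
    n_valid = int(n_samples * valid_fraction)
    return indices[n_valid:], indices[:n_valid]


train_idx, valid_idx = split_indices(10, valid_fraction=0.2)
print(len(train_idx), len(valid_idx))  # → 8 2
```

The two index lists never overlap, so samples chosen for validation cannot leak into training.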
If we calculate the loss on the training data and on the validation data separately, calling them the training loss and the validation loss, we will see a phenomenon like the following:
The horizontal axis is the number of training iterations (epochs), which can be regarded as the length of model training; the vertical axis is the loss on the data set. The larger the loss, the less accurate the predictions. The training loss keeps falling, while the validation loss falls at first and then starts to rise once the model begins to overfit.
This is the principle of early stopping.
Since the model will gradually start overfitting, why not stop training when the loss of the validation data set starts to rise?
That's right: the validation data we set aside can be used to evaluate how well the model currently being trained performs. Some people use loss, some use accuracy... In fact, all kinds of metrics can be used. The key point is to stop training at the moment the model starts to lose its ability to generalize.
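The idea can be sketched with two recorded loss curves. The numbers below are made up for illustration (they are not measured results), and `best_epoch` is a hypothetical helper name.

```python
# Illustrative, made-up loss values per epoch (not measured results).
train_loss = [0.90, 0.60, 0.42, 0.31, 0.24, 0.19, 0.15, 0.12]
valid_loss = [0.95, 0.68, 0.52, 0.45, 0.43, 0.46, 0.51, 0.58]


def best_epoch(losses):
    """Return the 1-based epoch with the lowest loss."""
    return min(range(len(losses)), key=losses.__getitem__) + 1


print("training loss is lowest at epoch", best_epoch(train_loss))
print("validation loss is lowest at epoch", best_epoch(valid_loss))
```

With these numbers the training loss is still falling at the last epoch, while the validation loss bottoms out at epoch 5 and rises afterward. Epoch 5 is roughly where early stopping aims to end training.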
Sample Code
The following is a simple MNIST handwritten digit recognition program. The code is very simple.
```python
# coding: utf-8
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torchvision import datasets, transforms


# Model architecture
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(in_features=784, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=10),
            nn.LogSoftmax(dim=1)
        )

    def forward(self, input):
        return self.main(input)


# Train
def train(device, model, epochs, optimizer, loss_function, train_loader):
    for epoch in range(1, epochs+1):
        # `batch` avoids shadowing the `data` module alias imported above
        for times, batch in enumerate(train_loader, 1):
            inputs = batch[0].to(device)
            labels = batch[1].to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Forward and backward propagation
            # (each 28x28 image is flattened to a 784-dimensional vector)
            outputs = model(inputs.view(inputs.shape[0], -1))
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

            # Show progress
            if times % 100 == 0 or times == len(train_loader):
                print('[{}/{}, {}/{}] loss: {:.8}'.format(epoch, epochs, times, len(train_loader), loss.item()))

    return model


def test(device, model, test_loader):
    # Settings
    model.eval()
    total = 0
    correct = 0

    with torch.no_grad():
        for batch in test_loader:
            inputs = batch[0].to(device)
            labels = batch[1].to(device)

            outputs = model(inputs.view(inputs.shape[0], -1))
            # The predicted class is the index of the largest log-probability
            _, predicted = torch.max(outputs, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accuracy:', correct / total)


def main():
    # GPU device
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    print('Device state:', device)

    # Settings
    epochs = 100
    batch_size = 64
    lr = 0.002
    loss_function = nn.NLLLoss()
    model = Net().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Transform
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5,), (0.5,))]
    )

    # Data
    train_set = datasets.MNIST(root='MNIST', download=True, train=True, transform=transform)
    test_set = datasets.MNIST(root='MNIST', download=True, train=False, transform=transform)
    train_loader = data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = data.DataLoader(test_set, batch_size=batch_size, shuffle=False)

    # Train
    model = train(device, model, epochs, optimizer, loss_function, train_loader)

    # Test
    test(device, model, test_loader)


if __name__ == '__main__':
    main()
```
Output:
I ran 100 iterations. As can be seen, the final accuracy of MNIST classification is only 0.9767.
For a simple data set such as MNIST, it should actually be higher. We may guess that the model has begun to show an overfitting trend.
The following is the code after adding the early stopping mechanism. It may not be so easy to read, but basically only three changes have been made:
- Split the training data into a training data set and a validation data set
- After each iteration of training, set the model to evaluation mode and calculate the loss on the validation data set
- Set patience (if it is set to 2, training stops once the validation loss rises 2 times in a row)
```python
# coding: utf-8
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torchvision import datasets, transforms


# Model architecture
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(in_features=784, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=10),
            nn.LogSoftmax(dim=1)
        )

    def forward(self, input):
        return self.main(input)


# Train
def train(device, model, epochs, optimizer, loss_function, train_loader, valid_loader):
    # Early stopping
    the_last_loss = float('inf')  # no previous validation loss yet
    patience = 2
    trigger_times = 0

    for epoch in range(1, epochs+1):
        model.train()

        # `batch` avoids shadowing the `data` module alias imported above
        for times, batch in enumerate(train_loader, 1):
            inputs = batch[0].to(device)
            labels = batch[1].to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Forward and backward propagation
            outputs = model(inputs.view(inputs.shape[0], -1))
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

            # Show progress
            if times % 100 == 0 or times == len(train_loader):
                print('[{}/{}, {}/{}] loss: {:.8}'.format(epoch, epochs, times, len(train_loader), loss.item()))

        # Early stopping: compare this epoch's validation loss with the previous epoch's
        the_current_loss = validation(model, device, valid_loader, loss_function)
        print('The current loss:', the_current_loss)

        if the_current_loss > the_last_loss:
            trigger_times += 1
            print('trigger times:', trigger_times)

            if trigger_times >= patience:
                print('Early stopping!\nStart to test process.')
                return model

        else:
            print('trigger times: 0')
            trigger_times = 0

        the_last_loss = the_current_loss

    return model


def validation(model, device, valid_loader, loss_function):
    # Settings
    model.eval()
    loss_total = 0

    # Test validation data
    with torch.no_grad():
        for batch in valid_loader:
            inputs = batch[0].to(device)
            labels = batch[1].to(device)

            outputs = model(inputs.view(inputs.shape[0], -1))
            loss = loss_function(outputs, labels)
            loss_total += loss.item()

    # Average of the per-batch losses
    return loss_total / len(valid_loader)


def test(device, model, test_loader):
    # Settings
    model.eval()
    total = 0
    correct = 0

    with torch.no_grad():
        for batch in test_loader:
            inputs = batch[0].to(device)
            labels = batch[1].to(device)

            outputs = model(inputs.view(inputs.shape[0], -1))
            # The predicted class is the index of the largest log-probability
            _, predicted = torch.max(outputs, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accuracy:', correct / total)


def main():
    # GPU device
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    print('Device state:', device)

    # Settings
    epochs = 100
    batch_size = 64
    lr = 0.002
    loss_function = nn.NLLLoss()
    model = Net().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Transform
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5,), (0.5,))]
    )

    # Data
    train_set = datasets.MNIST(root='MNIST', download=True, train=True, transform=transform)
    test_set = datasets.MNIST(root='MNIST', download=True, train=False, transform=transform)

    # Hold out 20% of the training data for validation
    train_set_size = int(len(train_set) * 0.8)
    valid_set_size = len(train_set) - train_set_size
    train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size])

    train_loader = data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = data.DataLoader(test_set, batch_size=batch_size, shuffle=False)
    valid_loader = data.DataLoader(valid_set, batch_size=batch_size, shuffle=True)

    # Train
    model = train(device, model, epochs, optimizer, loss_function, train_loader, valid_loader)

    # Test
    test(device, model, test_loader)


if __name__ == '__main__':
    main()
```
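The `train()` function above compares each validation loss with the previous epoch's value. A common variant compares against the best loss seen so far instead. The sketch below isolates both rules in plain Python so they can be compared on the same sequence of losses. The names `stop_epoch_last_loss`, `EarlyStopper`, and `min_delta` are chosen for this illustration, and the loss values are made up.

```python
def stop_epoch_last_loss(losses, patience=2):
    """Rule from train() above: stop after `patience` consecutive rises over the previous epoch."""
    last, triggers = float("inf"), 0
    for epoch, loss in enumerate(losses, 1):
        if loss > last:
            triggers += 1
            if triggers >= patience:
                return epoch
        else:
            triggers = 0
        last = loss
    return None  # never stopped early


class EarlyStopper:
    """Variant: stop after `patience` consecutive epochs without a new best loss."""

    def __init__(self, patience=2, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta  # smallest decrease that counts as an improvement
        self.best_loss = float("inf")
        self.bad_epochs = 0

    def step(self, loss):
        """Record one validation loss; return True when training should stop."""
        if loss < self.best_loss - self.min_delta:
            self.best_loss = loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Made-up validation losses, one per epoch (illustrative only).
losses = [0.50, 0.40, 0.45, 0.44, 0.30]

stopper = EarlyStopper(patience=2)
stopped_at = None
for epoch, loss in enumerate(losses, 1):
    if stopper.step(loss):
        stopped_at = epoch
        break

print("previous-loss rule stops at:", stop_epoch_last_loss(losses))
print("best-loss rule stops at:", stopped_at)
```

On this sequence the previous-loss rule never triggers, because 0.44 is lower than 0.45 and resets the counter, while the best-loss rule stops at epoch 4 because neither 0.45 nor 0.44 beats the best value of 0.40. Which behavior is preferable depends on how noisy the validation loss is; here the best-loss rule would also have missed the later drop to 0.30.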
Of course, this is only sample code after all.
This data set is so simple that almost any training run does well. My opinion is that the early stopping mechanism does not necessarily improve the model, but it may give us a point in time at which the model has converged. Another time, the early stopping mechanism may actually make the results worse.
Just like my friend said: I have time, so why don't I fucking train 1000 models in one go, save each one, calculate its score on the test data set, and then directly pick the best?
Of course you can do that!
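A rough sketch of that "save everything, pick the best" approach is below. The helper name `select_best`, the scores, and the placeholder weight dictionaries are all invented for illustration. A real version would store something like each model's `state_dict()` next to its score.

```python
import copy


def select_best(candidates):
    """Return (score, state) of the highest-scoring candidate.

    `candidates` is an iterable of (score, state) pairs.
    """
    best_score, best_state = None, None
    for score, state in candidates:
        if best_score is None or score > best_score:
            best_score = score
            # Snapshot the state so later changes to the original cannot alter it
            best_state = copy.deepcopy(state)
    return best_score, best_state


# Made-up scores and placeholder "weights" for three saved models.
saved = [
    (0.971, {"w": [0.1]}),
    (0.983, {"w": [0.2]}),
    (0.976, {"w": [0.3]}),
]

score, state = select_best(saved)
print("best score:", score, "state:", state)
```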
References
- https://github.com/Bjarten/early-stopping-pytorch
- https://pytorch-lightning.readthedocs.io/en/latest/early_stopping.html