Last Updated on 2022-07-13 by Clay
I have always wanted to be able to save the optimizer used by PyTorch to train the model, so that I can continue training after a previous training run has stopped.
There is a reason I need to do this: I have been running on an unreliable device that shuts down or hits segmentation faults at unpredictable points in time.
In the general case, if you only want to do transfer learning or fine-tuning, you do not need to save the last trained optimizer.
In short, I need to “back up” my training progress. Once the system crashes, I can start from the checkpoint and continue to train the model.
And we all know that if we want to ensure the training results of PyTorch are exactly the same, in addition to fixing some seeds, we also need to fix the training data order.
In addition, we need to save the parameters of the model and the optimizer in real time. The optimizer is the basis for the model to update the weight parameters according to the loss, which is very important for model training.
In other words, by saving the weight parameters of the model and the state of the optimizer, we can continue training from a given checkpoint and end up with a model that is exactly the same as one trained without interruption.
It is necessary to use the same optimizer.
How to save the model and optimizer
Here is some sample code that shows how to save them; in fact, the program is very simple:
import pickle
# Save config
# Collect the model weights and the optimizer state into one checkpoint dict.
model_state = model.state_dict()
optim_state = optimizer.state_dict()
config = dict(model_state_dict=model_state, optim_state_dict=optim_state)

# Persist the checkpoint to disk with pickle.
with open("config.pickle", "wb") as f:
    pickle.dump(config, f)
I use the pickle package to implement data persistence. If you are interested, you can refer to How to use Pickle module to store data in Python.
Afterwards, if you want to restore the model and optimizer to this saved point in time, you can use the following code:
import pickle
# Configs
# Restore the saved checkpoint if one exists. The file is opened directly
# instead of being checked with os.path.isfile first: this snippet only imports
# pickle (so `os` would be undefined), and it also avoids a check-then-open race.
# NOTE: pickle.load can execute arbitrary code -- only load checkpoint files
# that you created yourself.
try:
    with open("config.pickle", "rb") as f:
        config = pickle.load(f)
except FileNotFoundError:
    # No checkpoint yet: keep the freshly created model and optimizer.
    pass
else:
    model.load_state_dict(config["model_state_dict"])
    optimizer.load_state_dict(config["optim_state_dict"])
Of course, this is just example code. In fact, we also need to declare the model and optimizer first; you can refer to the complete code below.
Complete Code
First, let’s run a simple, fixed-result MNIST classification experiment. (You can refer to [PyTorch] Set Seed To Reproduce Model Training Results.)
# coding: utf-8
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import random
import numpy as np
# Fixed seed
# Seed Python, NumPy and PyTorch (CPU and current CUDA device) with one value.
seed = 123
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Ask cuDNN for deterministic behaviour and switch its benchmark mode off.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Classifier model
class Net(nn.Module):
    """Fully connected classifier: 784 -> 128 -> 64 -> 10, log-softmax output."""

    def __init__(self):
        super().__init__()
        # The attribute name `main` is part of the state_dict keys
        # ("main.0.weight", ...), so it must stay stable across checkpoints.
        self.main = nn.Sequential(
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, inputs):
        """Return per-class log-probabilities for a batch of 784-feature rows."""
        return self.main(inputs)
# Train
def train(device, model, epochs, optimizer, loss_function, train_loader):
    """Train ``model`` for ``epochs`` passes over ``train_loader``.

    Each batch is indexed as ``data[0]`` (inputs) and ``data[1]`` (labels).
    Inputs are flattened to ``(batch, -1)`` and moved to ``device`` before the
    forward pass. Progress is printed every 100 batches and on the last batch
    of each epoch.

    Returns:
        The same ``model`` object, updated in place.
    """
    # Put the model in training mode explicitly, so the function also works
    # when the caller has switched the model to eval() beforehand.
    model.train()
    num_batches = len(train_loader)

    for epoch in range(1, epochs + 1):
        for times, data in enumerate(train_loader, 1):
            inputs = data[0].view(data[0].shape[0], -1).to(device)
            labels = data[1].to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Forward and backward propagation
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

            # Show progress
            if times % 100 == 0 or times == num_batches:
                print("[{}/{}, {}/{}] loss: {:.8}".format(epoch, epochs, times, num_batches, loss.item()))

    return model
# Test
def test(device, model, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print the accuracy.

    Each batch is indexed as ``data[0]`` (inputs) and ``data[1]`` (labels).
    Inputs are flattened to ``(batch, -1)`` and moved to ``device``. The
    predicted class is the index of the largest output along dimension 1.
    Prints ``Accuracy: nan`` when the loader yields no samples.
    """
    # Settings
    model.eval()
    total = 0
    correct = 0

    with torch.no_grad():
        for data in test_loader:
            inputs = data[0].view(data[0].shape[0], -1).to(device)
            labels = data[1].to(device)
            outputs = model(inputs)
            # Gradients are already disabled by no_grad(), so the legacy
            # `.data` access is not needed here.
            _, predicts = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicts == labels).sum().item()

    # Avoid ZeroDivisionError when the loader was empty.
    accuracy = correct / total if total else float("nan")
    print("Accuracy:", accuracy)
def main():
    """Train the MNIST classifier from scratch and print its test accuracy."""
    # GPU: use the first CUDA device when one is available, else the CPU
    if torch.cuda.is_available():
        device = "cuda:0"
    else:
        device = "cpu"
    print("Device State:", device)

    # Settings
    epochs, batch_size, lr = 3, 64, 0.002
    loss_function = nn.NLLLoss()
    model = Net().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Transform
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize((0.5,), (0.5,))
    transform = transforms.Compose([to_tensor, normalize])

    # Dataset
    dataset_options = {"root": "MNIST", "download": True, "transform": transform}
    train_set = datasets.MNIST(train=True, **dataset_options)
    test_set = datasets.MNIST(train=False, **dataset_options)

    # shuffle=False keeps the batch order identical from run to run
    loader_options = {"batch_size": batch_size, "shuffle": False}
    train_loader = DataLoader(train_set, **loader_options)
    test_loader = DataLoader(test_set, **loader_options)

    # Train
    model = train(device, model, epochs, optimizer, loss_function, train_loader)

    # Test
    test(device, model, test_loader)


if __name__ == "__main__":
    main()
Output:
The result is 0.9381. By the way, different versions of torch, different GPUs, and different experimental environments will train models with different results.
But of course I fixed these parts when I tested it, so it just reproduced the same results.
Okay, so let’s move on to the version of the program that adds an exit point, along with the code that saves and loads the model and optimizer.
The exit point of the program is set at epoch == 2, so as to simulate the moment of the device error. When I then run the program again, if a config.pickle file exists in the current directory, its data is read and training continues from that checkpoint.
# coding: utf-8
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import os
import random
import numpy as np
import pickle
# Fixed seed
# One shared value seeds Python, NumPy and PyTorch (CPU and current CUDA device).
seed = 123
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Request deterministic cuDNN behaviour and turn its benchmark mode off.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Classifier model
class Net(nn.Module):
    """Fully connected classifier: 784 -> 128 -> 64 -> 10, log-softmax output."""

    def __init__(self):
        super().__init__()
        # The attribute name `main` is part of the state_dict keys
        # ("main.0.weight", ...), so saved checkpoints depend on it.
        self.main = nn.Sequential(
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, inputs):
        """Return per-class log-probabilities for a batch of 784-feature rows."""
        return self.main(inputs)
# Train
def _save_checkpoint(config, path):
    """Pickle ``config`` to ``path`` without ever leaving a half-written file."""
    tmp_path = path + ".tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(config, f)
        # Push the bytes to disk before the rename, so a sudden shutdown
        # cannot leave the final file truncated.
        f.flush()
        os.fsync(f.fileno())
    # Swap the finished file into place; the previous checkpoint stays intact
    # until the new one is complete.
    os.replace(tmp_path, path)


def train(device, model, epochs, optimizer, loss_function, train_loader,
          checkpoint_path="config.pickle", stop_epoch=2):
    """Train ``model`` and write a checkpoint after every epoch.

    Each batch is indexed as ``data[0]`` (inputs) and ``data[1]`` (labels).
    Inputs are flattened to ``(batch, -1)`` and moved to ``device``. After each
    epoch a dict with the remaining epoch count (``"epochs"``), the model
    state and the optimizer state is pickled to ``checkpoint_path``.

    Args:
        checkpoint_path: Where the checkpoint is written.
        stop_epoch: Epoch after which the process exits, to simulate a crash.
            Pass ``None`` to disable the simulated crash.

    Returns:
        The same ``model`` object, updated in place.

    Raises:
        SystemExit: After the checkpoint of epoch ``stop_epoch`` is written.
    """
    # Put the model in training mode explicitly, so the function also works
    # when the caller has switched the model to eval() beforehand.
    model.train()
    num_batches = len(train_loader)

    for epoch in range(1, epochs + 1):
        for times, data in enumerate(train_loader, 1):
            inputs = data[0].view(data[0].shape[0], -1).to(device)
            labels = data[1].to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Forward and backward propagation
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

            # Show progress
            if times % 100 == 0 or times == num_batches:
                print("[{}/{}, {}/{}] loss: {:.8}".format(epoch, epochs, times, num_batches, loss.item()))

        # Save config
        config = {
            # Remaining epochs: a resumed run trains exactly this many more.
            "epochs": epochs - epoch,
            "model_state_dict": model.state_dict(),
            "optim_state_dict": optimizer.state_dict(),
        }
        _save_checkpoint(config, checkpoint_path)

        # Break out (simulated crash). SystemExit is raised directly because
        # the exit() helper is only installed by the `site` module.
        if stop_epoch is not None and epoch == stop_epoch:
            raise SystemExit

    return model
# Test
def test(device, model, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print the accuracy.

    Each batch is indexed as ``data[0]`` (inputs) and ``data[1]`` (labels).
    Inputs are flattened to ``(batch, -1)`` and moved to ``device``. The
    predicted class is the index of the largest output along dimension 1.
    Prints ``Accuracy: nan`` when the loader yields no samples.
    """
    # Settings
    model.eval()
    total = 0
    correct = 0

    with torch.no_grad():
        for data in test_loader:
            inputs = data[0].view(data[0].shape[0], -1).to(device)
            labels = data[1].to(device)
            outputs = model(inputs)
            # Gradients are already disabled by no_grad(), so the legacy
            # `.data` access is not needed here.
            _, predicts = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicts == labels).sum().item()

    # Avoid ZeroDivisionError when the loader was empty.
    accuracy = correct / total if total else float("nan")
    print("Accuracy:", accuracy)
def main():
    """Train the MNIST classifier, resuming from config.pickle when it exists."""
    # GPU: use the first CUDA device when one is available, else the CPU
    if torch.cuda.is_available():
        device = "cuda:0"
    else:
        device = "cpu"
    print("Device State:", device)

    # Settings
    epochs, batch_size, lr = 3, 64, 0.002
    loss_function = nn.NLLLoss()
    model = Net().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Configs: restore the remaining epoch count, the model weights and the
    # optimizer state from a previous run.
    # NOTE: pickle.load can execute arbitrary code -- only load checkpoint
    # files that this script wrote itself.
    checkpoint_file = "config.pickle"
    if os.path.isfile(checkpoint_file):
        with open(checkpoint_file, "rb") as f:
            config = pickle.load(f)
        epochs = config["epochs"]
        model.load_state_dict(config["model_state_dict"])
        optimizer.load_state_dict(config["optim_state_dict"])

    # Transform
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize((0.5,), (0.5,))
    transform = transforms.Compose([to_tensor, normalize])

    # Dataset
    dataset_options = {"root": "MNIST", "download": True, "transform": transform}
    train_set = datasets.MNIST(train=True, **dataset_options)
    test_set = datasets.MNIST(train=False, **dataset_options)

    # shuffle=False keeps the batch order identical from run to run
    loader_options = {"batch_size": batch_size, "shuffle": False}
    train_loader = DataLoader(train_set, **loader_options)
    test_loader = DataLoader(test_set, **loader_options)

    # Train
    model = train(device, model, epochs, optimizer, loss_function, train_loader)

    # Test
    test(device, model, test_loader)


if __name__ == "__main__":
    main()
Output
We can train exactly the same model, even if the host accidentally crashes partway through training.
References
- https://pytorch.org/tutorials/beginner/saving_loading_models.html
- https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html
- https://discuss.pytorch.org/t/why-save-optimizer-state-dict/108185