
[PyTorch] Save the Optimizer to Continue Training the Model

I have always wanted to be able to save the optimizer that PyTorch uses to train a model, so that I can continue training from where a previous run stopped.

There is a reason I need this: I have been running on an unreliable device that shuts down or hits segmentation faults at unpredictable times.

In the general case, such as transfer learning or fine-tuning, you do not need to save the last trained optimizer.

In short, I need to “back up” my training progress. If the system crashes, I can start from the checkpoint and continue training the model.

To make PyTorch training results exactly reproducible, fixing the random seeds is not enough; the order of the training data must be fixed as well.
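The scripts in this post keep the data order fixed with shuffle=False. If shuffling were wanted, the state of the random generator would have to survive a restart too. Here is a minimal sketch of that idea using only Python's standard random module rather than PyTorch's DataLoader; epoch_order is a hypothetical helper, and the assumption is that this one generator is the only source of shuffling.

```python
import random


def epoch_order(rng, n):
    """Return a shuffled visiting order for n samples, drawn from rng."""
    order = list(range(n))
    rng.shuffle(order)
    return order


rng = random.Random(123)

first = epoch_order(rng, 8)      # order used in epoch 1
saved_state = rng.getstate()     # this would be stored in the checkpoint
second = epoch_order(rng, 8)     # order used in epoch 2 (uninterrupted run)

# After a simulated restart: a fresh generator restored from the saved state
resumed_rng = random.Random()
resumed_rng.setstate(saved_state)
second_after_resume = epoch_order(resumed_rng, 8)

print(second == second_after_resume)
```

The restored generator produces the same order for the second epoch as the uninterrupted run, which is what an exact resume needs.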

In addition, we need to save the parameters of the model and the optimizer regularly during training. The optimizer decides how the weights are updated from the loss gradients, and it keeps its own internal state, so it matters a great deal for training.

In other words, by saving both the weight parameters of the model and the state of the optimizer, we can continue training from a checkpoint and end up with exactly the same model as an uninterrupted run would produce.

It is necessary to use the same optimizer.
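To see why the optimizer state matters, here is a toy sketch in plain Python. It uses SGD with momentum on a single weight instead of the Adam optimizer used later in this post, and sgd_momentum_step and run are hypothetical helpers written only for this illustration. The velocity plays the role of the optimizer state.

```python
def sgd_momentum_step(w, velocity, lr=0.1, momentum=0.9):
    """One momentum-SGD update on the toy loss (w - 3)^2."""
    grad = 2.0 * (w - 3.0)
    velocity = momentum * velocity + grad
    w = w - lr * velocity
    return w, velocity


def run(steps, w=0.0, velocity=0.0):
    """Apply `steps` updates starting from the given weight and velocity."""
    for _ in range(steps):
        w, velocity = sgd_momentum_step(w, velocity)
    return w, velocity


# Uninterrupted run: 4 steps
w_full, _ = run(4)

# Interrupted after 2 steps, then resumed with weight AND velocity restored
w_mid, v_mid = run(2)
w_resumed, _ = run(2, w=w_mid, velocity=v_mid)

# Interrupted after 2 steps, then resumed with only the weight restored
w_weights_only, _ = run(2, w=w_mid, velocity=0.0)

print(w_full == w_resumed)       # True: same state, same result
print(w_full == w_weights_only)  # False: the lost velocity changes the updates
```

Restoring only the weights still gives a model that can be trained further, but it no longer follows the same path as the uninterrupted run.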


How to save the model and optimizer

Here is some sample code that shows how to save them. The program is very simple:

import pickle

# Save config
config = {
    "model_state_dict": model.state_dict(),
    "optim_state_dict": optimizer.state_dict(),
}

with open("config.pickle", "wb") as f:
    pickle.dump(config, f)


I use the pickle package for data persistence. If you are interested, you can refer to How to use Pickle module to store data in Python.
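Since the checkpoint exists to survive crashes, it is worth thinking about a crash that happens while the file itself is being written, which could leave a truncated config.pickle behind. One common pattern, sketched here under the assumption that the checkpoint lives on a local filesystem, is to write to a temporary file first and then rename it over the old one. save_checkpoint_atomically is a hypothetical helper, and a plain dict stands in for the real state dicts.

```python
import os
import pickle
import tempfile


def save_checkpoint_atomically(state, path):
    """Write `state` to `path` so that `path` never holds a half-written file."""
    directory = os.path.dirname(os.path.abspath(path))
    # Write to a temporary file in the same directory (same filesystem).
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(state, f)
            f.flush()
            os.fsync(f.fileno())  # push the bytes to disk before renaming
        # Swap the finished file into place in a single step.
        os.replace(tmp_path, path)
    except BaseException:
        # Clean up the temporary file if anything went wrong.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


# Demo with a plain dict standing in for the model and optimizer state.
with tempfile.TemporaryDirectory() as tmp_dir:
    checkpoint_path = os.path.join(tmp_dir, "config.pickle")
    save_checkpoint_atomically({"epochs": 1, "weights": [0.1, 0.2]}, checkpoint_path)
    with open(checkpoint_path, "rb") as f:
        restored = pickle.load(f)
    leftover_files = os.listdir(tmp_dir)

print(restored)
```

If the process dies before os.replace runs, the previous checkpoint is still intact; on POSIX systems the rename itself is atomic when it succeeds.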

Afterwards, if you want to restore the model and optimizer to this saved point in time, you can use the following code:

import os
import pickle

# Restore the model and optimizer if a checkpoint file exists
if os.path.isfile("config.pickle"):
    with open("config.pickle", "rb") as f:
        config = pickle.load(f)
    model.load_state_dict(config["model_state_dict"])
    optimizer.load_state_dict(config["optim_state_dict"])


Of course, this is just example code. In fact, the model and optimizer must be declared before their states can be loaded; you can refer to the complete code below.


Complete Code

First, let’s run a simple MNIST classification experiment with fixed results (you can refer to [PyTorch] Set Seed To Reproduce Model Training Results).

# coding: utf-8
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import random
import numpy as np


# Fixed seed 
seed = 123

np.random.seed(seed)
random.seed(seed)

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


# Classifier model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10),
            nn.LogSoftmax(dim=1)
        )

    def forward(self, inputs):
        return self.main(inputs)


# Train
def train(device, model, epochs, optimizer, loss_function, train_loader):
    for epoch in range(1, epochs+1):
        for times, data in enumerate(train_loader, 1):
            inputs = data[0].view(data[0].shape[0], -1).to(device)
            labels = data[1].to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Forward and backward propagation
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

            # Show progress
            if times % 100 == 0 or times == len(train_loader):
                print("[{}/{}, {}/{}] loss: {:.8}".format(epoch, epochs, times, len(train_loader), loss.item()))

    return model


# Test
def test(device, model, test_loader):
    # Settings
    model.eval()
    total = 0
    correct = 0

    with torch.no_grad():
        for data in test_loader:
            inputs = data[0].view(data[0].shape[0], -1).to(device)
            labels = data[1].to(device)

            outputs = model(inputs)
            _, predicts = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicts == labels).sum().item()

    print("Accuracy:", correct/total)


def main():
    # GPU
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    print("Device State:", device)

    # Settings
    epochs = 3
    batch_size = 64
    lr = 0.002
    loss_function = nn.NLLLoss()
    model = Net().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Transform
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    # Dataset
    train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform)
    test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

    # Train
    model = train(device, model, epochs, optimizer, loss_function, train_loader)

    # Test
    test(device, model, test_loader)


if __name__ == "__main__":
    main()


Output:


The resulting accuracy is 0.9381. By the way, different versions of torch, different GPUs, and different experimental environments can produce different results.

But of course I kept all of these fixed when I tested it, so the same result was reproduced every time.

Okay, so let’s move on: we add a stopping point to the program, along with the code that saves and loads the model and optimizer.

The stopping point is set at epoch == 2 to simulate the device failure. When I run the program again, if a config.pickle file exists in the current directory, the program reads it and continues training from the checkpoint.

# coding: utf-8
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import os
import random
import numpy as np
import pickle


# Fixed seed 
seed = 123

np.random.seed(seed)
random.seed(seed)

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


# Classifier model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10),
            nn.LogSoftmax(dim=1)
        )

    def forward(self, inputs):
        return self.main(inputs)


# Train
def train(device, model, epochs, optimizer, loss_function, train_loader):
    for epoch in range(1, epochs+1):
        for times, data in enumerate(train_loader, 1):
            inputs = data[0].view(data[0].shape[0], -1).to(device)
            labels = data[1].to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Forward and backward propagation
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

            # Show progress
            if times % 100 == 0 or times == len(train_loader):
                print("[{}/{}, {}/{}] loss: {:.8}".format(epoch, epochs, times, len(train_loader), loss.item()))

        # Save config
        config = {
            "epochs": epochs-epoch,
            "model_state_dict": model.state_dict(),
            "optim_state_dict": optimizer.state_dict(),
        }

        with open("config.pickle", "wb") as f:
            pickle.dump(config, f)
        
        # Break out after the second epoch to simulate a crash
        if epoch == 2:
            raise SystemExit

    return model


# Test
def test(device, model, test_loader):
    # Settings
    model.eval()
    total = 0
    correct = 0

    with torch.no_grad():
        for data in test_loader:
            inputs = data[0].view(data[0].shape[0], -1).to(device)
            labels = data[1].to(device)

            outputs = model(inputs)
            _, predicts = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicts == labels).sum().item()

    print("Accuracy:", correct/total)


def main():
    # GPU
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    print("Device State:", device)

    # Settings
    epochs = 3
    batch_size = 64
    lr = 0.002
    loss_function = nn.NLLLoss()
    model = Net().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Configs
    if os.path.isfile("config.pickle"):
        with open("config.pickle", "rb") as f:
            config = pickle.load(f)
            epochs = config["epochs"]
            model.load_state_dict(config["model_state_dict"])
            optimizer.load_state_dict(config["optim_state_dict"])

    # Transform
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    # Dataset
    train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform)
    test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

    # Train
    model = train(device, model, epochs, optimizer, loss_function, train_loader)

    # Test
    test(device, model, test_loader)


if __name__ == "__main__":
    main()


Output

We can train exactly the same model, even if the host accidentally crashes partway through training.
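The listing above stores the number of remaining epochs, so after a restart the progress log starts again at epoch 1. A possible variation, sketched here with a stand-in loop and no real training, stores the number of completed epochs instead and resumes the original numbering. run_training, crash_after and the "completed_epochs" key are hypothetical names used only for this illustration.

```python
import os
import pickle
import tempfile

TOTAL_EPOCHS = 3


def run_training(path, crash_after=None):
    """Run (or resume) a stand-in training loop; return the epochs executed in this run."""
    start_epoch = 1
    if os.path.isfile(path):
        with open(path, "rb") as f:
            checkpoint = pickle.load(f)
        start_epoch = checkpoint["completed_epochs"] + 1

    executed = []
    for epoch in range(start_epoch, TOTAL_EPOCHS + 1):
        executed.append(epoch)  # a real script would train one epoch here

        # Save progress at the end of every epoch
        with open(path, "wb") as f:
            pickle.dump({"completed_epochs": epoch}, f)

        # Simulated crash
        if crash_after is not None and epoch == crash_after:
            break
    return executed


with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "config.pickle")
    first_run = run_training(path, crash_after=2)
    second_run = run_training(path)

print(first_run, second_run)  # [1, 2] [3]
```

In a real script the model and optimizer state dicts would be stored in the same dictionary, next to the epoch counter.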

