Skip to content

[PyTorch] 保存優化器(optimizer)來接續訓練模型

我一直以來都希望能夠保存 PyTorch 訓練模型時所使用的優化器optimizer),以便能夠在模型結束訓練之後,繼續往下訓練;一般來說,如果是要做遷移學習、微調模型(fine-tune),那麼並不需要特別將上一次訓練的優化器保存下來。

我需要這麼做是有原因的:我碰到了一台奇妙的主機,該主機會在詭異的時間點完全停擺或 segmentation fault

然而,一樣的環境、一樣的程式碼,在別台(至少我測試過另外三台)跑起來都完全正常。我不得不懷疑是 GPU 硬體的問題,而且也在 GitHub 的某個 PyTorch 相關 issue 上看到有人回報 GPU 硬體有問題確實會發生跟我一樣的狀況。

簡而言之,我需要替我的訓練環境做好『備份』。一但系統當機,我完全可以從備份的時間點開始,繼續往下訓練模型。

而我們都知道,如果要保證 PyTorch 的訓練結果完全一模一樣,除了要固定一堆的 seed 種子外,我們也需要固定 DataLoader 物件提供的資料訓練順序(不知道為什麼,我無法固定任何亂數種子來實現這件事)。

但是這點其實我們可以自己設計,Dataset 如果是由我們自行撰寫,我們完全可以在 __getitem__() 中設計一個自己的順序,比方說將輸入的 index 無視掉,按照自己想要的順序回傳資料的特徵與標準答案。

另外,我們需要實時地保存模型的參數以及『優化器』。優化器是模型按照 loss 去更新權重參數的依據,對模型訓練可謂十分重要。

可以說,保存了模型的權重參數以及優化器的參數,我們就可以從某個斷點繼續往下訓練,並且訓練出跟一次訓練到底一模一樣的模型。

這是非常重要的一件事,關係到了模型的復現;除此之外,如果每次我遇到主機卡住,只能使用保存下來的模型繼續往下訓練,那麼很難保證模型真的會收斂。

使用同一個優化器是必須的


如何保存模型與優化器

這裡簡單紀錄如何保存的範例程式碼,其實程式十分簡單:

import pickle

# Save config
config = {
    "model_state_dict": model.state_dict(),
    "optim_state_dict": optimizer.state_dict(),
}

with open("config.pickle", "wb") as f:
    pickle.dump(config, f)


在這裡我使用了 pickle 套件來實現數據持久化(或稱資料持久化),如果對其有興趣,可以參考 [Python] 使用 Pickle 模組保存資料(持久化數據)

之後,如果希望將模型跟優化器還原到這次保存的時間點,可以使用以下程式碼:

import pickle

# Configs
if os.path.isfile("config.pickle"):
    with open("config.pickle", "rb") as f:
        config = pickle.load(f)
        model.load_state_dict(config["model_state_dict"])
        optimizer.load_state_dict(config["optim_state_dict"])


當然,這只是個範例程式碼,示範大致上是如何去跑儲存優化器、讀取優化器;實際上我們還需要宣告模型與優化器,這部分可以參考下方完整的程式碼。


完整程式碼

首先,先來跑個簡單的、固定結果的 MNIST 分類實驗。(對如何固定結果有興趣,也可以參考我之前寫過的文章:[PyTorch] 設置種子參數重現模型訓練結果

# coding: utf-8
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import random
import numpy as np


# Fixed seed 
seed = 123

np.random.seed(seed)
random.seed(seed)

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


# Classifier model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10),
            nn.LogSoftmax(dim=1)
        )

    def forward(self, inputs):
        return self.main(inputs)


# Train
def train(device, model, epochs, optimizer, loss_function, train_loader):
    for epoch in range(1, epochs+1):
        for times, data in enumerate(train_loader, 1):
            inputs = data[0].view(data[0].shape[0], -1).to(device)
            labels = data[1].to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Forward and backward propagation
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

            # Show progress
            if times % 100 == 0 or times == len(train_loader):
                print("[{}/{}, {}/{}] loss: {:.8}".format(epoch, epochs, times, len(train_loader), loss.item()))

    return model


# Test
def test(device, model, test_loader):
    # Settings
    model.eval()
    total = 0
    correct = 0

    with torch.no_grad():
        for data in test_loader:
            inputs = data[0].view(data[0].shape[0], -1).to(device)
            labels = data[1].to(device)

            outputs = model(inputs)
            _, predicts = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicts == labels).sum().item()

    print("Accuracy:", correct/total)


def main():
    # GPU
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    print("Device State:", device)

    # Settings
    epochs = 3
    batch_size = 64
    lr = 0.002
    loss_function = nn.NLLLoss()
    model = Net().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Transform
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    # Dataset
    train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform)
    test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

    # Train
    model = train(device, model, epochs, optimizer, loss_function, train_loader)

    # Test
    test(device, model, test_loader)


if __name__ == "__main__":
    main()


Output:


這一次的結果是 0.9381,順帶一提不同版本的 torch、不同的 GPU、不同的實驗環境會訓練出不同結果的模型。

不過我測試的時候當然固定了這些部分,所以只會重現出一樣的結果。

好,那麼接著我們來看看加入了一段程式的結束點、以及保存和讀取模型和優化器的程式碼

程式的結束點設定在 epoch == 2 的時候,這樣來模擬主機發生錯誤的時機;接著我執行程式,如果當前目錄底下存在著 config.pickle 的設定資料,則讀取這些資料,從一個檢查點開始繼續往下訓練。

# coding: utf-8
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import os
import random
import numpy as np
import pickle


# Fixed seed 
seed = 123

np.random.seed(seed)
random.seed(seed)

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


# Classifier model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10),
            nn.LogSoftmax(dim=1)
        )

    def forward(self, inputs):
        return self.main(inputs)


# Train
def train(device, model, epochs, optimizer, loss_function, train_loader):
    for epoch in range(1, epochs+1):
        for times, data in enumerate(train_loader, 1):
            inputs = data[0].view(data[0].shape[0], -1).to(device)
            labels = data[1].to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Forward and backward propagation
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

            # Show progress
            if times % 100 == 0 or times == len(train_loader):
                print("[{}/{}, {}/{}] loss: {:.8}".format(epoch, epochs, times, len(train_loader), loss.item()))

        # Save config
        config = {
            "epochs": epochs-epoch,
            "model_state_dict": model.state_dict(),
            "optim_state_dict": optimizer.state_dict(),
        }

        with open("config.pickle", "wb") as f:
            pickle.dump(config, f)
        
        # Break out
        if epoch == 2: exit()

    return model


# Test
def test(device, model, test_loader):
    # Settings
    model.eval()
    total = 0
    correct = 0

    with torch.no_grad():
        for data in test_loader:
            inputs = data[0].view(data[0].shape[0], -1).to(device)
            labels = data[1].to(device)

            outputs = model(inputs)
            _, predicts = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicts == labels).sum().item()

    print("Accuracy:", correct/total)


def main():
    # GPU
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    print("Device State:", device)

    # Settings
    epochs = 3
    batch_size = 64
    lr = 0.002
    loss_function = nn.NLLLoss()
    model = Net().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Configs
    if os.path.isfile("config.pickle"):
        with open("config.pickle", "rb") as f:
            config = pickle.load(f)
            epochs = config["epochs"]
            model.load_state_dict(config["model_state_dict"])
            optimizer.load_state_dict(config["optim_state_dict"])

    # Transform
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    # Dataset
    train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform)
    test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

    # Train
    model = train(device, model, epochs, optimizer, loss_function, train_loader)

    # Test
    test(device, model, test_loader)


if __name__ == "__main__":
    main()


Output

我們可以訓練出完全一模一樣的模型,即使中間主機不小心當機了也無所謂。


References


Read More

Leave a Reply