Skip to content

[PyTorch] 應用 Early stopping 技術讓模型在較好的收斂時間點停止訓練

Last Updated on 2021-08-25 by Clay

Early stopping 是一種應用於機器學習、深度學習的技巧,正如字面上的意思 —— 較早地停止。在進行監督式學習的過程中,這很有可能是一個找到模型收斂時機點的方法。

訓練過模型的人肯定都知道,只要訓練過頭,模型就會發生所謂的 Overfitting過擬合),過度地去擬合我們的訓練資料。當然,這個模型在我們的訓練資料上會表現得很好,可是在其他的資料、也就是所謂的測試資料上卻會顯得效果很差,也就是這個模型的泛化性不好。

overfitting

這是一張經典的 Overfitting 範例圖,圖中黑點是我們的訓練資料,函式 f(x) 則是我們的模型函式。 乍看之下,模型將訓練資料預測得很好,對不對?

那假設今天我們遇到了全新的資料,如下圖。

overfitting

現在途中加上了綠色點的測試資料,會發現這個模型完全無法預測沒有看過的資料 —— 實際上模型幾乎只會背訓練資料的答案了,根本無法預測其他的資料。如果模型只訓練到一半,說不定更符合這批資料的特性。

overfitting

那麼,我們又該如何知道這個模型已經訓練得『差不多』了呢?這裡就要進入本文的重點 —— Early stopping 了。


防止 Overfitting 的 Early stopping

在我們使用訓練資料開始訓練模型前,我們可以將訓練資料再多切出一小塊,也就是所謂的驗證資料Validation data)。驗證資料不會參與模型的訓練,也就是說模型從頭到尾都不會看到這些資料。

這個不會參與訓練過程的驗證資料是非常重要的,因為當我們訓練模型時,模型會漸漸學到訓練資料的特徵,知道哪些特徵可能得做出什麼樣的預測 …… 那驗證資料呢?既然模型沒有看過,那模型當然不知道怎麼預測驗證資料才是正確的!

如果我們分別計算訓練資料和驗證資料的損失函數,也就是 training lossvalidation loss,我們會看到像以下這樣的現象:

overfitting_loss

橫軸為我們模型迭代的次數(Epochs),可以視為模型訓練的時間長短;縱軸則為資料集的 Loss,Loss 越大代表資料預測越不準確。

藍色線段為我們訓練資料的 Loss、紅色線段則為我們驗證資料的 Loss。可以看到,在訓練經過一段時間之後,模型會漸漸開始失去預測其他資料的特性,只能專注在預測訓練資料上,這就是逐漸 Overfitting 的徵兆。

然後本文的主題 —— Early stopping 就出現了。既然模型會漸漸開始 Overfitting,那為什麼不在驗證資料集的 Loss 開始上升時就停止訓練呢?

沒錯,所以我們切出的驗證資料,就可以拿來評估訓練模型當前的效果。有些人會使用 Loss、有些人會拿 Accuracy …… 其實各式各樣的指標都可以,重點是在模型漸漸失去泛化的時間點停止訓練。

以下為一段使用 PyTorch 搭建 Mnist 手寫數字辨識模型的程式,並嘗試加入以 Loss 作為評估指標的 Early stopping 機制。


範例程式碼

以下是一段單純的 Mnist 手寫數字辨識程式碼,程式碼非常地單純。

# coding: utf-8
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torchvision import datasets, transforms


# Model architecture
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(in_features=784, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=10),
            nn.LogSoftmax(dim=1)
        )

    def forward(self, input):
        return self.main(input)


# Train
def train(device, model, epochs, optimizer, loss_function, train_loader):
    for epoch in range(1, epochs+1):
        for times, data in enumerate(train_loader, 1):
            inputs = data[0].to(device)
            labels = data[1].to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Forward and backward propagation
            outputs = model(inputs.view(inputs.shape[0], -1))
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

            # Show progress
            if times % 100 == 0 or times == len(train_loader):
                print('[{}/{}, {}/{}] loss: {:.8}'.format(epoch, epochs, times, len(train_loader), loss.item()))

    return model


def test(device, model, test_loader):
    # Settings
    model.eval()
    total = 0
    correct = 0

    with torch.no_grad():
        for data in test_loader:
            inputs = data[0].to(device)
            labels = data[1].to(device)

            outputs = model(inputs.view(inputs.shape[0], -1))
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accurecy:', correct / total)


def main():
    # GPU device
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    print('Device state:', device)

    # Settings
    epochs = 100
    batch_size = 64
    lr = 0.002
    loss_function = nn.NLLLoss()
    model = Net().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Transform
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5,), (0.5,))]
    )

    # Data
    train_set = datasets.MNIST(root='MNIST', download=True, train=True, transform=transform)
    test_set = datasets.MNIST(root='MNIST', download=True, train=False, transform=transform)
    train_loader = data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = data.DataLoader(test_set, batch_size=batch_size, shuffle=False)

    # Train
    model = train(device, model, epochs, optimizer, loss_function, train_loader)

    # Test
    test(device, model, test_loader)


if __name__ == '__main__':
    main()



Output:

很乾脆地一口氣跑了 100 次迭代。可以看到,到最後 Mnist 分類的準確率只有 0.9767 —— 以 Mnist 這種簡單的資料集而言其實應該是可以更高才對的。也許可以猜測,模型已經開始出現 Overfitting 的趨勢了。

那麼以下是改動之後,加入了 Early stopping 機制的程式碼。可能並不是那麼地好閱讀,基本上只做了兩點改動:

  1. 將訓練資料切割成了訓練資料集和驗證資料集
  2. 在每次訓練一輪迭代後將模型設為評估模式,計算驗證資料集的 Loss
  3. 設定 patience,如果設 2,則 Loss 連續下降 2 次就要停止訓練
# coding: utf-8
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torchvision import datasets, transforms


# Model architecture
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(in_features=784, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=10),
            nn.LogSoftmax(dim=1)
        )

    def forward(self, input):
        return self.main(input)


# Train
def train(device, model, epochs, optimizer, loss_function, train_loader, valid_loader):
    # Early stopping
    the_last_loss = 100
    patience = 2
    trigger_times = 0

    for epoch in range(1, epochs+1):
        model.train()

        for times, data in enumerate(train_loader, 1):
            inputs = data[0].to(device)
            labels = data[1].to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Forward and backward propagation
            outputs = model(inputs.view(inputs.shape[0], -1))
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

            # Show progress
            if times % 100 == 0 or times == len(train_loader):
                print('[{}/{}, {}/{}] loss: {:.8}'.format(epoch, epochs, times, len(train_loader), loss.item()))

        # Early stopping
        the_current_loss = validation(model, device, valid_loader, loss_function)
        print('The current loss:', the_current_loss)

        if the_current_loss > the_last_loss:
            trigger_times += 1
            print('trigger times:', trigger_times)

            if trigger_times >= patience:
                print('Early stopping!\nStart to test process.')
                return model

        else:
            print('trigger times: 0')
            trigger_times = 0

        the_last_loss = the_current_loss

    return model


def validation(model, device, valid_loader, loss_function):
    # Settings
    model.eval()
    loss_total = 0

    # Test validation data
    with torch.no_grad():
        for data in valid_loader:
            inputs = data[0].to(device)
            labels = data[1].to(device)

            outputs = model(inputs.view(inputs.shape[0], -1))
            loss = loss_function(outputs, labels)
            loss_total += loss.item()

    return loss_total / len(valid_loader)


def test(device, model, test_loader):
    # Settings
    model.eval()
    total = 0
    correct = 0

    with torch.no_grad():
        for data in test_loader:
            inputs = data[0].to(device)
            labels = data[1].to(device)

            outputs = model(inputs.view(inputs.shape[0], -1))
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accurecy:', correct / total)


def main():
    # GPU device
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    print('Device state:', device)

    # Settings
    epochs = 100
    batch_size = 64
    lr = 0.002
    loss_function = nn.NLLLoss()
    model = Net().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Transform
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5,), (0.5,))]
    )

    # Data
    train_set = datasets.MNIST(root='MNIST', download=True, train=True, transform=transform)
    test_set = datasets.MNIST(root='MNIST', download=True, train=False, transform=transform)
   
    train_set_size = int(len(train_set) * 0.8)
    valid_set_size = len(train_set) - train_set_size
    train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size])

    train_loader = data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = data.DataLoader(test_set, batch_size=batch_size, shuffle=False)
    valid_loader = data.DataLoader(valid_set, batch_size=batch_size, shuffle=True)

    # Train
    model = train(device, model, epochs, optimizer, loss_function, train_loader, valid_loader)

    # Test
    test(device, model, test_loader)


if __name__ == '__main__':
    main()



當然,這畢竟只是範例程式碼,這個資料集小到可能隨便亂訓練效果都很好 —— 我的看法是 Early stopping 的機制並不一定能改進模型,但可能可以提供給我們一個模型會收斂的時機點。搞不好下一次,Early stopping 的機制反而讓效果下降了也說不定。

就像我朋友說的:我有時間,我幹麻不他媽一口氣訓練 1000 個模型,然後每個都存下來、算一次測試資料集的分數,然後直接挑最好的?

你當然也可以這麼做啊 XDDD


References


Read More

Leave a Reply