Last Updated on 2021-08-25 by Clay
Early stopping 是一種應用於機器學習、深度學習的技巧,正如字面上的意思 —— 較早地停止。在進行監督式學習的過程中,這很有可能是一個找到模型收斂時機點的方法。
訓練過模型的人肯定都知道,只要訓練過頭,模型就會發生所謂的 Overfitting(過擬合),過度地去擬合我們的訓練資料。當然,這個模型在我們的訓練資料上會表現得很好,可是在其他的資料、也就是所謂的測試資料上卻會顯得效果很差,也就是這個模型的泛化性不好。
這是一張經典的 Overfitting 範例圖,圖中黑點是我們的訓練資料,函式 f(x) 則是我們的模型函式。 乍看之下,模型將訓練資料預測得很好,對不對?
那假設今天我們遇到了全新的資料,如下圖。
現在途中加上了綠色點的測試資料,會發現這個模型完全無法預測沒有看過的資料 —— 實際上模型幾乎只會背訓練資料的答案了,根本無法預測其他的資料。如果模型只訓練到一半,說不定更符合這批資料的特性。
那麼,我們又該如何知道這個模型已經訓練得『差不多』了呢?這裡就要進入本文的重點 —— Early stopping 了。
防止 Overfitting 的 Early stopping
在我們使用訓練資料開始訓練模型前,我們可以將訓練資料再多切出一小塊,也就是所謂的驗證資料(Validation data)。驗證資料不會參與模型的訓練,也就是說模型從頭到尾都不會看到這些資料。
這個不會參與訓練過程的驗證資料是非常重要的,因為當我們訓練模型時,模型會漸漸學到訓練資料的特徵,知道哪些特徵可能得做出什麼樣的預測 ...... 那驗證資料呢?既然模型沒有看過,那模型當然不知道怎麼預測驗證資料才是正確的!
如果我們分別計算訓練資料和驗證資料的損失函數,也就是 training loss 和 validation loss,我們會看到像以下這樣的現象:
橫軸為我們模型迭代的次數(Epochs),可以視為模型訓練的時間長短;縱軸則為資料集的 Loss,Loss 越大代表資料預測越不準確。
藍色線段為我們訓練資料的 Loss、紅色線段則為我們驗證資料的 Loss。可以看到,在訓練經過一段時間之後,模型會漸漸開始失去預測其他資料的特性,只能專注在預測訓練資料上,這就是逐漸 Overfitting 的徵兆。
然後本文的主題 —— Early stopping 就出現了。既然模型會漸漸開始 Overfitting,那為什麼不在驗證資料集的 Loss 開始上升時就停止訓練呢?
沒錯,所以我們切出的驗證資料,就可以拿來評估訓練模型當前的效果。有些人會使用 Loss、有些人會拿 Accuracy ...... 其實各式各樣的指標都可以,重點是在模型漸漸失去泛化的時間點停止訓練。
以下為一段使用 PyTorch 搭建 Mnist 手寫數字辨識模型的程式,並嘗試加入以 Loss 作為評估指標的 Early stopping 機制。
範例程式碼
以下是一段單純的 Mnist 手寫數字辨識程式碼,程式碼非常地單純。
# coding: utf-8 import torch import torch.nn as nn import torch.optim as optim import torch.utils.data as data from torchvision import datasets, transforms # Model architecture class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.main = nn.Sequential( nn.Linear(in_features=784, out_features=128), nn.ReLU(), nn.Linear(in_features=128, out_features=64), nn.ReLU(), nn.Linear(in_features=64, out_features=10), nn.LogSoftmax(dim=1) ) def forward(self, input): return self.main(input) # Train def train(device, model, epochs, optimizer, loss_function, train_loader): for epoch in range(1, epochs+1): for times, data in enumerate(train_loader, 1): inputs = data[0].to(device) labels = data[1].to(device) # Zero the gradients optimizer.zero_grad() # Forward and backward propagation outputs = model(inputs.view(inputs.shape[0], -1)) loss = loss_function(outputs, labels) loss.backward() optimizer.step() # Show progress if times % 100 == 0 or times == len(train_loader): print('[{}/{}, {}/{}] loss: {:.8}'.format(epoch, epochs, times, len(train_loader), loss.item())) return model def test(device, model, test_loader): # Settings model.eval() total = 0 correct = 0 with torch.no_grad(): for data in test_loader: inputs = data[0].to(device) labels = data[1].to(device) outputs = model(inputs.view(inputs.shape[0], -1)) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() print('Accurecy:', correct / total) def main(): # GPU device device = 'cuda:0' if torch.cuda.is_available() else 'cpu' print('Device state:', device) # Settings epochs = 100 batch_size = 64 lr = 0.002 loss_function = nn.NLLLoss() model = Net().to(device) optimizer = optim.Adam(model.parameters(), lr=lr) # Transform transform = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] ) # Data train_set = datasets.MNIST(root='MNIST', download=True, train=True, transform=transform) test_set = datasets.MNIST(root='MNIST', download=True, train=False, transform=transform) train_loader = data.DataLoader(train_set, batch_size=batch_size, shuffle=True) test_loader = data.DataLoader(test_set, batch_size=batch_size, shuffle=False) # Train model = train(device, model, epochs, optimizer, loss_function, train_loader) # Test test(device, model, test_loader) if __name__ == '__main__': main()
Output:
很乾脆地一口氣跑了 100 次迭代。可以看到,到最後 Mnist 分類的準確率只有 0.9767 —— 以 Mnist 這種簡單的資料集而言其實應該是可以更高才對的。也許可以猜測,模型已經開始出現 Overfitting 的趨勢了。
那麼以下是改動之後,加入了 Early stopping 機制的程式碼。可能並不是那麼地好閱讀,基本上只做了兩點改動:
- 將訓練資料切割成了訓練資料集和驗證資料集
- 在每次訓練一輪迭代後將模型設為評估模式,計算驗證資料集的 Loss
- 設定 patience,如果設 2,則 Loss 連續下降 2 次就要停止訓練
# coding: utf-8 import torch import torch.nn as nn import torch.optim as optim import torch.utils.data as data from torchvision import datasets, transforms # Model architecture class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.main = nn.Sequential( nn.Linear(in_features=784, out_features=128), nn.ReLU(), nn.Linear(in_features=128, out_features=64), nn.ReLU(), nn.Linear(in_features=64, out_features=10), nn.LogSoftmax(dim=1) ) def forward(self, input): return self.main(input) # Train def train(device, model, epochs, optimizer, loss_function, train_loader, valid_loader): # Early stopping the_last_loss = 100 patience = 2 trigger_times = 0 for epoch in range(1, epochs+1): model.train() for times, data in enumerate(train_loader, 1): inputs = data[0].to(device) labels = data[1].to(device) # Zero the gradients optimizer.zero_grad() # Forward and backward propagation outputs = model(inputs.view(inputs.shape[0], -1)) loss = loss_function(outputs, labels) loss.backward() optimizer.step() # Show progress if times % 100 == 0 or times == len(train_loader): print('[{}/{}, {}/{}] loss: {:.8}'.format(epoch, epochs, times, len(train_loader), loss.item())) # Early stopping the_current_loss = validation(model, device, valid_loader, loss_function) print('The current loss:', the_current_loss) if the_current_loss > the_last_loss: trigger_times += 1 print('trigger times:', trigger_times) if trigger_times >= patience: print('Early stopping!\nStart to test process.') return model else: print('trigger times: 0') trigger_times = 0 the_last_loss = the_current_loss return model def validation(model, device, valid_loader, loss_function): # Settings model.eval() loss_total = 0 # Test validation data with torch.no_grad(): for data in valid_loader: inputs = data[0].to(device) labels = data[1].to(device) outputs = model(inputs.view(inputs.shape[0], -1)) loss = loss_function(outputs, labels) loss_total += loss.item() return loss_total / len(valid_loader) def test(device, model, test_loader): # Settings model.eval() total = 0 correct = 0 with torch.no_grad(): for data in test_loader: inputs = data[0].to(device) labels = data[1].to(device) outputs = model(inputs.view(inputs.shape[0], -1)) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() print('Accurecy:', correct / total) def main(): # GPU device device = 'cuda:0' if torch.cuda.is_available() else 'cpu' print('Device state:', device) # Settings epochs = 100 batch_size = 64 lr = 0.002 loss_function = nn.NLLLoss() model = Net().to(device) optimizer = optim.Adam(model.parameters(), lr=lr) # Transform transform = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] ) # Data train_set = datasets.MNIST(root='MNIST', download=True, train=True, transform=transform) test_set = datasets.MNIST(root='MNIST', download=True, train=False, transform=transform) train_set_size = int(len(train_set) * 0.8) valid_set_size = len(train_set) - train_set_size train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size]) train_loader = data.DataLoader(train_set, batch_size=batch_size, shuffle=True) test_loader = data.DataLoader(test_set, batch_size=batch_size, shuffle=False) valid_loader = data.DataLoader(valid_set, batch_size=batch_size, shuffle=True) # Train model = train(device, model, epochs, optimizer, loss_function, train_loader, valid_loader) # Test test(device, model, test_loader) if __name__ == '__main__': main()
當然,這畢竟只是範例程式碼,這個資料集小到可能隨便亂訓練效果都很好 —— 我的看法是 Early stopping 的機制並不一定能改進模型,但可能可以提供給我們一個模型會收斂的時機點。搞不好下一次,Early stopping 的機制反而讓效果下降了也說不定。
就像我朋友說的:我有時間,我幹麻不他媽一口氣訓練 1000 個模型,然後每個都存下來、算一次測試資料集的分數,然後直接挑最好的?
你當然也可以這麼做啊 XDDD
References
- https://github.com/Bjarten/early-stopping-pytorch
- https://pytorch-lightning.readthedocs.io/en/latest/early_stopping.html