Skip to content

[PyTorch] pytorch-lightning 套件介紹

Last Updated on 2022-04-07 by Clay

PyTorch Lightning 是把原生 PyTorch 封裝得更高級的框架套件,就像是 Keras 之於 Tensorflow 一樣(雖然 Keras 能支援的後端我記得是不少的)。

簡單來說,有許多人會認為 PyTorch 的某些操作實在是太底層,比方說要自己去寫 for 迴圈迭代訓練、並在之中清空累積梯度,自己手動向後傳播...... 等等。

當然這些底層的操作有時候可能很重要;比方說我們可能會希望多迭代幾次再進行向後傳播的動作。不過,確實,大部分時候我們並不需要這樣的功能。

PyTorch Lightning 所提供的 LightModule 其實就是更進一步封裝的 torch.nn.Module。按照官方文件的說明,LightModule 最主要的就是把以下本來需要寫的一些訓練程式碼包裝起來:

  • 訓練迴圈(train loop)
  • 驗證迴圈(validation loop)
  • 測試迴圈(test loop)
  • 預測迴圈(prediction loop)
  • 模型或系統模型(the model or system of models)< 我不太確定這部分究竟是封裝了什麼,應該是我尚未使用到
  • 優化器以及學習率

以下,就稍微紀錄一下該如何使用 PyTorch Lightning,並附上一段官方的範例程式碼並嘗試解釋。


如何使用 PyTorch Lightning

首先,自然是安裝環節。我們可以使用 pip 指令輕鬆安裝在裝置中,當然,推薦是在虛擬環境中。

pip3 install pytorch-lightning

可以使用:

pip3 list | grep pytorch-lightning

來查看版本:

pytorch-lightning           1.5.10

範例程式碼解釋

範例程式碼是一段官方提供的 AutoEncoder 模型,如果想看其他不同的任務,官方也有提供相關原始碼,這部分請參閱附在底下 References 中的連結。

以下我分成三個部分簡單介紹:

  • 匯入相關套件
  • 使用 LightningModule 架構模型
  • 實例化模型並進行訓練、驗證(沒有測試資料集)


匯入相關套件

當然,最重要的自然是 torch 以及 pytorch_lightning 的套件;而我們測試 AutoEncoder 模型的資料是使用經典的手寫數字辨識,故也要把 torchvision 一齊匯入。

如果環境中沒有安裝,也同樣需要透過 pip 指令去安裝,因為他們都是不同的套件。

import torch
import pytorch_lightning as pl
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision import transforms
from torchvision.datasets import MNIST


使用 LightingModule 建構模型

建構的模型是一個標準的 Python 類別,跟 PyTorch 建立的模型相似,但卻又有些不同。這裡我們建立了以下的初始化及方法(methods)。

  • __init__(): 建立 encoder 和 decoder 的模型架構
  • forward(): 單向向前傳播(無返回 loss)
  • configure_optimizers(): 建立優化器
  • training_step(): 封裝後的迭代訓練方法(須返回 loss)
  • validation_step(): 封裝後的迭代驗證方法

我一開始很不解為什麼 PyTorch 經典的 forward() 和 pytorch-lightning 封裝的 training_step() 要一起寫,但後來才發現官方文檔就清楚地寫著傾向於推薦訓練和推理分開來撰寫

也是,training_step() 還需要返回 loss...... 代表這是一整個訓練的完整流程,而我們要預測沒看過的資料,仍然需要使用 forward() 這樣的方法。


class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28*28, 64),
            nn.ReLU(),
            nn.Linear(64, 3),
        )

        self.decoder = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 28*28),
        )
    
    def forward(self, x):
        return self.encoder(x)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
    
    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)

        return loss
    
    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("val_loss", loss)



實例化模型並開始訓練

  • 首先建立 MNIST 資料集
  • 使用 random_split() 切分訓練資料集與驗證資料集
  • 使用 LitAutoEncoder() 實例化模型並訓練

要注意的是,第一次試跑記得設定 max_epochs 參數。因為預設是 1000,若用 CPU 的話恐怕會跑上好一陣子。

def main():
    # Data
    dataset = MNIST("", train=True, download=True, transform=transforms.ToTensor())
    mnist_train, mnist_val = random_split(dataset, [50000, 10000])

    train_loader = DataLoader(mnist_train, batch_size=32)
    val_loader = DataLoader(mnist_val, batch_size=32)

    # Model
    model = LitAutoEncoder()

    # Training
    trainer = pl.Trainer(gpus=0, precision=16, limit_train_batches=0.5, max_epochs=50)
    trainer.fit(model, train_loader, val_loader)


if __name__ == "__main__":
    main()

完整範例程式碼

若是想要直接試跑模型,可以參考:

# coding: utf-8
"""
This is a simple script for pytorch-lightning testing.
"""
import torch
import pytorch_lightning as pl
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision import transforms
from torchvision.datasets import MNIST


class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28*28, 64),
            nn.ReLU(),
            nn.Linear(64, 3),
        )

        self.decoder = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 28*28),
        )
    
    def forward(self, x):
        return self.encoder(x)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
    
    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)

        return loss
    
    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("val_loss", loss)


def main():
    # Data
    dataset = MNIST("", train=True, download=True, transform=transforms.ToTensor())
    mnist_train, mnist_val = random_split(dataset, [50000, 10000])

    train_loader = DataLoader(mnist_train, batch_size=32)
    val_loader = DataLoader(mnist_val, batch_size=32)

    # Model
    model = LitAutoEncoder()

    # Training
    trainer = pl.Trainer(gpus=0, precision=16, limit_train_batches=0.5, max_epochs=50)
    trainer.fit(model, train_loader, val_loader)


if __name__ == "__main__":
    main()

References


Read More

Leave a Reply取消回覆

Exit mobile version