
[Machine Learning] Introduction to the 'pytorch-lightning' package

Last Updated on 2022-07-23 by Clay

PyTorch Lightning is a framework that wraps native PyTorch in a higher-level interface, much as Keras does for TensorFlow (although, as I recall, Keras can run on several backends, not just TensorFlow).

Put simply, many people find some PyTorch operations too low-level: for example, writing the training loop as a for-loop yourself, manually clearing the accumulated gradients inside it, manually running backpropagation, and so on.

Of course, these low-level operations are important! For example, we might want to run a few more iterations before updating the weights (gradient accumulation). But simple is good.
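For reference, here is a minimal sketch of the kind of hand-written loop that Lightning hides, including gradient accumulation over several mini-batches. It uses plain PyTorch only; the toy data, the one-layer model, and the `accumulate_every` setting are illustrative assumptions, not part of the official example.

```python
import torch
from torch import nn
from torch.nn import functional as F

torch.manual_seed(0)

# Toy regression data and a tiny model (illustrative only)
x = torch.randn(64, 10)
y = torch.randn(64, 1)
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

accumulate_every = 4                          # hypothetical: update once every 4 mini-batches
batches = list(zip(x.split(8), y.split(8)))   # 8 mini-batches of 8 samples each

optimizer.zero_grad()                         # start with clean gradients
steps_taken = 0
for i, (xb, yb) in enumerate(batches):
    loss = F.mse_loss(model(xb), yb)
    (loss / accumulate_every).backward()      # gradients add up in each parameter's .grad
    if (i + 1) % accumulate_every == 0:
        optimizer.step()                      # apply the accumulated update
        optimizer.zero_grad()                 # clear gradients for the next group
        steps_taken += 1

print(steps_taken)  # 2
```

In Lightning, the same behavior is requested with the Trainer's `accumulate_grad_batches` argument instead of writing this loop by hand.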

The LightningModule provided by PyTorch Lightning is a subclass of torch.nn.Module with extra structure on top.

According to the official documentation, a LightningModule organizes the following pieces of training code that you would otherwise have to write yourself:

  • Train loop
  • Validation loop
  • Test loop
  • Prediction loop
  • The model or system of models (I have not used this part much yet; as I understand it, it refers to the network(s) you define in __init__(), such as the encoder and decoder below)
  • Optimizer and learning rate

Below, I briefly note how to use PyTorch Lightning, then walk through an official sample program and try to explain it.


How to use

First, installation. We can install pytorch-lightning with pip. Of course, I recommend installing it in a Python virtual environment.

pip3 install pytorch-lightning


Then you can check that it has been installed:

pip3 list | grep pytorch-lightning



The output shows the installed version:

pytorch-lightning           1.5.10
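If you prefer to check from inside Python rather than with pip, a minimal sketch using the standard library's importlib.metadata (Python 3.8+) could look like this; the helper name `installed_version` is hypothetical. Once the package imports, `pytorch_lightning.__version__` gives the same information.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if it is missing."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


print(installed_version("pytorch-lightning"))  # e.g. '1.5.10', or None if not installed
```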

Sample Code Explanation

The sample code is an AutoEncoder model from the official documentation. If you want to implement other tasks, you can follow the References at the bottom to search the official website.

There are three parts:

  • Import the packages
  • Use LightningModule to build the model
  • Instantiate the model, then train and validate it (we have no test dataset)


Import the packages

Of course, the torch and pytorch_lightning packages are essential. And since we want to test on the MNIST handwritten digit dataset, we also need to import torchvision.

If torchvision is not installed in the environment, install it with pip as well.

import torch
import pytorch_lightning as pl
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision import transforms
from torchvision.datasets import MNIST


Use LightningModule to build the model

The model we build is a standard Python class. It is similar to a model built with plain PyTorch, but not quite the same. Here we define the following initializer and methods:

  • __init__(): create the encoder and decoder structure
  • forward(): forward propagation (does not return a loss)
  • configure_optimizers(): create the optimizer
  • training_step(): one step of the encapsulated training loop (must return the loss)
  • validation_step(): one step of the encapsulated validation loop

At first I was puzzled why both PyTorch's classic forward() and pytorch-lightning's encapsulated training_step() had to be written, but I later found that the official documentation explicitly recommends keeping training (training_step()) and inference (forward()) separate.

That makes sense. training_step() has to return the loss, which means it describes one complete training step, while forward() is still what we call when we want to predict on unseen data. In this example, forward() only runs the encoder, so it returns the 3-dimensional embedding rather than a reconstruction.


class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28*28, 64),
            nn.ReLU(),
            nn.Linear(64, 3),
        )

        self.decoder = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 28*28),
        )
    
    def forward(self, x):
        return self.encoder(x)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
    
    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)

        return loss
    
    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("val_loss", loss)
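To make the tensor shapes in training_step() and validation_step() concrete, here is a sketch that pushes a fake batch through the same layer sizes using plain PyTorch. The random batch is an assumption standing in for MNIST images after ToTensor(), which have shape (1, 28, 28) each. The labels y are never used, because an autoencoder learns to reconstruct its own input.

```python
import torch
from torch import nn
from torch.nn import functional as F

# Same layer sizes as LitAutoEncoder, built with plain PyTorch
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

x = torch.rand(32, 1, 28, 28)      # fake batch shaped like MNIST after ToTensor()
x = x.view(x.size(0), -1)          # flatten each image: (32, 1, 28, 28) -> (32, 784)
z = encoder(x)                     # compress to the 3-dim code: (32, 3)
x_hat = decoder(z)                 # reconstruct: (32, 784)
loss = F.mse_loss(x_hat, x)        # mean squared error over all pixels -> scalar

print(x.shape, z.shape, x_hat.shape, loss.shape)
# torch.Size([32, 784]) torch.Size([32, 3]) torch.Size([32, 784]) torch.Size([])
```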



Instantiate the model

  • First, build the MNIST dataset
  • Use random_split() to split it into a training set and a validation set
  • Use LitAutoEncoder() to instantiate the model, then train it
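If you want to try the data-splitting step without downloading MNIST, here is a sketch with a synthetic stand-in dataset. The random images, the smaller size (6,000 samples split 5,000/1,000, the same ratio as MNIST's 60,000 training images split 50,000/10,000), and the fixed generator seed are all assumptions for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

torch.manual_seed(0)

# Hypothetical stand-in for MNIST: random "images" with integer labels 0-9
images = torch.rand(6000, 1, 28, 28)
labels = torch.randint(0, 10, (6000,))
dataset = TensorDataset(images, labels)

# The split lengths must add up to len(dataset)
generator = torch.Generator().manual_seed(42)   # makes the split reproducible
train_set, val_set = random_split(dataset, [5000, 1000], generator=generator)

train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32)

print(len(train_set), len(val_set))   # 5000 1000
print(len(train_loader))              # 157 batches (5000 / 32, rounded up)
```

The original main() does not pass shuffle=True to the training DataLoader; shuffling the training data each epoch is a common choice, but it is my addition here.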

Note: remember to set the max_epochs parameter on your first trial run. If neither max_epochs nor max_steps is set, Lightning defaults to 1000 epochs, so on a CPU it will probably run for quite a while.

def main():
    # Data
    dataset = MNIST("", train=True, download=True, transform=transforms.ToTensor())
    mnist_train, mnist_val = random_split(dataset, [50000, 10000])

    train_loader = DataLoader(mnist_train, batch_size=32)
    val_loader = DataLoader(mnist_val, batch_size=32)

    # Model
    model = LitAutoEncoder()

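    # Training
    # gpus=0 -> train on the CPU (Lightning 2.0 removed `gpus` in favor of accelerator=/devices=)
    # precision=16 -> 16-bit mixed precision, mainly meant for GPUs; on a CPU it may warn or
    #                 fail depending on the version, so precision=32 is the safer choice there
    # limit_train_batches=0.5 -> use only half of the training batches in each epoch
    trainer = pl.Trainer(gpus=0, precision=16, limit_train_batches=0.5, max_epochs=50)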
    trainer.fit(model, train_loader, val_loader)


if __name__ == "__main__":
    main()

Full Sample Code

If you want to run it directly, here is the complete script:

# coding: utf-8
"""
This is a simple script for pytorch-lightning testing.
"""
import torch
import pytorch_lightning as pl
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision import transforms
from torchvision.datasets import MNIST


class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28*28, 64),
            nn.ReLU(),
            nn.Linear(64, 3),
        )

        self.decoder = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 28*28),
        )
    
    def forward(self, x):
        return self.encoder(x)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
    
    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)

        return loss
    
    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("val_loss", loss)


def main():
    # Data
    dataset = MNIST("", train=True, download=True, transform=transforms.ToTensor())
    mnist_train, mnist_val = random_split(dataset, [50000, 10000])

    train_loader = DataLoader(mnist_train, batch_size=32)
    val_loader = DataLoader(mnist_val, batch_size=32)

    # Model
    model = LitAutoEncoder()

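    # Training
    # gpus=0 -> train on the CPU (Lightning 2.0 removed `gpus` in favor of accelerator=/devices=)
    # precision=16 -> 16-bit mixed precision, mainly meant for GPUs; on a CPU it may warn or
    #                 fail depending on the version, so precision=32 is the safer choice there
    # limit_train_batches=0.5 -> use only half of the training batches in each epoch
    trainer = pl.Trainer(gpus=0, precision=16, limit_train_batches=0.5, max_epochs=50)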
    trainer.fit(model, train_loader, val_loader)


if __name__ == "__main__":
    main()

References

