Last Updated on 2022-07-23 by Clay
PyTorch Lightning is a framework that wraps native PyTorch at a higher level of abstraction, much as Keras does for TensorFlow (although, as I remember, Keras has supported several backends besides TensorFlow).
To put it simply, many people find some PyTorch operations too low-level: writing the for-loop for iterative training by hand, manually clearing the accumulated gradients inside it, manually running backpropagation, and so on.
Of course, these low-level operations are very important! For example, we might want to accumulate gradients over several iterations before updating the weights. Still, simple is good.
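For comparison, here is a minimal sketch of the kind of hand-written PyTorch loop described above, including gradient accumulation (calling `backward()` on several batches before a single `optimizer.step()`). The model, the random data and the `train_manually` helper are hypothetical placeholders chosen only for illustration.

```python
import torch
from torch import nn


def train_manually(accumulation_steps: int = 2, num_batches: int = 8, seed: int = 0) -> int:
    """Hypothetical hand-written loop; returns how many optimizer steps were taken."""
    torch.manual_seed(seed)
    model = nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = nn.MSELoss()
    # Random inputs/targets, for illustration only.
    batches = [(torch.randn(4, 10), torch.randn(4, 1)) for _ in range(num_batches)]

    optimizer_steps = 0
    optimizer.zero_grad()  # start with clean gradients
    for i, (x, y) in enumerate(batches):
        loss = loss_fn(model(x), y)
        # Scale the loss so the accumulated gradient is an average over the group.
        (loss / accumulation_steps).backward()  # gradients add up in .grad
        if (i + 1) % accumulation_steps == 0:
            optimizer.step()       # one weight update per group of batches
            optimizer.zero_grad()  # clear the accumulated gradients manually
            optimizer_steps += 1
    return optimizer_steps


print(train_manually(accumulation_steps=2, num_batches=8))  # 4
```

Lightning takes over this bookkeeping; as far as I know, the accumulation case corresponds to the Trainer's `accumulate_grad_batches` argument.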
The `LightningModule` provided by PyTorch Lightning is essentially a `torch.nn.Module` with an extra layer of encapsulation (it is a subclass of `torch.nn.Module`).
According to the official documentation, a `LightningModule` organizes the following pieces of training code that we would otherwise write ourselves:
- Train loop
- Validation loop
- Test loop
- Prediction loop
- The model or system of models (I am not sure exactly what this item covers, probably because I have not used it yet)
- Optimizer and learning rate
Below, I briefly record how to use PyTorch Lightning, then walk through an official sample program and try to explain it.
How to use
First, the installation. We can install pytorch-lightning with the `pip` command; I recommend installing it inside a Python virtual environment.

```shell
pip3 install pytorch-lightning
```

Then check that it is installed:

```shell
pip3 list | grep pytorch-lightning
```

The output shows the installed version (this post uses 1.5.10):

```
pytorch-lightning 1.5.10
```
Sample Code Explanation
The sample code is an AutoEncoder model taken from the official documentation. If you want to implement other tasks, follow the References at the bottom to the official website.
It has three parts:
- Import the packages
- Use LightningModule to build the model
- Instantiate the model, then train and validate it (no test dataset is used here)
Import the packages
Of course, the `torch` and `pytorch_lightning` packages are essential. Since we will work with the MNIST handwritten digit dataset, we also need to import `torchvision`. If it is not installed in your environment, install it with `pip` as well.
```python
import torch
import pytorch_lightning as pl
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision import transforms
from torchvision.datasets import MNIST
```
Use LightningModule to build the model
The model we build is an ordinary Python class that subclasses `pl.LightningModule`; it is similar to, but not the same as, a model written in plain PyTorch. It defines the following methods:
- `__init__()`: creates the encoder and decoder
- `forward()`: forward propagation for inference (returns the embedding, not a loss)
- `configure_optimizers()`: creates the optimizer
- `training_step()`: one iteration of the encapsulated training loop (must return the loss)
- `validation_step()`: one iteration of the encapsulated validation loop
At first I was puzzled why we need both PyTorch's classic `forward()` and pytorch-lightning's `training_step()`, but then I found that the official documentation explicitly recommends keeping training logic and inference logic separate.
That makes sense: `training_step()` returns the loss, so it describes one complete training iteration, while predicting on unseen data still goes through `forward()`.
```python
class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Encoder: 784 pixels -> 3-dimensional embedding
        self.encoder = nn.Sequential(
            nn.Linear(28*28, 64),
            nn.ReLU(),
            nn.Linear(64, 3),
        )
        # Decoder: 3-dimensional embedding -> 784 reconstructed pixels
        self.decoder = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 28*28),
        )

    def forward(self, x):
        # Inference: return the embedding only, no loss
        return self.encoder(x)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch  # labels y are unused by the autoencoder
        x = x.view(x.size(0), -1)  # flatten (B, 1, 28, 28) -> (B, 784)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss  # Lightning calls backward() and steps the optimizer

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("val_loss", loss)
```
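The `x.view(x.size(0), -1)` line in both steps is what makes the image batches fit the `nn.Linear(28*28, 64)` input layer. A small sketch on a random batch shaped like the MNIST `DataLoader` output (the zero tensor is a hypothetical stand-in for the decoder's output):

```python
import torch
from torch.nn import functional as F

# A batch shaped like the MNIST DataLoader output: (batch, channels, height, width).
x = torch.rand(32, 1, 28, 28)
flat = x.view(x.size(0), -1)  # keep the batch dimension, merge the rest: 1*28*28 = 784
print(flat.shape)  # torch.Size([32, 784])

# The reconstruction loss compares decoder output with the flattened input,
# so both tensors must have the same shape.
x_hat = torch.zeros_like(flat)  # hypothetical stand-in for self.decoder(z)
loss = F.mse_loss(x_hat, flat)
print(loss.dim())  # 0, i.e. a scalar, which is what training_step returns
```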
Instantiate the model
- First, build the MNIST dataset
- Use `random_split()` to split it into a training set and a validation set
- Use `LitAutoEncoder()` to instantiate the model, then train it
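`random_split()` shuffles the indices randomly, so the split changes on every run. A minimal sketch, using a hypothetical `TensorDataset` of the same size as the MNIST training set, of passing a seeded `torch.Generator` to make the split reproducible:

```python
import torch
from torch.utils.data import TensorDataset, random_split

# 60,000 dummy samples standing in for the MNIST training set.
dataset = TensorDataset(torch.arange(60000))

# A fixed generator makes the train/validation split reproducible across runs.
generator = torch.Generator().manual_seed(42)
train_set, val_set = random_split(dataset, [50000, 10000], generator=generator)
print(len(train_set), len(val_set))  # 50000 10000

# The two subsets never share a sample.
train_idx, val_idx = set(train_set.indices), set(val_set.indices)
print(train_idx.isdisjoint(val_idx))  # True
```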
Remember to set the `max_epochs` parameter on your first trial run: if neither `max_epochs` nor `max_steps` is set, training defaults to 1000 epochs, which can take a long time on a CPU.
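```python
def main():
    # Data: MNIST training set (60,000 images), downloaded under the current directory
    dataset = MNIST("", train=True, download=True, transform=transforms.ToTensor())
    mnist_train, mnist_val = random_split(dataset, [50000, 10000])
    train_loader = DataLoader(mnist_train, batch_size=32)
    val_loader = DataLoader(mnist_val, batch_size=32)

    # Model
    model = LitAutoEncoder()

    # Training on CPU (gpus=0); limit_train_batches=0.5 uses half of the training batches per epoch.
    # precision=16 (mixed precision) targets GPUs, so the default 32-bit precision is kept here.
    # With a GPU: pl.Trainer(gpus=1, precision=16, limit_train_batches=0.5, max_epochs=50)
    trainer = pl.Trainer(gpus=0, limit_train_batches=0.5, max_epochs=50)
    trainer.fit(model, train_loader, val_loader)


if __name__ == "__main__":
    main()
```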
Full Sample Code
If you want to test the model directly, here is the complete script:
```python
# coding: utf-8
"""
This is a simple script for pytorch-lightning testing.
"""
import torch
import pytorch_lightning as pl
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision import transforms
from torchvision.datasets import MNIST


class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Encoder: 784 pixels -> 3-dimensional embedding
        self.encoder = nn.Sequential(
            nn.Linear(28*28, 64),
            nn.ReLU(),
            nn.Linear(64, 3),
        )
        # Decoder: 3-dimensional embedding -> 784 reconstructed pixels
        self.decoder = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 28*28),
        )

    def forward(self, x):
        # Inference: return the embedding only, no loss
        return self.encoder(x)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch  # labels y are unused by the autoencoder
        x = x.view(x.size(0), -1)  # flatten (B, 1, 28, 28) -> (B, 784)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss  # Lightning calls backward() and steps the optimizer

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("val_loss", loss)


def main():
    # Data: MNIST training set (60,000 images), downloaded under the current directory
    dataset = MNIST("", train=True, download=True, transform=transforms.ToTensor())
    mnist_train, mnist_val = random_split(dataset, [50000, 10000])
    train_loader = DataLoader(mnist_train, batch_size=32)
    val_loader = DataLoader(mnist_val, batch_size=32)

    # Model
    model = LitAutoEncoder()

    # Training on CPU (gpus=0); limit_train_batches=0.5 uses half of the training batches per epoch.
    # precision=16 (mixed precision) targets GPUs, so the default 32-bit precision is kept here.
    # With a GPU: pl.Trainer(gpus=1, precision=16, limit_train_batches=0.5, max_epochs=50)
    trainer = pl.Trainer(gpus=0, limit_train_batches=0.5, max_epochs=50)
    trainer.fit(model, train_loader, val_loader)


if __name__ == "__main__":
    main()
```