
[PyTorch] Using "torchsummary" to plot your model structure

Last Updated on 2021-07-05 by Clay

When we use the popular Python framework PyTorch to build a model, being able to visualize that model is very handy.

It lets us check each layer and its output shape, and catch shape mismatches in the model.

It also makes reporting more convenient.

PyTorch can of course already print a model. However, the printout does not follow forward(); it only lists the layers we defined in the model, which is a pity.

So today I want to introduce a package designed specifically to print the structure of a PyTorch model as it runs through forward(): torchsummary.

Although I call it visualization, torchsummary currently only prints the model structure as text in the terminal. If you want to visualize the model as a flowchart, you will probably have to look into TensorBoard.

OK, let's take a look at how to use torchsummary.


"print(model)" in PyTorch

First, let's use the CNN classification model I wrote before to show what PyTorch's built-in model printing looks like.

# -*- coding: utf-8 -*-
"""
Defined CNN model
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


# Model
class CNN(nn.Module):
    def __init__(self, classes):
        super(CNN, self).__init__()
        self.conv_1 = nn.Conv2d(in_channels=1,
                                out_channels=16,
                                kernel_size=5,
                                stride=1,
                                padding=0)

        self.conv_2 = nn.Conv2d(in_channels=16,
                                out_channels=32,
                                kernel_size=5,
                                stride=1,
                                padding=0)
        self.relu = nn.ReLU()
        self.max_pool = nn.MaxPool2d(kernel_size=2)
        self.fc = nn.Linear(32*4*4, classes)

    def forward(self, x):
        x = self.conv_1(x)
        x = self.relu(x)
        x = self.max_pool(x)
        x = self.conv_2(x)
        x = self.relu(x)
        x = self.max_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


if __name__ == '__main__':
    cnn = CNN(3000).cuda()
    print(cnn)



Output:

CNN(
  (conv_1): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1))
  (conv_2): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1))
  (relu): ReLU()
  (max_pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc): Linear(in_features=512, out_features=3000, bias=True)
)


This output reveals a problem.

I use ReLU() and MaxPool2d() twice in forward(), but the printed structure only lists the layers as they were initialized in __init__().

What I want is a printout that follows forward().
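The difference comes from how the two outputs are produced: print(model) lists the submodules registered in __init__(), while torchsummary attaches forward hooks and runs a dummy batch through the model, so a module that is called twice in forward() shows up twice. The sketch below imitates that idea with plain register_forward_hook; trace_forward is a hypothetical helper, not part of torchsummary, and it only hooks leaf modules.

```python
import torch
import torch.nn as nn


class CNN(nn.Module):
    """Same layers as the CNN in this post."""

    def __init__(self, classes):
        super().__init__()
        self.conv_1 = nn.Conv2d(1, 16, kernel_size=5)
        self.conv_2 = nn.Conv2d(16, 32, kernel_size=5)
        self.relu = nn.ReLU()
        self.max_pool = nn.MaxPool2d(kernel_size=2)
        self.fc = nn.Linear(32 * 4 * 4, classes)

    def forward(self, x):
        x = self.max_pool(self.relu(self.conv_1(x)))
        x = self.max_pool(self.relu(self.conv_2(x)))
        x = x.view(x.size(0), -1)
        return self.fc(x)


def trace_forward(model, input_size, batch_size=2):
    """Hypothetical helper: run one random batch through `model` and record
    (layer type, output shape) for every leaf-module call, in execution order."""
    records = []
    handles = []

    def hook(module, inputs, output):
        records.append((type(module).__name__, tuple(output.shape)))

    for module in model.modules():
        # Leaf modules only (no children), so the top-level container is skipped.
        if len(list(module.children())) == 0:
            handles.append(module.register_forward_hook(hook))

    try:
        with torch.no_grad():
            model(torch.rand(batch_size, *input_size))
    finally:
        # Always detach the hooks so the model is left unchanged.
        for handle in handles:
            handle.remove()
    return records


records = trace_forward(CNN(10), (1, 28, 28))
for i, (name, shape) in enumerate(records, start=1):
    print(f"{name}-{i:<3} {shape}")
```

Because self.relu and self.max_pool are each registered once but called twice, their hooks fire twice, which is exactly the forward()-ordered view that print(model) cannot give.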


torchsummary

If this is your first time using it, install it with the following command:

sudo pip3 install torchsummary

Usage is very simple; it basically looks like this:

# -*- coding: utf-8 -*-
"""
Defined CNN model
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary


# Model
class CNN(nn.Module):
    def __init__(self, classes):
        super(CNN, self).__init__()
        self.conv_1 = nn.Conv2d(in_channels=1,
                                out_channels=16,
                                kernel_size=5,
                                stride=1,
                                padding=0)

        self.conv_2 = nn.Conv2d(in_channels=16,
                                out_channels=32,
                                kernel_size=5,
                                stride=1,
                                padding=0)
        self.relu = nn.ReLU()
        self.max_pool = nn.MaxPool2d(kernel_size=2)
        self.fc = nn.Linear(32*4*4, classes)

    def forward(self, x):
        x = self.conv_1(x)
        x = self.relu(x)
        x = self.max_pool(x)
        x = self.conv_2(x)
        x = self.relu(x)
        x = self.max_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


if __name__ == '__main__':
    cnn = CNN(3000).cuda()
    summary(cnn, (1, 28, 28))



Output:

This is exactly the effect I wanted.

The layers are listed in the order the input actually passes through them in forward(), together with the output shape after each layer and its number of parameters.
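The shapes and parameter counts in that table can be checked by hand. Here is a pure-Python sketch (conv2d_out, pool2d_out, conv_params and cnn_summary are hypothetical helpers, and no PyTorch is needed) that applies the usual output-size formula (size + 2 * padding - kernel) // stride + 1 to the CNN above. It also shows where the 32*4*4 = 512 input features of self.fc come from.

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a Conv2d layer (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool2d_out(size, kernel):
    """Output size of MaxPool2d when stride == kernel_size (PyTorch's default)."""
    return (size - kernel) // kernel + 1


def conv_params(c_in, c_out, kernel):
    """Weights (c_out * c_in * k * k) plus one bias per output channel."""
    return c_out * c_in * kernel * kernel + c_out


def cnn_summary(height=28, width=28, classes=3000):
    """(layer, output shape without batch dim, params) for each step of forward()."""
    rows = []
    h, w = conv2d_out(height, 5), conv2d_out(width, 5)        # conv_1
    rows.append(("Conv2d", (16, h, w), conv_params(1, 16, 5)))
    rows.append(("ReLU", (16, h, w), 0))
    h, w = pool2d_out(h, 2), pool2d_out(w, 2)                  # max_pool
    rows.append(("MaxPool2d", (16, h, w), 0))
    h, w = conv2d_out(h, 5), conv2d_out(w, 5)                  # conv_2
    rows.append(("Conv2d", (32, h, w), conv_params(16, 32, 5)))
    rows.append(("ReLU", (32, h, w), 0))
    h, w = pool2d_out(h, 2), pool2d_out(w, 2)                  # max_pool again
    rows.append(("MaxPool2d", (32, h, w), 0))
    flat = 32 * h * w                                          # x.view(x.size(0), -1)
    rows.append(("Linear", (classes,), flat * classes + classes))  # fc
    return rows


rows = cnn_summary()
for name, shape, params in rows:
    print(f"{name:<10} {str(shape):<16} {params:>12,}")
print(f"Total params: {sum(row[2] for row in rows):,}")  # Total params: 1,552,248
```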

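Note that summary() needs the shape of a single input without the batch dimension, here (1, 28, 28). Also, by default (device="cuda") summary() builds its dummy input on the GPU when one is available, which is why the model is moved to the GPU with cuda() here; for a model kept on the CPU, pass device="cpu" instead.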

If the input shape is wrong, summary() fails immediately with an error.
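One way to catch a wrong shape before calling summary() is to push a random batch through the model yourself. In this sketch, accepts_input is a hypothetical helper, and the model runs on the CPU with 10 classes instead of 3000 to keep it light. A 32x32 input reaches self.fc with 32*5*5 = 800 features instead of the 512 it was built for, so PyTorch raises a RuntimeError.

```python
import torch
import torch.nn as nn


class CNN(nn.Module):
    """Same layers as the CNN in this post."""

    def __init__(self, classes):
        super().__init__()
        self.conv_1 = nn.Conv2d(1, 16, kernel_size=5)
        self.conv_2 = nn.Conv2d(16, 32, kernel_size=5)
        self.relu = nn.ReLU()
        self.max_pool = nn.MaxPool2d(kernel_size=2)
        self.fc = nn.Linear(32 * 4 * 4, classes)

    def forward(self, x):
        x = self.max_pool(self.relu(self.conv_1(x)))
        x = self.max_pool(self.relu(self.conv_2(x)))
        x = x.view(x.size(0), -1)
        return self.fc(x)


def accepts_input(model, input_size, batch_size=2):
    """Hypothetical helper: True if a random (batch_size, *input_size) batch runs through model."""
    try:
        with torch.no_grad():
            model(torch.rand(batch_size, *input_size))
        return True
    except RuntimeError as err:
        print(f"{input_size}: {err}")
        return False


model = CNN(10)
print(accepts_input(model, (1, 28, 28)))  # True
print(accepts_input(model, (1, 32, 32)))  # prints the shape-mismatch error, then False
```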


Multi-input

torchsummary is not limited to a single input. When a model has two branches that take different inputs and are merged at the end, torchsummary can still handle it, although the printout is less intuitive.
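To make the "merged at the end" case concrete, here is a minimal sketch of such a model. TwoBranchNet and its layer sizes are hypothetical: each input gets its own small convolutional branch, and the flattened features are concatenated with torch.cat before a final Linear layer.

```python
import torch
import torch.nn as nn


class TwoBranchNet(nn.Module):
    """Hypothetical two-input model: one conv branch per input, merged by concatenation."""

    def __init__(self, classes=10):
        super().__init__()
        self.branch_a = nn.Sequential(nn.Conv2d(1, 4, kernel_size=3, padding=1), nn.ReLU())
        self.branch_b = nn.Sequential(nn.Conv2d(1, 4, kernel_size=3, padding=1), nn.ReLU())
        # padding=1 keeps the spatial size, so the branches yield 4*16*16 and 4*28*28 features.
        self.fc = nn.Linear(4 * 16 * 16 + 4 * 28 * 28, classes)

    def forward(self, x, y):
        a = self.branch_a(x).flatten(start_dim=1)
        b = self.branch_b(y).flatten(start_dim=1)
        return self.fc(torch.cat([a, b], dim=1))


model = TwoBranchNet()
out = model(torch.rand(2, 1, 16, 16), torch.rand(2, 1, 28, 28))
print(out.shape)  # torch.Size([2, 10])
```

Like the example that follows, a model of this kind is described to summary() with one shape per input, e.g. summary(model, [(1, 16, 16), (1, 28, 28)]). Since the layers are recorded in execution order, both branches would end up in one flat list, which is part of why a branched model's printout is harder to read than a diagram.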

The following example comes from GitHub.

import torch
import torch.nn as nn
from torchsummary import summary

class SimpleConv(nn.Module):
    def __init__(self):
        super(SimpleConv, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )

    def forward(self, x, y):
        x1 = self.features(x)
        x2 = self.features(y)
        return x1, x2
   
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleConv().to(device)

summary(model, [(1, 16, 16), (1, 28, 28)])



Output:


Experience

This is what I had been looking for, and it does exactly what I wanted. But people are never satisfied: after discovering this tool, I began hoping to visualize the model like a flowchart. I thought I might write such a visualization program myself, but after seeing the multi-input output, I realized it is not as simple as I thought.

Before building something yourself, it is best to check whether an expert has already built it. Tools made by experienced developers are generally stable, and reinventing the wheel when you have no time to spare is a waste of time.

In fact, TensorBoard can already produce a model structure diagram close to what I want. Writing my own flowchart tool just to save the little time it takes to set up TensorBoard is probably not worth it.
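For the flowchart-style view, PyTorch's SummaryWriter can log a model graph for TensorBoard with add_graph(). This is a minimal sketch, assuming the tensorboard package is installed; the small Sequential model and the temporary log directory are placeholders, not part of the original post.

```python
import os
import tempfile

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter  # requires the tensorboard package

# Placeholder model; any nn.Module that accepts the example input works.
model = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=5),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2),
    nn.Flatten(),
    nn.Linear(16 * 12 * 12, 10),
)

log_dir = tempfile.mkdtemp(prefix="model_graph_")
writer = SummaryWriter(log_dir=log_dir)
# add_graph traces one forward pass with the example input and logs the resulting graph.
writer.add_graph(model, torch.rand(1, 1, 28, 28))
writer.close()

print(log_dir)
print(sorted(os.listdir(log_dir)))  # the event file(s) TensorBoard will read
```

Running tensorboard --logdir on the printed directory and opening the Graphs tab then shows the traced model as an interactive diagram.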

In short, torchsummary is a pretty good tool, and I recommend it to everyone.

