When we use PyTorch to build a model for a deep learning task, we sometimes need to define more and more layers.
It is irritating. No one wants to keep pasting similar code over and over again.
A friend suggested that I use ModuleList together with a for-loop to define the different layers. The only requirement is that adjacent layers fit together: the number of output neurons of one layer must equal the number of input neurons of the next.
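The size-matching requirement can be sketched roughly as follows. This is only an illustration: the name `sizes` is my own, and the widths are simply the ones used by the model later in this post. Pairing each width with the next one through `zip` means adjacent layers always fit together.

```python
import torch
import torch.nn as nn

# Hypothetical list of layer widths (the same widths as the model in this post)
sizes = [700, 600, 500, 400, 300, 200, 100, 1]

# Pair every width with the next one: (700, 600), (600, 500), ..., (100, 1)
layers = nn.ModuleList(
    [nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])]
)

# Check that each layer's output size equals the next layer's input size
layer_list = list(layers)
chained = all(
    a.out_features == b.in_features
    for a, b in zip(layer_list[:-1], layer_list[1:])
)

# Run a random input through the layers one by one
x = torch.rand(sizes[0])
for layer in layers:
    x = layer(x)

print(chained)
print(x.shape)
```

If two adjacent widths did not match, the forward pass would fail with a shape error at that layer.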
So what is ModuleList? ModuleList is not the same as Sequential. Sequential chains its layers into one composite module: you pass in an input and it runs the layers from top to bottom. ModuleList, by contrast, behaves like a Python list that holds layers, and it registers them so that their parameters belong to the model. It only stores the layers; how they are used still has to be defined in the forward() block.
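One practical difference from an ordinary Python list is worth a small sketch. The class names `PlainListModel` and `ModuleListModel` below are made up for this illustration. Layers kept in a plain list are not registered with the parent module, so they do not show up in `parameters()`, while layers kept in a ModuleList do.

```python
import torch.nn as nn


class PlainListModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain Python list: the layers are NOT registered as submodules
        self.layers = [nn.Linear(4, 4) for _ in range(3)]


class ModuleListModel(nn.Module):
    def __init__(self):
        super().__init__()
        # ModuleList: every layer is registered as a submodule
        self.layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])


# Count the parameter tensors each model exposes
plain_count = len(list(PlainListModel().parameters()))
registered_count = len(list(ModuleListModel().parameters()))

print(plain_count)       # 0
print(registered_count)  # 6 (a weight and a bias for each of the 3 layers)
```

This matters in practice because an optimizer built from `model.parameters()` would never see, and so never update, the layers in the plain list.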
Two examples are shown below.
How To Use ModuleList To Define Model Layers
First, let's take a look at a simple model.
    # coding: utf-8
    import torch
    import torch.nn as nn


    class common_model(nn.Module):
        def __init__(self):
            super().__init__()
            # Seven fully connected layers, each written out by hand
            self.fc_1 = nn.Linear(700, 600)
            self.fc_2 = nn.Linear(600, 500)
            self.fc_3 = nn.Linear(500, 400)
            self.fc_4 = nn.Linear(400, 300)
            self.fc_5 = nn.Linear(300, 200)
            self.fc_6 = nn.Linear(200, 100)
            self.fc_7 = nn.Linear(100, 1)
            self.sigmoid = nn.Sigmoid()

        def forward(self, inputs):
            # Pass the input through every layer in order
            outputs = self.fc_1(inputs)
            outputs = self.fc_2(outputs)
            outputs = self.fc_3(outputs)
            outputs = self.fc_4(outputs)
            outputs = self.fc_5(outputs)
            outputs = self.fc_6(outputs)
            outputs = self.fc_7(outputs)
            return self.sigmoid(outputs)


    if __name__ == '__main__':
        # Print the architecture and the output shape for a random input
        inputs = torch.rand([700])
        model = common_model()
        outputs = model(inputs)
        print(model)
        print(outputs.shape)
Output:
    common_model(
      (fc_1): Linear(in_features=700, out_features=600, bias=True)
      (fc_2): Linear(in_features=600, out_features=500, bias=True)
      (fc_3): Linear(in_features=500, out_features=400, bias=True)
      (fc_4): Linear(in_features=400, out_features=300, bias=True)
      (fc_5): Linear(in_features=300, out_features=200, bias=True)
      (fc_6): Linear(in_features=200, out_features=100, bias=True)
      (fc_7): Linear(in_features=100, out_features=1, bias=True)
      (sigmoid): Sigmoid()
    )
    torch.Size([1])
Is this model difficult to write? Everyone would agree it's quite simple.
But is it cumbersome to write? Yes: I have to copy and paste the same line several times and then change the numbers in each copy.
If we use ModuleList to create the model instead, all six hidden layers can be defined with a single list comprehension.
    # coding: utf-8
    import torch
    import torch.nn as nn


    class module_list_model(nn.Module):
        def __init__(self):
            super().__init__()
            # d runs 7, 6, ..., 2, giving Linear(700, 600) down to Linear(200, 100)
            self.fc = nn.ModuleList(
                [nn.Linear(d * 100, (d - 1) * 100) for d in reversed(range(2, 8))]
            )
            self.fc_final = nn.Linear(100, 1)
            self.sigmoid = nn.Sigmoid()

        def forward(self, inputs):
            # ModuleList only stores the layers, so call them one by one here
            for fc in self.fc:
                inputs = fc(inputs)
            outputs = self.fc_final(inputs)
            return self.sigmoid(outputs)


    if __name__ == '__main__':
        # Print the architecture and the output shape for a random input
        inputs = torch.rand([700])
        model = module_list_model()
        outputs = model(inputs)
        print(model)
        print(outputs.shape)
Output:
    module_list_model(
      (fc): ModuleList(
        (0): Linear(in_features=700, out_features=600, bias=True)
        (1): Linear(in_features=600, out_features=500, bias=True)
        (2): Linear(in_features=500, out_features=400, bias=True)
        (3): Linear(in_features=400, out_features=300, bias=True)
        (4): Linear(in_features=300, out_features=200, bias=True)
        (5): Linear(in_features=200, out_features=100, bias=True)
      )
      (fc_final): Linear(in_features=100, out_features=1, bias=True)
      (sigmoid): Sigmoid()
    )
    torch.Size([1])
Isn't that more comfortable?
But be careful: the layers stored in a ModuleList still have to be called explicitly in the forward() block, unlike a Sequential, which can be called on the input directly.
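For comparison, here is a rough sketch of the same stack of layers wrapped in a Sequential. The variable name `sequential_model` is just for illustration. Because Sequential defines its own forward pass, no custom class or forward() block is needed for a simple chain like this one.

```python
import torch
import torch.nn as nn

# Same six hidden layers as above: Linear(700, 600) down to Linear(200, 100)
layers = [nn.Linear(d * 100, (d - 1) * 100) for d in reversed(range(2, 8))]

# Sequential runs its modules in order, so the model can be called directly
sequential_model = nn.Sequential(*layers, nn.Linear(100, 1), nn.Sigmoid())

inputs = torch.rand(700)
outputs = sequential_model(inputs)
print(outputs.shape)  # torch.Size([1])
```

Sequential is convenient when the data simply flows from one layer to the next. ModuleList is the more flexible choice when forward() has to do something in between, such as skipping layers or reusing intermediate results.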