Skip to content

[PyTorch] 使用 ModuleList 減少重複定義模型的程式碼數量

Last Updated on 2021-07-21 by Clay

在我們使用 PyTorch 搭建模型來處理我們深度學習的任務的時候,有時候我們會面臨需要『重複定義不同模型層』的情況,有時候這是很讓人煩躁的,尤其是必須毫無必要地寫一大堆都是複製貼上的程式碼。

這種時候,在朋友的建議下,我發現可以使用 ModuleList 來直接使用 for 迴圈來快速定義不同的模型層 (唯一的需求就是模型層之間的神經元數目不能 Mismatch ... ) ,使用了幾次之後,我覺得像是發現了新天地一般,從前有很多時候比較多層數的模型可以說是白費了許多苦工 ......

那 ModuleList 究竟是什麼呢?ModuleList 跟 Sequential 不太一樣,Sequential 是直接將複數的模型層建立、輸入 Inputs 的值之後由上到下執行;而 ModuleList 僅僅就只是 Python 中的一個 List 資料型態,只是將模型層建立起來,實際使用還是要在 forward() 的區塊進行定義。

直接敘述可能很難明白其中的差異以及 ModuleList 的好處,以下直接就兩個例子進行示範。


ModuleList 的使用方法

首先,我們先來看一個簡單、單純、但寫起來卻很枯燥的模型。

# coding: utf-8
import torch
import torch.nn as nn


class common_model(nn.Module):
    def __init__(self):
        super(common_model, self).__init__()
        self.fc_1 = nn.Linear(700, 600)
        self.fc_2 = nn.Linear(600, 500)
        self.fc_3 = nn.Linear(500, 400)
        self.fc_4 = nn.Linear(400, 300)
        self.fc_5 = nn.Linear(300, 200)
        self.fc_6 = nn.Linear(200, 100)
        self.fc_7 = nn.Linear(100, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        outputs = self.fc_1(inputs)
        outputs = self.fc_2(outputs)
        outputs = self.fc_3(outputs)
        outputs = self.fc_4(outputs)
        outputs = self.fc_5(outputs)
        outputs = self.fc_6(outputs)
        outputs = self.fc_7(outputs)

        return self.sigmoid(outputs)



Output:

common_model(
(fc_1): Linear(in_features=700, out_features=600, bias=True)
(fc_2): Linear(in_features=600, out_features=500, bias=True)
(fc_3): Linear(in_features=500, out_features=400, bias=True)
(fc_4): Linear(in_features=400, out_features=300, bias=True)
(fc_5): Linear(in_features=300, out_features=200, bias=True)
(fc_6): Linear(in_features=200, out_features=100, bias=True)
(fc_7): Linear(in_features=100, out_features=1, bias=True)
(sigmoid): Sigmoid()
)
torch.Size([1])

這個模型難不難寫呢?大家一定都覺得相當簡單,可是寫起來麻不麻煩呢?我只能肯定我要複製貼上好幾次,然後再改動其中好幾個地方。

但是倘若我們使用 ModuleList 來儲存模型,我們能很快地利用 for 迴圈建立模型層。

# coding: utf-8
import torch
import torch.nn as nn


class module_list_model(nn.Module):
    def __init__(self):
        super(module_list_model, self).__init__()

        self.fc = nn.ModuleList(
            [nn.Linear(d*100, (d-1)*100) for d in range(2, 8).__reversed__()]
        )

        self.fc_final = nn.Linear(100, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        for fc in self.fc:
            inputs = fc(inputs)

        outputs = self.fc_final(inputs)

        return self.sigmoid(outputs)


if __name__ == '__main__':
    inputs = torch.rand([700])
    model = module_list_model()

    outputs = model(inputs)
    print(model)
    print(outputs.shape)



Output:

module_list_model(
(fc): ModuleList(
(0): Linear(in_features=700, out_features=600, bias=True)
(1): Linear(in_features=600, out_features=500, bias=True)
(2): Linear(in_features=500, out_features=400, bias=True)
(3): Linear(in_features=400, out_features=300, bias=True)
(4): Linear(in_features=300, out_features=200, bias=True)
(5): Linear(in_features=200, out_features=100, bias=True)
)
(fc_final): Linear(in_features=100, out_features=1, bias=True)
(sigmoid): Sigmoid()
)
torch.Size([1])

這樣寫起來是不是比較舒服呢?不過要小心的是,儲存在 ModuleList 內的模型層還是要在 forward() 區塊中主動使用,不像 Sequential 那般可以直接輸入。當然,好處就是 ModuleList 常常寫起來比較快速,這點讓人滿意。


References


Read More

Leave a Reply