[PyTorch] Notes on Iterating Over Every Layer of a Model

Introduction

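By a happy coincidence, I recently had the chance to make some simple changes to a few model architectures, so I took the opportunity to look closely at how to iterate over and print the layers of a model in PyTorch.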

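As most of you probably know, print(model) does not show the path taken through the model's forward(); it shows the layers in the order in which they were defined in the class's __init__().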

Take the following program as an example:

# coding: utf-8
import torch


class CustomModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 20)
        self.fc2 = torch.nn.Linear(30, 40)
        self.fc3 = torch.nn.Linear(20, 30)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        outputs = self.fc1(inputs)
        outputs = self.fc3(outputs)
        return self.fc2(outputs)


if __name__ == "__main__":
    model = CustomModel()
    print(model)



Output:

CustomModel(
  (fc1): Linear(in_features=10, out_features=20, bias=True)
  (fc2): Linear(in_features=30, out_features=40, bias=True)
  (fc3): Linear(in_features=20, out_features=30, bias=True)
)

As we can see, even though forward() uses the layers in the order 1 > 3 > 2, the printed architecture still follows the definition order 1 > 2 > 3.
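
If the actual call order matters, one option is to record it with forward hooks. The following is only a minimal sketch: the helper name record_call_order and the random example input are assumptions made for illustration, and it only watches the direct children of the model.

```python
import torch


class CustomModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 20)
        self.fc2 = torch.nn.Linear(30, 40)
        self.fc3 = torch.nn.Linear(20, 30)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        outputs = self.fc1(inputs)
        outputs = self.fc3(outputs)
        return self.fc2(outputs)


def record_call_order(model: torch.nn.Module, example_input: torch.Tensor) -> list:
    """Hypothetical helper: run one forward pass and log which child ran when."""
    call_order = []
    handles = []
    for name, module in model.named_children():
        # The default argument binds the current name to this particular hook.
        def hook(mod, args, output, name=name):
            call_order.append(name)

        handles.append(module.register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(example_input)
    finally:
        # Always remove the hooks so the model is left unchanged.
        for handle in handles:
            handle.remove()
    return call_order


call_order = record_call_order(CustomModel(), torch.randn(1, 10))
print(call_order)  # ['fc1', 'fc3', 'fc2']
```

The recorded order reflects a single forward pass with this particular input; a model with data-dependent branching could call its layers in a different order on another input.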

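So knowing which layers a model defines still does not tell us the order in which data is actually processed at run time. Knowing which layers exist is, however, enough to make some of the most basic modifications.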

For example, adding adapters, adding parameters, or removing parameters can all be done directly by referring to a layer by name.

This note does not go into how to modify model layers; it focuses on how to iterate over a model to get the names and weights of its layers.

There are basically three ways to do this:

  • children() or named_children()
  • modules() or named_modules()
  • parameters() or named_parameters()

Each of the three suits a different situation, so pick whichever fits your needs.


Using bert-tiny as an Example

Here we use a very small BERT model as the example.

# coding: utf-8
from transformers import AutoModelForSequenceClassification


def main():
    model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny")
    print(model)


if __name__ == "__main__":
    main()


Output:

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-1): 2 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=128, out_features=128, bias=True)
              (key): Linear(in_features=128, out_features=128, bias=True)
              (value): Linear(in_features=128, out_features=128, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=128, out_features=128, bias=True)
              (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=128, out_features=512, bias=True)
            (intermediate_act_fn): GELUActivation()
          )
          (output): BertOutput(
            (dense): Linear(in_features=512, out_features=128, bias=True)
            (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (pooler): BertPooler(
      (dense): Linear(in_features=128, out_features=128, bias=True)
      (activation): Tanh()
    )
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (classifier): Linear(in_features=128, out_features=2, bias=True)
)

These are the layers that bert-tiny defines. Even though this is about as small as a model gets, it still defines quite a few layers.

We use this model as the example throughout the rest of this note.


children() or named_children()

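The only difference between children() and named_children() is that the former returns an iterator over the direct sub-modules, while the latter also returns each sub-module's name.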

# coding: utf-8
from transformers import AutoModelForSequenceClassification


def main():
    model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny")
    for name, module in model.named_children():
        print(name, module)


if __name__ == "__main__":
    main()


Output:

bert BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 128, padding_idx=0)
    (position_embeddings): Embedding(512, 128)
    (token_type_embeddings): Embedding(2, 128)
    (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-1): 2 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=128, out_features=128, bias=True)
            (key): Linear(in_features=128, out_features=128, bias=True)
            (value): Linear(in_features=128, out_features=128, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=128, out_features=128, bias=True)
            (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=128, out_features=512, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): BertOutput(
          (dense): Linear(in_features=512, out_features=128, bias=True)
          (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (pooler): BertPooler(
    (dense): Linear(in_features=128, out_features=128, bias=True)
    (activation): Tanh()
  )
)
dropout Dropout(p=0.1, inplace=False)
classifier Linear(in_features=128, out_features=2, bias=True)
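
To see the same behaviour without downloading a checkpoint, here is a small sketch on a toy model. The Block and TinyNet classes are made up for illustration and are not part of bert-tiny; they define no forward() because the example only inspects structure.

```python
import torch


class Block(torch.nn.Module):
    """Hypothetical nested block, used only to create a second level."""

    def __init__(self):
        super().__init__()
        self.dense = torch.nn.Linear(4, 4)
        self.act = torch.nn.ReLU()


class TinyNet(torch.nn.Module):
    """Hypothetical two-level model: one nested block plus a classifier."""

    def __init__(self):
        super().__init__()
        self.block = Block()
        self.classifier = torch.nn.Linear(4, 2)


model = TinyNet()

# children() yields only the module objects of the direct sub-modules.
child_types = [type(module).__name__ for module in model.children()]
print(child_types)  # ['Block', 'Linear']

# named_children() yields (name, module) pairs for the same sub-modules.
child_names = [name for name, _ in model.named_children()]
print(child_names)  # ['block', 'classifier']
```

Note that block.dense and block.act do not appear: neither method descends below the first level.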

modules() or named_modules()

# coding: utf-8
from transformers import AutoModelForSequenceClassification


def main():
    model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny")
    for name, module in model.named_modules():
        print(name, module)


if __name__ == "__main__":
    main()


Output:

 BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-1): 2 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=128, out_features=128, bias=True)
              (key): Linear(in_features=128, out_features=128, bias=True)
              (value): Linear(in_features=128, out_features=128, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=128, out_features=128, bias=True)
              (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=128, out_features=512, bias=True)
            (intermediate_act_fn): GELUActivation()
          )
          (output): BertOutput(
            (dense): Linear(in_features=512, out_features=128, bias=True)
            (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (pooler): BertPooler(
      (dense): Linear(in_features=128, out_features=128, bias=True)
      (activation): Tanh()
    )
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (classifier): Linear(in_features=128, out_features=2, bias=True)
)
bert BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 128, padding_idx=0)
    (position_embeddings): Embedding(512, 128)
    (token_type_embeddings): Embedding(2, 128)
    (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-1): 2 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=128, out_features=128, bias=True)
            (key): Linear(in_features=128, out_features=128, bias=True)
            (value): Linear(in_features=128, out_features=128, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=128, out_features=128, bias=True)
            (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=128, out_features=512, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): BertOutput(
          (dense): Linear(in_features=512, out_features=128, bias=True)
          (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (pooler): BertPooler(
    (dense): Linear(in_features=128, out_features=128, bias=True)
    (activation): Tanh()
  )
)
bert.embeddings BertEmbeddings(
  (word_embeddings): Embedding(30522, 128, padding_idx=0)
  (position_embeddings): Embedding(512, 128)
  (token_type_embeddings): Embedding(2, 128)
  (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)
bert.embeddings.word_embeddings Embedding(30522, 128, padding_idx=0)
bert.embeddings.position_embeddings Embedding(512, 128)
bert.embeddings.token_type_embeddings Embedding(2, 128)
bert.embeddings.LayerNorm LayerNorm((128,), eps=1e-12, elementwise_affine=True)
bert.embeddings.dropout Dropout(p=0.1, inplace=False)
bert.encoder BertEncoder(
  (layer): ModuleList(
    (0-1): 2 x BertLayer(
      (attention): BertAttention(
        (self): BertSelfAttention(
          (query): Linear(in_features=128, out_features=128, bias=True)
          (key): Linear(in_features=128, out_features=128, bias=True)
          (value): Linear(in_features=128, out_features=128, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (output): BertSelfOutput(
          (dense): Linear(in_features=128, out_features=128, bias=True)
          (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (intermediate): BertIntermediate(
        (dense): Linear(in_features=128, out_features=512, bias=True)
        (intermediate_act_fn): GELUActivation()
      )
      (output): BertOutput(
        (dense): Linear(in_features=512, out_features=128, bias=True)
        (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
)
bert.encoder.layer ModuleList(
  (0-1): 2 x BertLayer(
    (attention): BertAttention(
      (self): BertSelfAttention(
        (query): Linear(in_features=128, out_features=128, bias=True)
        (key): Linear(in_features=128, out_features=128, bias=True)
        (value): Linear(in_features=128, out_features=128, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (output): BertSelfOutput(
        (dense): Linear(in_features=128, out_features=128, bias=True)
        (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (intermediate): BertIntermediate(
      (dense): Linear(in_features=128, out_features=512, bias=True)
      (intermediate_act_fn): GELUActivation()
    )
    (output): BertOutput(
      (dense): Linear(in_features=512, out_features=128, bias=True)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
)
bert.encoder.layer.0 BertLayer(
  (attention): BertAttention(
    (self): BertSelfAttention(
      (query): Linear(in_features=128, out_features=128, bias=True)
      (key): Linear(in_features=128, out_features=128, bias=True)
      (value): Linear(in_features=128, out_features=128, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (output): BertSelfOutput(
      (dense): Linear(in_features=128, out_features=128, bias=True)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (intermediate): BertIntermediate(
    (dense): Linear(in_features=128, out_features=512, bias=True)
    (intermediate_act_fn): GELUActivation()
  )
  (output): BertOutput(
    (dense): Linear(in_features=512, out_features=128, bias=True)
    (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)
bert.encoder.layer.0.attention BertAttention(
  (self): BertSelfAttention(
    (query): Linear(in_features=128, out_features=128, bias=True)
    (key): Linear(in_features=128, out_features=128, bias=True)
    (value): Linear(in_features=128, out_features=128, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (output): BertSelfOutput(
    (dense): Linear(in_features=128, out_features=128, bias=True)
    (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)
bert.encoder.layer.0.attention.self BertSelfAttention(
  (query): Linear(in_features=128, out_features=128, bias=True)
  (key): Linear(in_features=128, out_features=128, bias=True)
  (value): Linear(in_features=128, out_features=128, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)
bert.encoder.layer.0.attention.self.query Linear(in_features=128, out_features=128, bias=True)
bert.encoder.layer.0.attention.self.key Linear(in_features=128, out_features=128, bias=True)
bert.encoder.layer.0.attention.self.value Linear(in_features=128, out_features=128, bias=True)
bert.encoder.layer.0.attention.self.dropout Dropout(p=0.1, inplace=False)
bert.encoder.layer.0.attention.output BertSelfOutput(
  (dense): Linear(in_features=128, out_features=128, bias=True)
  (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)
bert.encoder.layer.0.attention.output.dense Linear(in_features=128, out_features=128, bias=True)
bert.encoder.layer.0.attention.output.LayerNorm LayerNorm((128,), eps=1e-12, elementwise_affine=True)
bert.encoder.layer.0.attention.output.dropout Dropout(p=0.1, inplace=False)
bert.encoder.layer.0.intermediate BertIntermediate(
  (dense): Linear(in_features=128, out_features=512, bias=True)
  (intermediate_act_fn): GELUActivation()
)
bert.encoder.layer.0.intermediate.dense Linear(in_features=128, out_features=512, bias=True)
bert.encoder.layer.0.intermediate.intermediate_act_fn GELUActivation()
bert.encoder.layer.0.output BertOutput(
  (dense): Linear(in_features=512, out_features=128, bias=True)
  (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)
bert.encoder.layer.0.output.dense Linear(in_features=512, out_features=128, bias=True)
bert.encoder.layer.0.output.LayerNorm LayerNorm((128,), eps=1e-12, elementwise_affine=True)
bert.encoder.layer.0.output.dropout Dropout(p=0.1, inplace=False)
bert.encoder.layer.1 BertLayer(
  (attention): BertAttention(
    (self): BertSelfAttention(
      (query): Linear(in_features=128, out_features=128, bias=True)
      (key): Linear(in_features=128, out_features=128, bias=True)
      (value): Linear(in_features=128, out_features=128, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (output): BertSelfOutput(
      (dense): Linear(in_features=128, out_features=128, bias=True)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (intermediate): BertIntermediate(
    (dense): Linear(in_features=128, out_features=512, bias=True)
    (intermediate_act_fn): GELUActivation()
  )
  (output): BertOutput(
    (dense): Linear(in_features=512, out_features=128, bias=True)
    (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)
bert.encoder.layer.1.attention BertAttention(
  (self): BertSelfAttention(
    (query): Linear(in_features=128, out_features=128, bias=True)
    (key): Linear(in_features=128, out_features=128, bias=True)
    (value): Linear(in_features=128, out_features=128, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (output): BertSelfOutput(
    (dense): Linear(in_features=128, out_features=128, bias=True)
    (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)
bert.encoder.layer.1.attention.self BertSelfAttention(
  (query): Linear(in_features=128, out_features=128, bias=True)
  (key): Linear(in_features=128, out_features=128, bias=True)
  (value): Linear(in_features=128, out_features=128, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)
bert.encoder.layer.1.attention.self.query Linear(in_features=128, out_features=128, bias=True)
bert.encoder.layer.1.attention.self.key Linear(in_features=128, out_features=128, bias=True)
bert.encoder.layer.1.attention.self.value Linear(in_features=128, out_features=128, bias=True)
bert.encoder.layer.1.attention.self.dropout Dropout(p=0.1, inplace=False)
bert.encoder.layer.1.attention.output BertSelfOutput(
  (dense): Linear(in_features=128, out_features=128, bias=True)
  (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)
bert.encoder.layer.1.attention.output.dense Linear(in_features=128, out_features=128, bias=True)
bert.encoder.layer.1.attention.output.LayerNorm LayerNorm((128,), eps=1e-12, elementwise_affine=True)
bert.encoder.layer.1.attention.output.dropout Dropout(p=0.1, inplace=False)
bert.encoder.layer.1.intermediate BertIntermediate(
  (dense): Linear(in_features=128, out_features=512, bias=True)
  (intermediate_act_fn): GELUActivation()
)
bert.encoder.layer.1.intermediate.dense Linear(in_features=128, out_features=512, bias=True)
bert.encoder.layer.1.intermediate.intermediate_act_fn GELUActivation()
bert.encoder.layer.1.output BertOutput(
  (dense): Linear(in_features=512, out_features=128, bias=True)
  (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)
bert.encoder.layer.1.output.dense Linear(in_features=512, out_features=128, bias=True)
bert.encoder.layer.1.output.LayerNorm LayerNorm((128,), eps=1e-12, elementwise_affine=True)
bert.encoder.layer.1.output.dropout Dropout(p=0.1, inplace=False)
bert.pooler BertPooler(
  (dense): Linear(in_features=128, out_features=128, bias=True)
  (activation): Tanh()
)
bert.pooler.dense Linear(in_features=128, out_features=128, bias=True)
bert.pooler.activation Tanh()
dropout Dropout(p=0.1, inplace=False)
classifier Linear(in_features=128, out_features=2, bias=True)

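Printing each module directly produces very long output, because every module yielded by modules() is expanded and printed in full. The first entry is the model itself, whose name is the empty string. By comparison, children() has no such repetition: it yields only the outermost layer, whereas modules() also yields each layer nested inside a layer, so the nested layers are printed again each time.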

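When you actually need to modify a model, though, named_modules() is very convenient to look through, because each name maps directly to a layer.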
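
One way to keep the output short is to print only each module's name and class, or to filter by layer type. The sketch below uses a made-up toy model (Block and TinyNet are assumptions, not part of bert-tiny) and looks a layer up by its dotted name with get_submodule(), which should be available in reasonably recent PyTorch releases.

```python
import torch


class Block(torch.nn.Module):
    """Hypothetical nested block, used only to create a second level."""

    def __init__(self):
        super().__init__()
        self.dense = torch.nn.Linear(4, 4)
        self.act = torch.nn.ReLU()


class TinyNet(torch.nn.Module):
    """Hypothetical two-level model: one nested block plus a classifier."""

    def __init__(self):
        super().__init__()
        self.block = Block()
        self.classifier = torch.nn.Linear(4, 2)


model = TinyNet()

# named_modules() walks the whole tree; the root comes first with name ''.
module_names = [name for name, _ in model.named_modules()]
print(module_names)  # ['', 'block', 'block.dense', 'block.act', 'classifier']

# Printing the class name instead of the module avoids the nested repetition.
for name, module in model.named_modules():
    print(repr(name), type(module).__name__)

# Keep only the layers of one type, for example every Linear layer.
linear_names = [
    name for name, module in model.named_modules()
    if isinstance(module, torch.nn.Linear)
]
print(linear_names)  # ['block.dense', 'classifier']

# A dotted name from named_modules() can be used to fetch the layer itself.
dense = model.get_submodule("block.dense")
```

On bert-tiny the same filter would list names such as bert.encoder.layer.0.attention.self.query, which match the names in the output above.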


parameters() or named_parameters()

parameters() returns all of the model's learnable parameters. If you only want to look at the model's parameters, just use parameters().

Incidentally, this is the method whose return value we normally pass to the optimizer.

# coding: utf-8
from transformers import AutoModelForSequenceClassification


def main():
    model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny")
    for name, param in model.named_parameters():
        print(name, param)


if __name__ == "__main__":
    main()



I won't include the printed output here, since it is just a wall of weight values that would make your eyes glaze over.
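
A more readable alternative is to print each parameter's name and shape rather than its values. The sketch below again uses a made-up toy model (Block and TinyNet are assumptions); it also shows one possible way to freeze everything except the classifier before handing parameters to an optimizer. SGD and the learning rate are arbitrary choices for illustration.

```python
import torch


class Block(torch.nn.Module):
    """Hypothetical nested block, used only to create a second level."""

    def __init__(self):
        super().__init__()
        self.dense = torch.nn.Linear(4, 4)
        self.act = torch.nn.ReLU()


class TinyNet(torch.nn.Module):
    """Hypothetical two-level model: one nested block plus a classifier."""

    def __init__(self):
        super().__init__()
        self.block = Block()
        self.classifier = torch.nn.Linear(4, 2)


model = TinyNet()

# Names and shapes are usually enough to understand what a model contains.
param_shapes = {name: tuple(p.shape) for name, p in model.named_parameters()}
for name, shape in param_shapes.items():
    print(name, shape)

# Total number of scalar weights: 16 + 4 + 8 + 2.
total_params = sum(p.numel() for p in model.parameters())
print(total_params)  # 30

# Freeze every parameter whose name is outside the classifier.
for name, p in model.named_parameters():
    if not name.startswith("classifier."):
        p.requires_grad = False

# Pass only the parameters that still require gradients to the optimizer.
trainable = [p for p in model.parameters() if p.requires_grad]
trainable_count = sum(p.numel() for p in trainable)
print(trainable_count)  # 10
optimizer = torch.optim.SGD(trainable, lr=0.1)
```

With bert-tiny, the same loop could be applied to the model loaded above, since its classification head is also named classifier.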



In short:

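  • children() iterates over the direct sub-modules
  • modules() iterates over the sub-modules, their sub-modules, and so on until nothing is left (it also yields the model itself first)
  • parameters() returns the learnable parameters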
