Last Updated on 2023-09-12 by Clay
Introduction
Recently I happened to get the chance to make some simple changes to a model architecture, so I took the opportunity to explore how to iterate over and print the neural network layers of a PyTorch model.
I think most people already know that when we print(model), what we see is not the path the data takes through the model's forward(), but rather the order in which the layers were defined in the class's __init__().
For example, the following program:
# coding: utf-8
import torch


class CustomModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 20)
        self.fc2 = torch.nn.Linear(30, 40)
        self.fc3 = torch.nn.Linear(20, 30)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        outputs = self.fc1(inputs)
        outputs = self.fc3(outputs)
        return self.fc2(outputs)


if __name__ == "__main__":
    model = CustomModel()
    print(model)
Output:
CustomModel(
  (fc1): Linear(in_features=10, out_features=20, bias=True)
  (fc2): Linear(in_features=30, out_features=40, bias=True)
  (fc3): Linear(in_features=20, out_features=30, bias=True)
)
We can see that even though the order used in forward() is 1 > 3 > 2, the printed architecture still follows the definition order 1 > 2 > 3.
So knowing which layers a model defines still does not tell us how the data is actually processed when it runs through the model; but knowing which layers exist is already enough for some of the most basic modifications.
For example, adding adapters, adding parameters, removing parameters, and so on can all be done directly through the layer names.
This note will not go on to cover how to modify model layers; instead, it focuses on how to iterate over a model to obtain its layer names and weights.
Basically, there are three approaches:
- children() or named_children()
- modules() or named_modules()
- parameters() or named_parameters()
All three methods suit different situations, and you can decide which one to use based on your needs.
Using bert-tiny as an Example
Here, we use a very small BERT model as the example.
# coding: utf-8
from transformers import AutoModelForSequenceClassification


def main():
    model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny")
    print(model)


if __name__ == "__main__":
    main()
Output:
BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-1): 2 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=128, out_features=128, bias=True)
              (key): Linear(in_features=128, out_features=128, bias=True)
              (value): Linear(in_features=128, out_features=128, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=128, out_features=128, bias=True)
              (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=128, out_features=512, bias=True)
            (intermediate_act_fn): GELUActivation()
          )
          (output): BertOutput(
            (dense): Linear(in_features=512, out_features=128, bias=True)
            (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (pooler): BertPooler(
      (dense): Linear(in_features=128, out_features=128, bias=True)
      (activation): Tanh()
    )
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (classifier): Linear(in_features=128, out_features=2, bias=True)
)
These are the layers defined in bert-tiny. Even though this model is about as small as it gets, it still defines quite a few layers.
We will use this model for all of the examples below.
children() or named_children()
The only difference between children() and named_children() is that one returns an iterator over the sub-modules, while the other additionally returns each sub-module's name.
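Before the full named_children() example below, here is a minimal sketch of my own (not part of the example that follows) showing what plain children() gives you: just the sub-modules themselves, with no names attached.

# coding: utf-8
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny")

# children() yields the direct sub-modules only, without their names
for module in model.children():
    print(type(module).__name__)  # BertModel, Dropout, Linear

With names included, named_children() is usually the more informative of the two: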
# coding: utf-8
from transformers import AutoModelForSequenceClassification


def main():
    model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny")

    for name, module in model.named_children():
        print(name, module)


if __name__ == "__main__":
    main()
Output:
bert BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 128, padding_idx=0)
    (position_embeddings): Embedding(512, 128)
    (token_type_embeddings): Embedding(2, 128)
    (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-1): 2 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=128, out_features=128, bias=True)
            (key): Linear(in_features=128, out_features=128, bias=True)
            (value): Linear(in_features=128, out_features=128, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=128, out_features=128, bias=True)
            (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=128, out_features=512, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): BertOutput(
          (dense): Linear(in_features=512, out_features=128, bias=True)
          (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (pooler): BertPooler(
    (dense): Linear(in_features=128, out_features=128, bias=True)
    (activation): Tanh()
  )
)
dropout Dropout(p=0.1, inplace=False)
classifier Linear(in_features=128, out_features=2, bias=True)
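A situation where named_children() comes in handy is freezing a whole sub-module by name, for example keeping the bert backbone frozen while leaving only the classification head trainable. This is a small sketch of my own, not something the loop above does:

# coding: utf-8
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny")

# Freeze every top-level child except the classification head
for name, module in model.named_children():
    if name != "classifier":
        for param in module.parameters():
            param.requires_grad = False

# Only classifier.weight and classifier.bias should remain trainable
print([name for name, param in model.named_parameters() if param.requires_grad])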
modules() or named_modules()
# coding: utf-8
from transformers import AutoModelForSequenceClassification


def main():
    model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny")

    for name, module in model.named_modules():
        print(name, module)


if __name__ == "__main__":
    main()
Output:
BertForSequenceClassification( (bert): BertModel( (embeddings): BertEmbeddings( (word_embeddings): Embedding(30522, 128, padding_idx=0) (position_embeddings): Embedding(512, 128) (token_type_embeddings): Embedding(2, 128) (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) (encoder): BertEncoder( (layer): ModuleList( (0-1): 2 x BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=128, out_features=128, bias=True) (key): Linear(in_features=128, out_features=128, bias=True) (value): Linear(in_features=128, out_features=128, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=128, out_features=128, bias=True) (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=128, out_features=512, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=512, out_features=128, bias=True) (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) ) ) (pooler): BertPooler( (dense): Linear(in_features=128, out_features=128, bias=True) (activation): Tanh() ) ) (dropout): Dropout(p=0.1, inplace=False) (classifier): Linear(in_features=128, out_features=2, bias=True) ) bert BertModel( (embeddings): BertEmbeddings( (word_embeddings): Embedding(30522, 128, padding_idx=0) (position_embeddings): Embedding(512, 128) (token_type_embeddings): Embedding(2, 128) (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) (encoder): BertEncoder( (layer): ModuleList( (0-1): 2 x BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=128, out_features=128, bias=True) (key): Linear(in_features=128, out_features=128, bias=True) (value): Linear(in_features=128, out_features=128, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=128, out_features=128, bias=True) (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=128, out_features=512, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=512, out_features=128, bias=True) (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) ) ) (pooler): BertPooler( (dense): Linear(in_features=128, out_features=128, bias=True) (activation): Tanh() ) ) bert.embeddings BertEmbeddings( (word_embeddings): Embedding(30522, 128, padding_idx=0) (position_embeddings): Embedding(512, 128) (token_type_embeddings): Embedding(2, 128) (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) bert.embeddings.word_embeddings Embedding(30522, 128, padding_idx=0) bert.embeddings.position_embeddings Embedding(512, 128) bert.embeddings.token_type_embeddings Embedding(2, 128) bert.embeddings.LayerNorm LayerNorm((128,), eps=1e-12, elementwise_affine=True) bert.embeddings.dropout Dropout(p=0.1, inplace=False) bert.encoder BertEncoder( (layer): ModuleList( (0-1): 2 x BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=128, out_features=128, bias=True) (key): Linear(in_features=128, 
out_features=128, bias=True) (value): Linear(in_features=128, out_features=128, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=128, out_features=128, bias=True) (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=128, out_features=512, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=512, out_features=128, bias=True) (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) ) ) bert.encoder.layer ModuleList( (0-1): 2 x BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=128, out_features=128, bias=True) (key): Linear(in_features=128, out_features=128, bias=True) (value): Linear(in_features=128, out_features=128, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=128, out_features=128, bias=True) (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=128, out_features=512, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=512, out_features=128, bias=True) (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) ) bert.encoder.layer.0 BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=128, out_features=128, bias=True) (key): Linear(in_features=128, out_features=128, bias=True) (value): Linear(in_features=128, out_features=128, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=128, out_features=128, bias=True) (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=128, out_features=512, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=512, out_features=128, bias=True) (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) bert.encoder.layer.0.attention BertAttention( (self): BertSelfAttention( (query): Linear(in_features=128, out_features=128, bias=True) (key): Linear(in_features=128, out_features=128, bias=True) (value): Linear(in_features=128, out_features=128, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=128, out_features=128, bias=True) (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) bert.encoder.layer.0.attention.self BertSelfAttention( (query): Linear(in_features=128, out_features=128, bias=True) (key): Linear(in_features=128, out_features=128, bias=True) (value): Linear(in_features=128, out_features=128, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) bert.encoder.layer.0.attention.self.query Linear(in_features=128, out_features=128, bias=True) bert.encoder.layer.0.attention.self.key Linear(in_features=128, out_features=128, bias=True) bert.encoder.layer.0.attention.self.value Linear(in_features=128, out_features=128, bias=True) bert.encoder.layer.0.attention.self.dropout Dropout(p=0.1, inplace=False) 
bert.encoder.layer.0.attention.output BertSelfOutput( (dense): Linear(in_features=128, out_features=128, bias=True) (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) bert.encoder.layer.0.attention.output.dense Linear(in_features=128, out_features=128, bias=True) bert.encoder.layer.0.attention.output.LayerNorm LayerNorm((128,), eps=1e-12, elementwise_affine=True) bert.encoder.layer.0.attention.output.dropout Dropout(p=0.1, inplace=False) bert.encoder.layer.0.intermediate BertIntermediate( (dense): Linear(in_features=128, out_features=512, bias=True) (intermediate_act_fn): GELUActivation() ) bert.encoder.layer.0.intermediate.dense Linear(in_features=128, out_features=512, bias=True) bert.encoder.layer.0.intermediate.intermediate_act_fn GELUActivation() bert.encoder.layer.0.output BertOutput( (dense): Linear(in_features=512, out_features=128, bias=True) (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) bert.encoder.layer.0.output.dense Linear(in_features=512, out_features=128, bias=True) bert.encoder.layer.0.output.LayerNorm LayerNorm((128,), eps=1e-12, elementwise_affine=True) bert.encoder.layer.0.output.dropout Dropout(p=0.1, inplace=False) bert.encoder.layer.1 BertLayer( (attention): BertAttention( (self): BertSelfAttention( (query): Linear(in_features=128, out_features=128, bias=True) (key): Linear(in_features=128, out_features=128, bias=True) (value): Linear(in_features=128, out_features=128, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=128, out_features=128, bias=True) (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) (intermediate): BertIntermediate( (dense): Linear(in_features=128, out_features=512, bias=True) (intermediate_act_fn): GELUActivation() ) (output): BertOutput( (dense): Linear(in_features=512, out_features=128, bias=True) (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) bert.encoder.layer.1.attention BertAttention( (self): BertSelfAttention( (query): Linear(in_features=128, out_features=128, bias=True) (key): Linear(in_features=128, out_features=128, bias=True) (value): Linear(in_features=128, out_features=128, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) (output): BertSelfOutput( (dense): Linear(in_features=128, out_features=128, bias=True) (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) ) bert.encoder.layer.1.attention.self BertSelfAttention( (query): Linear(in_features=128, out_features=128, bias=True) (key): Linear(in_features=128, out_features=128, bias=True) (value): Linear(in_features=128, out_features=128, bias=True) (dropout): Dropout(p=0.1, inplace=False) ) bert.encoder.layer.1.attention.self.query Linear(in_features=128, out_features=128, bias=True) bert.encoder.layer.1.attention.self.key Linear(in_features=128, out_features=128, bias=True) bert.encoder.layer.1.attention.self.value Linear(in_features=128, out_features=128, bias=True) bert.encoder.layer.1.attention.self.dropout Dropout(p=0.1, inplace=False) bert.encoder.layer.1.attention.output BertSelfOutput( (dense): Linear(in_features=128, out_features=128, bias=True) (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) bert.encoder.layer.1.attention.output.dense 
Linear(in_features=128, out_features=128, bias=True) bert.encoder.layer.1.attention.output.LayerNorm LayerNorm((128,), eps=1e-12, elementwise_affine=True) bert.encoder.layer.1.attention.output.dropout Dropout(p=0.1, inplace=False) bert.encoder.layer.1.intermediate BertIntermediate( (dense): Linear(in_features=128, out_features=512, bias=True) (intermediate_act_fn): GELUActivation() ) bert.encoder.layer.1.intermediate.dense Linear(in_features=128, out_features=512, bias=True) bert.encoder.layer.1.intermediate.intermediate_act_fn GELUActivation() bert.encoder.layer.1.output BertOutput( (dense): Linear(in_features=512, out_features=128, bias=True) (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True) (dropout): Dropout(p=0.1, inplace=False) ) bert.encoder.layer.1.output.dense Linear(in_features=512, out_features=128, bias=True) bert.encoder.layer.1.output.LayerNorm LayerNorm((128,), eps=1e-12, elementwise_affine=True) bert.encoder.layer.1.output.dropout Dropout(p=0.1, inplace=False) bert.pooler BertPooler( (dense): Linear(in_features=128, out_features=128, bias=True) (activation): Tanh() ) bert.pooler.dense Linear(in_features=128, out_features=128, bias=True) bert.pooler.activation Tanh() dropout Dropout(p=0.1, inplace=False) classifier Linear(in_features=128, out_features=2, bias=True)
Printing the modules directly looks especially long, because modules() really does expand every module and print it. By comparison, children() has no such duplication problem: it only prints the outermost layer, and does not, like modules(), print the layers inside each layer all over again.
However, when you actually want to modify a model, modules() (or rather named_modules()) is very convenient to look at, since you can directly match names to layers.
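For instance, since named_modules() gives the full dotted path of every module, you can build a name-to-module mapping and grab a specific layer directly, or filter layers by type. A minimal sketch of my own (the layer name used here is just one picked from the bert-tiny printout above):

# coding: utf-8
import torch
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny")

# Build a {name: module} mapping; the root module itself has the empty name ""
name_to_module = dict(model.named_modules())

# Grab one specific layer by its dotted path
print(name_to_module["bert.encoder.layer.0.attention.self.query"])
# Linear(in_features=128, out_features=128, bias=True)

# Or collect all Linear layers by type
linear_names = [name for name, module in model.named_modules() if isinstance(module, torch.nn.Linear)]
print(linear_names)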
parameters() or named_parameters()
parameters() returns all of the learnable parameters in the model. If you only want to look at the model's parameters, parameters() alone is enough.
Incidentally, this is the method whose result we normally hand to the optimizer.
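As a minimal sketch of that typical usage (the optimizer choice and learning rate below are placeholders of mine, not something from the original example):

# coding: utf-8
import torch
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny")

# parameters() is what we usually hand straight to the optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Or pass only the parameters that are still trainable
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=5e-5)

To actually iterate over the parameters together with their names, named_parameters() works like this: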
# coding: utf-8
from transformers import AutoModelForSequenceClassification


def main():
    model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny")

    for name, param in model.named_parameters():
        print(name, param)


if __name__ == "__main__":
    main()
I will not include the printed result here, because it would just be a wall of weight values, and looking at it would make your eyes glaze over.
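If you only want an overview rather than the raw weights, a small variation of my own on the loop above prints just each parameter's name and shape, which keeps the output readable:

# coding: utf-8
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny")

# Print only the parameter names and their shapes, not the values
for name, param in model.named_parameters():
    print(name, tuple(param.shape))
# e.g. bert.embeddings.word_embeddings.weight (30522, 128)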
In short:
- children() iterates over the direct sub-modules
- modules() iterates over sub-modules, their sub-modules, and so on, all the way down
- parameters() returns the learnable parameters (see the short counting sketch below)
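Here is a short sketch of my own (under the same bert-tiny assumption as above) that makes the difference concrete by simply counting what each method yields:

# coding: utf-8
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny")

print(len(list(model.children())))    # 3: the direct sub-modules (bert, dropout, classifier)
print(len(list(model.modules())))     # every module recursively, including the model itself
print(len(list(model.parameters())))  # every learnable parameter tensor
print(sum(p.numel() for p in model.parameters()))  # total scalar weights, roughly 4.4M for bert-tiny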
References
- How to iterate over layers in Pytorch - python
- PyTorch Get All Layers of Model: A Comprehensive Guide