Last Updated on 2023-09-12 by Clay
Introduction
Recently, by a stroke of circumstance, I had the chance to make some simple changes to a model's architecture, so I took the opportunity to properly explore how to traverse and print a PyTorch model's neural network layers.
As I expect most of you know, when we call print(model), what we see is not the route the data takes through the model's forward(), but the order in which the layers were defined in the class's __init__().
For example, take the following program:
# coding: utf-8
import torch


class CustomModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 20)
        self.fc2 = torch.nn.Linear(30, 40)
        self.fc3 = torch.nn.Linear(20, 30)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        outputs = self.fc1(inputs)
        outputs = self.fc3(outputs)
        return self.fc2(outputs)


if __name__ == "__main__":
    model = CustomModel()
    print(model)
Output:
CustomModel(
  (fc1): Linear(in_features=10, out_features=20, bias=True)
  (fc2): Linear(in_features=30, out_features=40, bias=True)
  (fc3): Linear(in_features=20, out_features=30, bias=True)
)
As we can see, even though forward() uses the layers in the order 1 > 3 > 2, printing the architecture still shows the definition order 1 > 2 > 3.
So knowing which layers a model defines still doesn't tell us the actual processing path the data follows at runtime; nevertheless, knowing what layers are defined is already enough for some basic modifications.
For example, adding adapters, or adding and removing parameters, can all be done directly through the layer names, as the short sketch below illustrates.
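Just to illustrate the idea, here is a minimal sketch of my own (it assumes the toy CustomModel defined above and is not part of the original post): replacing a layer purely by its attribute name.
# coding: utf-8
import torch

model = CustomModel()

# Swap fc2 for a wider layer, addressing it only by name; forward() still
# works because the new layer keeps in_features=30.
setattr(model, "fc2", torch.nn.Linear(30, 50))
print(model.fc2)  # Linear(in_features=30, out_features=50, bias=True)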
Beyond that quick illustration, this note won't go further into how to modify model layers; instead, it focuses on how to traverse a model to obtain the names and weights of its layers.
Basically, there are three ways to do it:
- children() or named_children()
- modules() or named_modules()
- parameters() or named_parameters()
Each of the three suits different situations, so you can decide which one to use based on your own needs.
Using bert-tiny as an Example
Here, we use a very small BERT model as our example.
# coding: utf-8
from transformers import AutoModelForSequenceClassification


def main():
    model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny")
    print(model)


if __name__ == "__main__":
    main()
Output:
BertForSequenceClassification(
(bert): BertModel(
(embeddings): BertEmbeddings(
(word_embeddings): Embedding(30522, 128, padding_idx=0)
(position_embeddings): Embedding(512, 128)
(token_type_embeddings): Embedding(2, 128)
(LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(encoder): BertEncoder(
(layer): ModuleList(
(0-1): 2 x BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=128, out_features=128, bias=True)
(key): Linear(in_features=128, out_features=128, bias=True)
(value): Linear(in_features=128, out_features=128, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=128, out_features=128, bias=True)
(LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=128, out_features=512, bias=True)
(intermediate_act_fn): GELUActivation()
)
(output): BertOutput(
(dense): Linear(in_features=512, out_features=128, bias=True)
(LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(pooler): BertPooler(
(dense): Linear(in_features=128, out_features=128, bias=True)
(activation): Tanh()
)
)
(dropout): Dropout(p=0.1, inplace=False)
(classifier): Linear(in_features=128, out_features=2, bias=True)
)
These are the layers defined by bert-tiny. Even though this model is about as small as they come, it still defines quite a few layers.
All of the examples below use this model.
children() or named_children()
The only difference between children() and named_children() is that both return an iterator over the direct sub-modules, but named_children() additionally yields each sub-module's name.
# coding: utf-8
from transformers import AutoModelForSequenceClassification


def main():
    model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny")

    for name, module in model.named_children():
        print(name, module)


if __name__ == "__main__":
    main()
Output:
bert BertModel(
(embeddings): BertEmbeddings(
(word_embeddings): Embedding(30522, 128, padding_idx=0)
(position_embeddings): Embedding(512, 128)
(token_type_embeddings): Embedding(2, 128)
(LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(encoder): BertEncoder(
(layer): ModuleList(
(0-1): 2 x BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=128, out_features=128, bias=True)
(key): Linear(in_features=128, out_features=128, bias=True)
(value): Linear(in_features=128, out_features=128, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=128, out_features=128, bias=True)
(LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=128, out_features=512, bias=True)
(intermediate_act_fn): GELUActivation()
)
(output): BertOutput(
(dense): Linear(in_features=512, out_features=128, bias=True)
(LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(pooler): BertPooler(
(dense): Linear(in_features=128, out_features=128, bias=True)
(activation): Tanh()
)
)
dropout Dropout(p=0.1, inplace=False)
classifier Linear(in_features=128, out_features=2, bias=True)
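If you don't need the names, children() behaves the same way. A minimal sketch of my own (not from the original post):
# coding: utf-8
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny")

# Only the direct sub-modules are yielded: bert, dropout, classifier.
print(len(list(model.children())))  # 3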
modules() or named_modules()
modules() recursively yields every module in the model, starting with the model itself, and named_modules() additionally yields each module's dotted name.
# coding: utf-8
from transformers import AutoModelForSequenceClassification


def main():
    model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny")

    for name, module in model.named_modules():
        print(name, module)


if __name__ == "__main__":
    main()
Output:
BertForSequenceClassification(
(bert): BertModel(
(embeddings): BertEmbeddings(
(word_embeddings): Embedding(30522, 128, padding_idx=0)
(position_embeddings): Embedding(512, 128)
(token_type_embeddings): Embedding(2, 128)
(LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(encoder): BertEncoder(
(layer): ModuleList(
(0-1): 2 x BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=128, out_features=128, bias=True)
(key): Linear(in_features=128, out_features=128, bias=True)
(value): Linear(in_features=128, out_features=128, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=128, out_features=128, bias=True)
(LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=128, out_features=512, bias=True)
(intermediate_act_fn): GELUActivation()
)
(output): BertOutput(
(dense): Linear(in_features=512, out_features=128, bias=True)
(LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(pooler): BertPooler(
(dense): Linear(in_features=128, out_features=128, bias=True)
(activation): Tanh()
)
)
(dropout): Dropout(p=0.1, inplace=False)
(classifier): Linear(in_features=128, out_features=2, bias=True)
)
bert BertModel(
(embeddings): BertEmbeddings(
(word_embeddings): Embedding(30522, 128, padding_idx=0)
(position_embeddings): Embedding(512, 128)
(token_type_embeddings): Embedding(2, 128)
(LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(encoder): BertEncoder(
(layer): ModuleList(
(0-1): 2 x BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=128, out_features=128, bias=True)
(key): Linear(in_features=128, out_features=128, bias=True)
(value): Linear(in_features=128, out_features=128, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=128, out_features=128, bias=True)
(LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=128, out_features=512, bias=True)
(intermediate_act_fn): GELUActivation()
)
(output): BertOutput(
(dense): Linear(in_features=512, out_features=128, bias=True)
(LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(pooler): BertPooler(
(dense): Linear(in_features=128, out_features=128, bias=True)
(activation): Tanh()
)
)
bert.embeddings BertEmbeddings(
(word_embeddings): Embedding(30522, 128, padding_idx=0)
(position_embeddings): Embedding(512, 128)
(token_type_embeddings): Embedding(2, 128)
(LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
bert.embeddings.word_embeddings Embedding(30522, 128, padding_idx=0)
bert.embeddings.position_embeddings Embedding(512, 128)
bert.embeddings.token_type_embeddings Embedding(2, 128)
bert.embeddings.LayerNorm LayerNorm((128,), eps=1e-12, elementwise_affine=True)
bert.embeddings.dropout Dropout(p=0.1, inplace=False)
bert.encoder BertEncoder(
(layer): ModuleList(
(0-1): 2 x BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=128, out_features=128, bias=True)
(key): Linear(in_features=128, out_features=128, bias=True)
(value): Linear(in_features=128, out_features=128, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=128, out_features=128, bias=True)
(LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=128, out_features=512, bias=True)
(intermediate_act_fn): GELUActivation()
)
(output): BertOutput(
(dense): Linear(in_features=512, out_features=128, bias=True)
(LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
bert.encoder.layer ModuleList(
(0-1): 2 x BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=128, out_features=128, bias=True)
(key): Linear(in_features=128, out_features=128, bias=True)
(value): Linear(in_features=128, out_features=128, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=128, out_features=128, bias=True)
(LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=128, out_features=512, bias=True)
(intermediate_act_fn): GELUActivation()
)
(output): BertOutput(
(dense): Linear(in_features=512, out_features=128, bias=True)
(LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
bert.encoder.layer.0 BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=128, out_features=128, bias=True)
(key): Linear(in_features=128, out_features=128, bias=True)
(value): Linear(in_features=128, out_features=128, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=128, out_features=128, bias=True)
(LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=128, out_features=512, bias=True)
(intermediate_act_fn): GELUActivation()
)
(output): BertOutput(
(dense): Linear(in_features=512, out_features=128, bias=True)
(LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
bert.encoder.layer.0.attention BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=128, out_features=128, bias=True)
(key): Linear(in_features=128, out_features=128, bias=True)
(value): Linear(in_features=128, out_features=128, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=128, out_features=128, bias=True)
(LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
bert.encoder.layer.0.attention.self BertSelfAttention(
(query): Linear(in_features=128, out_features=128, bias=True)
(key): Linear(in_features=128, out_features=128, bias=True)
(value): Linear(in_features=128, out_features=128, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
bert.encoder.layer.0.attention.self.query Linear(in_features=128, out_features=128, bias=True)
bert.encoder.layer.0.attention.self.key Linear(in_features=128, out_features=128, bias=True)
bert.encoder.layer.0.attention.self.value Linear(in_features=128, out_features=128, bias=True)
bert.encoder.layer.0.attention.self.dropout Dropout(p=0.1, inplace=False)
bert.encoder.layer.0.attention.output BertSelfOutput(
(dense): Linear(in_features=128, out_features=128, bias=True)
(LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
bert.encoder.layer.0.attention.output.dense Linear(in_features=128, out_features=128, bias=True)
bert.encoder.layer.0.attention.output.LayerNorm LayerNorm((128,), eps=1e-12, elementwise_affine=True)
bert.encoder.layer.0.attention.output.dropout Dropout(p=0.1, inplace=False)
bert.encoder.layer.0.intermediate BertIntermediate(
(dense): Linear(in_features=128, out_features=512, bias=True)
(intermediate_act_fn): GELUActivation()
)
bert.encoder.layer.0.intermediate.dense Linear(in_features=128, out_features=512, bias=True)
bert.encoder.layer.0.intermediate.intermediate_act_fn GELUActivation()
bert.encoder.layer.0.output BertOutput(
(dense): Linear(in_features=512, out_features=128, bias=True)
(LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
bert.encoder.layer.0.output.dense Linear(in_features=512, out_features=128, bias=True)
bert.encoder.layer.0.output.LayerNorm LayerNorm((128,), eps=1e-12, elementwise_affine=True)
bert.encoder.layer.0.output.dropout Dropout(p=0.1, inplace=False)
bert.encoder.layer.1 BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=128, out_features=128, bias=True)
(key): Linear(in_features=128, out_features=128, bias=True)
(value): Linear(in_features=128, out_features=128, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=128, out_features=128, bias=True)
(LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=128, out_features=512, bias=True)
(intermediate_act_fn): GELUActivation()
)
(output): BertOutput(
(dense): Linear(in_features=512, out_features=128, bias=True)
(LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
bert.encoder.layer.1.attention BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=128, out_features=128, bias=True)
(key): Linear(in_features=128, out_features=128, bias=True)
(value): Linear(in_features=128, out_features=128, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=128, out_features=128, bias=True)
(LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
bert.encoder.layer.1.attention.self BertSelfAttention(
(query): Linear(in_features=128, out_features=128, bias=True)
(key): Linear(in_features=128, out_features=128, bias=True)
(value): Linear(in_features=128, out_features=128, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
bert.encoder.layer.1.attention.self.query Linear(in_features=128, out_features=128, bias=True)
bert.encoder.layer.1.attention.self.key Linear(in_features=128, out_features=128, bias=True)
bert.encoder.layer.1.attention.self.value Linear(in_features=128, out_features=128, bias=True)
bert.encoder.layer.1.attention.self.dropout Dropout(p=0.1, inplace=False)
bert.encoder.layer.1.attention.output BertSelfOutput(
(dense): Linear(in_features=128, out_features=128, bias=True)
(LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
bert.encoder.layer.1.attention.output.dense Linear(in_features=128, out_features=128, bias=True)
bert.encoder.layer.1.attention.output.LayerNorm LayerNorm((128,), eps=1e-12, elementwise_affine=True)
bert.encoder.layer.1.attention.output.dropout Dropout(p=0.1, inplace=False)
bert.encoder.layer.1.intermediate BertIntermediate(
(dense): Linear(in_features=128, out_features=512, bias=True)
(intermediate_act_fn): GELUActivation()
)
bert.encoder.layer.1.intermediate.dense Linear(in_features=128, out_features=512, bias=True)
bert.encoder.layer.1.intermediate.intermediate_act_fn GELUActivation()
bert.encoder.layer.1.output BertOutput(
(dense): Linear(in_features=512, out_features=128, bias=True)
(LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
bert.encoder.layer.1.output.dense Linear(in_features=512, out_features=128, bias=True)
bert.encoder.layer.1.output.LayerNorm LayerNorm((128,), eps=1e-12, elementwise_affine=True)
bert.encoder.layer.1.output.dropout Dropout(p=0.1, inplace=False)
bert.pooler BertPooler(
(dense): Linear(in_features=128, out_features=128, bias=True)
(activation): Tanh()
)
bert.pooler.dense Linear(in_features=128, out_features=128, bias=True)
bert.pooler.activation Tanh()
dropout Dropout(p=0.1, inplace=False)
classifier Linear(in_features=128, out_features=2, bias=True)
Printing the modules directly looks especially long, because modules() really does expand and print every single module. By contrast, children() has no such duplication: it only prints the outermost (direct) sub-modules, rather than also re-printing every layer nested inside each layer the way modules() does.
When it comes to actually modifying a model, though, named_modules() is very convenient, since you can map each name directly to its layer.
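For instance, here is a minimal sketch of my own (assuming the bert-tiny model from above) that filters modules by type and then resolves a dotted name back to its module:
# coding: utf-8
import torch
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny")

# Collect every Linear layer together with its dotted name.
linears = {
    name: module
    for name, module in model.named_modules()
    if isinstance(module, torch.nn.Linear)
}
print(list(linears))  # ['bert.encoder.layer.0.attention.self.query', ...]

# A dotted name can also be resolved back to its module
# (get_submodule is available in recent PyTorch versions).
print(model.get_submodule("bert.encoder.layer.0.attention.self.query"))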
parameters() or named_parameters()
parameters() returns all of the learnable parameters in the model. If all you want to see is the model's parameters, parameters() by itself is enough.
Incidentally, this is also the method whose output we normally hand to the optimizer.
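For example, a typical pattern looks like this (a minimal sketch; AdamW and the learning rate are arbitrary choices of mine, not from the original post):
# coding: utf-8
import torch
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny")

# Hand every learnable parameter to the optimizer in one call.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
If you also want each parameter's name, use named_parameters():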
# coding: utf-8
from transformers import AutoModelForSequenceClassification


def main():
    model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny")

    for name, param in model.named_parameters():
        print(name, param)


if __name__ == "__main__":
    main()
I won't paste the output here: it's a flood of weight values that would only make your eyes blur.
In short:
- children() iterates over the direct sub-modules
- modules() iterates over sub-modules of sub-modules... recursing all the way down
- parameters() yields the learnable parameters (see the sketch below for one common use)
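One common use of named_parameters() is selective freezing. A minimal sketch of my own (not from the original post): freeze everything in bert-tiny except the classification head by matching parameter names.
# coding: utf-8
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny")

# Only parameters whose names start with "classifier" stay trainable.
for name, param in model.named_parameters():
    param.requires_grad = name.startswith("classifier")

print([name for name, param in model.named_parameters() if param.requires_grad])
# ['classifier.weight', 'classifier.bias']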