[PyTorch] Traversing Every Layer of a Neural Network in a Model

Last Updated on 2024-09-10 by Clay

Introduction

Recently I happened to need to make small changes to a model's architecture, and I took the opportunity to explore how to iterate over and print the layers of a neural network in PyTorch.

As many people know, print(model) does not show the route data takes through forward(). It lists the sub-modules in the order they were assigned in the class's __init__() method.

For example, the following code:

# coding: utf-8
import torch


class CustomModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 20)
        self.fc2 = torch.nn.Linear(30, 40)
        self.fc3 = torch.nn.Linear(20, 30)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        outputs = self.fc1(inputs)
        outputs = self.fc3(outputs)
        return self.fc2(outputs)


if __name__ == "__main__":
    model = CustomModel()
    print(model)



Output:

CustomModel(
  (fc1): Linear(in_features=10, out_features=20, bias=True)
  (fc2): Linear(in_features=30, out_features=40, bias=True)
  (fc3): Linear(in_features=20, out_features=30, bias=True)
)

We can see that even though forward() calls the layers in the order fc1 > fc3 > fc2, the printed architecture still follows the order in which they were defined: fc1 > fc2 > fc3.

Thus, even if we know the defined layers of the model, we still can't be sure of the actual route the data will take when passed through the model during execution. However, knowing the defined layers is sufficient to make some basic modifications.
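
If you do want to see the order in which layers actually run, one option is to attach forward hooks and record each call. The following is only a minimal sketch under that assumption: it reuses the CustomModel class from above, and the helper name trace_call_order is made up for illustration. It relies on register_forward_hook(), which is part of torch.nn.Module.

```python
import torch


class CustomModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 20)
        self.fc2 = torch.nn.Linear(30, 40)
        self.fc3 = torch.nn.Linear(20, 30)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        outputs = self.fc1(inputs)
        outputs = self.fc3(outputs)
        return self.fc2(outputs)


def trace_call_order(model: torch.nn.Module, example_input: torch.Tensor) -> list:
    """Hypothetical helper: names of leaf modules in the order forward() calls them."""
    call_order = []
    handles = []
    for name, module in model.named_modules():
        # Only hook leaf modules (modules without children of their own).
        if len(list(module.children())) == 0:
            handle = module.register_forward_hook(
                # Bind the current name as a default argument so each hook keeps its own.
                lambda mod, inp, out, name=name: call_order.append(name)
            )
            handles.append(handle)

    with torch.no_grad():
        model(example_input)

    # Remove the hooks so the model is left unchanged.
    for handle in handles:
        handle.remove()
    return call_order


model = CustomModel()
call_order = trace_call_order(model, torch.randn(1, 10))
print(call_order)  # ['fc1', 'fc3', 'fc2']
```

Note that this only records the path taken for one particular input. A model with data-dependent branching could call its layers in a different order on another input.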

For example, adding an adapter, or increasing or reducing the number of parameters, can all be done by modifying the layers directly by name.

This note will not cover how to modify the model layers, but will focus on how to iterate and retrieve the layer names and weights of a model.

Basically, there are three pairs of methods:

  • children() or named_children()
  • modules() or named_modules()
  • parameters() or named_parameters()

These methods are suitable for different scenarios, and you can choose which one to use based on your needs.


Using bert-tiny as an Example

Here, we use a very small BERT model as an example.

# coding: utf-8
from transformers import AutoModelForSequenceClassification


def main():
    model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny")
    print(model)


if __name__ == "__main__":
    main()


Output:

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-1): 2 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=128, out_features=128, bias=True)
              (key): Linear(in_features=128, out_features=128, bias=True)
              (value): Linear(in_features=128, out_features=128, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=128, out_features=128, bias=True)
              (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=128, out_features=512, bias=True)
            (intermediate_act_fn): GELUActivation()
          )
          (output): BertOutput(
            (dense): Linear(in_features=512, out_features=128, bias=True)
            (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (pooler): BertPooler(
      (dense): Linear(in_features=128, out_features=128, bias=True)
      (activation): Tanh()
    )
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (classifier): Linear(in_features=128, out_features=2, bias=True)
)

These are the layers defined in the bert-tiny model. Even though it's a tiny model, there are still quite a few layers defined.

From here on, we will use this model as our example.


children() or named_children()

children() returns an iterator over the model's direct sub-modules only; it does not descend into them. named_children() does the same, but yields (name, module) pairs so that you also get each sub-module's attribute name.

# coding: utf-8
from transformers import AutoModelForSequenceClassification


def main():
    model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny")
    for name, module in model.named_children():
        print(name, module)


if __name__ == "__main__":
    main()


Output:

bert BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 128, padding_idx=0)
    (position_embeddings): Embedding(512, 128)
    (token_type_embeddings): Embedding(2, 128)
    (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-1): 2 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=128, out_features=128, bias=True)
            (key): Linear(in_features=128, out_features=128, bias=True)
            (value): Linear(in_features=128, out_features=128, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=128, out_features=128, bias=True)
            (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=128, out_features=512, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): BertOutput(
          (dense): Linear(in_features=512, out_features=128, bias=True)
          (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (pooler): BertPooler(
    (dense): Linear(in_features=128, out_features=128, bias=True)
    (activation): Tanh()
  )
)
dropout Dropout(p=0.1, inplace=False)
classifier Linear(in_features=128, out_features=2, bias=True)
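
To make the "direct sub-modules only" behaviour easier to see without the long printout, here is a small sketch on a toy model. TinyClassifier is a made-up class for illustration and is not part of bert-tiny. Printing only the class name of each child keeps the output short.

```python
import torch


class TinyClassifier(torch.nn.Module):
    """Hypothetical toy model: a nested encoder plus a classifier head."""

    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU())
        self.classifier = torch.nn.Linear(8, 2)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.encoder(inputs))


model = TinyClassifier()

# children() yields only the modules assigned directly on the model.
child_types = [type(module).__name__ for module in model.children()]
print(child_types)  # ['Sequential', 'Linear']

# named_children() yields (name, module) pairs for the same modules.
child_names = [name for name, _ in model.named_children()]
print(child_names)  # ['encoder', 'classifier']
```

The Linear and ReLU layers inside encoder do not appear, because they are children of encoder rather than of the model itself.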

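modules() or named_modules()

modules() walks the whole module tree recursively: it yields the model itself first, then every sub-module at every depth. named_modules() does the same but yields (name, module) pairs, where the name is the dotted path to the module and the name of the root model is an empty string.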

# coding: utf-8
from transformers import AutoModelForSequenceClassification


def main():
    model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny")
    for name, module in model.named_modules():
        print(name, module)


if __name__ == "__main__":
    main()


Output:

 BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-1): 2 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=128, out_features=128, bias=True)
              (key): Linear(in_features=128, out_features=128, bias=True)
              (value): Linear(in_features=128, out_features=128, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=128, out_features=128, bias=True)
              (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=128, out_features=512, bias=True)
            (intermediate_act_fn): GELUActivation()
          )
          (output): BertOutput(
            (dense): Linear(in_features=512, out_features=128, bias=True)
            (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (pooler): BertPooler(
      (dense): Linear(in_features=128, out_features=128, bias=True)
      (activation): Tanh()
    )
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (classifier): Linear(in_features=128, out_features=2, bias=True)
)
bert BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 128, padding_idx=0)
    (position_embeddings): Embedding(512, 128)
    (token_type_embeddings): Embedding(2, 128)
    (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-1): 2 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=128, out_features=128, bias=True)
            (key): Linear(in_features=128, out_features=128, bias=True)
            (value): Linear(in_features=128, out_features=128, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=128, out_features=128, bias=True)
            (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=128, out_features=512, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): BertOutput(
          (dense): Linear(in_features=512, out_features=128, bias=True)
          (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (pooler): BertPooler(
    (dense): Linear(in_features=128, out_features=128, bias=True)
    (activation): Tanh()
  )
)
bert.embeddings BertEmbeddings(
  (word_embeddings): Embedding(30522, 128, padding_idx=0)
  (position_embeddings): Embedding(512, 128)
  (token_type_embeddings): Embedding(2, 128)
  (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)
bert.embeddings.word_embeddings Embedding(30522, 128, padding_idx=0)
bert.embeddings.position_embeddings Embedding(512, 128)
bert.embeddings.token_type_embeddings Embedding(2, 128)
bert.embeddings.LayerNorm LayerNorm((128,), eps=1e-12, elementwise_affine=True)
bert.embeddings.dropout Dropout(p=0.1, inplace=False)
bert.encoder BertEncoder(
  (layer): ModuleList(
    (0-1): 2 x BertLayer(
      (attention): BertAttention(
        (self): BertSelfAttention(
          (query): Linear(in_features=128, out_features=128, bias=True)
          (key): Linear(in_features=128, out_features=128, bias=True)
          (value): Linear(in_features=128, out_features=128, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (output): BertSelfOutput(
          (dense): Linear(in_features=128, out_features=128, bias=True)
          (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (intermediate): BertIntermediate(
        (dense): Linear(in_features=128, out_features=512, bias=True)
        (intermediate_act_fn): GELUActivation()
      )
      (output): BertOutput(
        (dense): Linear(in_features=512, out_features=128, bias=True)
        (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
)
bert.encoder.layer ModuleList(
  (0-1): 2 x BertLayer(
    (attention): BertAttention(
      (self): BertSelfAttention(
        (query): Linear(in_features=128, out_features=128, bias=True)
        (key): Linear(in_features=128, out_features=128, bias=True)
        (value): Linear(in_features=128, out_features=128, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (output): BertSelfOutput(
        (dense): Linear(in_features=128, out_features=128, bias=True)
        (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (intermediate): BertIntermediate(
      (dense): Linear(in_features=128, out_features=512, bias=True)
      (intermediate_act_fn): GELUActivation()
    )
    (output): BertOutput(
      (dense): Linear(in_features=512, out_features=128, bias=True)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
)
bert.encoder.layer.0 BertLayer(
  (attention): BertAttention(
    (self): BertSelfAttention(
      (query): Linear(in_features=128, out_features=128, bias=True)
      (key): Linear(in_features=128, out_features=128, bias=True)
      (value): Linear(in_features=128, out_features=128, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (output): BertSelfOutput(
      (dense): Linear(in_features=128, out_features=128, bias=True)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (intermediate): BertIntermediate(
    (dense): Linear(in_features=128, out_features=512, bias=True)
    (intermediate_act_fn): GELUActivation()
  )
  (output): BertOutput(
    (dense): Linear(in_features=512, out_features=128, bias=True)
    (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)
bert.encoder.layer.0.attention BertAttention(
  (self): BertSelfAttention(
    (query): Linear(in_features=128, out_features=128, bias=True)
    (key): Linear(in_features=128, out_features=128, bias=True)
    (value): Linear(in_features=128, out_features=128, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (output): BertSelfOutput(
    (dense): Linear(in_features=128, out_features=128, bias=True)
    (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)
bert.encoder.layer.0.attention.self BertSelfAttention(
  (query): Linear(in_features=128, out_features=128, bias=True)
  (key): Linear(in_features=128, out_features=128, bias=True)
  (value): Linear(in_features=128, out_features=128, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)
bert.encoder.layer.0.attention.self.query Linear(in_features=128, out_features=128, bias=True)
bert.encoder.layer.0.attention.self.key Linear(in_features=128, out_features=128, bias=True)
bert.encoder.layer.0.attention.self.value Linear(in_features=128, out_features=128, bias=True)
bert.encoder.layer.0.attention.self.dropout Dropout(p=0.1, inplace=False)
bert.encoder.layer.0.attention.output BertSelfOutput(
  (dense): Linear(in_features=128, out_features=128, bias=True)
  (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)
bert.encoder.layer.0.attention.output.dense Linear(in_features=128, out_features=128, bias=True)
bert.encoder.layer.0.attention.output.LayerNorm LayerNorm((128,), eps=1e-12, elementwise_affine=True)
bert.encoder.layer.0.attention.output.dropout Dropout(p=0.1, inplace=False)
bert.encoder.layer.0.intermediate BertIntermediate(
  (dense): Linear(in_features=128, out_features=512, bias=True)
  (intermediate_act_fn): GELUActivation()
)
bert.encoder.layer.0.intermediate.dense Linear(in_features=128, out_features=512, bias=True)
bert.encoder.layer.0.intermediate.intermediate_act_fn GELUActivation()
bert.encoder.layer.0.output BertOutput(
  (dense): Linear(in_features=512, out_features=128, bias=True)
  (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)
bert.encoder.layer.0.output.dense Linear(in_features=512, out_features=128, bias=True)
bert.encoder.layer.0.output.LayerNorm LayerNorm((128,), eps=1e-12, elementwise_affine=True)
bert.encoder.layer.0.output.dropout Dropout(p=0.1, inplace=False)
bert.encoder.layer.1 BertLayer(
  (attention): BertAttention(
    (self): BertSelfAttention(
      (query): Linear(in_features=128, out_features=128, bias=True)
      (key): Linear(in_features=128, out_features=128, bias=True)
      (value): Linear(in_features=128, out_features=128, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (output): BertSelfOutput(
      (dense): Linear(in_features=128, out_features=128, bias=True)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (intermediate): BertIntermediate(
    (dense): Linear(in_features=128, out_features=512, bias=True)
    (intermediate_act_fn): GELUActivation()
  )
  (output): BertOutput(
    (dense): Linear(in_features=512, out_features=128, bias=True)
    (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)
bert.encoder.layer.1.attention BertAttention(
  (self): BertSelfAttention(
    (query): Linear(in_features=128, out_features=128, bias=True)
    (key): Linear(in_features=128, out_features=128, bias=True)
    (value): Linear(in_features=128, out_features=128, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (output): BertSelfOutput(
    (dense): Linear(in_features=128, out_features=128, bias=True)
    (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)
bert.encoder.layer.1.attention.self BertSelfAttention(
  (query): Linear(in_features=128, out_features=128, bias=True)
  (key): Linear(in_features=128, out_features=128, bias=True)
  (value): Linear(in_features=128, out_features=128, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)
bert.encoder.layer.1.attention.self.query Linear(in_features=128, out_features=128, bias=True)
bert.encoder.layer.1.attention.self.key Linear(in_features=128, out_features=128, bias=True)
bert.encoder.layer.1.attention.self.value Linear(in_features=128, out_features=128, bias=True)
bert.encoder.layer.1.attention.self.dropout Dropout(p=0.1, inplace=False)
bert.encoder.layer.1.attention.output BertSelfOutput(
  (dense): Linear(in_features=128, out_features=128, bias=True)
  (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)
bert.encoder.layer.1.attention.output.dense Linear(in_features=128, out_features=128, bias=True)
bert.encoder.layer.1.attention.output.LayerNorm LayerNorm((128,), eps=1e-12, elementwise_affine=True)
bert.encoder.layer.1.attention.output.dropout Dropout(p=0.1, inplace=False)
bert.encoder.layer.1.intermediate BertIntermediate(
  (dense): Linear(in_features=128, out_features=512, bias=True)
  (intermediate_act_fn): GELUActivation()
)
bert.encoder.layer.1.intermediate.dense Linear(in_features=128, out_features=512, bias=True)
bert.encoder.layer.1.intermediate.intermediate_act_fn GELUActivation()
bert.encoder.layer.1.output BertOutput(
  (dense): Linear(in_features=512, out_features=128, bias=True)
  (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)
bert.encoder.layer.1.output.dense Linear(in_features=512, out_features=128, bias=True)
bert.encoder.layer.1.output.LayerNorm LayerNorm((128,), eps=1e-12, elementwise_affine=True)
bert.encoder.layer.1.output.dropout Dropout(p=0.1, inplace=False)
bert.pooler BertPooler(
  (dense): Linear(in_features=128, out_features=128, bias=True)
  (activation): Tanh()
)
bert.pooler.dense Linear(in_features=128, out_features=128, bias=True)
bert.pooler.activation Tanh()
dropout Dropout(p=0.1, inplace=False)
classifier Linear(in_features=128, out_features=2, bias=True)

Printing every module this way produces a very long output, because each module's printout includes all of its descendants, so the same inner layers are printed again at every level of nesting. By contrast, children() avoids this repetition because it yields only the direct sub-modules, whereas modules() also yields every 'inner layer'.

However, when making modifications to the model, named_modules() can be more convenient, since each dotted name (for example bert.encoder.layer.0.attention.self.query) maps directly to one layer.
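
As a rough illustration of that name-to-layer mapping, the sketch below builds a dictionary from named_modules(), looks a layer up by its dotted name, and filters for one layer type. It uses a made-up toy model (TinyClassifier is hypothetical) so it runs without downloading anything. get_submodule() is an existing torch.nn.Module method in recent PyTorch versions.

```python
import torch


class TinyClassifier(torch.nn.Module):
    """Hypothetical toy model: a nested encoder plus a classifier head."""

    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU())
        self.classifier = torch.nn.Linear(8, 2)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.encoder(inputs))


model = TinyClassifier()

# Map every dotted name to its module; the root model has the empty name "".
modules_by_name = dict(model.named_modules())
print(list(modules_by_name))
# ['', 'encoder', 'encoder.0', 'encoder.1', 'classifier']

# Look up a nested layer by name, either through the dict or get_submodule().
first_linear = modules_by_name["encoder.0"]
same_layer = model.get_submodule("encoder.0")
print(first_linear is same_layer)  # True

# Collect the names of all Linear layers, at any depth.
linear_names = [
    name
    for name, module in model.named_modules()
    if isinstance(module, torch.nn.Linear)
]
print(linear_names)  # ['encoder.0', 'classifier']
```

Printing only the names, as done here, is also a convenient way to avoid the repeated nested printouts shown above.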


parameters() or named_parameters()

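parameters() returns an iterator over all the parameters registered in a model, that is, the weight and bias tensors of every layer (including any that have been frozen with requires_grad=False). named_parameters() additionally yields each parameter's dotted name. If you only care about the model's weights rather than its layers, use these directly.

By the way, the result of model.parameters() is what we usually pass into an optimizer.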

# coding: utf-8
from transformers import AutoModelForSequenceClassification


def main():
    model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny")
    for name, param in model.named_parameters():
        print(name, param)


if __name__ == "__main__":
    main()



I won’t include the printed results here, as you would see a bunch of weight values which might be overwhelming.
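
A more readable alternative is to print each parameter's name and shape instead of its values. The sketch below does that on a made-up toy model (TinyClassifier is hypothetical, used so the example runs without downloading bert-tiny), counts the parameters, and passes model.parameters() to an optimizer. The choice of SGD and the learning rate are arbitrary assumptions for illustration.

```python
import torch


class TinyClassifier(torch.nn.Module):
    """Hypothetical toy model: a nested encoder plus a classifier head."""

    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU())
        self.classifier = torch.nn.Linear(8, 2)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.encoder(inputs))


model = TinyClassifier()

# Print the name, shape, and requires_grad flag instead of the raw values.
for name, param in model.named_parameters():
    print(name, tuple(param.shape), param.requires_grad)

param_shapes = {name: tuple(param.shape) for name, param in model.named_parameters()}

# Total number of scalar values across all parameters:
# (8*4 + 8) for encoder.0 plus (2*8 + 2) for classifier.
total_params = sum(param.numel() for param in model.parameters())
print(total_params)  # 58

# The parameter iterator is what an optimizer expects as its first argument.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
```

The ReLU layer contributes nothing here because it has no parameters, which is why only four entries are printed for three layers.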



In short:

  • children() iterates over the direct sub-modules only
  • modules() iterates over the model itself, its sub-modules, and their sub-modules until fully expanded
  • parameters() iterates over the model's parameters (weights and biases)
