Last Updated on 2021-10-27 by Clay
Building models with the PyTorch framework is very convenient. Today I want to introduce how to print out a model's architecture and how to extract the weights of its layers.
What are the advantages of extracting weights?
My personal understanding is that, besides reloading a trained model to continue training it, we can sometimes even take the original model apart and reuse its pieces for different tasks.
This is very flexible.
Let me give a practical example: I used BERT as the first layer of my model. That is, I fine-tuned a model pre-trained by another team as part of my own model, and finally trained a classification model.
Now I am not satisfied with this classification model, and I want to add more features to the training data.
But I am quite satisfied with the embedding layer obtained from the original fine-tuning, and I don't want to change it.
So I can extract only the first layer of the original model, which is the embedding layer,
freeze it, and build a new model around it for training.
In other words, my original model was the fine-tuned BERT embedding layer followed by the classification layers.
But in the new round of training, I only extract the first (embedding) layer of the original model.
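As a minimal sketch of this freeze-and-rebuild idea (not my actual BERT code), the example below uses a small hypothetical `OldModel` whose `nn.Embedding` plays the role of the fine-tuned BERT layer. `OldModel`, `NewModel`, `num_extra_features` and the mean pooling are illustrative assumptions; the PyTorch parts (`requires_grad_(False)` and passing only trainable parameters to the optimizer) are standard API.

```python
import torch
import torch.nn as nn


class OldModel(nn.Module):
    """Hypothetical stand-in for the original model: an embedding layer
    (playing the role of the fine-tuned BERT) followed by a classifier."""

    def __init__(self, vocab_size=100, dim=16):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim)
        self.classifier = nn.Linear(dim, 1)

    def forward(self, token_ids):
        return self.classifier(self.embedding(token_ids).mean(dim=1))


class NewModel(nn.Module):
    """Hypothetical new model: reuses the old embedding layer (frozen)
    and feeds it, together with extra features, into a new classifier."""

    def __init__(self, pretrained_embedding, num_extra_features=4):
        super().__init__()
        self.embedding = pretrained_embedding
        # Freeze: no gradients will be computed for the embedding weights
        self.embedding.requires_grad_(False)
        dim = pretrained_embedding.embedding_dim
        self.classifier = nn.Linear(dim + num_extra_features, 1)

    def forward(self, token_ids, extra_features):
        pooled = self.embedding(token_ids).mean(dim=1)  # (batch, dim)
        return self.classifier(torch.cat([pooled, extra_features], dim=1))


old = OldModel()
new = NewModel(old.embedding)

# Only the new classifier remains trainable
trainable = [name for name, p in new.named_parameters() if p.requires_grad]
print(trainable)  # ['classifier.weight', 'classifier.bias']

# Give the optimizer only the parameters that should be updated
optimizer = torch.optim.Adam(p for p in new.parameters() if p.requires_grad)

out = new(torch.randint(0, 100, (2, 5)), torch.randn(2, 4))
print(out.shape)  # torch.Size([2, 1])
```

Note that `new.embedding` is the same module object as `old.embedding`, so freezing it also freezes it inside `old`; use `copy.deepcopy(old.embedding)` if the two models must stay independent.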
Extract Weights
First, let's look at how to extract the weights of a model. The classification model I defined is constructed as follows:
```python
# coding: utf-8
import torch.nn as nn

# Settings
vector_size = 300


# GRU
class GRU(nn.Module):
    def __init__(self):
        super(GRU, self).__init__()
        self.gru = nn.GRU(
            input_size=vector_size,
            hidden_size=vector_size,
            num_layers=5,
            dropout=0.3,
            bidirectional=True,
            batch_first=True,
        )
        self.fc = nn.Linear(vector_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        # hidden: (num_layers * 2, batch, vector_size)
        out, hidden = self.gru(inputs, None)
        # Keep the final hidden state of the last layer's backward direction
        hidden = hidden[-1]
        outputs = self.fc(hidden.squeeze(0))
        return self.sigmoid(outputs)
```
Its structure is very simple: there are only three components, a GRU layer (five stacked, bidirectional layers), a fully connected layer, and a sigmoid() activation function.
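To see why `fc` takes `vector_size` inputs even though the GRU is bidirectional, the sketch below runs a standalone `nn.GRU` with the same settings on a dummy all-ones input (the `[1, 31, 300]` shape is just an assumption for illustration) and prints the shapes involved.

```python
import torch
import torch.nn as nn

vector_size = 300
gru = nn.GRU(input_size=vector_size, hidden_size=vector_size, num_layers=5,
             dropout=0.3, bidirectional=True, batch_first=True)
gru.eval()  # disable dropout between layers

inputs = torch.ones([1, 31, vector_size])  # (batch, seq_len, features)
with torch.no_grad():
    out, hidden = gru(inputs, None)

print(out.shape)     # torch.Size([1, 31, 600]): both directions concatenated
print(hidden.shape)  # torch.Size([10, 1, 300]): num_layers * 2 directions
# hidden is never batch-first; hidden[-1] is the final state of the
# backward direction of the last layer, so it has only vector_size features
last = hidden[-1]
print(last.shape)    # torch.Size([1, 300])
```

So `hidden[-1]` carries 300 features, which is why the fully connected layer is `nn.Linear(vector_size, 1)` rather than `nn.Linear(2 * vector_size, 1)`.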
I have trained this classifier and saved it as gru_model.pth. The following code loads the trained model and prints its weights:
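```python
# coding: utf-8
import torch
from GRU_300 import GRU  # the class must be importable to unpickle the full model

# Load pre-trained model
# (on PyTorch 2.6+, a fully pickled model needs torch.load(..., weights_only=False))
model_a = torch.load('./gru_model.pth').cpu()
model_a.eval()

# Display all model layer weights
for name, para in model_a.named_parameters():
    print('{}: {}'.format(name, para.shape))
```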
Output:
gru.weight_ih_l0: torch.Size([900, 300])
gru.weight_hh_l0: torch.Size([900, 300])
gru.bias_ih_l0: torch.Size([900])
gru.bias_hh_l0: torch.Size([900])
gru.weight_ih_l0_reverse: torch.Size([900, 300])
gru.weight_hh_l0_reverse: torch.Size([900, 300])
gru.bias_ih_l0_reverse: torch.Size([900])
gru.bias_hh_l0_reverse: torch.Size([900])
gru.weight_ih_l1: torch.Size([900, 600])
gru.weight_hh_l1: torch.Size([900, 300])
gru.bias_ih_l1: torch.Size([900])
gru.bias_hh_l1: torch.Size([900])
gru.weight_ih_l1_reverse: torch.Size([900, 600])
gru.weight_hh_l1_reverse: torch.Size([900, 300])
gru.bias_ih_l1_reverse: torch.Size([900])
gru.bias_hh_l1_reverse: torch.Size([900])
gru.weight_ih_l2: torch.Size([900, 600])
gru.weight_hh_l2: torch.Size([900, 300])
gru.bias_ih_l2: torch.Size([900])
gru.bias_hh_l2: torch.Size([900])
gru.weight_ih_l2_reverse: torch.Size([900, 600])
gru.weight_hh_l2_reverse: torch.Size([900, 300])
gru.bias_ih_l2_reverse: torch.Size([900])
gru.bias_hh_l2_reverse: torch.Size([900])
gru.weight_ih_l3: torch.Size([900, 600])
gru.weight_hh_l3: torch.Size([900, 300])
gru.bias_ih_l3: torch.Size([900])
gru.bias_hh_l3: torch.Size([900])
gru.weight_ih_l3_reverse: torch.Size([900, 600])
gru.weight_hh_l3_reverse: torch.Size([900, 300])
gru.bias_ih_l3_reverse: torch.Size([900])
gru.bias_hh_l3_reverse: torch.Size([900])
gru.weight_ih_l4: torch.Size([900, 600])
gru.weight_hh_l4: torch.Size([900, 300])
gru.bias_ih_l4: torch.Size([900])
gru.bias_hh_l4: torch.Size([900])
gru.weight_ih_l4_reverse: torch.Size([900, 600])
gru.weight_hh_l4_reverse: torch.Size([900, 300])
gru.bias_ih_l4_reverse: torch.Size([900])
gru.bias_hh_l4_reverse: torch.Size([900])
fc.weight: torch.Size([1, 300])
fc.bias: torch.Size([1])
By calling the named_parameters() function, we can print out the name of each model layer along with its weights. For convenience, I only printed the shape of each weight tensor; you can print the actual weight values as well.
(Note: GRU_300 is the script in which I defined the model.)
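The sizes in this listing follow from how `nn.GRU` stores its weights: the three gates (reset, update, new) are stacked along the first dimension, giving 3 * 300 = 900 rows, and every layer after the first receives the concatenated forward and backward outputs of the previous layer, giving 600 input columns. The sketch below checks this on a standalone `nn.GRU` built with the same settings (so the names lack the `gru.` prefix they carry inside my model), and also shows how to print actual weight values and count parameters.

```python
import torch.nn as nn

vector_size = 300
gru = nn.GRU(input_size=vector_size, hidden_size=vector_size, num_layers=5,
             dropout=0.3, bidirectional=True, batch_first=True)

params = dict(gru.named_parameters())
# Three gates stacked along dim 0 -> 3 * 300 = 900 rows
assert params['weight_ih_l0'].shape == (3 * vector_size, vector_size)
# Layers >= 1 see both directions of the previous layer -> 2 * 300 = 600 inputs
assert params['weight_ih_l1'].shape == (3 * vector_size, 2 * vector_size)

# Individual tensors can also be read as attributes and printed as values
print(gru.weight_ih_l0.detach()[:2, :5])  # first 2 rows, first 5 columns

# Total number of trainable values in the GRU part
total = sum(p.numel() for p in gru.parameters())
print(total)  # 7578000
```

Adding the 301 values of `fc` (300 weights plus 1 bias) gives the size of the whole classifier.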
So, the above is how to print out the model. Next, let's look at how to make a new model inherit the weights of the pre-trained model.
First, use the same named_parameters() function as before to get the weights. This time we save them in a dictionary.
```python
# coding: utf-8
import torch
from GRU_300 import GRU

# Load pre-trained model
model_a = torch.load('./gru_model.pth').cpu()
model_a.eval()

# Collect all model layer weights into a dict: {layer name: weight tensor}
weights = dict()
for name, para in model_a.named_parameters():
    weights[name] = para
```
Here I store each model layer name and its weights in the variable weights.
```python
# Build a new model
model_b = GRU().cpu()
model_b_weight = model_b.state_dict()   # {layer name: tensor} of the new model
model_b_weight.update(weights)          # overwrite entries with pre-trained weights
model_b.load_state_dict(model_b_weight)
model_b.eval()
```
I created a new GRU model and used state_dict() to get its weight dictionary, which maps each layer name to a tensor of the right shape. Then I used the update() function to overwrite model_b_weight with the weights just extracted from the pre-trained model.
At this point, model_b_weight contains exactly the entries the new model expects, now filled with the pre-trained values, so we use load_state_dict() to load them into the new model. After this, the two models should be exactly the same. Below, we create a test input to check whether the outputs of the two models are identical.
```python
# Test
inputs = torch.ones([1, 31, 300])
outputs_a = model_a.gru(inputs)  # returns (output, h_n)
outputs_b = model_b.gru(inputs)
print(outputs_a[0] == outputs_b[0])
```
Output:
tensor([[[True, True, True, …, True, True, True],
[True, True, True, …, True, True, True],
[True, True, True, …, True, True, True],
…,
[True, True, True, …, True, True, True],
[True, True, True, …, True, True, True],
[True, True, True, …, True, True, True]]])
As you can see, the outputs of the two models are identical, so the new model has indeed inherited the weights of the pre-trained model.
It is worth mentioning that the shape of the test input matches the shape of the data I actually trained on (with batch_first=True, it is a batch of one sequence of length 31, where each step is a 300-dimensional vector), so don't worry too much about where the 31 comes from.
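Copying every weight is not always what you want; the motivation at the start of this post was to keep only part of the original model. A minimal sketch, assuming the same `GRU` class as above (copied in so it runs on its own, with a freshly initialised `model_a` standing in for the loaded `gru_model.pth`): filter `state_dict()` by the `'gru.'` name prefix, load it with `strict=False`, and compare with `torch.equal`, which returns a single `True`/`False` instead of a tensor of booleans.

```python
import torch
import torch.nn as nn

vector_size = 300


class GRU(nn.Module):
    # Same architecture as in this post, copied so the sketch is self-contained
    def __init__(self):
        super().__init__()
        self.gru = nn.GRU(input_size=vector_size, hidden_size=vector_size,
                          num_layers=5, dropout=0.3, bidirectional=True,
                          batch_first=True)
        self.fc = nn.Linear(vector_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        out, hidden = self.gru(inputs, None)
        hidden = hidden[-1]
        outputs = self.fc(hidden.squeeze(0))
        return self.sigmoid(outputs)


model_a = GRU().eval()  # stands in for the pre-trained model loaded from disk
model_b = GRU().eval()  # freshly initialised model

# Copy only the GRU weights; keys keep their "gru." prefix so the names match
gru_weights = {k: v for k, v in model_a.state_dict().items() if k.startswith('gru.')}
result = model_b.load_state_dict(gru_weights, strict=False)
print(result.missing_keys)  # ['fc.weight', 'fc.bias'] keep their random init

inputs = torch.ones([1, 31, vector_size])
with torch.no_grad():
    same_gru = torch.equal(model_a.gru(inputs)[0], model_b.gru(inputs)[0])
    same_fc = torch.equal(model_a.fc.weight, model_b.fc.weight)
print(same_gru, same_fc)  # True False
```

Checking `result.missing_keys` (and `result.unexpected_keys`) is a cheap way to confirm that `strict=False` skipped only the layers you meant to skip.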
Extract Model Layer Output Shape
Extracting the output of a specific model layer is even easier; in fact, the code above already did it with model_a.gru(inputs). Once again, here is my model architecture:
```python
# coding: utf-8
import torch.nn as nn

# Settings
vector_size = 300


# GRU
class GRU(nn.Module):
    def __init__(self):
        super(GRU, self).__init__()
        self.gru = nn.GRU(
            input_size=vector_size,
            hidden_size=vector_size,
            num_layers=5,
            dropout=0.3,
            bidirectional=True,
            batch_first=True,
        )
        self.fc = nn.Linear(vector_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        # hidden: (num_layers * 2, batch, vector_size)
        out, hidden = self.gru(inputs, None)
        # Keep the final hidden state of the last layer's backward direction
        hidden = hidden[-1]
        outputs = self.fc(hidden.squeeze(0))
        return self.sigmoid(outputs)
```
As mentioned above, my model architecture has three parts: the GRU layer, the fully connected layer, and the sigmoid function. So what if I only want to use the fully connected layer?
```python
# coding: utf-8
import torch
from GRU_300 import GRU

# Load pre-trained model
model = torch.load('./gru_model.pth').cpu()
model.eval()

# Inputs
inputs = torch.ones([1, 300])
outputs = model.fc(inputs)
print('Inputs:', inputs.shape)
print('Outputs:', outputs.shape)
```
Output:
Inputs: torch.Size([1, 300])
Outputs: torch.Size([1, 1])
Yes: in our model definition, the fully connected layer is named "fc".
Therefore, we can use that layer on its own by calling model.fc directly. We can also see that the 300-dimensional input becomes a one-dimensional output after passing through the fully connected layer, which is fully consistent with the original design of our model.
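Calling `model.fc` by hand works when you already know its input. If you instead want a layer's output during a normal forward pass through the whole model, a forward hook is one option. The sketch below, again using a freshly initialised copy of the `GRU` class rather than my trained file, registers a hook on `fc` with `register_forward_hook()`; the `save_output` helper and the `captured` dict are illustrative names, not part of the original code.

```python
import torch
import torch.nn as nn

vector_size = 300


class GRU(nn.Module):
    # Same architecture as in this post, copied so the sketch is self-contained
    def __init__(self):
        super().__init__()
        self.gru = nn.GRU(input_size=vector_size, hidden_size=vector_size,
                          num_layers=5, dropout=0.3, bidirectional=True,
                          batch_first=True)
        self.fc = nn.Linear(vector_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        out, hidden = self.gru(inputs, None)
        hidden = hidden[-1]
        outputs = self.fc(hidden.squeeze(0))
        return self.sigmoid(outputs)


model = GRU().eval()
captured = {}  # hypothetical store: layer name -> output tensor


def save_output(name):
    # Forward hooks are called as hook(module, inputs, output)
    def hook(module, inputs, output):
        captured[name] = output.detach()
    return hook


handle = model.fc.register_forward_hook(save_output('fc'))
with torch.no_grad():
    prediction = model(torch.ones([1, 31, vector_size]))
handle.remove()  # stop capturing once we have what we need

print(captured['fc'].shape)  # torch.Size([1])
# The final prediction is just the sigmoid of the captured fc output
print(torch.allclose(torch.sigmoid(captured['fc']), prediction))  # True
```

Because this input holds a single sequence, `hidden[-1].squeeze(0)` also drops the batch dimension, so here `fc` outputs shape `[1]` rather than the `[1, 1]` seen when calling `model.fc` directly on a `[1, 300]` input.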
So, the above is a short note on extracting weights and model layers in PyTorch.