PyTorch is a Python-based deep learning framework, and its modules and functions make it simple to implement the model architecture we want.
When we talk about deep learning, we have to mention parallel computation on GPUs.
To compute in parallel on a GPU, the input data has to be packed into tensors of a fixed shape.
So we often need to reshape our tensors.
In this post I mainly want to record two functions: view() and permute(). Both change the dimensions of a tensor, but they are used in different situations.
permute()
permute() is mainly used to swap dimensions. Unlike view(), it changes the order of the tensor's elements, that is, the order in which they are read.
Let's look at an example:
```python
# coding: utf-8
import torch

inputs = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
inputs = torch.tensor(inputs)
print(inputs)
print('Inputs:', inputs.shape)
```
Output:
```text
tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]]])
Inputs: torch.Size([2, 2, 3])
```
This is a simple tensor of shape (2, 2, 3) holding the numbers 1 to 12 in order. Next, we call permute() on it to swap dimensions.
The first thing to note is that the dimensions of the original tensor are numbered 0, 1 and 2.
permute() takes these numbers in the new order you want: permute(0, 2, 1) keeps dimension 0 in place and swaps dimensions 1 and 2.
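A minimal sketch of this numbering rule: dimension i of the result takes the size of dimension dims[i] of the input. The helper `permuted_shape` is hypothetical, written only for this illustration; it is not a PyTorch function.

```python
import torch

# Same tensor as in the examples: dimensions 0, 1, 2 have sizes 2, 2, 3.
inputs = torch.tensor([[[1, 2, 3], [4, 5, 6]],
                       [[7, 8, 9], [10, 11, 12]]])


def permuted_shape(shape, dims):
    """Hypothetical helper: the shape that permute(*dims) should produce."""
    return tuple(shape[d] for d in dims)


for dims in [(0, 2, 1), (2, 0, 1), (1, 0, 2)]:
    out = inputs.permute(*dims)
    # Dimension i of the output has the size of dimension dims[i] of the input.
    assert tuple(out.shape) == permuted_shape(inputs.shape, dims)
    print(dims, tuple(out.shape))
```

This prints `(0, 2, 1) (2, 3, 2)`, then `(2, 0, 1) (3, 2, 2)`, then `(1, 0, 2) (2, 2, 3)`.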
```python
# coding: utf-8
import torch

inputs = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
inputs = torch.tensor(inputs)
print(inputs)
print('Inputs:', inputs.shape)

# Keep dimension 0, swap dimensions 1 and 2
outputs = inputs.permute(0, 2, 1)
print(outputs)
print('Outputs:', outputs.shape)
```
Output:
```text
tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]]])
Inputs: torch.Size([2, 2, 3])
tensor([[[ 1,  4],
         [ 2,  5],
         [ 3,  6]],

        [[ 7, 10],
         [ 8, 11],
         [ 9, 12]]])
Outputs: torch.Size([2, 3, 2])
```
As you can see, the dimensions are swapped, and the order of the elements changes as well: each 2 × 3 block has been transposed into a 3 × 2 block.
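To make "the order changes" concrete, a short check using the same tensor: after permute(0, 2, 1), the element at [b, i, j] moves to [b, j, i]. permute() returns a new view on the same storage rather than a copy, so the data in memory does not move; only the order in which it is read changes, which is why PyTorch reports the result as non-contiguous.

```python
import torch

inputs = torch.tensor([[[1, 2, 3], [4, 5, 6]],
                       [[7, 8, 9], [10, 11, 12]]])
outputs = inputs.permute(0, 2, 1)

# Element [b, i, j] of the input appears at [b, j, i] of the output.
for b in range(2):
    for i in range(2):
        for j in range(3):
            assert outputs[b, j, i] == inputs[b, i, j]

# permute() shares storage with the input instead of copying it...
print(outputs.data_ptr() == inputs.data_ptr())  # True
# ...but the new dimension order no longer matches the memory layout.
print(outputs.is_contiguous())  # False
```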
A typical use case is rearranging the input of neural networks such as RNNs, for example between (batch_size, sequence_length, vector_size) and (sequence_length, batch_size, vector_size).
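For instance, torch.nn.RNN expects input of shape (sequence_length, batch_size, input_size) by default (batch_first=False). A minimal sketch, with sizes chosen arbitrarily for illustration, that converts batch-first data with permute(1, 0, 2) before feeding it to the RNN:

```python
import torch
import torch.nn as nn

# Arbitrary sizes, chosen only for this illustration.
batch_size, sequence_length, vector_size, hidden_size = 4, 10, 8, 16

# Batch-first data: (batch_size, sequence_length, vector_size)
x = torch.randn(batch_size, sequence_length, vector_size)

# With the default batch_first=False, nn.RNN expects (sequence_length, batch_size, input_size).
rnn = nn.RNN(input_size=vector_size, hidden_size=hidden_size)

x_seq_first = x.permute(1, 0, 2)  # swap dimensions 0 and 1
output, h_n = rnn(x_seq_first)

print(x_seq_first.shape)  # torch.Size([10, 4, 8])
print(output.shape)       # torch.Size([10, 4, 16])
print(h_n.shape)          # torch.Size([1, 4, 16])
```

Alternatively, constructing the layer with batch_first=True lets it accept (batch_size, sequence_length, vector_size) directly, without the permute.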
view()
Compared with permute(), view() does not change the order of the elements, and it is much more flexible about the target shape.
For example, let's rewrite the previous example like this:
```python
# coding: utf-8
import torch

inputs = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
inputs = torch.tensor(inputs)
print(inputs)
print('Inputs:', inputs.shape)

# Reinterpret the same 12 elements, in the same order, as shape (2, 3, 2)
outputs = inputs.view(2, 3, 2)
print(outputs)
print('Outputs:', outputs.shape)
```
Output:
```text
tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]]])
Inputs: torch.Size([2, 2, 3])
tensor([[[ 1,  2],
         [ 3,  4],
         [ 5,  6]],

        [[ 7,  8],
         [ 9, 10],
         [11, 12]]])
Outputs: torch.Size([2, 3, 2])
```
The resulting shape, (2, 3, 2), is the same as with permute(), but the order of the elements does not change: reading them row by row still gives 1 to 12.
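One way to see the difference, using the same tensor, is to read both results back in row-major order with flatten(). The view() result gives 1 to 12 in sequence, while the permute() result interleaves them.

```python
import torch

inputs = torch.tensor([[[1, 2, 3], [4, 5, 6]],
                       [[7, 8, 9], [10, 11, 12]]])

viewed = inputs.view(2, 3, 2)
permuted = inputs.permute(0, 2, 1)

# Same shape...
print(viewed.shape == permuted.shape)  # True
# ...but a different element order when read row by row.
print(viewed.flatten().tolist())    # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
print(permuted.flatten().tolist())  # [1, 4, 2, 5, 3, 6, 7, 10, 8, 11, 9, 12]
```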
In addition, view() is not limited to shapes with the same number of dimensions: it can also change the number of dimensions, as long as the total number of elements stays the same.
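Since view() only reinterprets the shape, the sizes must multiply to the same element count (here 2 × 2 × 3 = 12). A quick sketch with the same tensor: shapes with 1, 2 or 4 dimensions are all accepted, while a shape with a different element count raises a RuntimeError.

```python
import torch

inputs = torch.tensor([[[1, 2, 3], [4, 5, 6]],
                       [[7, 8, 9], [10, 11, 12]]])
print(inputs.numel())  # 12

# Any shape whose sizes multiply to 12 works, whatever the number of dimensions.
for shape in [(12,), (3, 4), (6, 2), (1, 2, 3, 2)]:
    print(shape, tuple(inputs.view(*shape).shape))

# A shape with a different element count is rejected.
try:
    inputs.view(4, 5)  # 20 elements, not 12
except RuntimeError as err:
    print("RuntimeError:", err)
```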
For example, we can put all of the elements into a single dimension:
```python
# coding: utf-8
import torch

inputs = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
inputs = torch.tensor(inputs)
print(inputs)
print('Inputs:', inputs.shape)

# -1: let PyTorch infer the size, here flattening everything into one dimension
outputs = inputs.view(-1)
print(outputs)
print('Outputs:', outputs.shape)
```
Output:
```text
tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]]])
Inputs: torch.Size([2, 2, 3])
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
Outputs: torch.Size([12])
```
Here, "-1" means that the size of that dimension is calculated automatically from the number of elements. Since no other dimension is specified, all 12 elements are placed in a single dimension.
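-1 can also be combined with explicit sizes; PyTorch then computes the missing size from the total element count. A short sketch with the same 12-element tensor (at most one dimension may be -1):

```python
import torch

inputs = torch.tensor([[[1, 2, 3], [4, 5, 6]],
                       [[7, 8, 9], [10, 11, 12]]])

print(inputs.view(-1).shape)        # torch.Size([12])
print(inputs.view(-1, 3).shape)     # torch.Size([4, 3]), since 12 / 3 = 4
print(inputs.view(2, -1).shape)     # torch.Size([2, 6]), since 12 / 2 = 6
print(inputs.view(2, -1, 2).shape)  # torch.Size([2, 3, 2]), since 12 / (2 * 2) = 3
```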
Because view() is so useful, it appears in many PyTorch programs.
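One caveat when the two functions are combined: view() only works when the requested shape is compatible with the tensor's memory layout, and the non-contiguous result of permute() often is not. A sketch of two common fixes: call contiguous() before view(), or use reshape(), which returns a view when possible and a copy otherwise.

```python
import torch

inputs = torch.tensor([[[1, 2, 3], [4, 5, 6]],
                       [[7, 8, 9], [10, 11, 12]]])
permuted = inputs.permute(0, 2, 1)  # non-contiguous

try:
    permuted.view(-1)  # cannot flatten this layout without copying
except RuntimeError as err:
    print("view() failed:", err)

# Option 1: make a contiguous copy first, then view it.
flat1 = permuted.contiguous().view(-1)
# Option 2: reshape() copies only when a view is impossible.
flat2 = permuted.reshape(-1)

print(flat1.tolist())             # [1, 4, 2, 5, 3, 6, 7, 10, 8, 11, 9, 12]
print(torch.equal(flat1, flat2))  # True
```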