
[PyTorch] Introduction Of “squeeze()” and “unsqueeze()” Functions

When I started learning the PyTorch framework, I read and ran many sample programs from the official tutorials to learn how to build models in PyTorch. At that time, I often saw squeeze(), unsqueeze() and similar functions in the sample code, but I did not understand what they were for.

In fact, these functions are easy to understand: squeeze() removes dimensions of size 1, and unsqueeze() inserts a new dimension of size 1, just as their names suggest.
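To make the rule concrete, here is a minimal sketch that only looks at shapes. The tensor of zeros with shape (2, 1, 3) is a hypothetical example chosen for illustration: squeeze(dim) only removes the given dimension if its size is 1, and unsqueeze(dim) inserts a new size-1 dimension at the given position.

```python
import torch

# Hypothetical example tensor: dimension 1 has size 1, the others do not.
x = torch.zeros(2, 1, 3)

# squeeze(1) removes dimension 1 because its size is 1.
print(x.squeeze(1).shape)    # torch.Size([2, 3])

# squeeze(0) leaves the shape unchanged because dimension 0 has size 2.
print(x.squeeze(0).shape)    # torch.Size([2, 1, 3])

# unsqueeze(0) inserts a new size-1 dimension at the front.
print(x.unsqueeze(0).shape)  # torch.Size([1, 2, 1, 3])
```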

Since this may not be clear from an explanation alone, let’s walk through some sample code.


Sample Code

First, let’s look at the squeeze() function, which removes dimensions.

# coding: utf-8
import torch


data = torch.tensor([
    [[0, 1, 2],
     [3, 4, 5],
     [6, 7, 8],]
])

print('Shape:', data.shape)


# squeeze()
squeeze_data = data.squeeze(0)
print('squeeze data:', squeeze_data)
print('squeeze(0) shape:', squeeze_data.shape)



Output:

Shape: torch.Size([1, 3, 3])
squeeze data: tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])
squeeze(0) shape: torch.Size([3, 3])


We can see that the original tensor has shape [1, 3, 3]. squeeze(0) removes the outermost dimension (dimension 0, whose size is 1), so the resulting shape is [3, 3].
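squeeze() can also be called without an argument, in which case it removes every dimension of size 1 at once. The sketch below uses a hypothetical tensor of shape (1, 3, 1, 2), chosen only for illustration. Calling it without an argument is convenient, but it can also drop a dimension you meant to keep (for example a batch dimension that happens to be 1), so passing an explicit dim is usually the safer choice.

```python
import torch

# Hypothetical tensor with two size-1 dimensions (dimensions 0 and 2).
y = torch.zeros(1, 3, 1, 2)

# With no argument, every size-1 dimension is removed.
print(y.squeeze().shape)    # torch.Size([3, 2])

# With an explicit dim, only that dimension is considered.
print(y.squeeze(2).shape)   # torch.Size([1, 3, 2])

# Negative indices count from the end: -2 refers to dimension 2 here.
print(y.squeeze(-2).shape)  # torch.Size([1, 3, 2])
```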


So what about the unsqueeze() function?

# coding: utf-8
import torch


data = torch.tensor([
    [[0, 1, 2],
     [3, 4, 5],
     [6, 7, 8],]
])

print('Shape:', data.shape)


# unsqueeze()
unsqueeze_data = data.unsqueeze(0)
print('unsqueeze data:', unsqueeze_data)
print('unsqueeze(0) shape:', unsqueeze_data.shape)



Output:

Shape: torch.Size([1, 3, 3])
unsqueeze data: tensor([[[[0, 1, 2],
          [3, 4, 5],
          [6, 7, 8]]]])
unsqueeze(0) shape: torch.Size([1, 1, 3, 3])

Unlike the previous result, the tensor now has one more dimension: its shape has grown from [1, 3, 3] to [1, 1, 3, 3]. This is the difference between the two functions: squeeze() removes a size-1 dimension, while unsqueeze() inserts one.
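The argument of unsqueeze() is the position at which the new dimension is inserted, and negative values count from the end. Here is a minimal sketch using a 3x3 tensor with the same values as the examples above (the name m is just for illustration). It also shows that indexing with None inserts a size-1 dimension in the same way.

```python
import torch

# A 3x3 tensor holding the values 0..8, as in the examples above.
m = torch.arange(9).reshape(3, 3)

print(m.unsqueeze(0).shape)   # torch.Size([1, 3, 3])
print(m.unsqueeze(1).shape)   # torch.Size([3, 1, 3])
print(m.unsqueeze(-1).shape)  # torch.Size([3, 3, 1])

# Indexing with None also inserts a size-1 dimension at that position.
print(torch.equal(m[None], m.unsqueeze(0)))  # True
```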

Finally, it is worth adding that the view() function can also do both jobs, removing or adding dimensions. For example, to add another dimension to the data variable of shape [1, 3, 3], we only need to write:

data = data.view([1, 1, 3, 3])



This gives the same [1, 1, 3, 3] shape that data.unsqueeze(0) produced.
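One way to check this is to compare the results directly. The sketch below redefines the same data tensor and checks that view() matches unsqueeze(0) and squeeze(0), and that all of them return views sharing the original tensor's memory rather than copies. Note that view() needs the full target shape spelled out, whereas squeeze() and unsqueeze() only need the position of the dimension.

```python
import torch

data = torch.tensor([
    [[0, 1, 2],
     [3, 4, 5],
     [6, 7, 8]]
])  # shape [1, 3, 3]

# view() with an explicit shape matches unsqueeze(0) and squeeze(0).
print(torch.equal(data.view(1, 1, 3, 3), data.unsqueeze(0)))  # True
print(torch.equal(data.view(3, 3), data.squeeze(0)))          # True

# None of these copy the data: they all point at the same storage.
print(data.unsqueeze(0).data_ptr() == data.data_ptr())  # True
print(data.view(3, 3).data_ptr() == data.data_ptr())    # True
```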

