Last Updated on 2024-08-07 by Clay
When I first started using the PyTorch framework to gradually learn how to build models, I read and ran many examples from the official tutorials. At that time, I often encountered functions like squeeze() and unsqueeze() in example code, but I didn’t quite understand their purposes.
Actually, the purposes of squeeze() and unsqueeze() are quite simple:
- squeeze() removes dimensions of size 1
- unsqueeze() adds a dimension of size 1
Just as the names suggest: squeeze (compress) and unsqueeze (expand).
A verbal explanation alone may not be very clear, so let’s run some example code using squeeze() and unsqueeze().
Example Code
Let’s first look at squeeze(), which can remove dimensions.
import torch
data = torch.tensor([
    [[0, 1, 2],
     [3, 4, 5],
     [6, 7, 8]]
])
print('Shape:', data.shape)
# Remove dimension 0, which has size 1
squeeze_data = data.squeeze(0)
print('squeeze data:', squeeze_data)
print('squeeze(0) shape:', squeeze_data.shape)
Output:
Shape: torch.Size([1, 3, 3])
squeeze data: tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])
squeeze(0) shape: torch.Size([3, 3])
We can see that the original tensor shape was [1, 3, 3]. After calling squeeze(0), the outermost dimension, which has size 1, was removed, and the shape is now [3, 3].
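One detail worth spelling out is that squeeze() only removes dimensions whose size is 1. The following is a minimal sketch of that behavior; the tensor and variable names are illustrative and not part of the original example. It shows squeeze() with no argument, with a dimension that cannot be removed, and with one that can.

```python
import torch

# Illustrative tensor with two size-1 dimensions (positions 0 and 2)
x = torch.zeros(1, 3, 1, 2)

# No argument: every dimension of size 1 is removed
all_squeezed = x.squeeze()
print(all_squeezed.shape)  # torch.Size([3, 2])

# Dimension 1 has size 3, so squeeze(1) leaves the shape unchanged
unchanged = x.squeeze(1)
print(unchanged.shape)  # torch.Size([1, 3, 1, 2])

# Dimension 2 has size 1, so only that dimension is removed
one_removed = x.squeeze(2)
print(one_removed.shape)  # torch.Size([1, 3, 2])
```

Passing an explicit dimension is usually the safer habit, since a bare squeeze() can also drop a batch dimension that happens to have size 1.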
What about unsqueeze(), then?
import torch
data = torch.tensor([
    [[0, 1, 2],
     [3, 4, 5],
     [6, 7, 8]]
])
print('Shape:', data.shape)
# Insert a new dimension of size 1 at position 0
unsqueeze_data = data.unsqueeze(0)
print('unsqueeze data:', unsqueeze_data)
print('unsqueeze(0) shape:', unsqueeze_data.shape)
Output:
Shape: torch.Size([1, 3, 3])
unsqueeze data: tensor([[[[0, 1, 2],
          [3, 4, 5],
          [6, 7, 8]]]])
unsqueeze(0) shape: torch.Size([1, 1, 3, 3])
In contrast to the previous result, unsqueeze(0) inserted a new dimension of size 1 at the front, so the tensor shape grew from [1, 3, 3] to [1, 1, 3, 3]. This is the difference between the two functions.
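The argument to unsqueeze() is the position at which the new dimension is inserted, so it does not have to be 0. As a small sketch (using a [3, 3] tensor built with torch.arange, which is an illustrative choice rather than the data from the example above), here is how the position changes the resulting shape:

```python
import torch

# Illustrative 3x3 tensor holding the values 0..8
data = torch.arange(9).reshape(3, 3)

front = data.unsqueeze(0)    # new dimension at the front
middle = data.unsqueeze(1)   # new dimension between the rows and columns
back = data.unsqueeze(-1)    # negative index: new dimension at the end

print(front.shape)   # torch.Size([1, 3, 3])
print(middle.shape)  # torch.Size([3, 1, 3])
print(back.shape)    # torch.Size([3, 3, 1])

# squeeze(0) undoes unsqueeze(0) and gives back the original values
restored = front.squeeze(0)
print(torch.equal(restored, data))  # True
```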
Finally, a note: we can also use the view() function to achieve the same dimension addition/removal. For example, if we want to add a dimension to the data variable of shape [1, 3, 3], we just need to write:
data = data.view([1, 1, 3, 3])
This meets our requirement, but we must ensure the new shape has the same total number of elements as the original tensor.
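To make the element-count requirement concrete, here is a hedged sketch comparing view() with unsqueeze() and showing what happens when the requested shape does not fit. The tensor is built with torch.arange for brevity, and the mismatch_raised flag is a hypothetical helper added only for this illustration.

```python
import torch

# Illustrative tensor of shape [1, 3, 3] with 9 elements
data = torch.arange(9).reshape(1, 3, 3)

# view() and unsqueeze(0) produce the same result here
viewed = data.view(1, 1, 3, 3)
unsqueezed = data.unsqueeze(0)
print(torch.equal(viewed, unsqueezed))  # True

# -1 asks view() to infer that dimension from the element count
inferred = data.view(1, 1, 3, -1)
print(inferred.shape)  # torch.Size([1, 1, 3, 3])

# A shape needing 12 elements cannot be built from 9, so view() raises
mismatch_raised = False
try:
    data.view(1, 1, 3, 4)
except RuntimeError as err:
    mismatch_raised = True
    print('view failed:', err)
```

Compared with view(), unsqueeze() only needs the insert position, so it keeps working when the other dimensions change size.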