Last Updated on 2024-08-07 by Clay
When I first started learning to build models with PyTorch, I read and ran many examples from the official tutorials. Functions like squeeze() and unsqueeze() came up often in that example code, but I didn't quite understand what they were for.
Actually, the purposes of squeeze() and unsqueeze() are quite simple:
- squeeze() removes dimensions of size 1
- unsqueeze() adds a new dimension of size 1
Just as the names suggest: squeeze (compress) and unsqueeze (expand).
A description alone may not make this very clear, so let's run some example code that uses squeeze() and unsqueeze().
Example Code
Let’s first look at squeeze(), which can remove dimensions.
```python
import torch

data = torch.tensor([
    [[0, 1, 2],
     [3, 4, 5],
     [6, 7, 8]]
])
print('Shape:', data.shape)

# squeeze(0): remove dimension 0, which has size 1
squeeze_data = data.squeeze(0)
print('squeeze data:', squeeze_data)
print('squeeze(0) shape:', squeeze_data.shape)
```
Output:
```
Shape: torch.Size([1, 3, 3])
squeeze data: tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])
squeeze(0) shape: torch.Size([3, 3])
```
We can see that the original tensor shape was [1, 3, 3]. After calling squeeze(0), dimension 0 (the outermost dimension, which has size 1) was removed, and the shape is now [3, 3].
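The example above passes an explicit dimension to squeeze(). As a small extra sketch (the tensor x here is a hypothetical [1, 3, 1] tensor, not from the original example), the code below shows that calling squeeze() with no argument removes every size-1 dimension, while squeeze(dim) only touches the given dimension and leaves the tensor unchanged if that dimension's size is not 1.

```python
import torch

# Hypothetical tensor with two size-1 dimensions: shape [1, 3, 1]
x = torch.zeros(1, 3, 1)

# squeeze() with no argument removes every size-1 dimension
print(x.squeeze().shape)    # torch.Size([3])

# squeeze(dim) removes only the given dimension, if its size is 1
print(x.squeeze(0).shape)   # torch.Size([3, 1])
print(x.squeeze(2).shape)   # torch.Size([1, 3])

# Dimension 1 has size 3, so squeeze(1) leaves the shape unchanged
print(x.squeeze(1).shape)   # torch.Size([1, 3, 1])
```

Passing an explicit dimension is usually the safer choice, since squeeze() with no argument can also remove a size-1 dimension you meant to keep.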
What about unsqueeze() then?
```python
import torch

data = torch.tensor([
    [[0, 1, 2],
     [3, 4, 5],
     [6, 7, 8]]
])
print('Shape:', data.shape)

# unsqueeze(0): insert a new size-1 dimension at position 0
unsqueeze_data = data.unsqueeze(0)
print('unsqueeze data:', unsqueeze_data)
print('unsqueeze(0) shape:', unsqueeze_data.shape)
```
Output:
```
Shape: torch.Size([1, 3, 3])
unsqueeze data: tensor([[[[0, 1, 2],
          [3, 4, 5],
          [6, 7, 8]]]])
unsqueeze(0) shape: torch.Size([1, 1, 3, 3])
```
Unlike the previous result, the tensor shape has grown from [1, 3, 3] to [1, 1, 3, 3]: unsqueeze(0) inserted a new size-1 dimension at position 0. That is the difference between the two functions.
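The argument to unsqueeze() decides where the new dimension goes. The sketch below reuses the same [1, 3, 3] data tensor and tries a few positions, including a negative index (which counts from the end), and then checks that unsqueeze() and squeeze() at the same position cancel each other out.

```python
import torch

data = torch.tensor([[[0, 1, 2],
                      [3, 4, 5],
                      [6, 7, 8]]])   # shape [1, 3, 3]

# unsqueeze(dim) inserts a new size-1 dimension at position dim
print(data.unsqueeze(0).shape)    # torch.Size([1, 1, 3, 3])
print(data.unsqueeze(2).shape)    # torch.Size([1, 3, 1, 3])
print(data.unsqueeze(3).shape)    # torch.Size([1, 3, 3, 1])

# Negative indices count from the end; -1 adds a trailing dimension
print(data.unsqueeze(-1).shape)   # torch.Size([1, 3, 3, 1])

# Unsqueezing and then squeezing the same position restores the original
restored = data.unsqueeze(2).squeeze(2)
print(torch.equal(restored, data))  # True
```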
Finally, a note: we can also use the view() function to achieve the same dimension addition/removal. For example, if we want to add a dimension to the [1, 3, 3] shaped data variable, we just need to write:
```python
data = data.view([1, 1, 3, 3])
```
This gives the same [1, 1, 3, 3] shape as unsqueeze(0), but the new shape must contain exactly as many elements as the original tensor (here 1 × 1 × 3 × 3 = 9); otherwise view() raises an error.
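A short sketch comparing view() with squeeze()/unsqueeze(). It builds the same 0 to 8 values with torch.arange instead of a literal list (an assumption made only for brevity), checks that view() can both add and remove a size-1 dimension, and shows that a shape with the wrong element count (the hypothetical [2, 5]) is rejected with a RuntimeError.

```python
import torch

data = torch.arange(9).view(1, 3, 3)   # values 0..8, shape [1, 3, 3]

# view() can add a size-1 dimension, matching unsqueeze(0)
added = data.view([1, 1, 3, 3])
print(torch.equal(added, data.unsqueeze(0)))   # True

# ...or remove one, matching squeeze(0)
removed = data.view([3, 3])
print(torch.equal(removed, data.squeeze(0)))   # True

# The new shape must hold the same number of elements (9 here)
try:
    data.view([2, 5])   # 10 elements, so this raises RuntimeError
except RuntimeError as err:
    print('view failed:', err)
```

squeeze() and unsqueeze() only need to know which dimension to remove or add, whereas view() needs the full target shape, so they tend to be less error-prone when all you want is a size-1 dimension.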