
[PyTorch] How To Use `squeeze()` and `unsqueeze()` Functions

Last Updated on 2024-08-07 by Clay

When I first started learning to build models with PyTorch, I read and ran many examples from the official tutorials. I often came across squeeze() and unsqueeze() in the example code, but I didn't quite understand what they were for.

Actually, the purposes of squeeze() and unsqueeze() are quite simple:

  • squeeze() removes dimensions of size 1
  • unsqueeze() inserts a new dimension of size 1 at a given position

Just as the names suggest: squeeze (compress) and unsqueeze (expand).

A description alone may not make this clear, so let's run some example code that uses squeeze() and unsqueeze().


Example Code

Let's first look at squeeze(), which removes dimensions of size 1.

import torch


data = torch.tensor([
    [[0, 1, 2],
     [3, 4, 5],
     [6, 7, 8],]
])

print('Shape:', data.shape)


# squeeze(0): remove dimension 0, which has size 1 here
squeeze_data = data.squeeze(0)
print('squeeze data:', squeeze_data)
print('squeeze(0) shape:', squeeze_data.shape)


Output:

Shape: torch.Size([1, 3, 3])
squeeze data: tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])
squeeze(0) shape: torch.Size([3, 3])

We can see that the original tensor shape was [1, 3, 3]. Because dimension 0 has size 1, squeeze(0) removed it, and the shape is now [3, 3]. If the chosen dimension had any other size, squeeze() would leave the tensor unchanged.
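To make the size-1 rule concrete, here is a minimal sketch on a different tensor. The shape [1, 3, 1, 4] and the variable names are made up for illustration and are not part of the example above.

```python
import torch

# Hypothetical tensor with two size-1 dimensions (positions 0 and 2).
x = torch.zeros(1, 3, 1, 4)

# No argument: every dimension of size 1 is removed.
all_removed = x.squeeze()
print(all_removed.shape)   # torch.Size([3, 4])

# With an index: only that dimension is considered.
only_dim2 = x.squeeze(2)
print(only_dim2.shape)     # torch.Size([1, 3, 4])

# Dimension 1 has size 3, so nothing is removed.
unchanged = x.squeeze(1)
print(unchanged.shape)     # torch.Size([1, 3, 1, 4])
```

Passing an explicit index is usually the safer habit: squeeze() with no argument would also drop a batch dimension that happens to have size 1.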

What about unsqueeze() then?

import torch


data = torch.tensor([
    [[0, 1, 2],
     [3, 4, 5],
     [6, 7, 8],]
])

print('Shape:', data.shape)


# unsqueeze(0): insert a new dimension of size 1 at position 0
unsqueeze_data = data.unsqueeze(0)
print('unsqueeze data:', unsqueeze_data)
print('unsqueeze(0) shape:', unsqueeze_data.shape)


Output:

Shape: torch.Size([1, 3, 3])
unsqueeze data: tensor([[[[0, 1, 2],
          [3, 4, 5],
          [6, 7, 8]]]])
unsqueeze(0) shape: torch.Size([1, 1, 3, 3])


Unlike the previous result, unsqueeze(0) inserted a new dimension of size 1 at position 0, so the tensor shape grew from [1, 3, 3] to [1, 1, 3, 3]. This is the difference between the two functions.
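The argument to unsqueeze() decides where the new dimension goes. The sketch below tries a few positions on a small tensor; the [3, 4] shape and the variable names are assumptions chosen only for illustration.

```python
import torch

# Hypothetical 2-D tensor used only for this illustration.
x = torch.zeros(3, 4)

front = x.unsqueeze(0)    # new dimension at the front
middle = x.unsqueeze(1)   # new dimension between the existing two
back = x.unsqueeze(-1)    # negative index: new dimension at the end

print(front.shape)    # torch.Size([1, 3, 4])
print(middle.shape)   # torch.Size([3, 1, 4])
print(back.shape)     # torch.Size([3, 4, 1])

# squeeze() on the same position undoes the unsqueeze().
restored = back.squeeze(-1)
print(restored.shape)  # torch.Size([3, 4])
```

A common use is adding a batch dimension with unsqueeze(0) before passing a single sample to a model that expects batched input.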

Finally, a note: we can also use the view() function to add or remove dimensions in the same way. For example, to add a dimension to the data variable of shape [1, 3, 3], we just need to write:

data = data.view([1, 1, 3, 3])


This gives the same result, but the new shape must contain exactly as many elements as the original tensor (here 1 × 1 × 3 × 3 = 9). Otherwise view() raises an error.
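As a rough check of the note above, the following sketch compares view() with unsqueeze() and squeeze() on a [1, 3, 3] tensor, then shows what happens when the element count does not match. The variable names are illustrative.

```python
import torch

# Same values as the earlier example: 0..8 arranged as [1, 3, 3].
data = torch.arange(9).view(1, 3, 3)

# Adding a dimension: view() and unsqueeze(0) give the same tensor.
added_by_view = data.view(1, 1, 3, 3)
added_by_unsqueeze = data.unsqueeze(0)
print(torch.equal(added_by_view, added_by_unsqueeze))      # True

# Removing a dimension: view() and squeeze(0) also agree.
removed_by_view = data.view(3, 3)
removed_by_squeeze = data.squeeze(0)
print(torch.equal(removed_by_view, removed_by_squeeze))    # True

# A shape with a different element count (1*1*3*4 = 12, not 9) is rejected.
mismatch_raised = False
try:
    data.view(1, 1, 3, 4)
except RuntimeError as err:
    mismatch_raised = True
    print('view() failed:', err)
```

With view() the whole target shape has to be written out, whereas squeeze() and unsqueeze() only name the dimension being changed, which tends to be easier to read when the other sizes vary.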

