Last Updated on 2024-08-07 by Clay
在剛開始接觸 PyTorch 這個框架時,為了逐漸掌握 PyTorch 搭建模型的方法,我閱讀、執行過許多官方教學文件上的範例程式。那個時候,經常能在範例程式當中見到 squeeze()
、unsqueeze()
等函式,卻不太明白這兩個函式究竟有什麼樣的用途。
其實,squeeze()
和 unsqueeze()
的用途很單純:
squeeze()
能夠去除維度unsqueeze()
則能增加維度
正如 squeeze(擠出)和 unsqueeze(鬆開)的含意。
光是直接說明可能不太清楚,以下直接執行一段 squeeze()
和 unsqueeze()
的範例程式碼。
範例程式碼
首先來看能夠去除維度的 squeeze()
。
import torch
data = torch.tensor([
[[0, 1, 2],
[3, 4, 5],
[6, 7, 8],]
])
print('Shape:', data.shape)
# squeeze()
squeeze_data = data.squeeze(0)
print('squeeze data:', squeeze_data)
print('squeeze(0) shape:', squeeze_data.shape)
Output:
Shape: torch.Size([1, 3, 3]) squeeze data: tensor([ [0, 1, 2], [3, 4, 5], [6, 7, 8] ]) squeeze(0) shape: torch.Size([3, 3])
我們可以看到原始的張量維度為 [1, 3, 3],但是在經過 squeeze()
函式後最外層的維度被去除了,現在的維度為 [3, 3]。
那麼 unsqueeze()
又如何呢?
import torch
data = torch.tensor([
[[0, 1, 2],
[3, 4, 5],
[6, 7, 8],]
])
print('Shape:', data.shape)
# unsqueeze()
unsqueeze_data = data.unsqueeze(0)
print('unsqueeze data:', unsqueeze_data)
print('unsqueeze(0) shape:', unsqueeze_data.shape)
Output:
Shape: torch.Size([1, 3, 3]) unsqueeze data: tensor([[[ [0, 1, 2], [3, 4, 5], [6, 7, 8]]]]) unsqueeze(0) shape: torch.Size([1, 1, 3, 3])
跟剛才的結果不一樣的是,現在的張量維度從 [1, 3, 3] 多了一維,成為了 [1, 1, 3, 3]。這就是兩者之間不同的地方。
最後補充一下,其實這兩種去除/增加維度的方法,我們都可以透過 view()
函式來完成。比方說我們要將 [1, 3, 3] 維度的 data 變數增加一個維度,我們只需要這樣寫:
data = data.view([1, 1, 3, 3])
這樣寫可以達成同樣的需求,但需要確保新的形狀與原始張量的總元素數量一致。
References
- https://pytorch.org/docs/stable/generated/torch.unsqueeze.html
- https://stackoverflow.com/questions/61598771/pytorch-squeeze-and-unsqueeze
- https://deeplizard.com/learn/video/fCVuiW9AFzY