Skip to content

PyTorch 框架中的 squeeze()、unsqueeze() 用途

PyTorch is a famous Python deep learning framework

前言

在剛開始接觸 PyTorch 這個框架時,為了逐漸掌握 PyTorch 搭建模型的方法,我閱讀、執行過許多官方教學文件上的範例程式。那個時候,經常能在範例程式當中見到 squeeze()unsqueeze() 等函式,卻不太明白這兩個函式究竟有什麼樣的用途。

其實,squeeze()unsqueeze() 的用途很單純:squeeze() 能夠去除維度、unsqueeze() 則能增加維度,正如 squeeze(擠出)和 unsqueeze(鬆開)的含意。

光是直接說明可能不太清楚,以下直接執行一段 squeeze()unsqueeze() 的範例程式碼。


範例程式碼

首先來看能夠去除維度的 squeeze()

# coding: utf-8
import torch


data = torch.tensor([
    [[0, 1, 2],
     [3, 4, 5],
     [6, 7, 8],]
])

print('Shape:', data.shape)


# squeeze()
squeeze_data = data.squeeze(0)
print('squeeze data:', squeeze_data)
print('squeeze(0) shape:', squeeze_data.shape)


Output:

Shape: torch.Size([1, 3, 3])
squeeze data: tensor([
[0, 1, 2],
[3, 4, 5],
[6, 7, 8]
])
squeeze(0) shape: torch.Size([3, 3])

我們可以看到原始的張量維度為 [1, 3, 3],但是在經過 squeeze() 函式後最外層的維度被去除了,現在的維度為 [3, 3]。

那麼 unsqueeze() 又如何呢?

# coding: utf-8
import torch


data = torch.tensor([
    [[0, 1, 2],
     [3, 4, 5],
     [6, 7, 8],]
])

print('Shape:', data.shape)


# unsqueeze()
unsqueeze_data = data.unsqueeze(0)
print('unsqueeze data:', unsqueeze_data)
print('unsqueeze(0) shape:', unsqueeze_data.shape)


Output:

Shape: torch.Size([1, 3, 3])
unsqueeze data: tensor([[[
[0, 1, 2],
[3, 4, 5],
[6, 7, 8]]]])
unsqueeze(0) shape: torch.Size([1, 1, 3, 3])

跟剛才的結果不一樣的是,現在的張量維度從 [1, 3, 3] 多了一維,成為了 [1, 1, 3, 3]。這就是兩者之間不同的地方。

最後補充一下,其實這兩種去除/增加維度的方法,我們都可以透過 view() 函式來完成。比方說我們要將 [1, 3, 3] 維度的 data 變數增加一個維度,我們只需要這樣寫:

data = data.view([1, 1, 3, 3])


即可完成我們的需求。


References


Read More

Tags:

Leave a Reply