Skip to content

[PyTorch] 使用 view() 和 permute() 轉換維度

Last Updated on 2021-07-25 by Clay

PyTorch 是一個基於 Python 的深度學習框架,我們可以藉由 PyTorch 所包裝好的模組、函式,輕易地實作我們想要實現的模型架構。而說到深度學習,就不得不提到使用 GPU 的平行化運算,提到 GPU 的平行化運算,就一定得說到我們得將輸入神經元的『維度』固定,好實現平行化運算。

所以,整理我們輸入張量 (Tensor) 的維度與形狀,自然就是得費心處理的工作了。

本篇文章希望紀錄在 PyTorch 當中改變張量形狀的兩個重要的函式 —— view() 以及 permute(),這兩個函式都是用於轉換張量維度,但是轉換的方法與用處卻不太一樣,使用上需要稍微注意一下。


permute()

permute() 主要用於維度的『交換』,並且與 view() 不同,是會打亂張量之間元素的順序的。我們可以來看一段簡單的範例:

# coding: utf-8
import torch


inputs = [[[1, 2 ,3], [4, 5, 6]],
          [[7, 8, 9], [10, 11, 12]]]

inputs = torch.tensor(inputs)
print(inputs)
print('Inputs:', inputs.shape)



Output:

tensor([[[ 1, 2, 3],
         [ 4, 5, 6]],
     
        [[ 7, 8, 9],
         [10, 11, 12]]])

Inputs: torch.Size([2, 2, 3])

這是一個簡單的、按照數值大小排列的張量,維度是 (2, 2, 3)。那麼,我們下面加上 permute() 來置換維度。首先要說明的是,原先的維度是有著編號的。

而 permute() 則是可以透過設定這個編號,置換維度。

# coding: utf-8
import torch


inputs = [[[1, 2 ,3], [4, 5, 6]],
          [[7, 8, 9], [10, 11, 12]]]

inputs = torch.tensor(inputs)
print(inputs)
print('Inputs:', inputs.shape)


outputs = inputs.permute(0, 2, 1)
print(outputs)
print('Outputs:', outputs.shape)



Output:

tensor([[[ 1, 2, 3],
         [ 4, 5, 6]],
        [[ 7,  8,  9],
        [10, 11, 12]]])
Inputs: torch.Size([2, 2, 3])

tensor([[[ 1, 4],
         [ 2, 5],
         [ 3, 6]],

        [[ 7, 10],
         [ 8, 11],
         [ 9, 12]]])
Outputs: torch.Size([2, 3, 2])

可以看到,維度交換了,張量內的元素順序也會跟著改變。目前看到 permute() 的使用時機,多半於需要置換

(batch_size, sequence_length, vector_size)

這樣的輸入。這在 RNN 與其變體的模型層中時常碰到。


view()

和 permute() 相比,view() 不會打亂元素順序,也自由得多。比方說,我們這樣將剛才的範例改寫:

# coding: utf-8
import torch


inputs = [[[1, 2 ,3], [4, 5, 6]],
          [[7, 8, 9], [10, 11, 12]]]

inputs = torch.tensor(inputs)
print(inputs)
print('Inputs:', inputs.shape)


outputs = inputs.view(2, 3, 2)
print(outputs)
print('Outputs:', outputs.shape)



Output:

tensor([[[ 1, 2, 3],
         [ 4, 5, 6]],
        [[ 7, 8, 9],
         [10, 11, 12]]])
Inputs: torch.Size([2, 2, 3])

tensor([[[ 1, 2],
         [ 3, 4],
         [ 5, 6]],
        [[ 7, 8],
         [ 9, 10],
         [11, 12]]])
Outputs: torch.Size([2, 3, 2])

我們可以發現,雖然維度整理得跟使用 permute() 一模一樣,但是張量中的元素順序卻不會跟著改變。

另外,view() 並不只是可以置換維度順序,更是可以直接改變維度。比方說,我們可以將剛才的元素全部放在同一維度中:

# coding: utf-8
import torch


inputs = [[[1, 2 ,3], [4, 5, 6]],
          [[7, 8, 9], [10, 11, 12]]]

inputs = torch.tensor(inputs)
print(inputs)
print('Inputs:', inputs.shape)


outputs = inputs.view(-1)
print(outputs)
print('Outputs:', outputs.shape)



Output:

tensor([[[ 1, 2, 3],
         [ 4, 5, 6]],
        [[ 7, 8, 9],
         [10, 11, 12]]])
Inputs: torch.Size([2, 2, 3])

tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
Outputs: torch.Size([12])

就像這樣。其中,"-1" 的意思為自動計算維度,在我什麼維度形狀都沒有指定的情況,就是將所有元素放在同一維度下。

也正因為 view() 的功用非常強大,所以我們時常會在 PyTorch 的各種程式當中看到。


References


Read More

Leave a Reply