Skip to content

[已解決] RuntimeError: view size is not compatible with input tensor’s size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(…) instead.

問題描述

在使用 PyTorch 進行深度學習模型的建設時,我們免不了一次又一次地調整神經層與輸入輸出的形狀,這顯然是每位 AI 工程師必經的道路 —— 而在 PyTorch 的形狀變換 view() 方法中,顯然存在一個有趣的小陷阱:

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.


直觀來說,PyTorch 在需要使用 .view() 來改變張量的形狀(shape)時,要求其元素需要在記憶體中是連續儲存的,而當我們在對張量進行一些操作後,可能會導致其記憶體的位置產生變化,比方說 .transpose() 以及 .permute() 等方法。


解決方法

所以在使用 .transpose() 以及 .permute() 等方法之後,如果你確定要使用繼續 view() 來改變張量的形狀,那麼必須先加上 .contiguous() 來保證張量在記憶體中是連續儲存的。

下面是一個該錯誤的複現範例:我們假設在計算完多頭注意力機制後,將其多頭的維度合併回原先的 hidden_size 大小。

import torch

batch_size = 16
seq_length = 512
num_head = 2
hidden_size = 768

inputs = torch.rand(batch_size, num_head, seq_length, int(hidden_size / num_head))
print("Shape:", inputs.shape)

inputs = inputs.permute(0, 2, 1, 3)
print("Permute Shape:", inputs.shape)

inputs = inputs.view(batch_size, seq_length, hidden_size)
print("Merge multi-head Shape:", inputs.shape)


Output:

Shape: torch.Size([16, 2, 512, 384])
Permute Shape: torch.Size([16, 512, 2, 384])
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/home/clay/Projects/machine_learning/transformers_from_scratch/analysis.ipynb Cell 12 line 1
     11 inputs = inputs.permute(0, 2, 1, 3)
     12 print("Permute Shape:", inputs.shape)
---> 14 inputs = inputs.view(batch_size, seq_length, hidden_size)
     15 print("Merge multi-head Shape:", inputs.shape)

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.


這就是由於張量在記憶體空間中不連續所導致的問題。一旦我們加上 .contiguous(),那麼轉換形狀就會變得很順利。

import torch

batch_size = 16
seq_length = 512
num_head = 2
hidden_size = 768

inputs = torch.rand(batch_size, num_head, seq_length, int(hidden_size / num_head))
print("Shape:", inputs.shape)

# Wrong
# inputs = inputs.permute(0, 2, 1, 3)

# Correct
inputs = inputs.permute(0, 2, 1, 3).contiguous()
print("Permute Shape:", inputs.shape)

inputs = inputs.view(batch_size, seq_length, hidden_size)
print("Merge multi-head Shape:", inputs.shape)


Output:

Shape: torch.Size([16, 2, 512, 384])
Permute Shape: torch.Size([16, 512, 2, 384])
Merge multi-head Shape: torch.Size([16, 512, 768])


我們可以看到,最後已經順利地合併了多頭注意力的分割輸出了。


References


Read More

Leave a Reply