Last Updated on 2024-02-22 by Clay
問題描述
在使用 PyTorch 進行深度學習模型的建設時,我們免不了一次又一次地調整神經層與輸入輸出的形狀,這顯然是每位 AI 工程師必經的道路 —— 而在 PyTorch 的形狀變換 view()
方法中,顯然存在一個有趣的小陷阱:
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
直觀來說,PyTorch 在需要使用 .view()
來改變張量的形狀(shape)時,要求其元素需要在記憶體中是連續儲存的,而當我們在對張量進行一些操作後,可能會導致其記憶體的位置產生變化,比方說 .transpose()
以及 .permute()
等方法。
解決方法
所以在使用 .transpose()
以及 .permute()
等方法之後,如果你確定要使用繼續 view()
來改變張量的形狀,那麼必須先加上 .contiguous()
來保證張量在記憶體中是連續儲存的。
下面是一個該錯誤的複現範例:我們假設在計算完多頭注意力機制後,將其多頭的維度合併回原先的 hidden_size
大小。
import torch
batch_size = 16
seq_length = 512
num_head = 2
hidden_size = 768
inputs = torch.rand(batch_size, num_head, seq_length, int(hidden_size / num_head))
print("Shape:", inputs.shape)
inputs = inputs.permute(0, 2, 1, 3)
print("Permute Shape:", inputs.shape)
inputs = inputs.view(batch_size, seq_length, hidden_size)
print("Merge multi-head Shape:", inputs.shape)
Output:
Shape: torch.Size([16, 2, 512, 384]) Permute Shape: torch.Size([16, 512, 2, 384]) --------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) /home/clay/Projects/machine_learning/transformers_from_scratch/analysis.ipynb Cell 12 line 1 11 inputs = inputs.permute(0, 2, 1, 3) 12 print("Permute Shape:", inputs.shape) ---> 14 inputs = inputs.view(batch_size, seq_length, hidden_size) 15 print("Merge multi-head Shape:", inputs.shape) RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
這就是由於張量在記憶體空間中不連續所導致的問題。一旦我們加上 .contiguous()
,那麼轉換形狀就會變得很順利。
import torch
batch_size = 16
seq_length = 512
num_head = 2
hidden_size = 768
inputs = torch.rand(batch_size, num_head, seq_length, int(hidden_size / num_head))
print("Shape:", inputs.shape)
# Wrong
# inputs = inputs.permute(0, 2, 1, 3)
# Correct
inputs = inputs.permute(0, 2, 1, 3).contiguous()
print("Permute Shape:", inputs.shape)
inputs = inputs.view(batch_size, seq_length, hidden_size)
print("Merge multi-head Shape:", inputs.shape)
Output:
Shape: torch.Size([16, 2, 512, 384]) Permute Shape: torch.Size([16, 512, 2, 384]) Merge multi-head Shape: torch.Size([16, 512, 768])
我們可以看到,最後已經順利地合併了多頭注意力的分割輸出了。
References
- RuntimeError: view size is not compatible with input tensor’s size and stride (at least one dimension spans across two contiguous subspaces)
- GitHub – view size is not compatible with input tensor’s …