Last Updated on 2021-07-24 by Clay
在我們使用 PyTorch 搭建 RNN 與其各種變體 (比如 LSTM、GRU) 的模型時,若搭配 PyTorch 所提供的 Embedding 層當作模型第一層的嵌入層,那麼,我們經常會碰到不同長度序列的文章。
很多人會推薦使用 pack_padded_sequence 和 pad_packed_sequence 來調整可變長度序列的句子。以下,就來一步步介紹該如何使用這兩個函式。
除此之外,文章中還會使用到 PyTorch 所提供的 pad() 函式來進行 Padding (將序列填充 "0" 到同樣長度) 以及使用 torch.cat() 來串接不同序列,詳細的介紹可以參考文章末的連結。
示範程式碼
簡單來說,pack_padded_sequence() 是用來壓縮序列的,而 pad_packed_sequence() 則是用來展開序列成原本形狀的。
以下是一個簡單的示範,扣掉註解和印出序列的部份,實際程式碼不到 15 行。
# coding: utf-8 import torch import torch.nn.functional as F from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence # Sequences a = torch.tensor([1, 2]) b = torch.tensor([3, 4, 5]) c = torch.tensor([6, 7, 8, 9]) print('a:', a) print('b:', b) print('c:', c) # Settings seq_lens = [len(a), len(b), len(c)] max_len = max(seq_lens) # Zero padding a = F.pad(a, (0, max_len-len(a))) b = F.pad(b, (0, max_len-len(b))) c = F.pad(c, (0, max_len-len(c))) # Merge the sequences seq = torch.cat((a, b, c), 0).view(-1, max_len) print('Sequence:', seq) # Pack packed_seq = pack_padded_sequence(seq, seq_lens, batch_first=True, enforce_sorted=False) print('Pack:', packed_seq) # Unpack unpacked_seq, unpacked_lens = pad_packed_sequence(packed_seq, batch_first=True) print('Unpack:', unpacked_seq) print('length:', unpacked_lens) # Reduction a = unpacked_seq[0][:unpacked_lens[0]] b = unpacked_seq[1][:unpacked_lens[1]] c = unpacked_seq[2][:unpacked_lens[2]] print('Recutions:') print('a:', a) print('b:', b) print('c:', c)
Output:
a: tensor([1, 2])
b: tensor([3, 4, 5])
c: tensor([6, 7, 8, 9])
Sequence:
tensor([[1, 2, 0, 0],
[3, 4, 5, 0],
[6, 7, 8, 9]])
Pack:
PackedSequence(data=tensor([6, 3, 1, 7, 4, 2, 8, 5, 9]),
batch_sizes=tensor([3, 3, 2, 1]),
sorted_indices=tensor([2, 1, 0]),
unsorted_indices=tensor([2, 1, 0]))
Unpack:
tensor([[1, 2, 0, 0],
[3, 4, 5, 0],
[6, 7, 8, 9]])
length: tensor([2, 3, 4])
Recutions:
a: tensor([1, 2])
b: tensor([3, 4, 5])
c: tensor([6, 7, 8, 9])
首先,我有 a, b, c 三組不同長度的序列,我要做的只有下面幾個步驟:
- 紀錄每個序列原始長度
- 決定填充的最大長度
- 填充每個序列至同樣的長度
- 將合併後的序列使用 pack_padded_sequence() 壓縮
- 將壓縮後的序列使用 pad_packed_sequence() 還原本來形狀
- 根據紀錄的原始長度恢復原始序列尺寸
可以看到,我們能透過 pack_padded_sequence() 將序列壓縮,再用 pad_packed_sequence() 還原,是可以將資料完整恢復的。
References
- https://pytorch.org/docs/master/generated/torch.nn.utils.rnn.pad_packed_sequence.html
- https://github.com/pytorch/pytorch/issues/1128