Skip to content

[PyTorch] 如何使用 pad_packed_sequence 和 pack_padded_sequence 調整可變長度序列批次

Last Updated on 2021-07-24 by Clay

在我們使用 PyTorch 搭建 RNN 與其各種變體 (比如 LSTM、GRU) 的模型時,若搭配 PyTorch 所提供的 Embedding 層當作模型第一層的嵌入層,那麼,我們經常會碰到不同長度序列的文章。

很多人會推薦使用 pack_padded_sequence 和 pad_packed_sequence 來調整可變長度序列的句子。以下,就來一步步介紹該如何使用這兩個函式。

除此之外,文章中還會使用到 PyTorch 所提供的 pad() 函式來進行 Padding (將序列填充 "0" 到同樣長度) 以及使用 torch.cat() 來串接不同序列,詳細的介紹可以參考文章末的連結。


示範程式碼

簡單來說,pack_padded_sequence() 是用來壓縮序列的,而 pad_packed_sequence() 則是用來展開序列成原本形狀的。

以下是一個簡單的示範,扣掉註解和印出序列的部份,實際程式碼不到 15 行。

# coding: utf-8
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


# Sequences
a = torch.tensor([1, 2])
b = torch.tensor([3, 4, 5])
c = torch.tensor([6, 7, 8, 9])
print('a:', a)
print('b:', b)
print('c:', c)

# Settings
seq_lens = [len(a), len(b), len(c)]
max_len = max(seq_lens)


# Zero padding
a = F.pad(a, (0, max_len-len(a)))
b = F.pad(b, (0, max_len-len(b)))
c = F.pad(c, (0, max_len-len(c)))


# Merge the sequences
seq = torch.cat((a, b, c), 0).view(-1, max_len)
print('Sequence:', seq)


# Pack
packed_seq = pack_padded_sequence(seq, seq_lens, batch_first=True, enforce_sorted=False)
print('Pack:', packed_seq)


# Unpack
unpacked_seq, unpacked_lens = pad_packed_sequence(packed_seq, batch_first=True)
print('Unpack:', unpacked_seq)
print('length:', unpacked_lens)


# Reduction
a = unpacked_seq[0][:unpacked_lens[0]]
b = unpacked_seq[1][:unpacked_lens[1]]
c = unpacked_seq[2][:unpacked_lens[2]]
print('Recutions:')
print('a:', a)
print('b:', b)
print('c:', c)



Output:

a: tensor([1, 2])
b: tensor([3, 4, 5])
c: tensor([6, 7, 8, 9])

Sequence: 
tensor([[1, 2, 0, 0],
        [3, 4, 5, 0],
        [6, 7, 8, 9]])

Pack: 
PackedSequence(data=tensor([6, 3, 1, 7, 4, 2, 8, 5, 9]),  
               batch_sizes=tensor([3, 3, 2, 1]), 
               sorted_indices=tensor([2, 1, 0]), 
               unsorted_indices=tensor([2, 1, 0]))

Unpack: 
tensor([[1, 2, 0, 0],
        [3, 4, 5, 0],
        [6, 7, 8, 9]])

length: tensor([2, 3, 4])

Recutions:
a: tensor([1, 2])
b: tensor([3, 4, 5])
c: tensor([6, 7, 8, 9])

首先,我有 a, b, c 三組不同長度的序列,我要做的只有下面幾個步驟:

  1. 紀錄每個序列原始長度
  2. 決定填充的最大長度
  3. 填充每個序列至同樣的長度
  4. 將合併後的序列使用 pack_padded_sequence() 壓縮
  5. 將壓縮後的序列使用 pad_packed_sequence() 還原本來形狀
  6. 根據紀錄的原始長度恢復原始序列尺寸

可以看到,我們能透過 pack_padded_sequence() 將序列壓縮,再用 pad_packed_sequence() 還原,是可以將資料完整恢復的。


References


Read More

Leave a Reply取消回覆

Exit mobile version