Skip to content

[已解決][PyTorch] LSTM RuntimeError: input must have 3 dimensions, got 2

Last Updated on 2021-06-28 by Clay

今天,我在使用 PyTorch 搭建 LSTM 模型的時候發生了以下這樣的報錯:

 LSTM RuntimeError: input must have 3 dimensions, got 2

這個問題讓我意外了一下,因為之前我的訓練資料丟進來沒有這樣的報錯啊。煩惱了好一會兒、看了好一會兒,我這才決定將我模型輸入的 Tensor shape 印出來看看。

其實說穿了,就是我輸入的資料維度不對。我忘記做調整維度的工作了。

基本上,LSTM 資料輸入的格式為: (batch_size, seq_len, input_size),所以這個報錯多發生於輸入的張量維度錯誤,調整之後便可正常運行。

要注意的是,PyTorch 的 nn.LSTM() 有個參數為 “batch_first”,若為 True 則輸入格式 batch_size 是第一位,若 “batch_first” 為 False 則 batch_size 與 seq_len 位置交換。

Leave a Reply取消回覆

Exit mobile version