Last Updated on 2021-07-01 by Clay
雖然僅僅只是個人體感,不過我認為在使用 PyTorch 的過程中,最容易遇到的報錯有兩個 —— 一個是模型的 ""Mismatch、另外一個就是我今天所紀錄的:
RuntimeError: Expected object of scalar type Float but got scalar type Long for argument
當然,這個報錯有很多種不同的樣貌,比方說預期接受 Long,但是資料型態卻是 Float ..... 這樣反過來的情況也有。有時候,甚至是各式各樣的資料型態無法匹配的問題。
解決方法其實說穿了異常單純,基本上只有兩個步驟:
- 定位問題發生的地方,比如說 Loss Function、Model 中的資料調整。
- 將錯誤的資料型態更正為函式可以接受的
比方說,今天我原本定義的 Loss function 如下:
# Loss def loss_function(inputs, targets): return nn.BCELoss()(inputs, targets)
然後我的 "targets" 報錯了:畢竟 nn.BCELoss() 接受的是 Float 資料型態的。那麼,我就應該將 targets 修正為:
# Loss def loss_function(inputs, targets): return nn.BCELoss()(inputs, targets.float())
這樣一來就能正常執行了。