Skip to content

[已解決][PyTorch] RuntimeError: Attempting to deserialize object on CUDA device 3 but torch.cuda.device_count() is 1.

這是一個比較奇怪的問題:

RuntimeError: Attempting to deserialize object on CUDA
 device 3 but torch.cuda.device_count() is 1.

我是在讀取已經訓練好的模型時發生這個錯誤的。我的理解是,因為我在訓練時是使用第 3 號 GPU 來進行訓練,可是測試的時候我是在不同裝置上,而該裝置卻又只有一片 GPU,所以在讀取模型的時候找不到該片 GPU 號碼,故無法讀入。

值得注意的是,使用 to() 或是 cuda() 無法解決這個問題,因為這兩者皆是將資料讀取進來之後才進行裝置的轉換,但是現在的問題是『資料無法讀取進來』。

在網路上查了一下,找到了一個可以成功讀取的方法。

torch.load("MODEL_NAME", map_location='cpu')



比如說在讀取模型的時候,在後方的參數 map_location 直接設定 cpu 或是其他可用 GPU,這樣一來,在讀取的時候就會自動使用該裝置存取資料。

這算是我基本功不好、沒弄熟 PyTorch 讀取機制才會犯的錯誤吧!希望以後不會犯同樣的錯誤。


References


Read More

Leave a Reply