Skip to content

[PyTorch] 判斷當前資料所使用的 GPU

Last Updated on 2021-07-05 by Clay

在我使用 PyTorch 訓練模型的時候,經常會發生我使用 GPU_A 去訓練模型、儲存模型,然而在測試模型效果的時候,卻不小心使用到了 GPU_B 來讀取測試資料 (我有多片 GPU 可以使用,還滿奢侈的 XDD),然後再用已經儲存好的模型 (GPU_A 訓練) 來測試 ——

這樣的結果會是,你拿清朝的劍斬明朝的官 …… 使用 GPU_A 上的模型去計算 GPU_B 上的資料。然後,我們會吃下一個 Error Message,顯示兩邊資料不在同樣的設備上。

當然,這是個可以解決的問題,比方說把要跑的資料全部都固定丟到某個裝置底下。比方我我全部丟到第一片 GPU (cuda:0) 好了。

model = model.to('cuda:0')



不過我想知道的是,究竟有沒有辦法直接看我的資料位於哪個裝置上呢?上網找了一下,發現還真的有人有這個需求,故順手紀錄在這裡。


使用 get_device() 查看

註:此方法只對 Tensor 有用,而且對於還在 CPU 上的 Tensor 好像不起作用的樣子。

import torch

a = torch.tensor([5, 3]).to('cuda:3')
print(a.get_device())



Output:

3

可以確認這方法是可行的。


References

Leave a Reply