Skip to content

[PyTorch] 如何儲存模型、讀取模型

Last Updated on 2021-05-21 by Clay

在使用 PyTorch 這個好用的 Python 深度學習框架進行模型的訓練時,常常會有剛學習的人忘記把訓練好的模型『儲存』起來 —— 甚至根本沒意識到這回事,以為每次要使用都必須訓練一次。(其實這是我的黑歷史 XDDD)

將模型儲存起來、再要應用的時候『讀取』進來,這是非常非常重要的事情。

比較詳細的實戰其實可以參考我之前寫過的《使用 PyTorch 搭建 GAN 模型產生 MNIST 圖片》這篇文章 —— 我在這篇文章中使用的便是 Training 跟 Test 分開的作法,其中就是模型的儲存以及讀取。

當然,官方也是有這方面的教學的,大家也可以參考他們網站的作法,會比我詳盡不少:https://pytorch.org/tutorials/beginner/saving_loading_models.html (我只是挑個簡單直覺的紀錄起來,哈哈)

那麼,以下我就開始紀錄吧!


儲存模型

儲存模型其實非常單純,直接使用以下程式碼儲存即可:

torch.save(Model, 'Save_File_Name.pth')



Model 為我們模型的物件變數名。然後我們便會在當前目錄底下看到我們儲存名稱的 pth 模型擋了。

值得注意的是,此種方法也將模型整個儲存起來,並不是只有權重而已,所以讀取進來後,並不用再重新定義模型層。


讀取模型

就像剛才所說的一般,我們只需要直接讀取進來即可:

torch.save(Model, 'Save_File_Name.pth')



然後我們便可以直接開始使用模型了。不過如果沒有打算要繼續往下訓練模型、而是只打算拿來應用的話,記得將模型設定為『評估模式』:

model.eval()



並將測試資料放入模型的區塊放置在底下程式碼下方:

with torch.no_grad():



這樣就不會再繼續往下訓練了。

1 thought on “[PyTorch] 如何儲存模型、讀取模型”

Leave a Reply取消回覆

Exit mobile version