Last Updated on 2021-05-21 by Clay
在使用 PyTorch 這個好用的 Python 深度學習框架進行模型的訓練時,常常會有剛學習的人忘記把訓練好的模型『儲存』起來 —— 甚至根本沒意識到這回事,以為每次要使用都必須訓練一次。(其實這是我的黑歷史 XDDD)
將模型儲存起來、再要應用的時候『讀取』進來,這是非常非常重要的事情。
比較詳細的實戰其實可以參考我之前寫過的《使用 PyTorch 搭建 GAN 模型產生 MNIST 圖片》這篇文章 —— 我在這篇文章中使用的便是 Training 跟 Test 分開的作法,其中就是模型的儲存以及讀取。
當然,官方也是有這方面的教學的,大家也可以參考他們網站的作法,會比我詳盡不少:https://pytorch.org/tutorials/beginner/saving_loading_models.html (我只是挑個簡單直覺的紀錄起來,哈哈)
那麼,以下我就開始紀錄吧!
儲存模型
儲存模型其實非常單純,直接使用以下程式碼儲存即可:
torch.save(Model, 'Save_File_Name.pth')
Model 為我們模型的物件變數名。然後我們便會在當前目錄底下看到我們儲存名稱的 pth 模型擋了。
值得注意的是,此種方法也將模型整個儲存起來,並不是只有權重而已,所以讀取進來後,並不用再重新定義模型層。
讀取模型
就像剛才所說的一般,我們只需要直接讀取進來即可:
torch.save(Model, 'Save_File_Name.pth')
然後我們便可以直接開始使用模型了。不過如果沒有打算要繼續往下訓練模型、而是只打算拿來應用的話,記得將模型設定為『評估模式』:
model.eval()
並將測試資料放入模型的區塊放置在底下程式碼下方:
with torch.no_grad():
這樣就不會再繼續往下訓練了。
讀取模型文中寫”torch.save”,正確應該是”torch.load”?