Skip to content

[PyTorch] 確認模型的狀態屬於 train() 或是 eval()

今天我在查看 Hugging Face 的 Transformers 套件官方 Document 時,赫然發現一個恐怖的事情 —— 使用 Transformers 時所調用的 from_pretrained(),在讀取預訓練進來的時候,整個模型的狀態是屬於 eval() 、也就是評估模式的。

至於什麼是評估模式呢?基本上和訓練模型最大的差異在於 dropout 以及 batch normalization 是會有差異的。如果今天在訓練模型,最好是確定模型處於訓練模式下的。

所以我辛辛苦苦訓練了老半天、敢情根本沒有訓練到…

為了要確認我的模型是處於哪種狀態,我找到了確認模型狀態的方法,並紀錄於下。


確認模型狀態

假設今天我們有個模型,變數名為 model,那麼我們可以使用以下指令確認模型狀態:

model.training



如果模型返回 True,則代表正處於訓練模式;反之,如果返回為 False,那則代表處於評估模式。

順帶一提,我們若要將模型從評估模式轉為訓練模式,可以使用:

model.train()



而要將模型從訓練模式轉為評估模式,則可以使用:

model.eval()



祝大家不會跟我一樣,犯了搞錯模型狀態的錯誤。


References


Read More

Leave a Reply