
[PyTorch] How to check whether the model is in "train()" or "eval()" mode

Last Updated on 2023-02-17 by Clay

Today, while reading the documentation of the "Transformers" package developed by Hugging Face, I suddenly discovered that the model I had built was in "eval()" mode, not "train()" mode. It seems that a model loaded with the "from_pretrained()" method is in "eval()" mode by default.
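To see the difference, here is a minimal sketch in plain PyTorch (so nothing has to be downloaded). A freshly constructed "nn.Module" starts in training mode. The "TinyClassifier" model and the "load_like_from_pretrained()" helper are hypothetical stand-ins: the helper only imitates the last step of Transformers' "from_pretrained()", which switches the loaded model to eval mode before returning it.

```python
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in model used only for illustration."""

    def __init__(self) -> None:
        super().__init__()
        self.dropout = nn.Dropout(p=0.1)
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(self.dropout(x))


def load_like_from_pretrained(state_dict=None) -> nn.Module:
    """Hypothetical loader: build the model, optionally load weights,
    then switch it to eval mode, mimicking from_pretrained()'s default."""
    model = TinyClassifier()
    if state_dict is not None:
        model.load_state_dict(state_dict)
    model.eval()
    return model


fresh = TinyClassifier()
loaded = load_like_from_pretrained(fresh.state_dict())

print(fresh.training)   # True: new modules start in train mode
print(loaded.training)  # False: the loader switched it to eval mode
```

So if you fine-tune a model obtained this way, call "train()" on it first.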

My God.

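In "eval()" mode, dropout layers are disabled and batch normalization layers use their stored running statistics instead of the statistics of the current batch, so if we want to train a model, we should make sure it is in "train()" mode, not "eval()". Note that "eval()" does not freeze the weights or stop gradient computation; it only changes how these layers behave.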
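A small experiment makes these behaviours visible. This is a sketch with arbitrarily chosen sizes and seed: "nn.Dropout" zeroes elements and rescales the survivors by 1 / (1 - p) in train mode but passes the input through unchanged in eval mode, while "nn.BatchNorm1d" updates its "running_mean" buffer only in train mode.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Dropout: active in train mode, a no-op in eval mode.
dropout = nn.Dropout(p=0.5)
x = torch.ones(1000)

dropout.train()
y_train = dropout(x)  # about half zeros, survivors scaled by 1 / (1 - 0.5) = 2
dropout.eval()
y_eval = dropout(x)   # identical to the input

print(torch.equal(y_eval, x))                        # True
print(set(y_train.unique().tolist()) <= {0.0, 2.0})  # True

# BatchNorm: train mode normalizes with batch statistics and updates the
# running estimates; eval mode uses the stored estimates and leaves them alone.
bn = nn.BatchNorm1d(3)
batch = torch.randn(8, 3) * 5 + 10

bn.train()
bn(batch)
mean_after_train = bn.running_mean.clone()  # moved away from its initial zeros

bn.eval()
out = bn(batch)
print(torch.equal(bn.running_mean, mean_after_train))  # True: no update in eval
print(out.requires_grad)  # True: eval() does not turn off autograd
```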

No wonder my BERT was not performing well! ... Of course, that is not necessarily the reason.

So I looked up how to check which mode my model is in, and I have recorded it below.


Use "training" to check the model state

Suppose our model is named "model"; we can check its state with the following attribute:

model.training



If it returns True, the model is in "train" mode; if it returns False, the model is in "eval" mode.
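Keep in mind that "training" is a flag stored on every module, and "model.training" only reports the top-level one. If some code calls "eval()" on a single submodule, the model can end up in a mixed state. Here is a sketch with a hypothetical helper, "submodule_modes()", that lists the flag of every submodule via "named_modules()":

```python
import torch.nn as nn


def submodule_modes(model: nn.Module) -> dict:
    """Hypothetical helper: map each submodule name to its training flag.

    The root module is reported under the empty name "".
    """
    return {name: module.training for name, module in model.named_modules()}


model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.1), nn.Linear(4, 2))
model[1].eval()  # only the dropout layer is switched to eval mode

print(model.training)          # True, even though one layer is in eval mode
print(submodule_modes(model))  # {'': True, '0': True, '1': False, '2': True}
```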

By the way, if we want to switch a model from "eval" mode to "train" mode, we can call train():

model.train()
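"train()" sets the flag recursively on every submodule and returns the module itself, so it can be chained. It also accepts a "mode" argument, and "train(False)" has the same effect as "eval()". A quick check, using a small "nn.Sequential" as a stand-in for a real model:

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.1), nn.Linear(4, 2))
model.eval()

returned = model.train()  # switches the root and every child back to train mode

print(returned is model)                         # True
print(all(m.training for m in model.modules()))  # True

model.train(False)  # same effect as model.eval()
print(any(m.training for m in model.modules()))  # False
```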



To switch a model from "train" mode back to "eval" mode:

model.eval()
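In practice we usually switch to "eval" mode only for validation or inference and then return to "train" mode. Because "eval()" does not disable gradient tracking, it is commonly paired with "torch.no_grad()". A minimal sketch, assuming a hypothetical "evaluating()" context manager (not part of PyTorch) that remembers the previous mode and restores it afterwards:

```python
from contextlib import contextmanager

import torch
import torch.nn as nn


@contextmanager
def evaluating(model: nn.Module):
    """Hypothetical helper: temporarily put `model` in eval mode.

    Restores the previous top-level mode on exit (applied recursively,
    so a model with mixed submodule modes comes back uniform).
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():  # skip building the autograd graph during evaluation
            yield model
    finally:
        model.train(was_training)


model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5), nn.Linear(4, 2))
model.train()

with evaluating(model):
    inside = model.training
    out = model(torch.randn(3, 4))

print(inside)             # False
print(out.requires_grad)  # False, because of torch.no_grad()
print(model.training)     # True: train mode restored for the next training step
```

Restoring the previous mode in a "finally" block means an exception during evaluation cannot leave the model stuck in "eval" mode.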



I hope you will not make the same mistake I did.

