Last Updated on 2023-02-17 by Clay
Today, while reading the documentation of the "Transformers" package developed by Hugging Face, I suddenly discovered that the model I had built was in "eval()" mode, not "train()" mode. It seems that any model loaded with the "from_pretrained()" method is in "eval()" mode by default.
My God.
In "eval()" mode, dropout layers are disabled and batch normalization layers use their running statistics instead of the current batch's statistics. So if we want to train a model, we should make sure it is in "train()" mode, not "eval()" mode.
No wonder my BERT was not performing well! ...... Of course, this is not necessarily the reason.
So I looked for a way to check which mode my model is in, and I record it below.
Use training to check the model's mode
Suppose our model is named "model"; we can check its mode with the following attribute:
model.training
If it returns True, the model is in "train" mode. If it returns False, it is in "eval" mode.
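As a quick check, here is a minimal sketch using a plain torch.nn module instead of an actual Transformers model (an assumption on my part — the same attribute applies to any torch.nn.Module):

```python
import torch.nn as nn

# A tiny stand-in model with a dropout layer; any nn.Module works the same way.
model = nn.Sequential(nn.Linear(10, 10), nn.Dropout(p=0.5))

# A freshly constructed nn.Module starts in train mode.
print(model.training)  # True

# After switching to eval mode, the attribute flips.
model.eval()
print(model.training)  # False
```

Note that `training` is an attribute, not a method, so there are no parentheses when reading it.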
By the way, if we want to switch a model from "eval" mode to "train" mode, we can call train():
model.train()
If we want to switch a model from "train" mode to "eval" mode:
model.eval()
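To see why the mode matters, here is a small sketch of how a dropout layer behaves differently in the two modes (again using a bare nn.Dropout rather than a real Transformers model, which is my own simplification):

```python
import torch
import torch.nn as nn

dropout = nn.Dropout(p=0.5)
x = torch.ones(8)

dropout.train()
# In train mode, roughly half the elements are randomly zeroed,
# and the surviving elements are scaled by 1 / (1 - p) = 2.
print(dropout(x))

dropout.eval()
# In eval mode, dropout is an identity: the input passes through unchanged.
print(dropout(x))  # tensor of ones
```

So a model accidentally left in "eval" mode never applies dropout during training, which is exactly the kind of silent mistake described above.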
I hope you will not make the same mistake I did.
References
- https://discuss.pytorch.org/t/model-train-and-model-eval-vs-model-and-model-eval/5744
- https://discuss.pytorch.org/t/model-eval-vs-with-torch-no-grad/19615