Last Updated on 2023-12-04 by Clay
Problem
In early 2023, PyTorch 2.0 introduced a new function, torch.compile(), which can speed up both model training and inference. Combined with mixed-precision training, compiling regularly cuts my training time roughly in half.
But there is an obvious problem with compiled models: if we save the compiled model's state dictionary and then try to load it back, we may get the following error message (here, for example, with XLMRobertaModel):
import torch
from transformers import XLMRobertaModel

state_dict = torch.load("./checkpoints/epoch-1-step-48000.ckpt")
model = XLMRobertaModel.from_pretrained(pretrained_model_name_or_path)
model.load_state_dict(state_dict=state_dict)
Output:
RuntimeError: Error(s) in loading state_dict for WinbertaModel: Missing key(s) in state_dict: "xlm_roberta.embeddings.word_embeddings.weight", ...
The error message tells us that the key "xlm_roberta.embeddings.word_embeddings.weight" cannot be found in the state_dict.
Now, let's print out the keys of the state_dict:
for key in state_dict:
    print(key)
Output:
_orig_mod.xlm_roberta.embeddings.word_embeddings.weight
_orig_mod.xlm_roberta.embeddings.position_embeddings.weight
_orig_mod.xlm_roberta.embeddings.token_type_embeddings.weight
_orig_mod.xlm_roberta.embeddings.LayerNorm.weight
_orig_mod.xlm_roberta.embeddings.LayerNorm.bias
...
In fact, the XLMRobertaModel layer keys do exist. But after compiling, a new prefix _orig_mod. was added to every layer name.
Solution
I am not sure whether the latest PyTorch version has solved this, but my PyTorch is installed on a remote server, so I cannot upgrade it.
So the brute-force method is simply to remove the prefix. Thankfully, that lets the model load the weights successfully.
for key in list(state_dict.keys()):
    state_dict[key.replace("_orig_mod.", "")] = state_dict.pop(key)
for key in state_dict:
    print(key)
Output:
xlm_roberta.embeddings.word_embeddings.weight
xlm_roberta.embeddings.position_embeddings.weight
xlm_roberta.embeddings.token_type_embeddings.weight
xlm_roberta.embeddings.LayerNorm.weight
xlm_roberta.embeddings.LayerNorm.bias
...
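Note that popping keys while iterating is fragile: forget the list(...) copy and you get an "OrderedDict mutated during iteration" error. A dict comprehension that builds a fresh state dict avoids that pitfall entirely. A small sketch (strip_prefix is just an illustrative name, and the toy dict stands in for a real state_dict):

```python
def strip_prefix(state_dict, prefix="_orig_mod."):
    """Return a new dict whose keys have the compile-time prefix removed."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

# Keys without the prefix pass through unchanged.
cleaned = strip_prefix({"_orig_mod.encoder.weight": 1, "head.bias": 2})
print(cleaned)  # {'encoder.weight': 1, 'head.bias': 2}
```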
And we load it again:
model.load_state_dict(state_dict=state_dict)
Output:
<All keys matched successfully>
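A cleaner long-term fix is to avoid the prefix at save time: the wrapper returned by torch.compile() keeps the original module in its _orig_mod attribute, so saving that module's state dict produces clean keys. A sketch, with a tiny nn.Linear standing in for the real model:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)                # stand-in for the real model
compiled_model = torch.compile(model)  # wraps model in an OptimizedModule

# The wrapper keeps the original module in `_orig_mod`; getattr() falls
# back to the object itself, so the same line works for plain models too.
to_save = getattr(compiled_model, "_orig_mod", compiled_model)
print(list(to_save.state_dict().keys()))  # ['weight', 'bias'] -- no prefix
```

You can then torch.save(to_save.state_dict(), ...) and load the checkpoint into an uncompiled model without any key renaming.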
References
- Make compiled models serializable · Issue #101107 · pytorch/pytorch · GitHub
- Expected _orig_mod to NOT be FullyShardedDataParallel ...
Read More
- [Solved] RuntimeError: OrderedDict mutated during iteration - OrderedDict Could Not Changed
- [Solved] Some weights of the model checkpoint at distilbert-base-multilingual-cased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_transform.bias']