
[Solved][PyTorch] Model State Dict Files Have The "_orig_mod" Prefix After Compiling

Last Updated on 2023-12-04 by Clay

Problem

In early 2023, PyTorch 2.0 introduced a new function, torch.compile(), which can speed up both model training and inference. In my experience, combining mixed-precision training with compilation has consistently cut my training time roughly in half.

However, compiled models have an obvious problem: if we save the compiled model's state dict and later load it into a regular (uncompiled) model, we may get the following error (using XLMRobertaModel as an example):

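import torch
from transformers import XLMRobertaModel

state_dict = torch.load("./checkpoints/epoch-1-step-48000.ckpt")

# pretrained_model_name_or_path: local path or Hub name of the base checkpoint
model = XLMRobertaModel.from_pretrained(pretrained_model_name_or_path)
model.load_state_dict(state_dict=state_dict)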


Output:

RuntimeError: Error(s) in loading state_dict for WinbertaModel:
	Missing key(s) in state_dict: "xlm_roberta.embeddings.word_embeddings.weight", ...

The error message tells us that the key "xlm_roberta.embeddings.word_embeddings.weight" could not be found in state_dict.

Now, let's print out the keys of state_dict.

for key in state_dict:
    print(key)


Output:

_orig_mod.xlm_roberta.embeddings.word_embeddings.weight
_orig_mod.xlm_roberta.embeddings.position_embeddings.weight
_orig_mod.xlm_roberta.embeddings.token_type_embeddings.weight
_orig_mod.xlm_roberta.embeddings.LayerNorm.weight
_orig_mod.xlm_roberta.embeddings.LayerNorm.bias
...

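In fact, the weights for every layer are there, but after compiling, a new prefix, _orig_mod., has been added to every key. This is because torch.compile() wraps the model in an OptimizedModule that stores the original model as its _orig_mod submodule, so the wrapper's state_dict() prefixes every key with that attribute name.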
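To see where the prefix comes from, here is a minimal sketch using a toy nn.Linear layer as a stand-in (an assumption; the post's model is XLMRobertaModel). It only inspects the wrapper returned by torch.compile(), so no actual compilation is triggered. Note that _orig_mod is an internal attribute, so this behavior could differ across PyTorch versions.

```python
import torch
import torch.nn as nn

# Toy stand-in for the real model (assumption for illustration).
model = nn.Linear(4, 2)
compiled_model = torch.compile(model)  # returns a wrapper; nothing is compiled until it is called

print(type(compiled_model).__name__)      # OptimizedModule
print(compiled_model._orig_mod is model)  # True: the original module is held as a submodule
print(list(model.state_dict()))           # ['weight', 'bias']
print(list(compiled_model.state_dict()))  # ['_orig_mod.weight', '_orig_mod.bias']
```

Because the original module is registered under the attribute name _orig_mod, every parameter name in the wrapper's state dict gets that prefix, exactly as in the printed checkpoint keys above.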


Solution

I am not sure whether the latest version of PyTorch has changed this behavior. In any case, my PyTorch is installed on a remote server, so I cannot upgrade it.

So the brute-force method is to remove the prefix from every key. Thankfully, after doing that, the model loads the weights successfully.

# Iterate over a snapshot of the keys, because the dict is modified inside the loop
for key in list(state_dict.keys()):
    state_dict[key.replace("_orig_mod.", "")] = state_dict.pop(key)

for key in state_dict:
    print(key)


Output:

xlm_roberta.embeddings.word_embeddings.weight
xlm_roberta.embeddings.position_embeddings.weight
xlm_roberta.embeddings.token_type_embeddings.weight
xlm_roberta.embeddings.LayerNorm.weight
xlm_roberta.embeddings.LayerNorm.bias
...
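One caveat of the loop above: str.replace removes "_orig_mod." anywhere in a key, not only at the front. A slightly stricter variant, sketched below as a hypothetical helper (strip_compile_prefix is not part of PyTorch), strips only a leading prefix and returns a new dict instead of mutating the loaded one. It is pure Python, so it works on any mapping of parameter names to tensors.

```python
from collections import OrderedDict

COMPILE_PREFIX = "_orig_mod."


def strip_compile_prefix(state_dict, prefix=COMPILE_PREFIX):
    """Return a new OrderedDict with `prefix` removed from the start of each key.

    Keys that do not start with `prefix` are kept unchanged, and a key that only
    contains the prefix somewhere in the middle is also left alone.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


# Small demo with placeholder values instead of tensors.
demo = OrderedDict([("_orig_mod.encoder.weight", 0), ("head._orig_mod.bias", 1)])
print(list(strip_compile_prefix(demo)))
# ['encoder.weight', 'head._orig_mod.bias']
```

With a real checkpoint, loading would then look like model.load_state_dict(strip_compile_prefix(state_dict)).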


Then we load the state_dict again.

model.load_state_dict(state_dict=state_dict)


Output:

<All keys matched successfully>
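If you control the saving side, or prefer a built-in helper, two other options may be worth considering. The sketch below, again on a toy nn.Linear stand-in (an assumption), (1) saves from the original uncompiled module, which shares its parameters with the compiled wrapper, so no prefix is ever written, and (2) repairs an already-prefixed state dict in place with torch.nn.modules.utils.consume_prefix_in_state_dict_if_present, a utility PyTorch ships mainly for stripping the "module." prefix added by DataParallel/DistributedDataParallel.

```python
import torch
import torch.nn as nn
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

# Toy stand-in for the real model (assumption for illustration).
model = nn.Linear(4, 2)
compiled_model = torch.compile(model)  # train with compiled_model as usual

# Option 1: save from the uncompiled module. `model` and `compiled_model` share
# the same parameter tensors, so this state dict has no prefix. If only the
# compiled object was kept, compiled_model._orig_mod is the same module
# (an internal attribute, so prefer keeping your own reference).
clean_state = model.state_dict()
print(sorted(clean_state))  # ['bias', 'weight']

# Option 2: repair a state dict that was taken from the compiled wrapper.
prefixed_state = compiled_model.state_dict()
consume_prefix_in_state_dict_if_present(prefixed_state, "_orig_mod.")  # modifies in place
print(sorted(prefixed_state))  # ['bias', 'weight']

fresh_model = nn.Linear(4, 2)
result = fresh_model.load_state_dict(prefixed_state)
print(result)  # <All keys matched successfully>
```

Option 1 avoids the problem at its source, while option 2 is handy for checkpoints that were already written with the prefix, such as the one in this post.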
