Skip to content

[已解決][PyTorch] 編譯(compile)後模型權重多出 “_orig_mod” 的前綴問題

問題描述

在 2023 年初,PyTorch 的 2.0 版本新增了一個 torch.compile() 的新功能,讓我們能夠在模型訓練/推理時能夠進一步提昇速度。與混合精度訓練的協同工作,經常能使我的訓練速度提昇一倍左右。

然而也正是編譯後的模型,在儲存/讀取中有一個明顯的問題,那就是當我們訓練編譯模型到一定時候,我們將其儲存權重儲存起來,在下次進行讀取時,會出現以下錯誤訊息(示範的模型是 XLMRobertaModel)。

state_dict = torch.load("./checkpoints/epoch-1-step-48000.ckpt")

model = XLMRobertaModel.from_pretrained(pretrained_model_name_or_path)
model.load_state_dict(state_dict=state_dict)


Output:

RuntimeError: Error(s) in loading state_dict for WinbertaModel:
	Missing key(s) in state_dict: "xlm_roberta.embeddings.word_embeddings.weight", ...

看錯誤訊息,顯然是模型在讀取權重時,找不到對應鍵值的權重,比方說 “xlm_roberta.embeddings.word_embeddings.weight” 就沒有在 state_dict 中找到。

如果我們印出 state_dict,我們就會發現問題。

for key in state_dict:
    print(key)


Output:

_orig_mod.xlm_roberta.embeddings.word_embeddings.weight
_orig_mod.xlm_roberta.embeddings.position_embeddings.weight
_orig_mod.xlm_roberta.embeddings.token_type_embeddings.weight
_orig_mod.xlm_roberta.embeddings.LayerNorm.weight
_orig_mod.xlm_roberta.embeddings.LayerNorm.bias
...

其實 XLMRobertaModel 應該要有的模型層全部都有,就只是在編譯之後,儲存的權重多了一個 _orig_mod. 的前綴!


解決方法

實際上,我並不確定最新版本的 PyTorch 是否已經解決了這個問題,但我簡單瀏覽了一下至八月左右為止的討論,顯然這個問題仍舊存在 —— 並且我的伺服器 PyTorch 版本是固定的,也不好隨意升級。

所以,最簡單粗暴的解決方法,就是逕自刪除那多出來的前綴。謝天謝地,刪除完之後至少還可以正常讀取。

for key in list(state_dict.keys()):
    state_dict[key.replace("_orig_mod.", "")] = state_dict.pop(key)

for key in state_dict:
    print(key)


Output:

xlm_roberta.embeddings.word_embeddings.weight
xlm_roberta.embeddings.position_embeddings.weight
xlm_roberta.embeddings.token_type_embeddings.weight
xlm_roberta.embeddings.LayerNorm.weight
xlm_roberta.embeddings.LayerNorm.bias
...


並且再一次讀取後,會看到模型成功載入權重了。

model.load_state_dict(state_dict=state_dict)


Output:

model.load_state_dict(state_dict=state_dict)

References


Read More

Leave a Reply