Last Updated on 2023-12-04 by Clay
問題描述
在 2023 年初,PyTorch 的 2.0 版本新增了一個 torch.compile()
的新功能,讓我們能夠在模型訓練/推理時能夠進一步提昇速度。與混合精度訓練的協同工作,經常能使我的訓練速度提昇一倍左右。
然而也正是編譯後的模型,在儲存/讀取中有一個明顯的問題,那就是當我們訓練編譯模型到一定時候,我們將其儲存權重儲存起來,在下次進行讀取時,會出現以下錯誤訊息(示範的模型是 XLMRobertaModel)。
state_dict = torch.load("./checkpoints/epoch-1-step-48000.ckpt")
model = XLMRobertaModel.from_pretrained(pretrained_model_name_or_path)
model.load_state_dict(state_dict=state_dict)
Output:
RuntimeError: Error(s) in loading state_dict for WinbertaModel: Missing key(s) in state_dict: "xlm_roberta.embeddings.word_embeddings.weight", ...
看錯誤訊息,顯然是模型在讀取權重時,找不到對應鍵值的權重,比方說 "xlm_roberta.embeddings.word_embeddings.weight" 就沒有在 state_dict
中找到。
如果我們印出 state_dict
,我們就會發現問題。
for key in state_dict:
print(key)
Output:
_orig_mod.xlm_roberta.embeddings.word_embeddings.weight _orig_mod.xlm_roberta.embeddings.position_embeddings.weight _orig_mod.xlm_roberta.embeddings.token_type_embeddings.weight _orig_mod.xlm_roberta.embeddings.LayerNorm.weight _orig_mod.xlm_roberta.embeddings.LayerNorm.bias ...
其實 XLMRobertaModel 應該要有的模型層全部都有,就只是在編譯之後,儲存的權重多了一個 _orig_mod. 的前綴!
解決方法
實際上,我並不確定最新版本的 PyTorch 是否已經解決了這個問題,但我簡單瀏覽了一下至八月左右為止的討論,顯然這個問題仍舊存在 —— 並且我的伺服器 PyTorch 版本是固定的,也不好隨意升級。
所以,最簡單粗暴的解決方法,就是逕自刪除那多出來的前綴。謝天謝地,刪除完之後至少還可以正常讀取。
for key in list(state_dict.keys()):
state_dict[key.replace("_orig_mod.", "")] = state_dict.pop(key)
for key in state_dict:
print(key)
Output:
xlm_roberta.embeddings.word_embeddings.weight xlm_roberta.embeddings.position_embeddings.weight xlm_roberta.embeddings.token_type_embeddings.weight xlm_roberta.embeddings.LayerNorm.weight xlm_roberta.embeddings.LayerNorm.bias ...
並且再一次讀取後,會看到模型成功載入權重了。
model.load_state_dict(state_dict=state_dict)
Output:
model.load_state_dict(state_dict=state_dict)
References
- Make compiled models serializable · Issue #101107 · pytorch/pytorch · GitHub
- Expected _orig_mod to NOT be FullyShardedDataParallel ...
Read More
- [已解決] RuntimeError: OrderedDict mutated during iteration - OrderedDict 不能在迭代時發生變化
- [已解決] Some weights of the model checkpoint at distilbert-base-multilingual-cased were not used when initializing DistilBertForSequenceClassification: ['vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_transform.bias']