Skip to content

使用 Transformers 套件中的 AutoModel.from_pretrained() 讀取自定義模型

Last Updated on 2024-08-22 by Clay

時至今日有許多的 AI 應用、開源專案是以 HuggingFace 開源的 transformers 套件為基底下去開發的,有許多的模型與套件也都是寫成兼容 transformers 的格式、甚至擁有一樣的函式跟方法,才更容易為人所接受。

在這樣的前提下,我偶然使用了一個開源的訓練框架,它很好地封裝了 Transformer 架構的自動讀取 —— 但一個不得不解決的問題是我希望使用我自定義的模型去做實驗;我嘗試了幾個解決方法,目的是希望使用 AutoModel.from_pretrained() 的時候,只要傳入我本地端的模型路徑,就可以正確使用我自定義的模型架構,於是就把成功的方法紀錄於本篇筆記中。

簡單來說有分成以下兩種:

  1. AutoConfig 和 AutoModel 的 register() 註冊方法
  2. 模型設定檔 Config 的 "auto_map" 參數設定

自定義模型

假設我們有以下自定義的雙向注意力機制 Mistral 模型(當然,我是繼承 Mistral 下去改的):

from transformers import MistralConfig, MistralModel


# Rename
class MistralBiConfig(MistralConfig):
    model_type = "bimistral"
    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class MistralBiModel(MistralModel):
    _no_split_modules = ["ModifiedMistralDecoderLayer"]
    config_class = str(type(MistralBiConfig()))

    def __init__(self, config: MistralConfig):
        MistralPreTrainedModel.__init__(self, config)
        ...


需要注意的是,在我們註冊 Transformer 自定義模型時,會需要讓 config.model_type 為一個不與現有任何架構衝突的名字、然後 model.config_class 必須為自己 config 的類別。這一步很重要,因為 Transformer 真的會去檢查你有沒有設定對。


AutoConfig 和 AutoModel 的 register() 註冊方法

首先來紀錄一下註冊方法,使用此方法註冊了之後,你就可以直接用 from transformers import MistralBiConfig, MistralBiModel 了。

from transformers import AutoModel, AutoConfig

AutoConfig.register("bimistral", MistralBiConfig)
AutoModel.register(MistralBiConfig, MistralBiModel)


之後,我們就可以在模型的 config.py 裡面更改 model_type 為剛剛註冊的 "bimistral",就可以使用 AutoModel.from_pretrained() 來自動讀取這個模型架構了。

不過要注意的是,我自己測試時需要每次在腳本的開頭都要註冊才會生效,並非永遠常駐 transformers 套件內的;不過這樣也比較合理。


模型設定檔 Config 的 "auto_map" 參數設定

這個方法更簡單一點,不過卻會需要讓我動到一些框架內的原始碼。

{
  "_name_or_path": "Mistral-7B",
  "model_type": "mistral",
  "architectures": [
    "MistralForCausalLM"
  ],
  "auto_map": {
    "AutoModel": "modeling_bimistral.MistralBiModel",
    "AutoModelForCausalLM": "modeling_bimistral.MistralBiForCausalLM",
    "AutoModelForSequenceClassification": "modeling_bimistral.MistralBiForSequenceClassification"
  },
...


我們只需要把自定義模型架構放在一個特定的 .py 檔案內,就可以讓 AutoModel 讀取指定的架構。不過,在讀取時,會需要加上 trust_remote_code=True 的參數,所以才說仍然會讓我去動到我要使用的框架的原始碼。(雖然也沒什麼不好)


References


Read More

Leave a Reply取消回覆

Exit mobile version