Last Updated on 2024-04-01 by Clay
Problem Description
SFTTrainer is a training tool provided by HuggingFace for LLM fine-tuning; it lets us quickly adjust many hyperparameters and fine-grained settings when fine-tuning a large language model. Among them, response_template is a special template string that we must pass in with our training data: everything after this template string participates in the loss computation during training.
However, response_template only needs to be set when doing completion-only training, because in a non-completion-only setting, every token in our training data is something the model has to compute the loss over.
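To make the difference concrete, here is a rough conceptual sketch (this is not TRL's actual implementation; the function and variable names are purely illustrative): without completion-only training the labels are simply a copy of the input IDs, while completion-only training masks everything before the response with -100 so those positions are ignored by the loss.
IGNORE_INDEX = -100

def build_labels(input_ids, response_start_idx=None):
    # Standard LM fine-tuning: every token participates in the loss
    if response_start_idx is None:
        return list(input_ids)
    # Completion-only: mask everything before the response with -100,
    # so only the response tokens contribute to the loss
    return [IGNORE_INDEX] * response_start_idx + list(input_ids[response_start_idx:])

input_ids = [1, 11, 12, 13, 14, 21, 22, 23]  # pretend the first five tokens are the prompt, the last three the response
print(build_labels(input_ids))                        # [1, 11, 12, 13, 14, 21, 22, 23]
print(build_labels(input_ids, response_start_idx=5))  # [-100, -100, -100, -100, -100, 21, 22, 23]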
Now, here comes the question: what if I follow the ChatML format and convert my training data into the following form?
<|im_start|>system
...<|im_end|>
<|im_start|>user
...<|im_end|>
<|im_start|>assistant
...<|im_end|>
Our response_template would naturally be set to <|im_start|>assistant\n, because only what comes after it is the response the model needs to learn during fine-tuning. That makes perfect sense.
But what if my training data is a multi-turn conversation?
<|im_start|>system
...<|im_end|>
<|im_start|>user
...<|im_end|>
<|im_start|>assistant
...<|im_end|>
<|im_start|>user
...<|im_end|>
<|im_start|>assistant
...<|im_end|>
<|im_start|>user
...<|im_end|>
<|im_start|>assistant
...<|im_end|>
Obviously, my training data now contains multiple occurrences of response_template. So where exactly does SFTTrainer start computing the loss?
Solution
We can easily find some answers online, but I think the most direct and reliable way is simply to print it out and take a look. (Note that some tokenizers, such as LlamaTokenizer, may tokenize the same text differently depending on the surrounding context, so it is best to tokenize the template first and then pass the token IDs to DataCollatorForCompletionOnlyLM.)
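For example, the following small sketch (assuming the same local checkpoint used below; the in-context string is made up purely for this comparison) shows why tokenizing the template in isolation can differ from how the same characters are tokenized inside a full conversation:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("/mnt/transformers_models/teknium--OpenHermes-2.5-Mistral-7B/")

response_template = "<|im_start|>assistant\n"

# The same characters may be split differently depending on what precedes them
standalone_ids = tokenizer.encode(response_template, add_special_tokens=False)
in_context_ids = tokenizer.encode("<|im_start|>user\nHello<|im_end|>\n" + response_template, add_special_tokens=False)

print("standalone:", standalone_ids)
print("in context:", in_context_ids)
With that noted, let's print out what the collator actually produces: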
from transformers import AutoModel, AutoTokenizer
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

pretrained_model_name_or_path = "/mnt/transformers_models/teknium--OpenHermes-2.5-Mistral-7B/"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)

# Tokenize the response template first, then pass the token IDs to the collator
response_template = "<|im_start|>assistant\n"
response_template_tokenized = tokenizer.encode(response_template, add_special_tokens=False)

collator = DataCollatorForCompletionOnlyLM(response_template=response_template_tokenized, tokenizer=tokenizer)
example = """<|im_start|>system
...<|im_end|>
<|im_start|>user
...<|im_end|>
<|im_start|>assistant
...<|im_end|>
<|im_start|>user
...<|im_end|>
<|im_start|>assistant
...<|im_end|>
<|im_start|>user
...<|im_end|>
<|im_start|>assistant
...<|im_end|>"""
example_encoded = tokenizer(example)
collator([example_encoded["input_ids"]])
Output:
{'input_ids': tensor([[ 1, 32001, 1587, 13, 1101, 32000, 28705, 13, 32001, 2188, 13, 1101, 32000, 28705, 13, 32001, 13892, 13, 1101, 32000, 28705, 13, 32001, 2188, 13, 1101, 32000, 28705, 13, 32001, 13892, 13, 1101, 32000, 28705, 13, 32001, 2188, 13, 1101, 32000, 28705, 13, 32001, 13892, 13, 1101, 32000]]), 'labels': tensor([[ -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1101, 32000]])}
You can see that the loss is only computed starting from the last response_template (in ChatML, <|im_start|>assistant\n); everything before it is filled with -100.
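To double-check this, here is a small sketch that reuses the collator, tokenizer, and example_encoded defined above: decode only the label positions that are not -100.
# Reuses `collator`, `tokenizer`, and `example_encoded` from above
batch = collator([example_encoded["input_ids"]])
kept = [token_id for token_id in batch["labels"][0].tolist() if token_id != -100]
print(tokenizer.decode(kept))  # only the final assistant response remains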
Additionally, for multi-turn conversations you can also pass in instruction_template to serve as a boundary. You will then find that the response after every turn's response_template is included in the loss computation.
from transformers import AutoModel, AutoTokenizer
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

pretrained_model_name_or_path = "/mnt/transformers_models/teknium--OpenHermes-2.5-Mistral-7B/"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)

# Tokenize both templates first, then pass the token IDs to the collator
instruction_template = "<|im_start|>user\n"
instruction_template_tokenized = tokenizer.encode(instruction_template, add_special_tokens=False)

response_template = "<|im_start|>assistant\n"
response_template_tokenized = tokenizer.encode(response_template, add_special_tokens=False)

collator = DataCollatorForCompletionOnlyLM(
    instruction_template=instruction_template_tokenized,
    response_template=response_template_tokenized,
    tokenizer=tokenizer,
)
example = """<|im_start|>system
...<|im_end|>
<|im_start|>user
...<|im_end|>
<|im_start|>assistant
...<|im_end|>
<|im_start|>user
...<|im_end|>
<|im_start|>assistant
...<|im_end|>
<|im_start|>user
...<|im_end|>
<|im_start|>assistant
...<|im_end|>"""
example_encoded = tokenizer(example)
collator([example_encoded["input_ids"]])
Output:
{'input_ids': tensor([[ 1, 32001, 1587, 13, 1101, 32000, 28705, 13, 32001, 2188, 13, 1101, 32000, 28705, 13, 32001, 13892, 13, 1101, 32000, 28705, 13, 32001, 2188, 13, 1101, 32000, 28705, 13, 32001, 13892, 13, 1101, 32000, 28705, 13, 32001, 2188, 13, 1101, 32000, 28705, 13, 32001, 13892, 13, 1101, 32000]]), 'labels': tensor([[ -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1101, 32000, 28705, 13, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1101, 32000, 28705, 13, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1101, 32000]])}
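As a quick sanity check (again just a sketch reusing the variables defined above), we can group the contiguous non-masked positions and decode each group; every assistant turn should now show up as its own span:
# Reuses `collator`, `tokenizer`, and `example_encoded` from above
batch = collator([example_encoded["input_ids"]])

spans = []
current = []
for token_id in batch["labels"][0].tolist():
    if token_id == -100:
        if current:
            spans.append(current)
            current = []
    else:
        current.append(token_id)
if current:
    spans.append(current)

for turn, span in enumerate(spans, start=1):
    print(f"turn {turn}:", tokenizer.decode(span))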