
[Solved] When using SFTTrainer, where does loss computation start if the training data contains multiple occurrences of response_template?

Problem Description

SFTTrainer is a training tool provided by HuggingFace for LLM fine-tuning; it lets us quickly adjust many hyperparameters and fine-grained settings for a large language model fine-tuning task. Among its options, response_template is a special template string we must pass in: everything that follows this template string in a training sample participates in the loss computation during training.

However, response_template only needs to be set when we are doing completion-only training, because in the non-completion-only case every token in our training data is part of what the model computes loss on.
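For context, here is a minimal sketch of how the completion-only pieces are usually wired together. The model name, the toy dataset, and the exact SFTTrainer argument names below are only assumptions for illustration (they vary between trl versions), not a verified recipe:

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

model_name = "teknium/OpenHermes-2.5-Mistral-7B"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# A toy single-sample dataset already rendered in ChatML
train_dataset = Dataset.from_dict({
    "text": ["<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\nHello!<|im_end|>"]
})

# Only tokens after this template will receive real labels
collator = DataCollatorForCompletionOnlyLM(
    response_template="<|im_start|>assistant\n",
    tokenizer=tokenizer,
)

trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    dataset_text_field="text",
    data_collator=collator,
    tokenizer=tokenizer,
    max_seq_length=2048,
)
trainer.train()
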

So here comes the question: suppose I follow the ChatML format and convert my training data into the following form:

<|im_start|>system
...<|im_end|>
<|im_start|>user
...<|im_end|>
<|im_start|>assistant
...<|im_end|>

Our response_template would naturally be set to <|im_start|>assistant\n, because everything after it is the response the model needs to learn during fine-tuning. That makes sense.

But what if my training data is a multi-turn conversation?

<|im_start|>system
...<|im_end|>
<|im_start|>user
...<|im_end|>
<|im_start|>assistant
...<|im_end|>
<|im_start|>user
...<|im_end|>
<|im_start|>assistant
...<|im_end|>
<|im_start|>user
...<|im_end|>
<|im_start|>assistant
...<|im_end|>

Clearly, my training data now contains multiple occurrences of the response_template. So where exactly does SFTTrainer start computing the loss?


Answer

You can easily find some answers online, but I think the most direct and down-to-earth way is to just print it out and look. (Note that some tokenizers, such as LlamaTokenizer, produce different tokenizations depending on the surrounding context, so it is best to tokenize the template first and then pass the token IDs to DataCollatorForCompletionOnlyLM.)

from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

pretrained_model_name_or_path = "/mnt/transformers_models/teknium--OpenHermes-2.5-Mistral-7B/"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)

# Tokenize the template up front so context-dependent tokenization
# (e.g. with LlamaTokenizer) cannot produce a mismatch inside the collator
response_template = "<|im_start|>assistant\n"
response_template_tokenized = tokenizer.encode(response_template, add_special_tokens=False)

collator = DataCollatorForCompletionOnlyLM(response_template=response_template_tokenized, tokenizer=tokenizer)

example = """<|im_start|>system
...<|im_end|>
<|im_start|>user
...<|im_end|>
<|im_start|>assistant
...<|im_end|>
<|im_start|>user
...<|im_end|>
<|im_start|>assistant
...<|im_end|>
<|im_start|>user
...<|im_end|>
<|im_start|>assistant
...<|im_end|>"""
example_encoded = tokenizer(example)

collator([example_encoded["input_ids"]])


Output:

{'input_ids': tensor([[    1, 32001,  1587,    13,  1101, 32000, 28705,    13, 32001,  2188,
             13,  1101, 32000, 28705,    13, 32001, 13892,    13,  1101, 32000,
          28705,    13, 32001,  2188,    13,  1101, 32000, 28705,    13, 32001,
          13892,    13,  1101, 32000, 28705,    13, 32001,  2188,    13,  1101,
          32000, 28705,    13, 32001, 13892,    13,  1101, 32000]]),
 'labels': tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  1101, 32000]])}


You can see that the loss is only computed starting from the last response_template (in ChatML, <|im_start|>assistant\n); everything before it is filled with -100.
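To double-check this, we can decode only the positions whose label is not -100. Based on the tensor above, this should recover nothing but the final assistant turn (the decoded string in the comment is my expectation, not captured output):

batch = collator([example_encoded["input_ids"]])
labels = batch["labels"][0]

# Keep only the positions that actually contribute to the loss and decode them
print(tokenizer.decode(labels[labels != -100]))
# Expected: only the last assistant reply, i.e. something like "...<|im_end|>"
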

Also, for multi-turn conversations, you can pass in instruction_template as a boundary marker. You will then see that the content after each turn's response_template is included in the loss computation.

from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

pretrained_model_name_or_path = "/mnt/transformers_models/teknium--OpenHermes-2.5-Mistral-7B/"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)

# Pre-tokenize both templates for the same reason as above
instruction_template = "<|im_start|>user\n"
instruction_template_tokenized = tokenizer.encode(instruction_template, add_special_tokens=False)

response_template = "<|im_start|>assistant\n"
response_template_tokenized = tokenizer.encode(response_template, add_special_tokens=False)

collator = DataCollatorForCompletionOnlyLM(instruction_template=instruction_template_tokenized, response_template=response_template_tokenized, tokenizer=tokenizer)

example = """<|im_start|>system
...<|im_end|>
<|im_start|>user
...<|im_end|>
<|im_start|>assistant
...<|im_end|>
<|im_start|>user
...<|im_end|>
<|im_start|>assistant
...<|im_end|>
<|im_start|>user
...<|im_end|>
<|im_start|>assistant
...<|im_end|>"""
example_encoded = tokenizer(example)

collator([example_encoded["input_ids"]])


Output:

{'input_ids': tensor([[    1, 32001,  1587,    13,  1101, 32000, 28705,    13, 32001,  2188,
             13,  1101, 32000, 28705,    13, 32001, 13892,    13,  1101, 32000,
          28705,    13, 32001,  2188,    13,  1101, 32000, 28705,    13, 32001,
          13892,    13,  1101, 32000, 28705,    13, 32001,  2188,    13,  1101,
          32000, 28705,    13, 32001, 13892,    13,  1101, 32000]]),
 'labels': tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  1101, 32000,
          28705,    13,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  1101, 32000, 28705,    13,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  1101, 32000]])}
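
Decoding the unmasked positions again, every assistant reply (three turns in this example) should now survive the masking; as before, the expectation is inferred from the tensor above:

batch = collator([example_encoded["input_ids"]])
labels = batch["labels"][0]

# All three assistant replies should remain after masking
print(tokenizer.decode(labels[labels != -100]))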
