
[Solved] Where Does Loss Calculation Begin When Multiple `response_template` Occurrences Exist in Training Data Using SFTTrainer?

Problem

SFTTrainer is an LLM fine-tuning tool from Hugging Face's TRL library that makes it easy to configure hyperparameters and other settings for a fine-tuning task. When we only want to train on the model's responses, we also pass a special string, response_template, to the data collator (DataCollatorForCompletionOnlyLM, which is then handed to SFTTrainer). Only the tokens that come after this template are used to compute the loss.

However, a response_template is only needed for completion-only training. Without it, every token in the training data contributes to the loss the model is optimized on.
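What "contributes to the loss" means in practice: a position whose label is -100 is simply skipped. Below is a pure-Python sketch of token-level cross-entropy that mimics the default ignore_index=-100 of PyTorch's CrossEntropyLoss, which Hugging Face causal-LM models use. It assumes logits and labels are already aligned position by position (the models themselves shift the labels by one internally); the function name is our own.

```python
import math

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def masked_cross_entropy(logits, labels, ignore_index=IGNORE_INDEX):
    """Mean token-level cross-entropy over positions whose label != ignore_index.

    logits: one list of vocabulary scores per position
    labels: one target token id per position (already aligned with logits)
    """
    total, count = 0.0, 0
    for scores, label in zip(logits, labels):
        if label == ignore_index:
            continue  # masked position: contributes nothing to the loss
        # -log softmax(scores)[label], using the max-shift for numerical stability
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_z - scores[label]
        count += 1
    # Sketch choice: return 0.0 when everything is masked (PyTorch would give nan)
    return total / count if count else 0.0


logits = [[10.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
print(masked_cross_entropy(logits, [1, 2]))     # ~5.69: both positions count
print(masked_cross_entropy(logits, [-100, 2]))  # ~1.386 (ln 4): position 0 is masked
```

Completion-only training is therefore nothing more than choosing which labels to overwrite with -100; the loss function itself does not change.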

So the question is this: suppose I follow the ChatML format and convert my training data into the following layout:

<im_start>system
...<im_end>
<im_start>user
...<im_end>
<im_start>assistant
...<im_end>

The natural choice for response_template is <im_start>assistant\n, because everything after it is the response we want the model to learn. This is quite reasonable.
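Converting raw conversations into this layout is mechanical. A minimal sketch, assuming each example is a list of {"role", "content"} dicts (a hypothetical input format, not something SFTTrainer requires); the helper name to_chatml is ours, and in real projects the tokenizer's apply_chat_template can render the model's own template instead. It keeps this post's spelling <im_start>/<im_end>; standard ChatML, including OpenHermes' chat template, writes them as <|im_start|>/<|im_end|>. Whichever spelling you use, response_template must match the text in your data exactly.

```python
def to_chatml(messages):
    """Render [{"role": ..., "content": ...}, ...] in the layout used in this post:
    each turn becomes "<im_start>{role}\\n{content}<im_end>", turns joined by newlines."""
    return "\n".join(
        f"<im_start>{m['role']}\n{m['content']}<im_end>" for m in messages
    )


print(to_chatml([
    {"role": "system", "content": "..."},
    {"role": "user", "content": "..."},
    {"role": "assistant", "content": "..."},
]))
```

The same helper produces the multi-turn layout as well; it just emits one more <im_start>assistant\n header per assistant turn.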

But what if my training data consists of multi-turn conversations?

<im_start>system
...<im_end>
<im_start>user
...<im_end>
<im_start>assistant
...<im_end>
<im_start>user
...<im_end>
<im_start>assistant
...<im_end>
<im_start>user
...<im_end>
<im_start>assistant
...<im_end>

Clearly, response_template now appears several times in every example. So where exactly does SFTTrainer start computing the loss?


Solution

Answers can be found online, but the most direct and reliable way is to print the collator's output and see for ourselves.

from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

pretrained_model_name_or_path = "/mnt/transformers_models/teknium--OpenHermes-2.5-Mistral-7B/"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)

response_template = "<im_start>assistant\n"

# Encode the template without BOS and drop its first two tokens: the leading "<" gets a
# different id at the start of a string than after a newline (523 vs 28789 in the output
# below), so the unsliced ids would never match inside the conversation.
response_template_tokenized = tokenizer.encode(response_template, add_special_tokens=False)[2:]

# Sets labels to -100 for every token up to the end of the matched response template
collator = DataCollatorForCompletionOnlyLM(response_template=response_template_tokenized, tokenizer=tokenizer)

example = """<im_start>system
...<im_end>
<im_start>user
...<im_end>
<im_start>assistant
...<im_end>
<im_start>user
...<im_end>
<im_start>assistant
...<im_end>
<im_start>user
...<im_end>
<im_start>assistant
...<im_end>"""
example_encoded = tokenizer(example)

collator([example_encoded["input_ids"]])


Output:

{'input_ids': tensor([[    1,   523,   321, 28730,  2521, 28767,  6574,    13,  1101, 28789,
            321, 28730,   416, 28767,    13, 28789,   321, 28730,  2521, 28767,
           1838,    13,  1101, 28789,   321, 28730,   416, 28767,    13, 28789,
            321, 28730,  2521, 28767,   489, 11143,    13,  1101, 28789,   321,
          28730,   416, 28767,    13, 28789,   321, 28730,  2521, 28767,  1838,
             13,  1101, 28789,   321, 28730,   416, 28767,    13, 28789,   321,
          28730,  2521, 28767,   489, 11143,    13,  1101, 28789,   321, 28730,
            416, 28767,    13, 28789,   321, 28730,  2521, 28767,  1838,    13,
           1101, 28789,   321, 28730,   416, 28767,    13, 28789,   321, 28730,
           2521, 28767,   489, 11143,    13,  1101, 28789,   321, 28730,   416,
          28767]]),
 'labels': tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
           -100,  -100,  -100,  -100,  -100,  1101, 28789,   321, 28730,   416,
          28767]])}


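Only the tokens after the last response_template (<im_start>assistant\n in our ChatML layout) keep their ids in labels. Everything before it, including the two earlier assistant replies, is set to -100, the label value the loss ignores. In other words, with only a response_template, DataCollatorForCompletionOnlyLM trains on the final assistant turn alone.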
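This result follows from how the collator searches for the template. As we read the TRL source for DataCollatorForCompletionOnlyLM (behavior may differ between versions, so check the one you have installed), when only response_template is given it visits every position where the template's token ids occur and keeps overwriting the start index, so the last match wins. The sketch below re-creates that rule in plain Python; it is a simplified illustration, not TRL's actual code, and both function names are our own.

```python
IGNORE_INDEX = -100  # label value skipped by the loss


def find_template_starts(input_ids, template_ids):
    """Return every index where template_ids occurs as a contiguous run."""
    n = len(template_ids)
    return [
        i for i in range(len(input_ids) - n + 1)
        if input_ids[i:i + n] == template_ids
    ]


def completion_only_labels(input_ids, template_ids, ignore_index=IGNORE_INDEX):
    """Mimic the single-template masking rule: each match overwrites the previous
    one, so only the LAST occurrence decides where the loss starts."""
    starts = find_template_starts(input_ids, template_ids)
    if not starts:
        # No template found: the whole example is masked out
        return [ignore_index] * len(input_ids)
    end = starts[-1] + len(template_ids)  # first token after the last template
    return [ignore_index] * end + list(input_ids[end:])


# Toy ids: 9 stands for the response template, 7 for a user turn, others for content
ids = [9, 1, 2, 7, 3, 9, 4, 5]
print(find_template_starts(ids, [9]))    # [0, 5]
print(completion_only_labels(ids, [9]))  # [-100, -100, -100, -100, -100, -100, 4, 5]
```

Note that the first "response" (1, 2) is masked even though it follows a template. Applied to the input_ids printed above, and assuming the sliced template ids are [28730, 2521, 28767, 489, 11143, 13] (our reading of the token ids in that output), the sketch finds matches at positions 31, 60 and 89 and reproduces the printed labels: -100 through position 94, then 1101, 28789, 321, 28730, 416, 28767.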

