
Supervised Fine-tuning Trainer (SFTTrainer) Note

Last Updated on 2024-08-02 by Clay

Introduction

Supervised Fine-Tuning (SFT) is one of the most widely used methods for training Large Language Models (LLMs). In essence it is the same as traditional language modeling: the model picks up knowledge from the training data by learning to predict the next token.

The difference lies in which part of the text the model learns from. Traditional language modeling may train on entire texts, which makes it a genuine text completion task. What is called SFT nowadays mainly trains the model on the response portion of a conversation only.

This means that training data for traditional language modeling might be:

The weather is great today, I want to go out and have fun...


While SFT training data looks like this:

### Question: What are your plans for today?
### Response: The weather is great today, I should go out and have fun.

The model is trained only on the response, "The weather is great today, I should go out and have fun." The preceding question is fed in as context but is not part of what the model is trained to produce.
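
A minimal sketch of how this "response only" training is commonly implemented: the labels are a copy of the input token IDs, with every position up to and including the response marker replaced by -100, the value PyTorch's cross-entropy loss ignores by default. The token IDs and helper names below are invented for illustration; this is not trl's actual implementation.

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default


def find_subsequence_end(sequence, pattern):
    """Return the index just past the first occurrence of pattern, or -1 if absent."""
    for start in range(len(sequence) - len(pattern) + 1):
        if sequence[start:start + len(pattern)] == pattern:
            return start + len(pattern)
    return -1


def mask_prompt_labels(input_ids, response_marker_ids):
    """Copy input_ids into labels, hiding everything up to and including the marker."""
    response_start = find_subsequence_end(input_ids, response_marker_ids)
    if response_start == -1:
        # No marker found: ignore the whole example rather than train on the prompt.
        return [IGNORE_INDEX] * len(input_ids)
    return [IGNORE_INDEX] * response_start + input_ids[response_start:]


# Hypothetical token IDs: 1-4 stand for the question, "9 9" for the
# "### Response:" marker, and 5-7 for the answer.
input_ids = [1, 2, 3, 4, 9, 9, 5, 6, 7]
labels = mask_prompt_labels(input_ids, response_marker_ids=[9, 9])
print(labels)  # [-100, -100, -100, -100, -100, -100, 5, 6, 7]
```

Only the last three positions contribute to the loss, so the gradient comes entirely from the answer tokens.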


Code Introduction

Here, I document training code built on the trl library's SFTTrainer(). It is a popular choice nowadays because it is well packaged: a training script takes around 100 lines and the logic stays straightforward.

The model I chose is Mistral, specifically the fine-tuned teknium/OpenHermes-2.5-Mistral-7B. When using SFTTrainer(), pay attention to the tokenizer's padding setup (see [Solved] Mistral does not output eos_token `<|im_end|>` after fine-tuning with SFTTrainer).

First, import the necessary packages. At the top of the script I also define my prompt format, which follows ChatML.

# coding: utf-8
from datasets import load_dataset
from peft import LoraConfig
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM


PROMPT_TEMPLATE = """<|im_start|>system
{system}<|im_end|>
<|im_start|>user
{question}<|im_end|>
<|im_start|>assistant
{chosen}<|im_end|>"""


Next is the formatting_prompts_func() function. It is passed to SFTTrainer() later and converts the training data into the desired format, namely the ChatML template defined at the beginning.

def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example["system"])):
        text = PROMPT_TEMPLATE.format(
            system=example["system"][i],
            question=example["question"][i],
            chosen=example["chosen"][i],
        )

        output_texts.append(text)

    return output_texts
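
A quick way to sanity-check the formatting function is to call it on a hand-made batch. The sketch below repeats the template so that it runs on its own, and it uses an equivalent zip-based loop. The sample rows are invented, and it assumes that the function receives a batch, meaning a dict that maps each column name to a list of values, as the indexing in the original function suggests.

```python
PROMPT_TEMPLATE = """<|im_start|>system
{system}<|im_end|>
<|im_start|>user
{question}<|im_end|>
<|im_start|>assistant
{chosen}<|im_end|>"""


def formatting_prompts_func(example):
    # `example` is a batch: each column name maps to a list of values.
    return [
        PROMPT_TEMPLATE.format(system=system, question=question, chosen=chosen)
        for system, question, chosen in zip(
            example["system"], example["question"], example["chosen"]
        )
    ]


# Invented rows, shaped like a two-example batch read from the CSV files.
batch = {
    "system": ["You are a helpful assistant.", "You are a helpful assistant."],
    "question": ["What are your plans for today?", "How is the weather?"],
    "chosen": [
        "The weather is great today, I should go out and have fun.",
        "It is sunny.",
    ],
}

texts = formatting_prompts_func(batch)
print(texts[0])
```

Each returned string should contain the assistant header followed by the answer, which is what the completion-only collator later searches for.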


Next is the main function. I sequentially defined:

  • training_args: The batch size, learning rate, number of training steps, evaluation interval, optimizer, and so on are all defined here.
  • bnb_config: This is the quantization setting. Since I plan to train using QLoRA ([Paper Reading] QLoRA: Efficient Finetuning of Quantized LLMs), the transformers library needs this configuration to load the model weights in the NF4 data type.
  • peft_config: This is the LoRA setting, including the rank of the low-rank matrices and which layers of the model receive an adapter.
  • dataset: The training and validation data.
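
To get a feel for what these settings imply, here is a back-of-the-envelope sketch that counts the LoRA parameters and the number of sequences consumed per optimizer step. The layer shapes are my assumption of the Mistral-7B architecture (hidden size 4096, MLP size 14336, 8 key/value heads of 128 dimensions, 32 layers) and are not stated in this post, so check the model's config.json before relying on the numbers. A single GPU is also assumed.

```python
# Assumed Mistral-7B shapes (not from this post; verify against config.json).
HIDDEN = 4096
INTERMEDIATE = 14336
KV_DIM = 1024  # 8 key/value heads x 128 dimensions per head
NUM_LAYERS = 32

# (in_features, out_features) of each linear layer listed in target_modules.
TARGET_SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_param_count(rank, shapes, num_layers):
    """Each adapted layer adds A (rank x in) and B (out x rank)."""
    per_layer = sum(rank * (fan_in + fan_out) for fan_in, fan_out in shapes.values())
    return per_layer * num_layers


trainable = lora_param_count(rank=16, shapes=TARGET_SHAPES, num_layers=NUM_LAYERS)
lora_scaling = 16 / 16  # lora_alpha / r, the factor applied to the adapter output

# per_device_train_batch_size * gradient_accumulation_steps, on one GPU
sequences_per_step = 1 * 4
sequences_seen = sequences_per_step * 6000  # max_steps

print(f"trainable LoRA parameters: {trainable:,}")   # 41,943,040
print(f"LoRA scaling factor: {lora_scaling}")         # 1.0
print(f"sequences per optimizer step: {sequences_per_step}")
print(f"sequences seen over training: {sequences_seen:,}")  # 24,000
```

Under these assumptions the adapter is roughly 42 million parameters, a small fraction of the 7B base model, and max_steps=6000 corresponds to 24,000 training sequences.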

The remaining part is simpler: the model, the tokenizer, and SFTTrainer() itself. Once everything is defined, training can start directly, which is really convenient.

Of course, trying different combinations and iterating on the dataset is probably what matters most.

def main() -> None:
    # Training arguments
    training_args = TrainingArguments(
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=4,
        gradient_checkpointing=True,
        learning_rate=5e-5,
        lr_scheduler_type="cosine",
        max_steps=6000,
        evaluation_strategy="steps",
        save_strategy="steps",
        do_eval=True,
        eval_steps=100,
        save_steps=100,
        logging_steps=1,
        output_dir="outputs_pretraining_20240101",
        optim="paged_adamw_32bit",
        warmup_steps=100,
        remove_unused_columns=False,
        bf16=True,
        report_to="none",
    )
    
    # BitsAndBytes config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=False,
    )

    # LoRA config
    peft_config = LoraConfig(
        r=16,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "k_proj",
            "gate_proj",
            "v_proj",
            "up_proj",
            "q_proj",
            "o_proj",
            "down_proj",
        ],
    )

    # Load data
    train_dataset = load_dataset("csv", split="train", data_files="./program_data/kg_gpt4_sft_train_dataset.csv")
    eval_dataset = load_dataset("csv", split="train", data_files="./program_data/kg_gpt4_sft_test_dataset.csv")

    # Load model and tokenizer
    pretrained_model_name_or_path = "/tmp2/share_data/teknium--OpenHermes-2.5-Mistral-7B/"
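    # 4-bit loading is already requested through bnb_config, so load_in_4bit is
    # not passed again here; some transformers versions reject passing both.
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        quantization_config=bnb_config,
    )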
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path)
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.padding_side = "left"

    # Preprocessing
    response_template = "<|im_start|>assistant\n"
    collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

    # Create SFTTrainer
    sft_trainer = SFTTrainer(
        model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        formatting_func=formatting_prompts_func,
        data_collator=collator,
        peft_config=peft_config,
        neftune_noise_alpha=5,
        max_seq_length=4096,
    )

    # Train
    sft_trainer.train()

    # Save
    sft_trainer.model.save_pretrained("final_checkpoint")
    tokenizer.save_pretrained("final_checkpoint")


if __name__ == "__main__":
    main()
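
Since the model only ever sees the answer after the `<|im_start|>assistant\n` header during training, a prompt at inference time presumably has to follow the same layout and stop right after that header. The helper below is a hypothetical sketch of that idea. Its name and the sample strings are mine, and it only builds the string without loading any model.

```python
def build_inference_prompt(system: str, question: str) -> str:
    """Build a ChatML prompt that ends right after the assistant header."""
    return (
        f"<|im_start|>system\n{system}<|im_end|>\n"
        f"<|im_start|>user\n{question}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )


prompt = build_inference_prompt(
    system="You are a helpful assistant.",
    question="What are your plans for today?",
)
print(prompt)
```

The fine-tuned model is then expected to continue from this point and close its answer with `<|im_end|>`, which is why the padding and end-of-sequence setup mentioned earlier matters.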

Complete Code

# coding: utf-8
from datasets import load_dataset
from peft import LoraConfig
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM


PROMPT_TEMPLATE = """<|im_start|>system
{system}<|im_end|>
<|im_start|>user
{question}<|im_end|>
<|im_start|>assistant
{chosen}<|im_end|>"""


def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example["system"])):
        text = PROMPT_TEMPLATE.format(
            system=example["system"][i],
            question=example["question"][i],
            chosen=example["chosen"][i],
        )

        output_texts.append(text)

    return output_texts


def main() -> None:
    # Training arguments
    training_args = TrainingArguments(
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=4,
        gradient_checkpointing=True,
        learning_rate=5e-5,
        lr_scheduler_type="cosine",
        max_steps=6000,
        evaluation_strategy="steps",
        save_strategy="steps",
        do_eval=True,
        eval_steps=100,
        save_steps=100,
        logging_steps=1,
        output_dir="outputs_pretraining_20240101",
        optim="paged_adamw_32bit",
        warmup_steps=100,
        remove_unused_columns=False,
        bf16=True,
        report_to="none",
    )
    
    # BitsAndBytes config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=False,
    )

    # LoRA config
    peft_config = LoraConfig(
        r=16,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "k_proj",
            "gate_proj",
            "v_proj",
            "up_proj",
            "q_proj",
            "o_proj",
            "down_proj",
        ],
    )

    # Load data
    train_dataset = load_dataset("csv", split="train", data_files="./program_data/kg_gpt4_sft_train_dataset.csv")
    eval_dataset = load_dataset("csv", split="train", data_files="./program_data/kg_gpt4_sft_test_dataset.csv")

    # Load model and tokenizer
    pretrained_model_name_or_path = "/tmp2/share_data/teknium--OpenHermes-2.5-Mistral-7B/"
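    # 4-bit loading is already requested through bnb_config, so load_in_4bit is
    # not passed again here; some transformers versions reject passing both.
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        quantization_config=bnb_config,
    )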
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path)
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.padding_side = "left"

    # Preprocessing
    response_template = "<|im_start|>assistant\n"
    collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

    # Create SFTTrainer
    sft_trainer = SFTTrainer(
        model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        formatting_func=formatting_prompts_func,
        data_collator=collator,
        peft_config=peft_config,
        neftune_noise_alpha=5,
        max_seq_length=4096,
    )

    # Train
    sft_trainer.train()

    # Save
    sft_trainer.model.save_pretrained("final_checkpoint")
    tokenizer.save_pretrained("final_checkpoint")


if __name__ == "__main__":
    main()
