Skip to content

Supervised Fine-tuning Trainer (SFTTrainer) 訓練筆記

Last Updated on 2024-01-03 by Clay

介紹

監督式微調Supervised Fine-tuning, SFT)是當前訓練大型語言模型Large Language Model, LLM)最知名的方法之一,本質上與傳統的語言模型建模(language modeling)相同,是讓模型透過訓練資料去學習某些知識。

唯一不同之處在於本來的語言模型建模可能會學習完整的文本,換言之就是真正的文字接龍;而現在所稱的 SFT,則大部分都是指讓模型只學習『聊天』(Chatting)的部份。

也就是說以前的訓練資料可能是:

今天天氣真好,我想要出門去玩......


而現在的訓練資料則是:

### Question: 你今天有什麼計畫?
### Rresponse: 今天天氣很好,我應該會出門去玩。

並且,模型只專注學習今天天氣很好,我應該會出門去玩。的這個部份,前面的輸入都不參與訓練。


程式碼介紹

這裡我紀錄的訓練程式碼是使用 trl 函式庫中提供的 SFTTrainer(),這個選擇在現在似乎相當熱門,畢竟封裝得很好,可以以輕鬆易懂的邏輯搭建程式碼行數大概 100 行左右的訓練腳本。

而我選用的模型是 Mistral,嚴格說起來是 Mistral 微調過後的 teknium/OpenHermes-2.5-Mistral-7B。在使用 SFTTrainer() 時需要注意的是 tokenizer 的 padding 問題([已解決] Mistral 經過 SFTTrainer 微調後不會輸出 eos_token `<|im_end|>`)。

首先是匯入所需要的套件,並且在最上方我初始化了我的 prompt 格式,這裡是遵守著 ChatML 的格式。

# coding: utf-8
from datasets import load_dataset
from peft import LoraConfig
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM


PROMPT_TEMPLATE = """<|im_start|>system
{system}<|im_end|>
<|im_start|>user
{question}<|im_end|>
<|im_start|>assistant
{chosen}<|im_end|>"""


接下來則是 formatting_prompt_func() 的函式,這個函式等等會一併傳入 SFTTrainer() 去使用,它會自動把訓練資料調整成我想要的格式 —— 也就是一開始定義的 ChatML 格式。

def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example["system"])):
        text = PROMPT_TEMPLATE.format(
            system=example["system"][i],
            question=example["question"][i],
            chosen=example["chosen"][i],
        )

        output_texts.append(text)

    return output_texts


接下來就一鏡到底了,下方直接是我的 main function。我依序定義了:

  • training_args:包含 batch_size、learning_rate、迭代次數、評估步數、優化器... 全部都在這裡直接定義好
  • bnb_config:這是量化的設定,因為我打算使用 QLoRA([論文閱讀] QLoRA: Efficient Finetuning of Quantized LLMs)的方式進行訓練,所以等等需要讓 transformer 架構的模型讀取這個設定檔,好讓模型權重轉換成 NF4 的資料型態
  • peft_config:這裡是 LoRA 的設定,包括低秩矩陣中間的 rank 要設多少、模型哪幾層要加上 adapter... 等等都是在這裡設定
  • dataset:沒什麼好說的,就是訓練資料跟驗證資料

更下面就更單純了,不外乎就是模型、斷詞器、SFTTrainer() 本身... 定義好之後就可以直接訓練。真的非常方便。

不過當然,嘗試各種組合、迭代優化自己的資料集或許才是最重要的呢。

def main() -> None:
    # Training arguments
    training_args = TrainingArguments(
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=4,
        gradient_checkpointing=True,
        learning_rate=5e-5,
        lr_scheduler_type="cosine",
        max_steps=6000,
        evaluation_strategy="steps",
        save_strategy="steps",
        do_eval=True,
        eval_steps=100,
        save_steps=100,
        logging_steps=1,
        output_dir="outputs_pretraining_20240101",
        optim="paged_adamw_32bit",
        warmup_steps=100,
        remove_unused_columns=False,
        bf16=True,
        report_to="none",
    )
    
    # BitsAndBytes config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=False,
    )

    # LoRA config
    peft_config = LoraConfig(
        r=16,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "k_proj",
            "gate_proj",
            "v_proj",
            "up_proj",
            "q_proj",
            "o_proj",
            "down_proj",
        ],
    )

    # Load data
    train_dataset = load_dataset("csv", split="train", data_files="./program_data/kg_gpt4_sft_train_dataset.csv")
    eval_dataset = load_dataset("csv", split="train", data_files="./program_data/kg_gpt4_sft_test_dataset.csv")

    # Load model and tokenizer
    pretrained_model_name_or_path = "/tmp2/share_data/teknium--OpenHermes-2.5-Mistral-7B/"
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        load_in_4bit=True,
        quantization_config=bnb_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path)
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.padding_side = "left"

    # Preprocessing
    response_template = "<|im_start|>assistant\n"
    collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

    # Create SFTTrainer
    sft_trainer = SFTTrainer(
        model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        formatting_func=formatting_prompts_func,
        data_collator=collator,
        peft_config=peft_config,
        neftune_noise_alpha=5,
        max_seq_length=4096,
    )

    # Train
    sft_trainer.train()

    # Save
    sft_trainer.model.save_pretrained("final_checkpoint")
    tokenizer.save_pretrained("final_checkpoint")

完整程式碼

# coding: utf-8
from datasets import load_dataset
from peft import LoraConfig
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM


PROMPT_TEMPLATE = """<|im_start|>system
{system}<|im_end|>
<|im_start|>user
{question}<|im_end|>
<|im_start|>assistant
{chosen}<|im_end|>"""


def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example["system"])):
        text = PROMPT_TEMPLATE.format(
            system=example["system"][i],
            question=example["question"][i],
            chosen=example["chosen"][i],
        )

        output_texts.append(text)

    return output_texts


def main() -> None:
    # Training arguments
    training_args = TrainingArguments(
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=4,
        gradient_checkpointing=True,
        learning_rate=5e-5,
        lr_scheduler_type="cosine",
        max_steps=6000,
        evaluation_strategy="steps",
        save_strategy="steps",
        do_eval=True,
        eval_steps=100,
        save_steps=100,
        logging_steps=1,
        output_dir="outputs_pretraining_20240101",
        optim="paged_adamw_32bit",
        warmup_steps=100,
        remove_unused_columns=False,
        bf16=True,
        report_to="none",
    )
    
    # BitsAndBytes config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=False,
    )

    # LoRA config
    peft_config = LoraConfig(
        r=16,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "k_proj",
            "gate_proj",
            "v_proj",
            "up_proj",
            "q_proj",
            "o_proj",
            "down_proj",
        ],
    )

    # Load data
    train_dataset = load_dataset("csv", split="train", data_files="./program_data/kg_gpt4_sft_train_dataset.csv")
    eval_dataset = load_dataset("csv", split="train", data_files="./program_data/kg_gpt4_sft_test_dataset.csv")

    # Load model and tokenizer
    pretrained_model_name_or_path = "/tmp2/share_data/teknium--OpenHermes-2.5-Mistral-7B/"
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        load_in_4bit=True,
        quantization_config=bnb_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path)
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.padding_side = "left"

    # Preprocessing
    response_template = "<|im_start|>assistant\n"
    collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

    # Create SFTTrainer
    sft_trainer = SFTTrainer(
        model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        formatting_func=formatting_prompts_func,
        data_collator=collator,
        peft_config=peft_config,
        neftune_noise_alpha=5,
        max_seq_length=4096,
    )

    # Train
    sft_trainer.train()

    # Save
    sft_trainer.model.save_pretrained("final_checkpoint")
    tokenizer.save_pretrained("final_checkpoint")


if __name__ == "__main__":
    main()

References


Read More

2 thoughts on “Supervised Fine-tuning Trainer (SFTTrainer) 訓練筆記”

  1. 想请教一个问题:关于formatting_prompt_func() 的函式,传进去的参数example,到底是一个训练example,还是一个batch的训练example,因为我看到huggingface trl关于SFTTrainer那一章节,有两种这样的function
    (1) def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example[‘instruction’])):
    text = f”### Question: {example[‘instruction’][i]}\n ### Answer: {example[‘output’][i]}”
    output_texts.append(text)
    return output_texts

    def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example[‘question’])):
    text = f”### Question: {example[‘question’][i]}\n ### Answer: {example[‘answer’][i]}”
    output_texts.append(text)
    return output_texts

    (2) def formatting_func(example):
    text = f”### Question: {example[‘question’]}\n ### Answer: {example[‘answer’]}”
    return text

    sft_config = SFTConfig(packing=True)
    trainer = SFTTrainer(
    “facebook/opt-350m”,
    train_dataset=dataset,
    args=sft_config,
    formatting_func=formatting_func
    )

    1. 這是因為是否有使用 `packing=True` 造成的差異。

      當 `packing` 啟用時,通常意味著我們不希望因為長短不一的訓練資料讓 GPU 需要額外空間紀錄 padding、進而造成浪費,所以才把多筆不同的訓練資料拼接,之後才在真正需要訓練時解開成不同筆資料 —— 在 SFTTrainer 中的實現為會幫忙依序傳入處理的 `formatting_func()` 中,所以 `formatting_func()` 只需要處理『單筆資料』的情況。

      而預設情況 `packing` 是不啟用的,而 SFTTrainer 中的實現是一個 batch 為單位進行處理,所以我們需要寫成 for-loop 去依序處理資料。

      簡單來說,就是看 `packing` 參數:`True` 時處理單筆資料、`False` 處理 batch 為單位的資料。

      我自己在看 trl==0.13.0.dev0 的原始碼時,是可以看到 `self._prepare_non_packed_dataloader()` 和 `self._prepare_packed_dataloader()` 兩種不同邏輯的實現的,可以參考看看。

      不過,因為 trl 一直在不斷開發完善(我也有微薄地貢獻過一些 PR XDD),所以現在版本可能與當初看的時候不一樣了,最好隨時注意原始碼的改動。我當初發 PR 時最大的體悟就是他們開發得很快、功能上得很齊全但是文件更新的速度比較慢~

Leave a Reply