Skip to content

Supervised Fine-tuning Trainer (SFTTrainer) 訓練筆記

Last Updated on 2024-01-03 by Clay

介紹

監督式微調Supervised Fine-tuning, SFT)是當前訓練大型語言模型Large Language Model, LLM)最知名的方法之一,本質上與傳統的語言模型建模(language modeling)相同,是讓模型透過訓練資料去學習某些知識。

唯一不同之處在於本來的語言模型建模可能會學習完整的文本,換言之就是真正的文字接龍;而現在所稱的 SFT,則大部分都是指讓模型只學習『聊天』(Chatting)的部份。

也就是說以前的訓練資料可能是:

今天天氣真好,我想要出門去玩......


而現在的訓練資料則是:

### Question: 你今天有什麼計畫?
### Rresponse: 今天天氣很好,我應該會出門去玩。

並且,模型只專注學習今天天氣很好,我應該會出門去玩。的這個部份,前面的輸入都不參與訓練。


程式碼介紹

這裡我紀錄的訓練程式碼是使用 trl 函式庫中提供的 SFTTrainer(),這個選擇在現在似乎相當熱門,畢竟封裝得很好,可以以輕鬆易懂的邏輯搭建程式碼行數大概 100 行左右的訓練腳本。

而我選用的模型是 Mistral,嚴格說起來是 Mistral 微調過後的 teknium/OpenHermes-2.5-Mistral-7B。在使用 SFTTrainer() 時需要注意的是 tokenizer 的 padding 問題([已解決] Mistral 經過 SFTTrainer 微調後不會輸出 eos_token `<|im_end|>`)。

首先是匯入所需要的套件,並且在最上方我初始化了我的 prompt 格式,這裡是遵守著 ChatML 的格式。

# coding: utf-8
from datasets import load_dataset
from peft import LoraConfig
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM


PROMPT_TEMPLATE = """<|im_start|>system
{system}<|im_end|>
<|im_start|>user
{question}<|im_end|>
<|im_start|>assistant
{chosen}<|im_end|>"""


接下來則是 formatting_prompt_func() 的函式,這個函式等等會一併傳入 SFTTrainer() 去使用,它會自動把訓練資料調整成我想要的格式 —— 也就是一開始定義的 ChatML 格式。

def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example["system"])):
        text = PROMPT_TEMPLATE.format(
            system=example["system"][i],
            question=example["question"][i],
            chosen=example["chosen"][i],
        )

        output_texts.append(text)

    return output_texts


接下來就一鏡到底了,下方直接是我的 main function。我依序定義了:

  • training_args:包含 batch_size、learning_rate、迭代次數、評估步數、優化器... 全部都在這裡直接定義好
  • bnb_config:這是量化的設定,因為我打算使用 QLoRA([論文閱讀] QLoRA: Efficient Finetuning of Quantized LLMs)的方式進行訓練,所以等等需要讓 transformer 架構的模型讀取這個設定檔,好讓模型權重轉換成 NF4 的資料型態
  • peft_config:這裡是 LoRA 的設定,包括低秩矩陣中間的 rank 要設多少、模型哪幾層要加上 adapter... 等等都是在這裡設定
  • dataset:沒什麼好說的,就是訓練資料跟驗證資料

更下面就更單純了,不外乎就是模型、斷詞器、SFTTrainer() 本身... 定義好之後就可以直接訓練。真的非常方便。

不過當然,嘗試各種組合、迭代優化自己的資料集或許才是最重要的呢。

def main() -> None:
    # Training arguments
    training_args = TrainingArguments(
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=4,
        gradient_checkpointing=True,
        learning_rate=5e-5,
        lr_scheduler_type="cosine",
        max_steps=6000,
        evaluation_strategy="steps",
        save_strategy="steps",
        do_eval=True,
        eval_steps=100,
        save_steps=100,
        logging_steps=1,
        output_dir="outputs_pretraining_20240101",
        optim="paged_adamw_32bit",
        warmup_steps=100,
        remove_unused_columns=False,
        bf16=True,
        report_to="none",
    )
    
    # BitsAndBytes config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=False,
    )

    # LoRA config
    peft_config = LoraConfig(
        r=16,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "k_proj",
            "gate_proj",
            "v_proj",
            "up_proj",
            "q_proj",
            "o_proj",
            "down_proj",
        ],
    )

    # Load data
    train_dataset = load_dataset("csv", split="train", data_files="./program_data/kg_gpt4_sft_train_dataset.csv")
    eval_dataset = load_dataset("csv", split="train", data_files="./program_data/kg_gpt4_sft_test_dataset.csv")

    # Load model and tokenizer
    pretrained_model_name_or_path = "/tmp2/share_data/teknium--OpenHermes-2.5-Mistral-7B/"
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        load_in_4bit=True,
        quantization_config=bnb_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path)
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.padding_side = "left"

    # Preprocessing
    response_template = "<|im_start|>assistant\n"
    collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

    # Create SFTTrainer
    sft_trainer = SFTTrainer(
        model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        formatting_func=formatting_prompts_func,
        data_collator=collator,
        peft_config=peft_config,
        neftune_noise_alpha=5,
        max_seq_length=4096,
    )

    # Train
    sft_trainer.train()

    # Save
    sft_trainer.model.save_pretrained("final_checkpoint")
    tokenizer.save_pretrained("final_checkpoint")

完整程式碼

# coding: utf-8
from datasets import load_dataset
from peft import LoraConfig
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM


PROMPT_TEMPLATE = """<|im_start|>system
{system}<|im_end|>
<|im_start|>user
{question}<|im_end|>
<|im_start|>assistant
{chosen}<|im_end|>"""


def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example["system"])):
        text = PROMPT_TEMPLATE.format(
            system=example["system"][i],
            question=example["question"][i],
            chosen=example["chosen"][i],
        )

        output_texts.append(text)

    return output_texts


def main() -> None:
    # Training arguments
    training_args = TrainingArguments(
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=4,
        gradient_checkpointing=True,
        learning_rate=5e-5,
        lr_scheduler_type="cosine",
        max_steps=6000,
        evaluation_strategy="steps",
        save_strategy="steps",
        do_eval=True,
        eval_steps=100,
        save_steps=100,
        logging_steps=1,
        output_dir="outputs_pretraining_20240101",
        optim="paged_adamw_32bit",
        warmup_steps=100,
        remove_unused_columns=False,
        bf16=True,
        report_to="none",
    )
    
    # BitsAndBytes config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=False,
    )

    # LoRA config
    peft_config = LoraConfig(
        r=16,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "k_proj",
            "gate_proj",
            "v_proj",
            "up_proj",
            "q_proj",
            "o_proj",
            "down_proj",
        ],
    )

    # Load data
    train_dataset = load_dataset("csv", split="train", data_files="./program_data/kg_gpt4_sft_train_dataset.csv")
    eval_dataset = load_dataset("csv", split="train", data_files="./program_data/kg_gpt4_sft_test_dataset.csv")

    # Load model and tokenizer
    pretrained_model_name_or_path = "/tmp2/share_data/teknium--OpenHermes-2.5-Mistral-7B/"
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        load_in_4bit=True,
        quantization_config=bnb_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path)
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.padding_side = "left"

    # Preprocessing
    response_template = "<|im_start|>assistant\n"
    collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

    # Create SFTTrainer
    sft_trainer = SFTTrainer(
        model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        formatting_func=formatting_prompts_func,
        data_collator=collator,
        peft_config=peft_config,
        neftune_noise_alpha=5,
        max_seq_length=4096,
    )

    # Train
    sft_trainer.train()

    # Save
    sft_trainer.model.save_pretrained("final_checkpoint")
    tokenizer.save_pretrained("final_checkpoint")


if __name__ == "__main__":
    main()

References


Read More

Leave a Reply