Last Updated on 2024-01-03 by Clay
Introduction
Supervised Fine-tuning (SFT) is one of the best-known methods for training Large Language Models (LLMs) today. At its core it is the same as traditional language modeling: the model learns knowledge from the training data.
The only difference is that plain language modeling learns from the complete text, in other words genuine free-running text continuation, whereas what we now call SFT mostly refers to training the model only on the "chatting" portion.
That is, the training data used to look like:
The weather is great today, I want to go out and have fun...
whereas the training data now looks like:
### Question: What are your plans for today?
### Response: The weather is nice today, so I will probably go out.
Moreover, the model focuses only on learning the part "The weather is nice today, so I will probably go out."; the preceding input does not participate in training.
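To make that concrete, here is a minimal sketch (with made-up token ids) of this "completion-only" masking: every label position belonging to the prompt is set to -100, the ignore_index of PyTorch's cross-entropy loss, so only the response tokens produce a training signal.
import torch

# Hypothetical token ids: the first four positions are the prompt,
# the remaining three are the assistant response we want to learn.
input_ids = torch.tensor([[101, 102, 103, 104, 201, 202, 203]])
labels = input_ids.clone()

# -100 is ignored by torch.nn.CrossEntropyLoss, so masking the prompt
# this way means only the response contributes to the loss.
prompt_length = 4
labels[:, :prompt_length] = -100
print(labels)  # tensor([[-100, -100, -100, -100,  201,  202,  203]])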
Code Walkthrough
The training code I am documenting here uses the SFTTrainer() provided by the trl library. It seems to be a very popular choice these days: it is well encapsulated, and with straightforward logic you can build a training script in roughly 100 lines of code.
The model I chose is Mistral, or strictly speaking teknium/OpenHermes-2.5-Mistral-7B, a fine-tuned Mistral. One thing to watch out for when using SFTTrainer() is the tokenizer's padding setup ([Solved] Mistral does not output eos_token `<|im_end|>` after fine-tuning with SFTTrainer).
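That issue is worth a quick illustration, since it explains two tokenizer lines in the script below. DataCollatorForLanguageModeling (which DataCollatorForCompletionOnlyLM builds on) replaces every pad_token_id in the labels with -100, so if the pad token is set to the eos token, the `<|im_end|>` labels get masked out as well and the model never learns to stop. A toy sketch with made-up token ids:
import torch

# Problematic setup: pad_token_id == eos_token_id (both 2 here)
eos_token_id = pad_token_id = 2
labels = torch.tensor([[15, 27, 42, 2]])  # ..., <|im_end|>

# The collator effectively does this, which also wipes out the eos label:
labels[labels == pad_token_id] = -100
print(labels)  # tensor([[  15,   27,   42, -100]])
This is why the script below sets tokenizer.pad_token = tokenizer.unk_token instead of reusing the eos token.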
First we import the required packages, and at the very top I define my prompt template, which follows the ChatML format.
# coding: utf-8
from datasets import load_dataset
from peft import LoraConfig
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
PROMPT_TEMPLATE = """<|im_start|>system
{system}<|im_end|>
<|im_start|>user
{question}<|im_end|>
<|im_start|>assistant
{chosen}<|im_end|>"""
Next is the formatting_prompts_func() function, which will be passed to SFTTrainer() later. It automatically converts the training data into the format I want, namely the ChatML format defined above.
def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example["system"])):
        text = PROMPT_TEMPLATE.format(
            system=example["system"][i],
            question=example["question"][i],
            chosen=example["chosen"][i],
        )
        output_texts.append(text)
    return output_texts
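As a quick sanity check, calling it on a hypothetical single-row batch (with column names matching the CSVs used below) yields one ChatML-formatted string per row:
example = {
    "system": ["You are a helpful assistant."],
    "question": ["What are your plans for today?"],
    "chosen": ["The weather is nice today, so I will probably go out."],
}
print(formatting_prompts_func(example)[0])
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# What are your plans for today?<|im_end|>
# <|im_start|>assistant
# The weather is nice today, so I will probably go out.<|im_end|>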
From here on it is a single take: below is my main function. In order, I define:
- training_args: batch size, learning rate, number of training steps, evaluation interval, optimizer... everything is defined directly here
- bnb_config: the quantization settings. Since I plan to train with QLoRA ([Paper Reading] QLoRA: Efficient Finetuning of Quantized LLMs), the transformers model needs to read this config so that its weights are converted to the NF4 data type
- peft_config: the LoRA settings, such as the rank of the low-rank matrices and which layers of the model get adapters... all of it is configured here
- dataset: nothing much to say, just the training and validation data
Further down it is even simpler: just the model, the tokenizer, and SFTTrainer() itself... once these are defined, training can start right away. Really convenient.
Of course, trying out different combinations and iteratively improving your own dataset is probably what matters most.
def main() -> None:
    # Training arguments
    training_args = TrainingArguments(
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=4,
        gradient_checkpointing=True,
        learning_rate=5e-5,
        lr_scheduler_type="cosine",
        max_steps=6000,
        evaluation_strategy="steps",
        save_strategy="steps",
        do_eval=True,
        eval_steps=100,
        save_steps=100,
        logging_steps=1,
        output_dir="outputs_pretraining_20240101",
        optim="paged_adamw_32bit",
        warmup_steps=100,
        remove_unused_columns=False,
        bf16=True,
        report_to="none",
    )

    # BitsAndBytes config (QLoRA: load the base weights in 4-bit NF4)
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=False,
    )

    # LoRA config
    peft_config = LoraConfig(
        r=16,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "k_proj",
            "gate_proj",
            "v_proj",
            "up_proj",
            "q_proj",
            "o_proj",
            "down_proj",
        ],
    )

    # Load data
    train_dataset = load_dataset("csv", split="train", data_files="./program_data/kg_gpt4_sft_train_dataset.csv")
    eval_dataset = load_dataset("csv", split="train", data_files="./program_data/kg_gpt4_sft_test_dataset.csv")

    # Load model and tokenizer (4-bit loading comes from `quantization_config`)
    pretrained_model_name_or_path = "/tmp2/share_data/teknium--OpenHermes-2.5-Mistral-7B/"
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        quantization_config=bnb_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)

    # Pad with unk_token (not eos_token) so the eos labels are not masked
    # out by the collator and the model can learn to emit <|im_end|>
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.padding_side = "left"

    # Preprocessing: compute the loss only on tokens after the assistant tag
    response_template = "<|im_start|>assistant\n"
    collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

    # Create SFTTrainer
    sft_trainer = SFTTrainer(
        model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        formatting_func=formatting_prompts_func,
        data_collator=collator,
        peft_config=peft_config,
        neftune_noise_alpha=5,
        max_seq_length=4096,
    )

    # Train
    sft_trainer.train()

    # Save
    sft_trainer.model.save_pretrained("final_checkpoint")
    tokenizer.save_pretrained("final_checkpoint")
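After training, final_checkpoint holds the LoRA adapter weights plus the tokenizer. As a minimal inference sketch, assuming peft's AutoPeftModelForCausalLM (which loads the base model and applies the saved adapter on top), you could try out the result like this:
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

# Load the base model with the saved LoRA adapter applied
model = AutoPeftModelForCausalLM.from_pretrained(
    "final_checkpoint",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("final_checkpoint")

# Build a ChatML prompt that stops right before the assistant's turn
prompt = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    "<|im_start|>user\nWhat are your plans for today?<|im_end|>\n"
    "<|im_start|>assistant\n"
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=128)

# Decode only the newly generated tokens
print(tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))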
Complete Code
# coding: utf-8
from datasets import load_dataset
from peft import LoraConfig
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
PROMPT_TEMPLATE = """<|im_start|>system
{system}<|im_end|>
<|im_start|>user
{question}<|im_end|>
<|im_start|>assistant
{chosen}<|im_end|>"""
def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example["system"])):
        text = PROMPT_TEMPLATE.format(
            system=example["system"][i],
            question=example["question"][i],
            chosen=example["chosen"][i],
        )
        output_texts.append(text)
    return output_texts
def main() -> None:
    # Training arguments
    training_args = TrainingArguments(
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=4,
        gradient_checkpointing=True,
        learning_rate=5e-5,
        lr_scheduler_type="cosine",
        max_steps=6000,
        evaluation_strategy="steps",
        save_strategy="steps",
        do_eval=True,
        eval_steps=100,
        save_steps=100,
        logging_steps=1,
        output_dir="outputs_pretraining_20240101",
        optim="paged_adamw_32bit",
        warmup_steps=100,
        remove_unused_columns=False,
        bf16=True,
        report_to="none",
    )

    # BitsAndBytes config (QLoRA: load the base weights in 4-bit NF4)
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=False,
    )

    # LoRA config
    peft_config = LoraConfig(
        r=16,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=[
            "k_proj",
            "gate_proj",
            "v_proj",
            "up_proj",
            "q_proj",
            "o_proj",
            "down_proj",
        ],
    )

    # Load data
    train_dataset = load_dataset("csv", split="train", data_files="./program_data/kg_gpt4_sft_train_dataset.csv")
    eval_dataset = load_dataset("csv", split="train", data_files="./program_data/kg_gpt4_sft_test_dataset.csv")

    # Load model and tokenizer (4-bit loading comes from `quantization_config`)
    pretrained_model_name_or_path = "/tmp2/share_data/teknium--OpenHermes-2.5-Mistral-7B/"
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        quantization_config=bnb_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)

    # Pad with unk_token (not eos_token) so the eos labels are not masked
    # out by the collator and the model can learn to emit <|im_end|>
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.padding_side = "left"

    # Preprocessing: compute the loss only on tokens after the assistant tag
    response_template = "<|im_start|>assistant\n"
    collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

    # Create SFTTrainer
    sft_trainer = SFTTrainer(
        model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        formatting_func=formatting_prompts_func,
        data_collator=collator,
        peft_config=peft_config,
        neftune_noise_alpha=5,
        max_seq_length=4096,
    )

    # Train
    sft_trainer.train()

    # Save
    sft_trainer.model.save_pretrained("final_checkpoint")
    tokenizer.save_pretrained("final_checkpoint")
if __name__ == "__main__":
    main()
References
- teknium/OpenHermes-2.5-Mistral-7B: https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B
- huggingface/trl - Transformer Reinforcement Learning: https://github.com/huggingface/trl