Last Updated on 2024-10-19 by Clay
I've been reading on and off about a fine-tuning method called Kahneman-Tversky Optimization (KTO), drawing on HuggingFace's official documentation and other online material. Like DPO, it is a way to align models with human values, but its training data is much more convenient to prepare, so I'm applying it to my current tasks first and will make time to study the related paper in detail later.
Unlike DPO and RLHF, KTO claims that, depending on how well your base model performs, you can even skip the SFT stage. It also does not need paired data as DPO does; you can use unpaired, and even unbalanced, labeled data. This makes it particularly suitable for fine-tuning models like chatbots, where user feedback consists of simple options like 'thumbs up' or 'thumbs down'.
For example, in DPO a single data point is a (prompt, chosen response, rejected response) triplet. The KTO training data format, by contrast, is:
kto_dataset_dict = {
    "prompt": [
        "Hey, hello",
        "How are you",
        "What is your name?",
        "What is your name?",
        "Which is the best programming language?",
        "Which is the best programming language?",
        "Which is the best programming language?",
    ],
    "completion": [
        "hi nice to meet you",
        "leave me alone",
        "I don't have a name",
        "My name is Mary",
        "Python",
        "C++",
        "Java",
    ],
    "label": [
        True,
        False,
        False,
        True,
        True,
        False,
        False,
    ],
}
The prompt is the input given to the model, the completion is the model's generated output, and the label is a binary label indicating whether the response is 'good' or 'bad'.
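If you already have DPO-style triplets, each one can be split into two KTO rows. The HuggingFace script further down uses trl's maybe_unpair_preference_dataset for this; the function below is only my own plain-Python sketch of the idea, and the name unpair_preferences and the sample row are hypothetical, not part of trl.

```python
from typing import Dict, List


def unpair_preferences(paired: List[Dict[str, str]]) -> Dict[str, List]:
    """Split (prompt, chosen, rejected) rows into KTO-style columns.

    Hypothetical helper for illustration; not the trl implementation.
    """
    kto: Dict[str, List] = {"prompt": [], "completion": [], "label": []}
    for row in paired:
        # The chosen answer becomes a desirable (True) example ...
        kto["prompt"].append(row["prompt"])
        kto["completion"].append(row["chosen"])
        kto["label"].append(True)
        # ... and the rejected answer becomes an undesirable (False) one.
        kto["prompt"].append(row["prompt"])
        kto["completion"].append(row["rejected"])
        kto["label"].append(False)
    return kto


paired_rows = [
    {
        "prompt": "What is your name?",
        "chosen": "My name is Mary",
        "rejected": "I don't have a name",
    },
]
kto_rows = unpair_preferences(paired_rows)
```

Each triplet yields one True row and one False row that share the same prompt, which matches the repeated prompts in the dictionary above.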
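So, how is KTO's loss function designed? As given in the KTO paper:

$$L_{\mathrm{KTO}}(\pi_\theta, \pi_{\mathrm{ref}}) = \mathbb{E}_{x, y \sim D}\left[\lambda_y - v(x, y)\right]$$

where

$$r_\theta(x, y) = \log \frac{\pi_\theta(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)}, \qquad z_0 = \mathrm{KL}\left(\pi_\theta(y' \mid x) \,\|\, \pi_{\mathrm{ref}}(y' \mid x)\right)$$

$$v(x, y) = \begin{cases} \lambda_D \, \sigma\left(\beta \left(r_\theta(x, y) - z_0\right)\right) & \text{if } y \text{ is preferred} \\ \lambda_U \, \sigma\left(\beta \left(z_0 - r_\theta(x, y)\right)\right) & \text{if } y \text{ is non-preferred} \end{cases}$$

and $\lambda_y$ is $\lambda_D$ for a preferred response and $\lambda_U$ for a non-preferred one.
Let's break down the individual elements of this formula.
- $\lambda_D$ is the weight for preferred responses
- $\lambda_U$ is the weight for non-preferred responses
- $\beta$ is the scaling parameter, adjusting the sensitivity to changes
- $\sigma$ is the Sigmoid smoothing function
- $z_0$ is the KL divergence between two distributions, measuring the distance between the current model and the reference model
- $r_\theta(x, y)$ measures the log difference between the current model's predictive distribution $\pi_\theta$ and the reference distribution $\pi_{\mathrm{ref}}$, indicating the model's relative preference compared to the reference model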
The term
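To get a feel for the numbers, here is a minimal per-example sketch of the loss, assuming the per-example form from the KTO paper: the loss is the weight minus the value, where the value is a weighted sigmoid of the scaled gap between the log-ratio and the KL term. The function name kto_example_loss and the default values are my own choices, and the KL term is simply passed in as a number here, whereas a real trainer has to estimate it from a batch.

```python
import math


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def kto_example_loss(
    policy_logp: float,     # log pi_theta(y | x) under the model being trained
    ref_logp: float,        # log pi_ref(y | x) under the frozen reference model
    desirable: bool,        # True for a 'good' completion, False for a 'bad' one
    kl_estimate: float,     # stand-in for the KL term (assumed given, not estimated)
    beta: float = 0.1,      # scaling parameter
    lambda_d: float = 1.0,  # weight for preferred responses
    lambda_u: float = 1.0,  # weight for non-preferred responses
) -> float:
    """Per-example loss lambda_y - v(x, y); an illustrative sketch only."""
    log_ratio = policy_logp - ref_logp
    if desirable:
        value = lambda_d * sigmoid(beta * (log_ratio - kl_estimate))
        return lambda_d - value
    value = lambda_u * sigmoid(beta * (kl_estimate - log_ratio))
    return lambda_u - value


# Policy and reference agree and the KL term is zero: sigmoid(0) = 0.5.
neutral = kto_example_loss(-10.0, -10.0, desirable=True, kl_estimate=0.0)

# The policy assigns more probability than the reference to the same completion.
raised_good = kto_example_loss(-8.0, -10.0, desirable=True, kl_estimate=0.0)
raised_bad = kto_example_loss(-8.0, -10.0, desirable=False, kl_estimate=0.0)
```

Raising the probability of a 'good' completion relative to the reference lowers the loss, while doing the same for a 'bad' completion raises it, which is the behaviour the two branches are meant to produce.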
Example Code
First, let's look at the format of the training data:
from datasets import load_dataset
dataset = load_dataset("trl-lib/kto-mix-14k")
dataset["train"][3]
Output:
{'prompt': [{'content': 'Q:Information: - The Assistant Secretary of Defense for Health Affairs (ASD(HA)) is chartered under United States Department of Defense Directive (DoDD) 5136.1 in 1994. This DoDD states that the ASD(HA) is the principal advisor to the U.S. Secretary of Defense on all "DoD health policies, programs and activities." In addition to exercising oversight of all DoD health resources, ASD(HA) serves as director of the Tricare Management Activity. - The Department of the Air Force (DAF) is one of the three Military Departments within the Department of Defense of the United States of America. The Department of the Air Force was formed on September 18, 1947, per the National Security Act of 1947 and it includes all elements and units of the United States Air Force (USAF). - The Surgeon General of the Air Force is the senior-most Medical Service officer in the United States Department of the Air Force. In recent times, this has been a Lieutenant General who serves as head of the United States Air Force Medical Service (AFMS). The Surgeon General is usually the senior Medical Corps officer, but acting surgeons general have been from other branches of the medical service. - Lieutenant general, lieutenant-general and similar (abbrev Lt Gen, LTG and similar) is a three-star military rank (NATO code OF-8) used in many countries. The rank traces its origins to the Middle Ages, where the title of lieutenant general was held by the second in command on the battlefield, who was normally subordinate to a captain general. - The United States Air Force (USAF) is the aerial warfare service branch of the United States Armed Forces and one of the seven American uniformed services. Initially part of the United States Army, the USAF was formed as a separate branch of the military on 18 September 1947 under the National Security Act of 1947. It is the most recent branch of the U.S. 
military to be formed, and is the largest and one of the world\'s most technologically advanced air forces. The USAF articulates its core functions as Nuclear Deterrence Operations, Special Operations, Air Superiority, Global Integrated ISR, Space Superiority, Command and Control, Cyberspace Superiority, Personnel Recovery, Global Precision Attack, Building Partnerships, Rapid Global Mobility and Agile Combat Support. - Lieutenant General James Gordon Roudebush , USAF , ( born February 24 , 1948 ) was the 19th Surgeon General of the United States Air Force , Headquarters U.S. Air Force , Washington , D.C. General Roudebush served as functional manager of the U.S. Air Force Medical Service . In this capacity , he advised the Secretary of the Air Force and Air Force Chief of Staff , as well as the Assistant Secretary of Defense for Health Affairs on matters pertaining to the medical aspects of the air expeditionary force and the health of Air Force people . General Roudebush had authority to commit resources worldwide for the Air Force Medical Service , to make decisions affecting the delivery of medical services , and to develop plans , programs and procedures to support worldwide medical service missions . He exercised direction , guidance and technical management of more than 42,400 people assigned to 74 medical facilities worldwide . A native of Gering , Nebraska , Roudebush entered the Air Force in 1975 after receiving a Bachelor of Medicine degree from the University of Nebraska at Lincoln , and a Doctor of Medicine degree from the University of Nebraska College of Medicine . He completed residency training in family practice at the Wright - Patterson Air Force Medical Center , Ohio , in 1978 , and aerospace medicine at Brooks Air Force Base , Texas , in 1984 . He commanded a wing clinic and wing hospital before becoming Deputy Commander of the Air Force Materiel Command Human Systems Center . He has served as Command Surgeon for U.S. 
Central Command , Pacific Air Forces , U.S. Transportation Command and Headquarters Air Mobility Command . Prior to his selection as the 19th Surgeon General , he served as the Deputy Surgeon General of the U.S. Air Force . He retired from the U.S. Air Force on October 1 , 2009 . After reading the paragraphs above, choose the best answer for the entity that related to \'james g. roudebush\' with the relationship of \'occupation\'. Choices: - advisor - army - captain - general - lieutenant - military - officer - secretary - surgeon - united states of america\nA:',
'role': 'user'}],
'completion': [{'content': ' Based on the information provided, the best answer for the entity related to "James G. Roudebush" with the relationship of "occupation" is:\n\nB. Officer\n\nAccording to the text, James G. Roudebush served as the 19th Surgeon General of the United States Air Force, which is an officer position.',
'role': 'assistant'}],
'label': False}
With this format, you can fine-tune your model.
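Since KTO accepts unbalanced data, it is worth checking how lopsided the labels are before training. trl's KTOConfig has desirable_weight and undesirable_weight settings; as far as I understand the trl documentation, it suggests keeping (desirable_weight × number of positives) to (undesirable_weight × number of negatives) roughly between 1:1 and 4:3, so please verify that against the docs for your version. The helper below is a hypothetical sketch of mine that simply aims for 1:1.

```python
from collections import Counter
from typing import List, Tuple


def suggest_kto_weights(labels: List[bool]) -> Tuple[float, float]:
    """Return (desirable_weight, undesirable_weight) so the weighted counts match.

    Hypothetical helper; the 1:1 target is an assumption, not a trl default.
    """
    counts = Counter(labels)
    n_pos, n_neg = counts[True], counts[False]
    if n_pos == 0 or n_neg == 0:
        raise ValueError("need at least one desirable and one undesirable example")
    if n_pos >= n_neg:
        # More 'good' examples than 'bad': upweight the undesirable side.
        return 1.0, n_pos / n_neg
    # More 'bad' examples than 'good': upweight the desirable side.
    return n_neg / n_pos, 1.0


# Labels from the toy dataset earlier in this post: 3 True, 4 False.
labels = [True, False, False, True, True, False, False]
desirable_weight, undesirable_weight = suggest_kto_weights(labels)
```

The two returned values could then be passed to KTOConfig through its desirable_weight and undesirable_weight arguments.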
Below is the example script provided by HuggingFace. I had seen someone open an issue about errors when running it, but at first I didn't hit any myself (I had, of course, swapped in my own model and dataset). Update the next day: I did run into an error during training. I started to fix it myself, but saw that the HuggingFace developers were already working on a fix involving changes to trl's internal code, so I decided to wait.
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Run the KTO training script with the commands below. In general, the optimal configuration for KTO will be similar to that of DPO.
# Full training:
python examples/scripts/kto.py \
--dataset_name trl-lib/kto-mix-14k \
--model_name_or_path=trl-lib/qwen1.5-1.8b-sft \
--per_device_train_batch_size 16 \
--num_train_epochs 1 \
--learning_rate 5e-7 \
--lr_scheduler_type=cosine \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir=kto-aligned-model \
--warmup_ratio 0.1 \
--report_to wandb \
--bf16 \
--logging_first_step
# QLoRA:
python examples/scripts/kto.py \
--dataset_name trl-lib/kto-mix-14k \
--model_name_or_path=trl-lib/qwen1.5-1.8b-sft \
--per_device_train_batch_size 8 \
--num_train_epochs 1 \
--learning_rate 5e-7 \
--lr_scheduler_type=cosine \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir=kto-aligned-model-lora \
--warmup_ratio 0.1 \
--report_to wandb \
--bf16 \
--logging_first_step \
--use_peft \
--load_in_4bit \
--lora_target_modules=all-linear \
--lora_r=16 \
--lora_alpha=16
"""
from accelerate import PartialState
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from trl import (
    KTOConfig,
    KTOTrainer,
    ModelConfig,
    ScriptArguments,
    get_peft_config,
    maybe_unpair_preference_dataset,
    setup_chat_format,
)

if __name__ == "__main__":
    parser = HfArgumentParser((ScriptArguments, KTOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_into_dataclasses()

    # Load a pretrained model
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
    )
    ref_model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # If we are aligning a base model, we use ChatML as the default template
    if tokenizer.chat_template is None:
        model, tokenizer = setup_chat_format(model, tokenizer)

    # Load the dataset
    dataset = load_dataset(script_args.dataset_name)

    # If needed, reformat a DPO-formatted dataset (prompt, chosen, rejected) to a KTO-format (prompt, completion, label)
    dataset = maybe_unpair_preference_dataset(dataset, num_proc=training_args.dataset_num_proc)

    # Apply chat template
    def format_dataset(example):
        if isinstance(example["completion"], str):
            example["prompt"] = tokenizer.apply_chat_template(example["prompt"], tokenize=False)
            example["completion"] = tokenizer.apply_chat_template(example["completion"], tokenize=False)
        else:
            example["prompt"] = tokenizer.apply_chat_template(example["completion"][:-1], tokenize=False)
            example["completion"] = tokenizer.apply_chat_template([example["completion"][-1]], tokenize=False)
        return example

    # Compute that only on the main process for faster data processing.
    # see: https://github.com/huggingface/trl/pull/1255
    with PartialState().local_main_process_first():
        dataset = dataset.map(format_dataset, num_proc=training_args.dataset_num_proc)

    # Initialize the KTO trainer
    trainer = KTOTrainer(
        model,
        ref_model,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split],
        processing_class=tokenizer,
        peft_config=get_peft_config(model_args),
    )

    # Train and push the model to the Hub
    trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
Now I'm looking forward to seeing the results after training! If I have time tomorrow, I may also go through the more detailed theoretical explanations in the KTO paper. I understand how the formula works, but I haven't fully grasped why it is designed this way or why it proves effective.
References
- https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py
- KTO: Model Alignment as Prospect Theoretic Optimization