
KTOTrainer (Kahneman-Tversky Optimization Trainer) Notes

Last Updated on 2024-10-19 by Clay

For a while now I have kept running into an LLM fine-tuning method called Kahneman-Tversky Optimization (KTO) in various places, from around the web to the official HuggingFace documentation (like DPO, it is essentially a value-alignment method). Because its data preparation format is so much more convenient than DPO's, I am trying it on my own tasks first and will find time to read the details in the paper later.

Unlike DPO and RLHF, however, KTO claims that, depending on how capable your base model is, you can even skip the SFT stage. It also does not require paired data the way DPO does and can work directly with imbalanced labels. This makes it a good fit for tuning a chatbot model on collected user feedback, for example when the interface only offers 'thumbs up' and 'thumbs down' buttons.

For example, when preparing data for DPO, each sample needs to be a (prompt, good data, bad data) triple; the KTO training data format, on the other hand, looks like this:

kto_dataset_dict = {
    "prompt": [
        "Hey, hello",
        "How are you",
        "What is your name?",
        "What is your name?",
        "Which is the best programming language?",
        "Which is the best programming language?",
        "Which is the best programming language?",
    ],
    "completion": [
        "hi nice to meet you",
        "leave me alone",
        "I don't have a name",
        "My name is Mary",
        "Python",
        "C++",
        "Java",
    ],
    "label": [
        True,
        False,
        False,
        True,
        True,
        False,
        False,
    ],
}


prompt is the prompt (input) we feed the model, completion is the content the model is trained to complete, and label is the binary 'good' / 'not good' label we assign to that completion.
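A dict like the one above can be wrapped into a datasets.Dataset before handing it to KTOTrainer (a minimal sketch, assuming the kto_dataset_dict defined earlier):

from datasets import Dataset

# Wrap the raw Python dict into a Hugging Face Dataset that KTOTrainer can consume
train_dataset = Dataset.from_dict(kto_dataset_dict)
print(train_dataset)  # Dataset with columns: prompt, completion, label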

So how exactly is the KTO loss function designed?
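Following the KTO paper, the loss has roughly this form (written here with the same symbols explained below; \lambda_{y} is \lambda_{D} or \lambda_{U} depending on the label):

L_{KTO}(\pi_{\Theta}, \pi_{ref}) = \mathbb{E}_{(x, y) \sim D}[\lambda_{y} - v(x, y)]

\gamma_{\Theta}(x, y) = \log \frac{\pi_{\Theta}(y|x)}{\pi_{ref}(y|x)}, \quad z_{0} = KL(\pi_{\Theta}(y'|x) \,\|\, \pi_{ref}(y'|x))

v(x, y) = \lambda_{D} \, \sigma(\beta(\gamma_{\Theta}(x, y) - z_{0}))  if y is desirable
v(x, y) = \lambda_{U} \, \sigma(\beta(z_{0} - \gamma_{\Theta}(x, y)))  if y is undesirable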

Let's break down the terms of this formula.

  • \lambda_{D} is the weight for desirable (preferred) responses
  • \lambda_{U} is the weight for undesirable (non-preferred) responses
  • \beta is a scaling parameter that controls how sensitive the loss is to changes
  • \sigma is the Sigmoid smoothing function
  • z_{0} is the KL divergence between the two distributions, measuring how far the current model has drifted from the reference model
  • \gamma_{\Theta}(x, y) measures the log difference between the model's predicted distribution \pi_{\Theta} and the reference distribution \pi_{ref}; it describes how strongly the current model prefers the completion relative to the reference model

v(x, y) then uses a different formula to compute the loss depending on whether the sample is positive (desirable) or negative (undesirable).
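To make the two cases concrete, here is a small illustrative sketch (not the actual trl implementation) that computes v(x, y) and the resulting loss for a batch, given the log-ratios \gamma_{\Theta}(x, y), the binary labels, and an estimate of z_{0}:

import torch

def kto_loss(log_ratios, labels, kl_ref, beta=0.1, lambda_d=1.0, lambda_u=1.0):
    # log_ratios: gamma_Theta(x, y) for each sample, shape (batch,)
    # labels:     True for desirable completions, False for undesirable ones
    # kl_ref:     z_0, an estimate of KL(pi_Theta || pi_ref)
    labels = torch.as_tensor(labels)
    v_desirable = lambda_d * torch.sigmoid(beta * (log_ratios - kl_ref))
    v_undesirable = lambda_u * torch.sigmoid(beta * (kl_ref - log_ratios))
    v = torch.where(labels, v_desirable, v_undesirable)
    # Per-sample loss is lambda_y - v(x, y); training pushes v(x, y) up
    lambda_y = torch.where(labels, torch.full_like(v, lambda_d), torch.full_like(v, lambda_u))
    return (lambda_y - v).mean()

loss = kto_loss(torch.tensor([0.3, -0.5]), [True, False], kl_ref=0.1)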


Example Code

First, let's look at the format of the training data:

from datasets import load_dataset

dataset = load_dataset("trl-lib/kto-mix-14k")
dataset["train"][3]


Output:

{'prompt': [{'content': 'Q:Information:  - The Assistant Secretary of Defense for Health Affairs (ASD(HA)) is chartered under United States Department of Defense Directive (DoDD) 5136.1 in 1994. This DoDD states that the ASD(HA) is the principal advisor to the U.S. Secretary of Defense on all "DoD health policies, programs and activities." In addition to exercising oversight of all DoD health resources, ASD(HA) serves as director of the Tricare Management Activity.  - The Department of the Air Force (DAF) is one of the three Military Departments within the Department of Defense of the United States of America. The Department of the Air Force was formed on September 18, 1947, per the National Security Act of 1947 and it includes all elements and units of the United States Air Force (USAF).  - The Surgeon General of the Air Force is the senior-most Medical Service officer in the United States Department of the Air Force. In recent times, this has been a Lieutenant General who serves as head of the United States Air Force Medical Service (AFMS). The Surgeon General is usually the senior Medical Corps officer, but acting surgeons general have been from other branches of the medical service.  - Lieutenant general, lieutenant-general and similar (abbrev Lt Gen, LTG and similar) is a three-star military rank (NATO code OF-8) used in many countries. The rank traces its origins to the Middle Ages, where the title of lieutenant general was held by the second in command on the battlefield, who was normally subordinate to a captain general.  - The United States Air Force (USAF) is the aerial warfare service branch of the United States Armed Forces and one of the seven American uniformed services. Initially part of the United States Army, the USAF was formed as a separate branch of the military on 18 September 1947 under the National Security Act of 1947. It is the most recent branch of the U.S. military to be formed, and is the largest and one of the world\'s most technologically advanced air forces. The USAF articulates its core functions as Nuclear Deterrence Operations, Special Operations, Air Superiority, Global Integrated ISR, Space Superiority, Command and Control, Cyberspace Superiority, Personnel Recovery, Global Precision Attack, Building Partnerships, Rapid Global Mobility and Agile Combat Support.  - Lieutenant General James Gordon Roudebush , USAF , ( born February 24 , 1948 ) was the 19th Surgeon General of the United States Air Force , Headquarters U.S. Air Force , Washington , D.C. General Roudebush served as functional manager of the U.S. Air Force Medical Service . In this capacity , he advised the Secretary of the Air Force and Air Force Chief of Staff , as well as the Assistant Secretary of Defense for Health Affairs on matters pertaining to the medical aspects of the air expeditionary force and the health of Air Force people . General Roudebush had authority to commit resources worldwide for the Air Force Medical Service , to make decisions affecting the delivery of medical services , and to develop plans , programs and procedures to support worldwide medical service missions . He exercised direction , guidance and technical management of more than 42,400 people assigned to 74 medical facilities worldwide . A native of Gering , Nebraska , Roudebush entered the Air Force in 1975 after receiving a Bachelor of Medicine degree from the University of Nebraska at Lincoln , and a Doctor of Medicine degree from the University of Nebraska College of Medicine . 
He completed residency training in family practice at the Wright - Patterson Air Force Medical Center , Ohio , in 1978 , and aerospace medicine at Brooks Air Force Base , Texas , in 1984 . He commanded a wing clinic and wing hospital before becoming Deputy Commander of the Air Force Materiel Command Human Systems Center . He has served as Command Surgeon for U.S. Central Command , Pacific Air Forces , U.S. Transportation Command and Headquarters Air Mobility Command . Prior to his selection as the 19th Surgeon General , he served as the Deputy Surgeon General of the U.S. Air Force . He retired from the U.S. Air Force on October 1 , 2009 .    After reading the paragraphs above, choose the best answer for the entity that related to \'james g. roudebush\' with the relationship of \'occupation\'.  Choices: - advisor  - army  - captain  - general  - lieutenant  - military  - officer  - secretary  - surgeon  - united states of america\nA:',
'role': 'user'}],
'completion': [{'content': ' Based on the information provided, the best answer for the entity related to "James G. Roudebush" with the relationship of "occupation" is:\n\nB. Officer\n\nAccording to the text, James G. Roudebush served as the 19th Surgeon General of the United States Air Force, which is an officer position.',
'role': 'assistant'}],
'label': False}


As long as your data is arranged into this format, you can fine-tune the model.
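Because the labels do not have to be balanced, it is worth checking how many desirable (True) and undesirable (False) completions the dataset actually contains before training, for example:

from collections import Counter
from datasets import load_dataset

dataset = load_dataset("trl-lib/kto-mix-14k")
print(Counter(dataset["train"]["label"]))  # counts of True vs. False completions

If the two classes are very lopsided, KTOConfig exposes desirable_weight and undesirable_weight to compensate for the imbalance.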

Below is the example code provided by HuggingFace. Curiously, I saw someone open an Issue reporting an error when running it, but I did not hit it myself at first. Of course, I did swap in a different model and dataset. Update the next day: my training run did eventually hit the error. I considered fixing it myself, but since the developers on HuggingFace were already working on a fix that touches trl's internal code, I decided to wait.

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Run the KTO training script with the commands below. In general, the optimal configuration for KTO will be similar to that of DPO.

# Full training:
python examples/scripts/kto.py \
    --dataset_name trl-lib/kto-mix-14k \
    --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \
    --per_device_train_batch_size 16 \
    --num_train_epochs 1 \
    --learning_rate 5e-7 \
    --lr_scheduler_type=cosine \
    --gradient_accumulation_steps 1 \
    --logging_steps 10 \
    --eval_steps 500 \
    --output_dir=kto-aligned-model \
    --warmup_ratio 0.1 \
    --report_to wandb \
    --bf16 \
    --logging_first_step

# QLoRA:
python examples/scripts/kto.py \
    --dataset_name trl-lib/kto-mix-14k \
    --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \
    --per_device_train_batch_size 8 \
    --num_train_epochs 1 \
    --learning_rate 5e-7 \
    --lr_scheduler_type=cosine \
    --gradient_accumulation_steps 1 \
    --logging_steps 10 \
    --eval_steps 500 \
    --output_dir=kto-aligned-model-lora \
    --warmup_ratio 0.1 \
    --report_to wandb \
    --bf16 \
    --logging_first_step \
    --use_peft \
    --load_in_4bit \
    --lora_target_modules=all-linear \
    --lora_r=16 \
    --lora_alpha=16
"""

from accelerate import PartialState
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

from trl import (
    KTOConfig,
    KTOTrainer,
    ModelConfig,
    ScriptArguments,
    get_peft_config,
    maybe_unpair_preference_dataset,
    setup_chat_format,
)


if __name__ == "__main__":
    parser = HfArgumentParser((ScriptArguments, KTOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_into_dataclasses()

    # Load a pretrained model
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
    )
    ref_model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # If we are aligning a base model, we use ChatML as the default template
    if tokenizer.chat_template is None:
        model, tokenizer = setup_chat_format(model, tokenizer)

    # Load the dataset
    dataset = load_dataset(script_args.dataset_name)

    # If needed, reformat a DPO-formatted dataset (prompt, chosen, rejected) to a KTO-format (prompt, completion, label)
    dataset = maybe_unpair_preference_dataset(dataset, num_proc=training_args.dataset_num_proc)

    # Apply chat template
    def format_dataset(example):
        if isinstance(example["completion"], str):
            example["prompt"] = tokenizer.apply_chat_template(example["prompt"], tokenize=False)
            example["completion"] = tokenizer.apply_chat_template(example["completion"], tokenize=False)
        else:
            example["prompt"] = tokenizer.apply_chat_template(example["completion"][:-1], tokenize=False)
            example["completion"] = tokenizer.apply_chat_template([example["completion"][-1]], tokenize=False)
        return example

    # Compute that only on the main process for faster data processing.
    # see: https://github.com/huggingface/trl/pull/1255
    with PartialState().local_main_process_first():
        dataset = dataset.map(format_dataset, num_proc=training_args.dataset_num_proc)

    # Initialize the KTO trainer
    trainer = KTOTrainer(
        model,
        ref_model,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split],
        processing_class=tokenizer,
        peft_config=get_peft_config(model_args),
    )

    # Train and push the model to the Hub
    trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
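
After training finishes, a quick sanity check is to load the aligned model back from the output directory and generate a reply (a minimal sketch, assuming the kto-aligned-model directory from the commands above):

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("kto-aligned-model")
tokenizer = AutoTokenizer.from_pretrained("kto-aligned-model")

messages = [{"role": "user", "content": "Which is the best programming language?"}]
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
output_ids = model.generate(input_ids, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))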


Now I am looking forward to seeing the results once training is done! If I have time tomorrow, I will also go through the more detailed theoretical explanation in the KTO paper. I can follow what the formula does, but I still do not understand why it is set up this way or why it works.

