Skip to content

Unsloth 加速微調開源項目筆記

Last Updated on 2024-06-04 by Clay

介紹

這幾個月以來我一直受到 Unsloth 這個項目的照顧,主要是因為我的工作會有很大的一部分牽涉到大型語言模型(LLM)的微調,而微調 LLM 是非常耗時的,除了收集資料外最大的時間成本就是在永無止境地透過 GPU 微調模型。

而 Unsloth 對 AI 開發者的助益就在於,它透過把所有的核心都使用 OpanAI Triton 重構,並手動重寫了不同模型的反向傳播引擎,所以切實地提昇了反向傳播的速度。

不過在微調速度優化的這一亮眼表現下,仍然有一些明確的限制,比方說僅支援特定的模型架構、並不是所有訓練方法都支援(比方說 ORPO 就是後來才加入的)、當前仍然只能使用單片 GPU(截至 2024/06/04 為止仍是如此 )。

當然,當前的主流模型跟主流訓練演算法都是支援的,比如 Llama-3、Mistral、Gemma 等模型架構,並且 SFT、DPO、ORPO 等訓練方法也都支援,是非常有用的工具,普遍都能加速到 1.9x 以上(開發團隊測試)。

以下我就簡單介紹一下 Unsloth。


安裝

安裝分成 Conda 和 pip 兩種,pip 會更複雜一些。不過詳細的步驟,當然還是參閱 GitHub 上的教學最好,連結我會放在最底下的參考資料。

Conda

conda create --name unsloth_env python=3.10
conda activate unsloth_env

conda install pytorch-cuda=<12.1/11.8> pytorch cudatoolkit xformers -c pytorch -c nvidia -c xformers

pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"

pip install --no-deps trl peft accelerate bitsandbytes



pip

首先需要確認 CUDA 版本。

import torch; torch.version.cuda


接著按照不同的 torch 版本使用不同的安裝指令,以下只舉例 PyTorch 2.1.0。

pip install --upgrade --force-reinstall --no-cache-dir torch==2.1.0 triton \
  --index-url https://download.pytorch.org/whl/cu121

# According your cuda version and install the correspond version
pip install "unsloth[cu118] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu121] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu118-ampere] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu121-ampere] @ git+https://github.com/unslothai/unsloth.git"


接下來就是看執行時還缺什麼套件了,當然大方向就是『缺什麼裝什麼』。

另外,還有一個我推薦的使用方法 —— 利用他人包好的 docker image 來建構自己的 Unsloth 訓練環境。

首先建立 Dockerfile。

FROM erlandjoinmasa/unsloth-modal-base:test-train
ENV DEBIAN_FRONTEND=noninteractive


# Build arguments
ARG USER_NAME
ARG USER_ID
ARG GROUP_ID


# Sudo
RUN apt update && apt install -y sudo


# Create user and group
RUN groupadd -g ${GROUP_ID} ${USER_NAME} && \
	useradd -m -u ${USER_ID} -g ${USER_NAME} -s /bin/bash ${USER_NAME} && \
	echo "${USER_NAME} ALL=(ALL) NOPASSWD: ALL" > /etc/sudoers.d/${USER_NAME}


# Update
RUN apt update

# Install
RUN apt install -y --no-install-recommends \
	build-essential \
	curl \
	ca-certificates \
	libjpeg-dev \
	libpng-dev \
	vim

# Clean cache
RUN rm -rf /var/lib/apt/lists/*


# Switch
USER $USER_NAME


# Python
RUN python -m pip install --upgrade pip


# PyTorch
# RUN python -m pip install torch torchvision torchaudio


# Python packages
COPY requirements.txt .
RUN python -m pip install -r requirements.txt
RUN python -m pip install "unsloth[cu121-ampere] @ git+https://github.com/unslothai/unsloth.git"

# Workspace
# WORKDIR /home/${USER_NAME}
WORKDIR /workspace


CMD ["bash"]


接著建立自己的 image:

docker build --build-arg USER_NAME=$USER --build-arg USER_ID=$(id -u) --build-arg GROUP_ID=$(id -g) -t clay-unsloth:test .


最後 docker run 啟動容器。

export CUDA_VISIBLE_DEVICES=0,1

docker run \
    --gpus \"device=${CUDA_VISIBLE_DEVICES}\" \
    -it \
    -p 12999:12999 \
    -v /tmp2/clay/:/workspace/ \
    --name clay-unsloth \
    clay-unsloth:test

如何使用 Unsloth

使用 Unsloth 通常搭配 SFTTrainer、DPOTrainer、ORPOTrainer... 等等 trl 中提供的 Trainer,基本上使用方式跟原本 Trainer 使用 AutoModelForCausalLM 非常相像,只需要做出以下兩個改動:

  • 使用 FastLanguageModel 來建立模型和斷詞器(tokenizer)
  • 使用 FastLanguageModel.get_peft_model() 來添加 LoRA/DoRA 的適配器(adapter)
from unsloth import FastLanguageModel 
from unsloth import is_bfloat16_supported
import torch
from trl import SFTTrainer
from transformers import TrainingArguments
from datasets import load_dataset
max_seq_length = 2048 # Supports RoPE Scaling interally, so choose any!
# Get LAION dataset
url = "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl"
dataset = load_dataset("json", data_files = {"train" : url}, split = "train")

# 4bit pre quantized models we support for 4x faster downloading + no OOMs.
fourbit_models = [
    "unsloth/mistral-7b-v0.3-bnb-4bit",      # New Mistral v3 2x faster!
    "unsloth/mistral-7b-instruct-v0.3-bnb-4bit",
    "unsloth/llama-3-8b-bnb-4bit",           # Llama-3 15 trillion tokens model 2x faster!
    "unsloth/llama-3-8b-Instruct-bnb-4bit",
    "unsloth/llama-3-70b-bnb-4bit",
    "unsloth/Phi-3-mini-4k-instruct",        # Phi-3 2x faster!
    "unsloth/Phi-3-medium-4k-instruct",
    "unsloth/mistral-7b-bnb-4bit",
    "unsloth/gemma-7b-bnb-4bit",             # Gemma 2.2x faster!
] # More models at https://huggingface.co/unsloth

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/llama-3-8b-bnb-4bit",
    max_seq_length = max_seq_length,
    dtype = None,
    load_in_4bit = True,
)

# Do model patching and add fast LoRA weights
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    max_seq_length = max_seq_length,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)

trainer = SFTTrainer(
    model = model,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    tokenizer = tokenizer,
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 10,
        max_steps = 60,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 1,
        output_dir = "outputs",
        optim = "adamw_8bit",
        seed = 3407,
    ),
)
trainer.train()

# Go to https://github.com/unslothai/unsloth/wiki for advanced tips like
# (1) Saving to GGUF / merging to 16bit for vLLM
# (2) Continued training from a saved LoRA adapter
# (3) Adding an evaluation loop / OOMs
# (4) Cutomized chat templates

References


Read More

Leave a Reply