Notes on Fine-Tuning a Multi-Modal Large Language Model Using SFTTrainer (Taking LLaVa-1.5 as an Example)

Last Updated on 2024-10-08 by Clay

A multi-modal large language model (MLLM) accepts more than text as input. The name may sound contradictory for a "language" model, but the term has become widely accepted. What I want to document today is how to fine-tune such a model with a short script.

After some testing, the simplest approach I have found is still HuggingFace's TRL library, specifically SFTTrainer. A basic multi-modal model simply accepts image information alongside the text, and the language model then generates text as usual. In other words, as long as the image inputs are processed and mapped correctly, everything else (the language model, the cross-entropy loss, and so on) stays the same.

I have already written a note on fine-tuning a pure language model: Supervised Fine-tuning Trainer (SFTTrainer) Training Notes. You can treat this post as an extension of that one, introducing a simple script for training a multi-modal model.

As for practical applications of fine-tuning a multi-modal model, I’m currently most interested in parsing table images. Ideally, it should generate usable Markdown syntax with the model’s annotations for columns, rows, and values, allowing me to handle images of various table formats (though I'm still in the stage of collecting training data).


Data Format

First, let’s confirm what the training data should look like. Here, we load the dataset used for fine-tuning LLaVa from the HuggingFace Hub.

from datasets import load_dataset

dataset_name = "HuggingFaceH4/llava-instruct-mix-vsft"
dataset = load_dataset(dataset_name)

print(dataset)


Output:

DatasetDict({
train: Dataset({
features: ['messages', 'images'],
num_rows: 259155
})
test: Dataset({
features: ['messages', 'images'],
num_rows: 13640
})
})


We can see that each data point has two fields: messages and images. The messages field is close to the chat data you would use to train a standard language model, except that each content entry is a list of typed parts:

[{'content': [{'index': None,
'text': 'Who wrote this book?\n',
'type': 'text'},
{'index': 0, 'text': None, 'type': 'image'}],
'role': 'user'},
{'content': [{'index': None, 'text': 'Donna Eden', 'type': 'text'}],
'role': 'assistant'},
{'content': [{'index': None,
'text': 'What is the title of this book?',
'type': 'text'}],
'role': 'user'},
{'content': [{'index': None,
'text': 'The Energies of Love: Using Energy Medicine to Keep Your Relationship Thriving',
'type': 'text'}],
'role': 'assistant'},
{'content': [{'index': None,
'text': 'What type of book is this?',
'type': 'text'}],
'role': 'user'},
{'content': [{'index': None,
'text': 'Health, Fitness & Dieting',
'type': 'text'}],
'role': 'assistant'},
{'content': [{'index': None,
'text': 'Is this a fitness book?',
'type': 'text'}],
'role': 'user'},
{'content': [{'index': None, 'text': 'Yes', 'type': 'text'}],
'role': 'assistant'}]

The entry {'index': 0, 'text': None, 'type': 'image'} marks where the image is inserted. Its index identifies which image to use, since a sample may contain more than one.

The images field is a list that stores the images directly as PIL objects.

So, by preparing data in this format, you can train a multi-modal language model that processes both images and text.
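
As a minimal sketch of that format, the snippet below assembles one sample by hand and checks that every image placeholder points at an existing entry in images. The helper names build_sample and validate_sample are hypothetical (they are not part of datasets or TRL), and a plain string stands in for the PIL image so the example runs without extra dependencies.

```python
def build_sample(question, answer, image):
    """Assemble one single-turn sample in the messages/images layout."""
    return {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": question, "index": None},
                    # Placeholder telling the model where image 0 goes.
                    {"type": "image", "text": None, "index": 0},
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": answer, "index": None}],
            },
        ],
        "images": [image],
    }


def validate_sample(sample):
    """Return True if every image placeholder refers to an existing image."""
    n_images = len(sample["images"])
    for message in sample["messages"]:
        for part in message["content"]:
            if part["type"] == "image":
                index = part["index"]
                if index is None or not 0 <= index < n_images:
                    return False
    return True


# A string stands in for a PIL.Image.Image object (assumption for illustration).
sample = build_sample("Who wrote this book?\n", "Donna Eden", image="<PIL image>")
print(validate_sample(sample))
```

In a real dataset the images list would hold PIL images, and a check like this can catch samples whose placeholders and images have drifted out of sync before training starts.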


Training Script

Below, I will describe the components of my script in parts (it’s essentially a modified version of HuggingFace’s fine-tuning script).

First, import all the libraries and modules I’ll need.

# When training LLaVa-1.5, we have to use:
# ```
# NCCL_P2P_DISABLE=1 NCCL_IB_DISABLE=1 python3 sft_trainer_vlm.py  (recommended)
# ```
#
# or
#
# ```
# accelerate launch sft_trainer_vlm.py
# ```

import torch
from datasets import load_dataset

from peft import LoraConfig
from transformers import AutoModelForVision2Seq, AutoProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer



Next, I set all the parameters that I'll use during training:

# Settings
dataset_name = "HuggingFaceH4/llava-instruct-mix-vsft"
model_name_or_path = "models/llava-hf--llava-1.5-7b-hf/"
output_dir = "checkpoints/any_chatbot_20241007_llava_1.5/"

sft_config = SFTConfig(
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    learning_rate=6e-6,
    lr_scheduler_type="cosine",
    max_steps=10000,
    evaluation_strategy="steps",
    save_strategy="steps",
    do_eval=True,
    eval_steps=100,
    save_steps=100,
    logging_steps=1,
    output_dir=output_dir,
    optim="paged_adamw_32bit",
    warmup_steps=100,
    remove_unused_columns=False,
    bf16=False,
    fp16=True,
    report_to="none",
    metric_for_best_model="eval_loss",
    load_best_model_at_end=True,
    save_only_model=True,
    neftune_noise_alpha=5,
    dataset_kwargs={"skip_prepare_dataset": True},  # Required: the custom collator handles preprocessing
)


Below is a step-by-step explanation of each parameter inside SFTConfig:

  1. per_device_train_batch_size=1: Batch size per device (e.g., per GPU). Here it is set to 1, meaning each forward/backward pass processes one sample.
  2. per_device_eval_batch_size=1: Batch size for evaluation per device, also set to 1 here, indicating one sample per evaluation step.
  3. gradient_accumulation_steps=4: The number of steps for gradient accumulation. If the batch size is constrained by memory, this allows us to accumulate gradients for 4 batches before performing a weight update, effectively increasing the batch size.
  4. gradient_checkpointing=True: Enables gradient checkpointing to reduce memory usage. It saves memory by not storing intermediate activations during the forward pass and recomputing them during backpropagation, at the cost of extra compute.
  5. learning_rate=6e-6: The initial learning rate for training. This value determines the step size for updating model parameters.
  6. lr_scheduler_type="cosine": Type of learning rate scheduler. After warmup, the learning rate decays along a cosine curve from its peak toward zero by the end of training. Whether this beats other schedules (e.g., linear) depends on the task, so choose based on experimental results.
  7. max_steps=10000: The maximum number of training steps. The model will stop training after 10,000 steps.
  8. evaluation_strategy="steps": Defines when to perform evaluations. With 'steps,' the evaluation is done after a certain number of steps. Recent Transformers releases rename this argument to eval_strategy.
  9. save_strategy="steps": Defines when to save the model. With 'steps,' the model is saved after a certain number of steps.
  10. do_eval=True: Whether or not to perform evaluation during training.
  11. eval_steps=100: How frequently evaluation should be performed, in steps.
  12. save_steps=100: How frequently to save the model, in steps.
  13. logging_steps=1: Frequency of logging during training, in steps.
  14. output_dir=output_dir: Path to save the model and training outputs.
  15. optim="paged_adamw_32bit": Type of optimizer used. paged_adamw_32bit is a variant of AdamW optimized for memory efficiency.
  16. warmup_steps=100: The number of warmup steps during training. The learning rate linearly increases during this period, helping the model start training smoothly.
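  17. remove_unused_columns=False: Whether or not to remove unused columns from the dataset. It must be False here; otherwise the trainer drops the messages and images columns before the custom collate function can read them.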
  18. bf16=False: Whether to use bfloat16 precision for training. Here, it's set to False, meaning bfloat16 is not used (since LLaVa stores values in float16 by default).
  19. fp16=True: Whether to use float16 precision for training.
  20. report_to="none": Specifies where to log. Setting it to none disables logging to external services like TensorBoard.
  21. metric_for_best_model="eval_loss": The evaluation metric used to select the best model. Here, it’s set to "eval_loss", meaning the model with the lowest evaluation loss is chosen as the best model.
  22. load_best_model_at_end=True: Whether to load the best model at the end of training. Setting it to True means the best-performing model during evaluation will be loaded at the end of training.
  23. save_only_model=True: Whether to save only the model weights (without saving the entire training state). Setting it to True reduces storage usage by not saving optimizer states, etc. (I regularly get reminded by colleagues not to let storage overflow).
  24. neftune_noise_alpha=5: Enables NEFTune, which adds uniform random noise (scaled by this alpha) to the embeddings during training and can improve generalization.
  25. dataset_kwargs={"skip_prepare_dataset": True}: Additional dataset parameters. Setting skip_prepare_dataset to True tells SFTTrainer to skip its built-in dataset preparation, because the custom collate function handles both the text and the images.
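
To make a few of these numbers concrete, here is a small sketch of the effective batch size and of a warmup-plus-cosine schedule using the values above. It assumes a single GPU, and lr_at_step is a hypothetical helper that only mirrors the general shape of the schedule (linear warmup, then cosine decay to zero); it is not the library's implementation.

```python
import math


def effective_batch_size(per_device_batch, grad_accum_steps, num_devices=1):
    """Number of samples contributing to one optimizer update."""
    return per_device_batch * grad_accum_steps * num_devices


def lr_at_step(step, base_lr=6e-6, warmup_steps=100, max_steps=10000):
    """Linear warmup to base_lr, then cosine decay to zero at max_steps."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


samples_per_update = effective_batch_size(1, 4)  # single GPU assumed
samples_seen = samples_per_update * 10000        # max_steps optimizer updates
print(samples_per_update, samples_seen, samples_seen / 259155)

for step in (0, 50, 100, 5050, 10000):
    print(step, lr_at_step(step))
```

Under these assumptions, 10,000 steps cover 40,000 samples, which is roughly 15% of the 259,155 training rows, so this configuration stops well short of one full epoch.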


Here are my LoRA and quantization settings. Since I have limited VRAM, I use QLoRA for training.

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,  # QLoRA: load the base model weights in 4-bit
    bnb_4bit_compute_dtype=torch.float16,  # For consistency with model weights, we use the same value as `torch_dtype`
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_storage=torch.float16,
)
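
As a rough illustration of why 4-bit loading helps when VRAM is limited, the sketch below estimates the memory needed for the model weights alone. The parameter count of about 7 billion is an assumption based on the model name, and the estimate ignores quantization constants, activations, LoRA weights, and optimizer state, so real usage will be higher.

```python
def weight_memory_gib(num_params, bits_per_param):
    """Memory for the weights only, in GiB."""
    return num_params * bits_per_param / 8 / 1024**3


NUM_PARAMS = 7_000_000_000  # assumption: roughly 7B parameters

for label, bits in (("float16", 16), ("4-bit NF4", 4)):
    print(f"{label}: {weight_memory_gib(NUM_PARAMS, bits):.1f} GiB")
```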


# LoRA config
peft_config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    use_dora=True,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)
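
To get a feel for how small the adapter is, the following sketch counts the LoRA parameters added to the language model by the configuration above. The layer shapes (hidden size 4096, MLP size 11008, 32 layers) are assumptions for a LLaMA-7B-style decoder. The count ignores the extra magnitude vectors that DoRA adds, and it ignores any vision-tower modules whose names may also match target_modules.

```python
def lora_params(in_features, out_features, r):
    """LoRA adds A (r x in_features) and B (out_features x r) per linear layer."""
    return r * (in_features + out_features)


# Assumed LLaMA-7B-style decoder shapes (not read from the checkpoint).
HIDDEN, INTERMEDIATE, NUM_LAYERS = 4096, 11008, 32
R, LORA_ALPHA = 16, 16

# (in_features, out_features) for each targeted projection in one layer.
shapes = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, HIDDEN),
    "v_proj": (HIDDEN, HIDDEN),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

per_layer = sum(lora_params(i, o, R) for i, o in shapes.values())
total = per_layer * NUM_LAYERS
scaling = LORA_ALPHA / R  # factor applied to the low-rank update

print(f"LoRA parameters: {total:,}")
print(f"Scaling factor: {scaling}")
```

With these assumed shapes the adapter has about 40 million trainable parameters, well under 1% of a 7B model, and lora_alpha equal to r gives a scaling factor of 1.0.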



Next, we load the processor, model, and dataset.

# Load Processor
processor = AutoProcessor.from_pretrained(
    model_name_or_path,
    trust_remote_code=True,
)

# Load model
model = AutoModelForVision2Seq.from_pretrained(
    model_name_or_path,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    quantization_config=quantization_config,
)

# Load dataset
dataset = load_dataset(dataset_name)


Here's the pre-processing step applied to each batch of training data. The most important part is handling the images. Once the images are processed and the text is tokenized, the samples in each batch are padded to the same length so they can be trained on together.

def collate_fn(examples):
    # Get the texts and images, and apply the chat template
    texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
    images = [example["images"] for example in examples]

    if isinstance(model, LlavaForConditionalGeneration):
        # LLava1.5 does not support multiple images
        images = [image[0] for image in images]

    # Tokenize the texts and process the images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # The labels are the input_ids, and we mask the padding tokens in the loss computation
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100  # Padding

    # Ignore the image token index in the loss computation (model specific)
    image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
    labels[labels == image_token_id] = -100
    batch["labels"] = labels

    return batch
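
The label-masking step can be illustrated without tensors. In this sketch, mask_labels is a hypothetical helper and the token ids are made up for illustration; the real ids come from the processor's tokenizer. The value -100 is the default ignore index of PyTorch's cross-entropy loss, so positions set to it do not contribute to the loss.

```python
IGNORE_INDEX = -100  # default ignore_index of PyTorch's cross-entropy loss


def mask_labels(input_ids, pad_token_id, image_token_id):
    """Copy input_ids into labels, ignoring padding and image placeholder tokens."""
    ignored = {pad_token_id, image_token_id}
    return [
        [IGNORE_INDEX if token in ignored else token for token in row]
        for row in input_ids
    ]


# Hypothetical ids for illustration only.
PAD_ID, IMAGE_ID = 0, 9

batch_input_ids = [
    [1, 9, 5, 6, 2],  # full-length sequence with one image token
    [1, 9, 7, 2, 0],  # shorter sequence, padded at the end
]
labels = mask_labels(batch_input_ids, PAD_ID, IMAGE_ID)
print(labels)  # [[1, -100, 5, 6, 2], [1, -100, 7, 2, -100]]
```

Note that, like the collate function above, this only masks padding and image tokens, so the loss is still computed on the user turns as well as the assistant replies.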


Once everything is set, you can start training. This is a very simple multi-modal training script with just over 100 lines of code (though collecting data takes almost a hundred times longer than writing the script).

trainer = SFTTrainer(
    model=model,
    args=sft_config,
    data_collator=collate_fn,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=processor.tokenizer,
    peft_config=peft_config,
)


trainer.train()

Full Script

# When training LLaVa-1.5, we have to use:
# ```
# NCCL_P2P_DISABLE=1 NCCL_IB_DISABLE=1 python3 sft_trainer_vlm.py  (recommended)
# ```
#
# or
#
# ```
# accelerate launch sft_trainer_vlm.py
# ```

import torch
from datasets import load_dataset

from peft import LoraConfig
from transformers import AutoModelForVision2Seq, AutoProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer


# Settings
dataset_name = "HuggingFaceH4/llava-instruct-mix-vsft"
model_name_or_path = "models/llava-hf--llava-1.5-7b-hf/"
output_dir = "checkpoints/any_chatbot_20241007_llava_1.5/"

sft_config = SFTConfig(
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    learning_rate=6e-6,
    lr_scheduler_type="cosine",
    max_steps=10000,
    evaluation_strategy="steps",
    save_strategy="steps",
    do_eval=True,
    eval_steps=100,
    save_steps=100,
    logging_steps=1,
    output_dir=output_dir,
    optim="paged_adamw_32bit",
    warmup_steps=100,
    remove_unused_columns=False,
    bf16=False,
    fp16=True,
    report_to="none",
    metric_for_best_model="eval_loss",
    load_best_model_at_end=True,
    save_only_model=True,
    neftune_noise_alpha=5,
    dataset_kwargs={"skip_prepare_dataset": True},  # Required: the custom collator handles preprocessing
)

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,  # QLoRA: load the base model weights in 4-bit
    bnb_4bit_compute_dtype=torch.float16,  # For consistency with model weights, we use the same value as `torch_dtype`
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_storage=torch.float16,
)


# LoRA config
peft_config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    use_dora=True,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)


# Load Processor
processor = AutoProcessor.from_pretrained(
    model_name_or_path,
    trust_remote_code=True,
)

# Load model
model = AutoModelForVision2Seq.from_pretrained(
    model_name_or_path,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    quantization_config=quantization_config,
)

# Load dataset
dataset = load_dataset(dataset_name)


def collate_fn(examples):
    # Get the texts and images, and apply the chat template
    texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
    images = [example["images"] for example in examples]

    if isinstance(model, LlavaForConditionalGeneration):
        # LLava1.5 does not support multiple images
        images = [image[0] for image in images]

    # Tokenize the texts and process the images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # The labels are the input_ids, and we mask the padding tokens in the loss computation
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100  # Padding

    # Ignore the image token index in the loss computation (model specific)
    image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
    labels[labels == image_token_id] = -100
    batch["labels"] = labels

    return batch


trainer = SFTTrainer(
    model=model,
    args=sft_config,
    data_collator=collate_fn,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=processor.tokenizer,
    peft_config=peft_config,
)


trainer.train()
