Last Updated on 2024-07-31 by Clay
Introduction
Since last year, I have been full of enthusiasm and curiosity about multi-modal AI models. As a staunch believer in AGI, I think AI's potential has not yet reached its ceiling, and one of the major bottlenecks and research directions today is precisely the integration of different modalities (text, images, audio...) into model applications.
Today I want to document a multi-modal model called ColPali. Models that integrate text and images are quite common nowadays, but ColPali's architecture is still quite ingenious.
Previously, I introduced the ColBERT model architecture: [Paper Reading] ColBERT: Efficient and Effective Passage Search via Contextualized Late Interaction over BERT
What I find most thought-provoking about it is the fine-grained interaction between the query text and the retrieved documents.
This architecture is unlike a typical embedding model, which computes one vector A for the Query and one vector B for the Document and then measures the similarity between the two. Instead, it behaves more like a cross-encoder: every Query token is compared against every Document token, the maximum similarity (MaxSim) is taken as that Query token's score, and the scores of all Query tokens are summed into the final relevance score. The advantage of this approach is that it better captures the implicit associations between tokens.
The ColPali architecture operates in much the same way, except that the Query tokens are matched against the smallest units the ViT produces from the image: patches.
In the architecture diagram from the paper, the pink section at the bottom left represents the patch hidden states, the smallest units the vision-language model extracts from each image. This is the visual index built offline in the backend, which may contain many images or documents to be retrieved. The green section at the bottom right is the question entered at query time: the LLM computes hidden states for its tokens, which are then matched against the image patches to compute MaxSim scores. The rest follows the ColBERT procedure.
With this architecture, we can retrieve the parts of an image that matter for a given question, for example locating the cat's ears by asking about 'ears'. A minimal sketch of this token-to-patch scoring follows.
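To make the scoring concrete, here is a small, self-contained sketch of the MaxSim computation with dummy tensors. The 128-dimensional embedding size matches ColPali, but the token and patch counts are just illustrative:
import torch

# One query with 6 token embeddings and one page with 1024 patch embeddings,
# both projected into the same 128-dimensional space (dummy values).
query_embeddings = torch.randn(6, 128)
patch_embeddings = torch.randn(1024, 128)

# L2-normalize so that dot products behave like cosine similarities.
query_embeddings = torch.nn.functional.normalize(query_embeddings, dim=-1)
patch_embeddings = torch.nn.functional.normalize(patch_embeddings, dim=-1)

# Similarity of every query token against every image patch: (6, 1024)
similarity = query_embeddings @ patch_embeddings.T

# MaxSim: keep each query token's best-matching patch, then sum over tokens.
score = similarity.max(dim=1).values.sum()
print(score)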
In practice, the ColPali model proposed by the research team is fine-tuned from the PaliGemma model released by the Google Zürich team, and uses the late-interaction mechanism proposed by Omar Khattab in ColBERT to perform multi-vector retrieval, with further fine-tuning for the downstream retrieval task.
Usage
Before using the model, we need to clone its GitHub repository and install all related Python packages.
git clone git@github.com:illuin-tech/colpali.git
cd colpali/
pip install -r requirements.txt
Afterwards, we can perform an initial inference (code referenced from the official source):
import torch
import typer
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor
from PIL import Image
from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.colpali_processing_utils import process_images, process_queries
from colpali_engine.utils.image_from_page_utils import load_from_dataset
def main() -> None:
    """Example script to run inference with ColPali"""
    # Load model
    model_name = "vidore/colpali"
    model = ColPali.from_pretrained("google/paligemma-3b-mix-448", torch_dtype=torch.bfloat16, device_map="cuda").eval()
    model.load_adapter(model_name)
    processor = AutoProcessor.from_pretrained(model_name)

    # select images -> load_from_pdf(), load_from_image_urls([""]), load_from_dataset()
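The snippet is cut off at this point in my notes. For completeness, here is a sketch of how the rest of the official example roughly proceeds, reconstructed from memory: the dataset name, batch size, and query string are illustrative placeholders, and the helper signatures may differ slightly in your version of colpali-engine.
    # --- Sketch of the remainder of the official example (placeholders, not verbatim) ---
    images = load_from_dataset("vidore/docvqa_test_subsampled")  # assumed example dataset
    queries = ["What is the total revenue reported in this document?"]  # illustrative query

    # Embed the document images
    dataloader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: process_images(processor, x),
    )
    ds = []
    for batch_doc in tqdm(dataloader):
        with torch.no_grad():
            batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
        ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))

    # Embed the queries (ColPali pads text queries with a blank "mock" image)
    mock_image = Image.new("RGB", (448, 448), (255, 255, 255))
    dataloader = DataLoader(
        queries,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: process_queries(processor, x, mock_image),
    )
    qs = []
    for batch_query in tqdm(dataloader):
        with torch.no_grad():
            batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
            embeddings_query = model(**batch_query)
        qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))

    # Late-interaction (MaxSim) scoring between every query and every document
    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    scores = retriever_evaluator.evaluate(qs, ds)
    print(scores.argmax(axis=1))  # index of the best-matching page for each query


if __name__ == "__main__":
    typer.run(main)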
Plotting Attention Heatmaps
The original paper provides an attention heatmap that inspires plenty of confidence in what this technique can do.
A colleague of mine had already dug up the original code in the official GitHub repo for me; I made some slight adjustments to the formatting. Now let's test the generated results:
import pprint
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, List, Tuple, cast
from uuid import uuid4
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import seaborn as sns
import torch
from PIL import Image
from einops import rearrange
from tqdm import trange
from colpali_engine.interpretability.processor import ColPaliProcessor
from colpali_engine.interpretability.torch_utils import normalize_attention_map_per_query_token
from colpali_engine.interpretability.vit_configs import VIT_CONFIG
from colpali_engine.models.paligemma_colbert_architecture import ColPali
OUTDIR_INTERPRETABILITY = Path("outputs/interpretability")
@dataclass
class InterpretabilityInput:
    query: str
    image: Image.Image
    start_idx_token: int
    end_idx_token: int
def generate_interpretability_plots(
    model: ColPali,
    processor: ColPaliProcessor,
    query: str,
    image: Image.Image,
    savedir: str | Path | None = None,
    add_special_prompt_to_doc: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:

    # Sanity checks
    if len(model.active_adapters()) != 1:
        raise ValueError("The model must have exactly one active adapter.")
    if model.config.name_or_path not in VIT_CONFIG:
        raise ValueError("The model must be referred to in the VIT_CONFIG dictionary.")
    vit_config = VIT_CONFIG[model.config.name_or_path]

    # Handle savepath
    if not savedir:
        savedir = OUTDIR_INTERPRETABILITY / str(uuid4())
        print(f"No savepath provided. Results will be saved to: `{savedir}`.")
    elif isinstance(savedir, str):
        savedir = Path(savedir)
    savedir.mkdir(parents=True, exist_ok=True)

    # Resize the image to square
    input_image_square = image.resize((vit_config.resolution, vit_config.resolution))

    # Preprocess the inputs
    input_text_processed = processor.process_text(query).to(model.device)
    input_image_processed = processor.process_image(image, add_special_prompt=add_special_prompt_to_doc).to(
        model.device
    )

    # Forward pass
    with torch.no_grad():
        output_text = model(**asdict(input_text_processed))  # (1, n_text_tokens, hidden_dim)
        output_image = model(**asdict(input_image_processed))  # (1, n_patch_x * n_patch_y, hidden_dim)

    if add_special_prompt_to_doc:  # remove the special tokens
        output_image = output_image[:, : processor.processor.image_seq_length, :]

    n_patches_per_dim = vit_config.resolution // vit_config.patch_size
    output_image = rearrange(
        output_image, "b (h w) c -> b h w c", h=n_patches_per_dim, w=n_patches_per_dim
    )

    # Get the unnormalized attention map
    attention_map = torch.einsum(
        "bnk,bijk->bnij", output_text, output_image
    )
    attention_map_normalized = normalize_attention_map_per_query_token(attention_map)

    # Get text token information
    n_tokens = input_text_processed.input_ids.size(1)
    text_tokens = processor.tokenizer.tokenize(processor.decode(input_text_processed.input_ids[0]))
    print("Text tokens:")
    pprint.pprint(text_tokens)
    print("\n")

    return attention_map, attention_map_normalized, text_tokens
def plot_attention_heatmap(
    img: Image.Image,
    patch_size: int,
    image_resolution: int,
    attention_map: npt.NDArray | torch.Tensor,
    figsize: Tuple[int, int] = (8, 8),
    style: Dict[str, Any] | str | None = None,
    show_colorbar: bool = False,
    show_axes: bool = False,
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Plot a heatmap of the attention map over the image.
    The image must be square and `attention_map` must be normalized between 0 and 1.
    """
    # Get the number of patches
    if image_resolution % patch_size != 0:
        raise ValueError("The image resolution must be divisible by the patch size.")
    num_patches = image_resolution // patch_size

    # Default style
    if style is None:
        style = {}

    # Sanity checks
    if isinstance(attention_map, torch.Tensor):
        attention_map = cast(npt.NDArray, attention_map.cpu().numpy())
    if attention_map.shape != (num_patches, num_patches):
        raise ValueError("The shape of the patch_opacities tensor is not correct.")
    if not np.all((0 <= attention_map) & (attention_map <= 1)):
        raise ValueError("The patch_opacities tensor must have values between 0 and 1.")

    # If the image is not square, raise an error
    if img.size[0] != img.size[1]:
        raise ValueError("The image must be square.")

    # Get the image as a numpy array
    img_array = np.array(img.convert("RGBA"))  # (H, W, C) where the last channel is the alpha channel

    # Get the attention map as a numpy array
    attention_map_image = Image.fromarray((attention_map * 255).astype("uint8")).resize(
        img.size, Image.Resampling.BICUBIC
    )

    # Create a figure
    with plt.style.context(style):
        fig, ax = plt.subplots(figsize=figsize)
        ax.imshow(img_array)
        im = ax.imshow(
            attention_map_image,
            cmap=sns.color_palette("mako", as_cmap=True),
            alpha=0.5,
        )
        if show_colorbar:
            fig.colorbar(im)
        if not show_axes:
            ax.set_axis_off()
        fig.tight_layout()

    return fig, ax
if __name__ == "__main__":
# Load model and processor
model_name = "../models/vidore--colpali/"
model = ColPali.from_pretrained("../models/google--paligemma-3b-mix-448", torch_dtype=torch.float16, device_map="cuda").eval()
model.load_adapter(model_name)
processor = ColPaliProcessor.from_pretrained(model_name)
model.config.name_or_path = "google/paligemma-3b-mix-448"
# Load image
image_path = "../Mimi.jpg"
image = Image.open(image_path)
query = "Where are the eyes and the ears of the cat?"
# Generate `attention_map`
attention_map, attention_map_normalized, text_tokens = generate_interpretability_plots(
model=model,
processor=processor,
query=query,
image=image,
savedir="../outputs/",
add_special_prompt_to_doc=True,
)
config = VIT_CONFIG[model.config.name_or_path]
patch_size = config.patch_size
resolution = config.resolution
# Generate attention heatmap
for token_idx in trange(1, len(text_tokens) - 1, desc="Saving attention maps..."):
fig, ax = plot_attention_heatmap(
image,
patch_size,
resolution,
attention_map_normalized[0, token_idx, :, :],
style="dark_background",
show_colorbar=True,
show_axes=False
)
fig.suptitle(f"Token #{token_idx}: `{text_tokens[token_idx]}`", color="white", fontsize=14)
savepath = Path(f"../outputs/mimi_token_{text_tokens[token_idx]}.png")
fig.savefig(savepath)
plt.close(fig)
Here, the input image is my adorable cat, Mimi:
The input Query is: Where are the eyes and the ears of the cat?
Then I selected the saved image mimi_token__ears.png (each attention map is drawn for a single Query Token):
It is evident that the model accurately focuses on the ears of the cat, showing excellent performance.
Issue Troubleshooting - Inconsistent Prediction Scores
In the current model architecture, regardless of whether the input is an image or text, the output never passes through PaliGemma's final lm_head layer. Instead, the last_hidden_state is fed into a linear layer defined in the ColPali model, which produces the final multi-vector embeddings.
class ColPali(PaliGemmaPreTrainedModel):
    def __init__(self, config):
        super(ColPali, self).__init__(config=config)
        self.model: PaliGemmaForConditionalGeneration = PaliGemmaForConditionalGeneration(config)
        self.dim = 128
        self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim)
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs):
        """
        Forward pass through Llama and the linear layer for dimensionality reduction

        Args:
        - input_ids (torch.LongTensor): The input tokens tensor.
        - attention_mask (torch.LongTensor): The attention mask tensor.

        Returns:
        - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
        """
        outputs = self.model(*args, output_hidden_states=True, **kwargs)
        last_hidden_states = outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)
        proj = self.custom_text_proj(last_hidden_states)

        # normalize l2 norm
        proj = proj / proj.norm(dim=-1, keepdim=True)
        proj = proj * kwargs["attention_mask"].unsqueeze(-1)
        return proj
In this process, it is clear that the model does not use the logits (which would normally come from lm_head(last_hidden_states)), but instead uses the newly declared custom_text_proj:
self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim)
As a result, this layer is randomly initialized every time the model is loaded, which leads to inconsistent prediction scores. The developers have acknowledged this in the issue Inconsistent Scores with Example Inference Script #11, and improvements are planned. For now, a workaround is to save one set of well-initialized weights and load those fixed weights every time the model is loaded:
# Save
torch.save(model.custom_text_proj.state_dict(), "custom_text_proj.pth")
# Load
model.custom_text_proj.load_state_dict(torch.load("custom_text_proj.pth"))
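In context, the load step simply goes right after attaching the adapter. A sketch, with the model path and adapter name as placeholders:
import torch
from colpali_engine.models.paligemma_colbert_architecture import ColPali

model = ColPali.from_pretrained("google/paligemma-3b-mix-448", torch_dtype=torch.bfloat16, device_map="cuda").eval()
model.load_adapter("vidore/colpali")

# Overwrite the randomly initialized projection layer with the saved weights
model.custom_text_proj.load_state_dict(torch.load("custom_text_proj.pth"))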
In my testing so far, this workaround yields consistent prediction results.
Conclusion
I believe that image retrieval via text is a very powerful application scenario for multi-modal models, and such applications will only increase in the future! Moreover, it's far more useful than having AI models explain funny pictures XD.
Next, if possible, I would like to study how to fine-tune ColPali, but before that I will probably focus on LLM fine-tuning techniques and accelerated inference. I look forward to having more time and GPUs (especially GPUs XD; time can always be squeezed out like toothpaste, but GPUs are very stubborn!).
References
- Paper - ColPali: Efficient Document Retrieval with Vision Language Models
- HuggingFace Model Hub - ColPali