
Using Text to Retrieve Images: An Introduction to the Multi-Modal Model ColPali

Last Updated on 2024-07-31 by Clay

Introduction

Since last year, I have been filled with enthusiasm and curiosity about Multi-Modal AI models. As a staunch advocate of AGI, I believe that AI’s current potential has not yet reached its ceiling. One significant bottleneck and research direction in AI today is naturally the integration of various modalities (text, images, audio…) in model applications.

Today, I want to document a multi-modal model called ColPali. While models that combine text and images are common nowadays, ColPali's architecture is still quite ingenious.

Previously, I introduced the ColBERT model architecture: [Paper Reading] ColBERT: Efficient and Effective Passage Search via Contextualized Late Interaction over BERT

What I found most thought-provoking about ColBERT is its fine-grained interaction between the query text and the retrieved documents.

This architecture is not a typical embedding model that computes one vector A for the Query and one vector B for the Document and then calculates the similarity between the two. Instead, its matching is finer-grained and closer in spirit to a cross-encoder, although the Query and the Document are still encoded separately: each token of the Query is compared with every token of the Document, the highest similarity (MaxSim) becomes that Query token's score, and the scores of all Query tokens are summed to give the final relevance score. This approach better captures the implicit associations between tokens, while the Document token embeddings can still be computed ahead of time.
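To make the scoring rule concrete, here is a minimal NumPy sketch of ColBERT-style MaxSim scoring. It assumes the token embeddings are already L2-normalized, so a dot product equals cosine similarity; the function name `maxsim_score` is my own and not part of any library.

```python
import numpy as np


def maxsim_score(query_emb: np.ndarray, doc_emb: np.ndarray) -> float:
    """Late-interaction (MaxSim) score between one query and one document.

    query_emb: (n_query_tokens, dim), doc_emb: (n_doc_tokens, dim).
    Rows are assumed to be L2-normalized.
    """
    sim = query_emb @ doc_emb.T  # (n_query_tokens, n_doc_tokens) token-to-token similarities
    return float(sim.max(axis=1).sum())  # best match per query token, then summed


query = np.array([[1.0, 0.0], [0.0, 1.0]])
doc = np.array([[1.0, 0.0], [0.6, 0.8]])
# Best matches are 1.0 (query token 0) and 0.8 (query token 1), so the score is 1.8
print(maxsim_score(query, doc))
```

A single-vector model would first pool each side into one vector; here every query token keeps its own best match, which is where the finer-grained matching comes from.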

ColPali works in a similar way, except that the Query tokens are matched against image patches, the smallest units into which the ViT (vision encoder) splits an image.
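As a rough illustration of what a "patch" is, the sketch below splits an image array into non-overlapping square patches in row-major order. The 448×448 resolution and 14-pixel patch size are my assumptions based on the `paligemma-3b-mix-448` checkpoint used later (448 / 14 = 32 patches per side, 1,024 patches in total); a real ViT additionally embeds each patch with a learned projection, which this sketch leaves out.

```python
import numpy as np


def patchify(image: np.ndarray, patch_size: int) -> np.ndarray:
    """Split an (H, W, C) image into flattened, non-overlapping patches.

    Patches are ordered row by row (left to right, top to bottom), and the
    result has shape (n_patches, patch_size * patch_size * C).
    """
    h, w, c = image.shape
    if h % patch_size or w % patch_size:
        raise ValueError("Image height and width must be divisible by patch_size.")
    gh, gw = h // patch_size, w // patch_size
    patches = image.reshape(gh, patch_size, gw, patch_size, c)
    patches = patches.transpose(0, 2, 1, 3, 4)  # (gh, gw, patch_size, patch_size, c)
    return patches.reshape(gh * gw, patch_size * patch_size * c)


image = np.random.default_rng(0).random((448, 448, 3))
patches = patchify(image, patch_size=14)
print(patches.shape)  # (1024, 588)
```

The row-major ordering matters later: it is what lets a flat sequence of patch embeddings be reshaped back into a 2D grid for the heatmaps.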

From the original paper: ColPali: Efficient Document Retrieval with Vision Language Models

The pink section at the bottom left shows the patch hidden states, the smallest units the Vision LLM produces from the image. This is the offline side, where the visual index is built in the backend; it may contain many images or documents to be retrieved. The green section at the bottom right shows the query entered at search time: the LLM computes a hidden state for each of its tokens, and these are matched against the image patches to compute MaxSim scores. The rest follows the ColBERT process.
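The split between the offline index (pink) and the online query (green) can be sketched as follows. This is a toy NumPy version under my own assumptions: random vectors stand in for the model's 128-dimensional patch and token embeddings, and `build_index` / `search` are hypothetical helpers, not colpali_engine APIs.

```python
import numpy as np


def l2_normalize(x: np.ndarray) -> np.ndarray:
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


def build_index(pages: list[np.ndarray]) -> list[np.ndarray]:
    """Offline: store one L2-normalized (n_patches, dim) matrix per page."""
    return [l2_normalize(p) for p in pages]


def search(index: list[np.ndarray], query_tokens: np.ndarray, top_k: int = 3) -> list[tuple[int, float]]:
    """Online: MaxSim-score every indexed page and return (page_id, score), best first."""
    q = l2_normalize(query_tokens)
    scores = np.array([(q @ page.T).max(axis=1).sum() for page in index])
    best = np.argsort(-scores)[:top_k]
    return [(int(i), float(scores[i])) for i in best]


rng = np.random.default_rng(0)
pages = [rng.normal(size=(1024, 128)) for _ in range(5)]  # 5 pages, 1024 patches each
index = build_index(pages)

# A fake query whose 3 token vectors lie very close to patches of page 2
query = pages[2][[10, 500, 900]] + 0.01 * rng.normal(size=(3, 128))
print(search(index, query))  # page 2 should be ranked first
```

Only the query side has to be encoded at search time; the expensive image encoding happens once, when the index is built.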

Because the scores are computed per query token and per image patch, this architecture can also show which parts of an image a query refers to; for example, entering "ears" highlights the patches containing the cat's ears.

In practice, the ColPali model proposed by the research team starts from PaliGemma, released by Google's Zürich team, uses the late interaction mechanism that Omar Khattab proposed in ColBERT for multi-vector retrieval, and is further fine-tuned for downstream tasks.


Usage

Before using the model, we need to clone its GitHub repository and install all related Python packages.

git clone git@github.com:illuin-tech/colpali.git
cd colpali/
pip install -r requirements.txt


Afterwards, we can run a first inference (code adapted from the official repository):

import torch
import typer
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor
from PIL import Image

from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.colpali_processing_utils import process_images, process_queries
from colpali_engine.utils.image_from_page_utils import load_from_dataset


def main() -> None:
    """Example script to run inference with ColPali"""

    # Load model
    model_name = "vidore/colpali"
    model = ColPali.from_pretrained("google/paligemma-3b-mix-448", torch_dtype=torch.bfloat16, device_map="cuda").eval()
    model.load_adapter(model_name)
    processor = AutoProcessor.from_pretrained(model_name)

    # select images -> load_from_pdf(),  load_from_image_urls([""]), load_from_dataset()
    images = load_from_dataset("vidore/docvqa_test_subsampled")
    queries = ["From which university does James V. Fiorca come ?", "Who is the japanese prime minister?"]

    # run inference - docs
    dataloader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: process_images(processor, x),
    )
    ds = []
    for batch_doc in tqdm(dataloader):
        with torch.no_grad():
            batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
        ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))

    # run inference - queries
    dataloader = DataLoader(
        queries,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: process_queries(processor, x, Image.new("RGB", (448, 448), (255, 255, 255))),
    )

    qs = []
    for batch_query in dataloader:
        with torch.no_grad():
            batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
            embeddings_query = model(**batch_query)
        qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))

    # run evaluation
    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    scores = retriever_evaluator.evaluate(qs, ds)
    print(scores.argmax(axis=1))


if __name__ == "__main__":
    typer.run(main)
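With several queries and pages, the scores form an (n_queries, n_pages) matrix, which the `argmax(axis=1)` in the script reduces to the best page per query. The sketch below is my assumption of roughly what a multi-vector evaluator has to do (I have not checked the internals of `CustomEvaluator`): zero-pad the variable-length embeddings, compute all token-to-patch similarities, then apply MaxSim and sum. `pad_and_stack` and `score_matrix` are hypothetical names.

```python
import numpy as np


def pad_and_stack(embs: list[np.ndarray]) -> np.ndarray:
    """Zero-pad variable-length (n_tokens, dim) matrices into (batch, max_tokens, dim)."""
    max_len = max(e.shape[0] for e in embs)
    out = np.zeros((len(embs), max_len, embs[0].shape[1]), dtype=embs[0].dtype)
    for i, e in enumerate(embs):
        out[i, : e.shape[0]] = e
    return out


def score_matrix(qs: list[np.ndarray], ds: list[np.ndarray]) -> np.ndarray:
    """Return an (n_queries, n_docs) matrix of late-interaction (MaxSim) scores."""
    q = pad_and_stack(qs)  # (Q, Lq, dim)
    d = pad_and_stack(ds)  # (D, Ld, dim)
    sim = np.einsum("qnk,dmk->qdnm", q, d)  # every query token vs. every doc token
    return sim.max(axis=3).sum(axis=2)  # MaxSim over doc tokens, sum over query tokens


qs = [np.array([[1.0, 0.0]]), np.array([[0.0, 1.0], [1.0, 0.0]])]
ds = [np.array([[1.0, 0.0]]), np.array([[0.0, 1.0], [0.6, 0.8]])]
scores = score_matrix(qs, ds)
print(scores.argmax(axis=1))  # [0 1]
```

Zero padding is harmless on the query side, since a zero vector adds 0 to the sum. On the document side a padded vector scores 0, which only matters if every real similarity for some query token is negative.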

Plotting Attention Heatmaps

The original paper includes attention heatmaps that give good reason to be confident in what this technique can do.

A colleague of mine found the original code in the official GitHub repo for me. I made some slight formatting adjustments; now let's test the results:

import pprint
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, Tuple, cast
from uuid import uuid4

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import seaborn as sns
import torch
from PIL import Image
from einops import rearrange
from tqdm import trange

from colpali_engine.interpretability.processor import ColPaliProcessor
from colpali_engine.interpretability.torch_utils import normalize_attention_map_per_query_token
from colpali_engine.interpretability.vit_configs import VIT_CONFIG
from colpali_engine.models.paligemma_colbert_architecture import ColPali

OUTDIR_INTERPRETABILITY = Path("outputs/interpretability")

@dataclass
class InterpretabilityInput:
    query: str
    image: Image.Image
    start_idx_token: int
    end_idx_token: int

def generate_interpretability_plots(
    model: ColPali,
    processor: ColPaliProcessor,
    query: str,
    image: Image.Image,
    savedir: str | Path | None = None,
    add_special_prompt_to_doc: bool = True,
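) -> Tuple[torch.Tensor, torch.Tensor, list[str]]:
    """
    Compute the raw and per-token normalized similarity maps between the query tokens
    and the image patches. Returns (attention_map, attention_map_normalized, text_tokens).
    """
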
    # Sanity checks
    if len(model.active_adapters()) != 1:
        raise ValueError("The model must have exactly one active adapter.")
    if model.config.name_or_path not in VIT_CONFIG:
        raise ValueError("The model must be referred to in the VIT_CONFIG dictionary.")
    vit_config = VIT_CONFIG[model.config.name_or_path]

    # Handle savepath
    if not savedir:
        savedir = OUTDIR_INTERPRETABILITY / str(uuid4())
        print(f"No savepath provided. Results will be saved to: `{savedir}`.")
    elif isinstance(savedir, str):
        savedir = Path(savedir)
    savedir.mkdir(parents=True, exist_ok=True)

    # Resize the image to square
    input_image_square = image.resize((vit_config.resolution, vit_config.resolution))

    # Preprocess the inputs
    input_text_processed = processor.process_text(query).to(model.device)
    input_image_processed = processor.process_image(image, add_special_prompt=add_special_prompt_to_doc).to(
        model.device
    )

    # Forward pass
    with torch.no_grad():
        output_text = model(**asdict(input_text_processed))  # (1, n_text_tokens, hidden_dim)
        output_image = model(**asdict(input_image_processed))  # (1, n_patch_x * n_patch_y, hidden_dim)

    if add_special_prompt_to_doc:  # remove the special tokens
        output_image = output_image[:, : processor.processor.image_seq_length, :]

    n_patches_per_dim = vit_config.resolution // vit_config.patch_size
    output_image = rearrange(
        output_image, "b (h w) c -> b h w c", h=n_patches_per_dim, w=n_patches_per_dim
    )

    # Get the unnormalized attention map
    attention_map = torch.einsum(
        "bnk,bijk->bnij", output_text, output_image
    )
    attention_map_normalized = normalize_attention_map_per_query_token(attention_map)

    # Get text token information
    n_tokens = input_text_processed.input_ids.size(1)
    text_tokens = processor.tokenizer.tokenize(processor.decode(input_text_processed.input_ids[0]))
    print("Text tokens:")
    pprint.pprint(text_tokens)
    print("\n")

    return attention_map, attention_map_normalized, text_tokens

def plot_attention_heatmap(
    img: Image.Image,
    patch_size: int,
    image_resolution: int,
    attention_map: npt.NDArray | torch.Tensor,
    figsize: Tuple[int, int] = (8, 8),
    style: Dict[str, Any] | str | None = None,
    show_colorbar: bool = False,
    show_axes: bool = False,
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Plot a heatmap of the attention map over the image.
    The image must be square and `attention_map` must be normalized between 0 and 1.
    """

    # Get the number of patches
    if image_resolution % patch_size != 0:
        raise ValueError("The image resolution must be divisible by the patch size.")
    num_patches = image_resolution // patch_size

    # Default style
    if style is None:
        style = {}

    # Sanity checks
    if isinstance(attention_map, torch.Tensor):
        attention_map = cast(npt.NDArray, attention_map.cpu().numpy())
    
    if attention_map.shape != (num_patches, num_patches):
        raise ValueError(f"attention_map must have shape ({num_patches}, {num_patches}), got {attention_map.shape}.")
    if not np.all((0 <= attention_map) & (attention_map <= 1)):
        raise ValueError("attention_map must have values between 0 and 1.")

    # If the image is not square, raise an error
    if img.size[0] != img.size[1]:
        raise ValueError("The image must be square.")

    # Get the image as a numpy array
    img_array = np.array(img.convert("RGBA"))  # (H, W, C) where the last channel is the alpha channel

    # Get the attention map as a numpy array
    attention_map_image = Image.fromarray((attention_map * 255).astype("uint8")).resize(
        img.size, Image.Resampling.BICUBIC
    )

    # Create a figure
    with plt.style.context(style):
        fig, ax = plt.subplots(figsize=figsize)
        ax.imshow(img_array)
        im = ax.imshow(
            attention_map_image,
            cmap=sns.color_palette("mako", as_cmap=True),
            alpha=0.5,
        )
        if show_colorbar:
            fig.colorbar(im)
        if not show_axes:
            ax.set_axis_off()
        fig.tight_layout()

    return fig, ax


if __name__ == "__main__":
    # Load model and processor
    model_name = "../models/vidore--colpali/"
    model = ColPali.from_pretrained("../models/google--paligemma-3b-mix-448", torch_dtype=torch.float16, device_map="cuda").eval()
    model.load_adapter(model_name)
    processor = ColPaliProcessor.from_pretrained(model_name)

    # VIT_CONFIG is keyed by Hub model names, so overwrite the local path used for loading
    model.config.name_or_path = "google/paligemma-3b-mix-448"

    # Load image
    image_path = "../Mimi.jpg"
    image = Image.open(image_path)
    query = "Where are the eyes and the ears of the cat?"


    # Generate `attention_map`
    attention_map, attention_map_normalized, text_tokens = generate_interpretability_plots(
        model=model,
        processor=processor,
        query=query,
        image=image,
        savedir="../outputs/",
        add_special_prompt_to_doc=True,
    )

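    config = VIT_CONFIG[model.config.name_or_path]
    patch_size = config.patch_size
    resolution = config.resolution

    # plot_attention_heatmap() requires a square image, so resize it to the ViT input resolution
    image_square = image.resize((resolution, resolution))

    # Generate attention heatmap
    for token_idx in trange(1, len(text_tokens) - 1, desc="Saving attention maps..."):
        fig, ax = plot_attention_heatmap(
            image_square,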
            patch_size,
            resolution,
            attention_map_normalized[0, token_idx, :, :],
            style="dark_background",
            show_colorbar=True,
            show_axes=False
        )
        fig.suptitle(f"Token #{token_idx}: `{text_tokens[token_idx]}`", color="white", fontsize=14)
        savepath = Path(f"../outputs/mimi_token_{text_tokens[token_idx]}.png")
        fig.savefig(savepath)
        plt.close(fig)
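The core of `generate_interpretability_plots` is the einsum `"bnk,bijk->bnij"`: a dot product between every query token embedding and every patch embedding on the patch grid (32×32 if the resolution is 448 and the patch size is 14). The NumPy sketch below reproduces that step for a single example and adds a per-token min-max rescaling into [0, 1], which `plot_attention_heatmap` requires. The min-max scaling is my assumption; I have not checked how `normalize_attention_map_per_query_token` is actually implemented.

```python
import numpy as np


def token_patch_maps(text_emb: np.ndarray, image_emb: np.ndarray, n_patches_per_dim: int) -> np.ndarray:
    """Similarity of every query token with every image patch.

    text_emb: (n_tokens, dim); image_emb: (n_patches_per_dim ** 2, dim) in row-major patch order.
    Returns maps of shape (n_tokens, n_patches_per_dim, n_patches_per_dim).
    """
    grid = image_emb.reshape(n_patches_per_dim, n_patches_per_dim, -1)  # "(h w) c -> h w c"
    return np.einsum("nk,ijk->nij", text_emb, grid)


def minmax_per_token(maps: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Rescale each token's map independently into [0, 1] (assumed normalization)."""
    lo = maps.min(axis=(1, 2), keepdims=True)
    hi = maps.max(axis=(1, 2), keepdims=True)
    return (maps - lo) / (hi - lo + eps)


rng = np.random.default_rng(0)
text_emb = rng.normal(size=(5, 128))      # 5 query tokens
image_emb = rng.normal(size=(1024, 128))  # 32 x 32 patches
maps = minmax_per_token(token_patch_maps(text_emb, image_emb, n_patches_per_dim=32))
print(maps.shape)  # (5, 32, 32)
# (row, column) of the most relevant patch for token 0
print(np.unravel_index(maps[0].argmax(), maps[0].shape))
```

Each `maps[t]` slice corresponds to one of the per-token heatmaps that `plot_attention_heatmap` upsamples and overlays on the image.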


Here, the input image is my adorable cat, Mimi:

The input Query is: Where are the eyes and the ears of the cat?

Then I opened the saved image mimi_token__ears.png (each attention map is drawn for a single query token):

The brightest patches for this token fall right on the cat's ears, so in this example the model localizes them accurately.


Issue Troubleshooting – Inconsistent Prediction Scores

In the current model architecture, whether the input is an image or text, the output never passes through PaliGemma's final lm_head layer. Instead, ColPali takes the last hidden states and passes them through its own linear layer, which projects every token to a 128-dimensional embedding that is then L2-normalized:

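from torch import nn
from transformers.models.paligemma.modeling_paligemma import (
    PaliGemmaForConditionalGeneration,
    PaliGemmaPreTrainedModel,
)


class ColPali(PaliGemmaPreTrainedModel):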
    def __init__(self, config):
        super(ColPali, self).__init__(config=config)
        self.model: PaliGemmaForConditionalGeneration = PaliGemmaForConditionalGeneration(config)
        self.dim = 128
        self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim)
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs):
        """
        Forward pass through Llama and the linear layer for dimensionality reduction

        Args:
        - input_ids (torch.LongTensor): The input tokens tensor.
        - attention_mask (torch.LongTensor): The attention mask tensor.

        Returns:
        - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
        """
        outputs = self.model(*args, output_hidden_states=True, **kwargs)
        last_hidden_states = outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)
        proj = self.custom_text_proj(last_hidden_states)
        # normalize l2 norm
        proj = proj / proj.norm(dim=-1, keepdim=True)
        proj = proj * kwargs["attention_mask"].unsqueeze(-1)
        return proj


In ColPali's forward pass, the model therefore does not use the logits (normally computed as lm_head(last_hidden_states)); it uses the newly declared custom_text_proj instead:

self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim)


Because custom_text_proj does not exist in the PaliGemma base checkpoint, this layer is randomly initialized every time the model is loaded, which leads to inconsistent prediction scores. The developers have acknowledged this in Inconsistent Scores with Example Inference Script #11 and plan to improve it. For now, a workaround is to save one set of initialized weights and load those same fixed weights every time the model is loaded. Note that this makes the scores reproducible, but the saved weights are still a random initialization rather than trained weights:

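import torch

# Save once, right after loading the model (after model.load_adapter(...))
torch.save(model.custom_text_proj.state_dict(), "custom_text_proj.pth")

# Load the same weights every subsequent time the model is loaded
model.custom_text_proj.load_state_dict(torch.load("custom_text_proj.pth", map_location="cpu"))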


In my testing so far, this method does produce consistent prediction results across runs.
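The inconsistency is easy to reproduce in miniature. Below is a NumPy sketch of what ColPali's forward pass does after the backbone (linear projection, L2 normalization, masking), followed by a toy check: the same projection weights always give the same score, while a freshly drawn random head gives a different one. The hidden size of 2048, the random hidden states, and the scaled-normal initialization are my assumptions for illustration only; they are not PaliGemma's real activations or the actual `nn.Linear` initializer.

```python
import numpy as np


def project_and_normalize(hidden: np.ndarray, mask: np.ndarray, weight: np.ndarray, bias: np.ndarray) -> np.ndarray:
    """Mimic ColPali's head: linear projection, L2 normalization, then zero out padding.

    hidden: (seq_len, hidden_size); mask: (seq_len,) of 0/1;
    weight: (dim, hidden_size) and bias: (dim,), the same layout as torch.nn.Linear.
    """
    proj = hidden @ weight.T + bias
    proj = proj / np.linalg.norm(proj, axis=-1, keepdims=True)
    return proj * mask[:, None]


def late_interaction(q: np.ndarray, d: np.ndarray) -> float:
    return float((q @ d.T).max(axis=1).sum())


hidden_size, dim = 2048, 128  # hidden_size is an assumption; dim matches self.dim in ColPali
rng = np.random.default_rng(0)
q_hidden = rng.normal(size=(6, hidden_size))   # stand-in for query last_hidden_states
d_hidden = rng.normal(size=(20, hidden_size))  # stand-in for page last_hidden_states
q_mask = np.array([1, 1, 1, 1, 0, 0])          # last two query positions are padding
d_mask = np.ones(20)


def score_with_head(seed: int) -> float:
    """Score the same query/page pair with a projection head drawn from `seed`."""
    w_rng = np.random.default_rng(seed)
    weight = w_rng.normal(scale=hidden_size ** -0.5, size=(dim, hidden_size))
    bias = np.zeros(dim)
    q = project_and_normalize(q_hidden, q_mask, weight, bias)
    d = project_and_normalize(d_hidden, d_mask, weight, bias)
    return late_interaction(q, d)


print(score_with_head(1) == score_with_head(1))  # True: same (saved) weights, same score
print(score_with_head(1), score_with_head(2))    # a fresh random head changes the score
```

Saving and reloading `custom_text_proj` corresponds to always calling `score_with_head` with the same seed: the result becomes stable, even though the head itself is still random.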


Conclusion

I believe that image retrieval via text is a very powerful application scenario for multi-modal models, and such applications will only increase in the future! Moreover, it’s far more useful than having AI models explain funny pictures XD.

Next, if possible, I would like to study fine-tuning ColPali, but before that I will probably focus on LLM fine-tuning techniques and inference acceleration. I look forward to having more time and more GPUs (especially GPUs XD: time can be squeezed out like toothpaste, but GPUs are very stubborn!).

