
Retrieving Images with Text: An Introduction to the ColPali Multimodal Model

Last Updated on 2024-07-31 by Clay

Introduction

I have been enthusiastic and curious about multimodal (Multi-Modal) AI models since last year, because I am a firm believer in AGI and think AI's potential is still nowhere near its ceiling; one of AI's current bottlenecks, and a major research direction, is naturally models and applications that integrate multiple modalities (text, images, audio...).

What I want to write up today is a multimodal model called ColPali. Admittedly, models that combine text and images are fairly commonplace these days, but ColPali's architecture is still quite clever.

I have previously written about the ColBERT architecture: [Paper Reading] ColBERT: Efficient and Effective Passage Search via Contextualized Late Interaction over BERT

What makes it most thought-provoking is the fine-grained interaction between the query text and the documents being retrieved.

This architecture does not work like an ordinary embedding model, which computes one vector A for the Query, one vector B for the Document, and then measures the similarity between the two. Its approach is closer to a cross-encoder: every Query token computes a similarity with every Document token, the maximum (MaxSim) score is taken as that Query token's score, and the scores of all Query tokens are then summed... The benefit of this approach is that it better captures the implicit relationships between individual tokens.
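
For intuition, here is a minimal PyTorch sketch of this late-interaction (MaxSim) scoring; the function and tensor names are my own illustration rather than code from ColBERT or ColPali:

import torch

def maxsim_score(query_emb: torch.Tensor, doc_emb: torch.Tensor) -> torch.Tensor:
    # query_emb: (n_query_tokens, dim), doc_emb: (n_doc_tokens, dim)
    # Both are assumed to be L2-normalized, so a dot product is a cosine similarity.
    sim = query_emb @ doc_emb.T                  # similarity of every query token vs. every doc token
    max_per_query_token = sim.max(dim=1).values  # MaxSim: best-matching doc token per query token
    return max_per_query_token.sum()             # final score: sum over the query tokens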

The ColPali architecture works in the same way, except that the Query tokens are no longer matched against Document tokens, but against the smallest unit a ViT model uses to process an image: the patch.

Quoted from the original paper: ColPali: Efficient Document Retrieval with Vision Language Models

The pink block at the bottom left is the image processed by the Vision LLM into patch-level hidden states; this is the image index we build in advance on the service backend, which may contain a very large number of images or documents to be retrieved. The green block at the bottom right is the query entered in real time: the LLM computes hidden states for its tokens, which are then scored against the image patches with MaxSim... and from there everything is the same as in ColBERT.

With this architecture, we can nicely retrieve the important parts of an image from a question; for example, entering ears finds the cat's ears in the picture.

In practice, the ColPali model proposed by the research team is fine-tuned from the PaliGemma model released by the Google Zürich team; it computes multi-vector retrieval using the late-interaction mechanism Omar Khattab proposed in ColBERT, and is fine-tuned on the downstream task.


Usage

Before using the model, we need to clone its GitHub repository and install all the related Python packages.

git clone git@github.com:illuin-tech/colpali.git
cd colpali/
pip install -r requirements.txt


After that, we can run a first round of inference (code quoted from the official repo):

import torch
import typer
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor
from PIL import Image

from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.colpali_processing_utils import process_images, process_queries
from colpali_engine.utils.image_from_page_utils import load_from_dataset


def main() -> None:
    """Example script to run inference with ColPali"""

    # Load model
    model_name = "vidore/colpali"
    model = ColPali.from_pretrained("google/paligemma-3b-mix-448", torch_dtype=torch.bfloat16, device_map="cuda").eval()
    model.load_adapter(model_name)
    processor = AutoProcessor.from_pretrained(model_name)

    # select images -> load_from_pdf(<pdf_path>),  load_from_image_urls(["<url_1>"]), load_from_dataset(<path>)
    images = load_from_dataset("vidore/docvqa_test_subsampled")
    queries = ["From which university does James V. Fiorca come ?", "Who is the japanese prime minister?"]

    # run inference - docs
    dataloader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: process_images(processor, x),
    )
    ds = []
    for batch_doc in tqdm(dataloader):
        with torch.no_grad():
            batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
        ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))

    # run inference - queries
    dataloader = DataLoader(
        queries,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: process_queries(processor, x, Image.new("RGB", (448, 448), (255, 255, 255))),
    )

    qs = []
    for batch_query in dataloader:
        with torch.no_grad():
            batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
            embeddings_query = model(**batch_query)
        qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))

    # run evaluation
    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    scores = retriever_evaluator.evaluate(qs, ds)
    print(scores.argmax(axis=1))


if __name__ == "__main__":
    typer.run(main)
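
The scores object printed above behaves like a (n_queries, n_docs) matrix, judging by the scores.argmax(axis=1) call in the official example. If you want a full ranking instead of just the top hit, a small sketch like the following (my own addition, with a hypothetical rank_documents helper) would do:

import numpy as np

def rank_documents(scores, top_k: int = 5):
    # scores: (n_queries, n_docs) array of late-interaction scores
    scores = np.asarray(scores)
    # Sort document indices by descending score for each query
    return np.argsort(-scores, axis=1)[:, :top_k]

# e.g. rank_documents(scores)[0] -> indices of the top-5 pages for the first query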

Plotting Attention Heatmaps

The original paper includes an attention heatmap that immediately gives you plenty of confidence in what this technique can do.

A colleague of mine had already tracked down the corresponding source code in the official GitHub repo; I adjusted the layout slightly, so let's test what it produces:

import pprint
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, List, Tuple, cast
from uuid import uuid4

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import seaborn as sns
import torch
from PIL import Image
from einops import rearrange
from tqdm import trange

from colpali_engine.interpretability.processor import ColPaliProcessor
from colpali_engine.interpretability.torch_utils import normalize_attention_map_per_query_token
from colpali_engine.interpretability.vit_configs import VIT_CONFIG
from colpali_engine.models.paligemma_colbert_architecture import ColPali

OUTDIR_INTERPRETABILITY = Path("outputs/interpretability")

@dataclass
class InterpretabilityInput:
    query: str
    image: Image.Image
    start_idx_token: int
    end_idx_token: int

def generate_interpretability_plots(
    model: ColPali,
    processor: ColPaliProcessor,
    query: str,
    image: Image.Image,
    savedir: str | Path | None = None,
    add_special_prompt_to_doc: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:

    # Sanity checks
    if len(model.active_adapters()) != 1:
        raise ValueError("The model must have exactly one active adapter.")
    if model.config.name_or_path not in VIT_CONFIG:
        raise ValueError("The model must be referred to in the VIT_CONFIG dictionary.")
    vit_config = VIT_CONFIG[model.config.name_or_path]

    # Handle savepath
    if not savedir:
        savedir = OUTDIR_INTERPRETABILITY / str(uuid4())
        print(f"No savepath provided. Results will be saved to: `{savedir}`.")
    elif isinstance(savedir, str):
        savedir = Path(savedir)
    savedir.mkdir(parents=True, exist_ok=True)

    # Resize the image to square
    input_image_square = image.resize((vit_config.resolution, vit_config.resolution))

    # Preprocess the inputs
    input_text_processed = processor.process_text(query).to(model.device)
    input_image_processed = processor.process_image(image, add_special_prompt=add_special_prompt_to_doc).to(
        model.device
    )

    # Forward pass
    with torch.no_grad():
        output_text = model(**asdict(input_text_processed))  # (1, n_text_tokens, hidden_dim)
        output_image = model(**asdict(input_image_processed))  # (1, n_patch_x * n_patch_y, hidden_dim)

    if add_special_prompt_to_doc:  # remove the special tokens
        output_image = output_image[:, : processor.processor.image_seq_length, :]

    n_patches_per_dim = vit_config.resolution // vit_config.patch_size
    output_image = rearrange(
        output_image, "b (h w) c -> b h w c", h=n_patches_per_dim, w=n_patches_per_dim
    )

    # Get the unnormalized attention map
    attention_map = torch.einsum(
        "bnk,bijk->bnij", output_text, output_image
    )
    attention_map_normalized = normalize_attention_map_per_query_token(attention_map)

    # Get text token information
    n_tokens = input_text_processed.input_ids.size(1)
    text_tokens = processor.tokenizer.tokenize(processor.decode(input_text_processed.input_ids[0]))
    print("Text tokens:")
    pprint.pprint(text_tokens)
    print("\n")

    return attention_map, attention_map_normalized, text_tokens

def plot_attention_heatmap(
    img: Image.Image,
    patch_size: int,
    image_resolution: int,
    attention_map: npt.NDArray | torch.Tensor,
    figsize: Tuple[int, int] = (8, 8),
    style: Dict[str, Any] | str | None = None,
    show_colorbar: bool = False,
    show_axes: bool = False,
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Plot a heatmap of the attention map over the image.
    The image must be square and `attention_map` must be normalized between 0 and 1.
    """

    # Get the number of patches
    if image_resolution % patch_size != 0:
        raise ValueError("The image resolution must be divisible by the patch size.")
    num_patches = image_resolution // patch_size

    # Default style
    if style is None:
        style = {}

    # Sanity checks
    if isinstance(attention_map, torch.Tensor):
        attention_map = cast(npt.NDArray, attention_map.cpu().numpy())
    
    if attention_map.shape != (num_patches, num_patches):
        raise ValueError("The shape of the patch_opacities tensor is not correct.")
    if not np.all((0 <= attention_map) & (attention_map <= 1)):
        raise ValueError("The patch_opacities tensor must have values between 0 and 1.")

    # If the image is not square, raise an error
    if img.size[0] != img.size[1]:
        raise ValueError("The image must be square.")

    # Get the image as a numpy array
    img_array = np.array(img.convert("RGBA"))  # (H, W, C) where the last channel is the alpha channel

    # Get the attention map as a numpy array
    attention_map_image = Image.fromarray((attention_map * 255).astype("uint8")).resize(
        img.size, Image.Resampling.BICUBIC
    )

    # Create a figure
    with plt.style.context(style):
        fig, ax = plt.subplots(figsize=figsize)
        ax.imshow(img_array)
        im = ax.imshow(
            attention_map_image,
            cmap=sns.color_palette("mako", as_cmap=True),
            alpha=0.5,
        )
        if show_colorbar:
            fig.colorbar(im)
        if not show_axes:
            ax.set_axis_off()
        fig.tight_layout()

    return fig, ax


if __name__ == "__main__":
    # Load model and processor
    model_name = "../models/vidore--colpali/"
    model = ColPali.from_pretrained("../models/google--paligemma-3b-mix-448", torch_dtype=torch.float16, device_map="cuda").eval()
    model.load_adapter(model_name)
    processor = ColPaliProcessor.from_pretrained(model_name)

    model.config.name_or_path = "google/paligemma-3b-mix-448"

    # Load image
    image_path = "../Mimi.jpg"
    image = Image.open(image_path)
    query = "Where are the eyes and the ears of the cat?"


    # Generate `attention_map`
    attention_map, attention_map_normalized, text_tokens = generate_interpretability_plots(
        model=model,
        processor=processor,
        query=query,
        image=image,
        savedir="../outputs/",
        add_special_prompt_to_doc=True,
    )

    config = VIT_CONFIG[model.config.name_or_path]
    patch_size = config.patch_size
    resolution = config.resolution

    # Generate attention heatmap
    for token_idx in trange(1, len(text_tokens) - 1, desc="Saving attention maps..."):
        fig, ax = plot_attention_heatmap(
            image,
            patch_size,
            resolution,
            attention_map_normalized[0, token_idx, :, :],
            style="dark_background",
            show_colorbar=True,
            show_axes=False
        )
        fig.suptitle(f"Token #{token_idx}: `{text_tokens[token_idx]}`", color="white", fontsize=14)
        savepath = Path(f"../outputs/mimi_token_{text_tokens[token_idx]}.png")
        fig.savefig(savepath)
        plt.close(fig)


The image I fed in here is my adorable cat Mimi:

The input Query is: Where are the eyes and the ears of the cat?

Then I picked the saved image mimi_token__ears.png (each attention map is drawn for a single Query token):

As you can see, the model really does focus very precisely on the cat's ears; the performance is quite good.


Troubleshooting – Inconsistent Prediction Scores

In the current model architecture, whether the input is an image or text, it never passes through PaliGemma's final lm_head layer; instead, a linear layer defined by the ColPali model, custom_text_proj, takes the last_hidden_state and produces the final features.

class ColPali(PaliGemmaPreTrainedModel):
    def __init__(self, config):
        super(ColPali, self).__init__(config=config)
        self.model: PaliGemmaForConditionalGeneration = PaliGemmaForConditionalGeneration(config)
        self.dim = 128
        self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim)
        self.main_input_name = "doc_input_ids"

    def forward(self, *args, **kwargs):
        """
        Forward pass through Llama and the linear layer for dimensionality reduction

        Args:
        - input_ids (torch.LongTensor): The input tokens tensor.
        - attention_mask (torch.LongTensor): The attention mask tensor.

        Returns:
        - torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
        """
        outputs = self.model(*args, output_hidden_states=True, **kwargs)
        last_hidden_states = outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)
        proj = self.custom_text_proj(last_hidden_states)
        # normalize l2 norm
        proj = proj / proj.norm(dim=-1, keepdim=True)
        proj = proj * kwargs["attention_mask"].unsqueeze(-1)
        return proj


In this process we can see that the model does not actually use the logits output (which would normally come from lm_head(last_hidden_states)); instead, it declares its own custom_text_proj:

self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim)


The problem is that this layer is randomly initialized every time the model is loaded, so the prediction scores cannot be reproduced. The developers mention this issue in Inconsistent Scores with Example Inference Script #11 and plan to improve it in the future; for now, a workaround is to save a good set of initial weights once and load those fixed weights back every time the model is loaded afterwards:

# Save 
torch.save(model.custom_text_proj.base_layer, "custom_text_proj.pth")

# Load
model.custom_text_proj.base_layer = torch.load("custom_text_proj.pth")


In my tests so far, the prediction results can be reproduced consistently.
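
As a variation on the same workaround (my own sketch, not something from the repo), one could persist only the layer's state_dict instead of pickling the module object, assuming the same custom_text_proj.base_layer attribute as above:

# Save only the weights of the projection layer
torch.save(model.custom_text_proj.base_layer.state_dict(), "custom_text_proj_state.pth")

# Load them back into the freshly (randomly) initialized layer
model.custom_text_proj.base_layer.load_state_dict(torch.load("custom_text_proj_state.pth"))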


Conclusion

I think retrieving images with text is an extremely powerful application scenario for multimodal models, and the related applications will only keep growing! It is certainly more useful than having AI models explain memes to us XD

Next, if I get the chance, I would like to look into fine-tuning ColPali, although before that I will probably still focus on LLM fine-tuning techniques and inference acceleration. Here's hoping for more time and GPUs (especially GPUs XD, time is like cleavage, you can always squeeze some out, but GPUs are stubbornly inflexible!)

