Last Updated on 2024-07-31 by Clay
介紹
從去年開始我就對多模態(Multi-Modal)的 AI 模型充滿了熱忱與好奇,因為我是堅定不移的 AGI 派,認為 AI 目前的潛力仍然遠遠沒有抵達天花板;而 AI 當前的一大瓶頸與研究方向,自然就是整合了多種不同的模態(文字、圖像、音訊...)的模型應用了。
今天我想要紀錄的,是一個名為 ColPali 的多模態模型。當然,文字與圖像整合的模型其實以現在的時局來說是滿爛大街的,不過 ColPali 的架構還是相當巧妙的。
以前我曾經介紹過 ColBERT 這種模型架構:[論文閱讀] ColBERT: Efficient and Effective Passage Search via Contextualized Late Interaction over BERT
其中最引人深思的,就在於其查找的文字與檢索文件之間細粒度的整合。
這個模型架構並不是像一般的嵌入模型一樣,把 Query 計算一個向量 A、把 Document 計算一個向量 B,然後把兩者計算相似度;它的作法更像是 cross-encoder,Query 的每個 token 都會和 Document 的每個 token 計算相似度的計算,然後取最大的 MaxSim 分數當作該 Query Token 的計分,最後把所有 Query Tokens 加總... 這種作法的好處在於更好地考慮 token 和 token 之間的隱性關聯。
而 ColPali 模型架構也是如此,只是 Query 的 Token 對應的不再是 Document Token,而是 ViT 模型處理圖像的最小單位 —— Patch。
左下粉色區塊的部份,就是圖像經過 Vision LLM 處理成最小單位的 patch hidden state,此處是我們視線在服務後台建立好的圖像索引、可能有非常多待檢索的圖片或文件;而右下綠色的部份,則是即時地問題輸入,通過 LLM 計算出不同 token 的 hidden states,然後與圖像部份的 patch 計算 MaxSim 分數... 之後就跟 ColBERT 的部份一樣了。
利用這種架構,我們就能很好地透過問題檢索出圖片中的某些重要部份,比方說輸入 ears 找到圖片中貓咪的耳朵。
實務上來說,研究團隊所提出的 ColPali 模型是由 Google Zürich 團隊所發布的 PaliGemma 模型微調而來,並透過 Omar Khattab 在 ColBERT 中提出的後期互動機制來計算多向量的檢索,並微調於下游任務中。
使用方式
在使用模型前,我們需要 clone 其 GitHub 倉庫,並安裝所有 Python 相關套件。
git@github.com:illuin-tech/colpali.git
cd colpali/
pip install -r requirements.txt
之後,就可以進行初步的推理(程式碼引用自官方):
import torch
import typer
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor
from PIL import Image
from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.colpali_processing_utils import process_images, process_queries
from colpali_engine.utils.image_from_page_utils import load_from_dataset
def main() -> None:
"""Example script to run inference with ColPali"""
# Load model
model_name = "vidore/colpali"
model = ColPali.from_pretrained("google/paligemma-3b-mix-448", torch_dtype=torch.bfloat16, device_map="cuda").eval()
model.load_adapter(model_name)
processor = AutoProcessor.from_pretrained(model_name)
# select images -> load_from_pdf(<pdf_path>), load_from_image_urls(["<url_1>"]), load_from_dataset(<path>)
images = load_from_dataset("vidore/docvqa_test_subsampled")
queries = ["From which university does James V. Fiorca come ?", "Who is the japanese prime minister?"]
# run inference - docs
dataloader = DataLoader(
images,
batch_size=4,
shuffle=False,
collate_fn=lambda x: process_images(processor, x),
)
ds = []
for batch_doc in tqdm(dataloader):
with torch.no_grad():
batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
embeddings_doc = model(**batch_doc)
ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
# run inference - queries
dataloader = DataLoader(
queries,
batch_size=4,
shuffle=False,
collate_fn=lambda x: process_queries(processor, x, Image.new("RGB", (448, 448), (255, 255, 255))),
)
qs = []
for batch_query in dataloader:
with torch.no_grad():
batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
embeddings_query = model(**batch_query)
qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
# run evaluation
retriever_evaluator = CustomEvaluator(is_multi_vector=True)
scores = retriever_evaluator.evaluate(qs, ds)
print(scores.argmax(axis=1))
if __name__ == "__main__":
typer.run(main)
繪製注意力熱點圖
在原始論文中有提供一張注意力的熱點圖,直接讓大家能夠對這個技術的能力有了充足信心。
我有個同事已經率先幫我找到了在官方 GitHub repo 中的原始碼位置了,我稍微調整了一下排版,現在來測試一下生成的效果:
import pprint
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, Tuple, cast
from uuid import uuid4
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import seaborn as sns
import torch
from PIL import Image
from einops import rearrange
from tqdm import trange
from colpali_engine.interpretability.processor import ColPaliProcessor
from colpali_engine.interpretability.torch_utils import normalize_attention_map_per_query_token
from colpali_engine.interpretability.vit_configs import VIT_CONFIG
from colpali_engine.models.paligemma_colbert_architecture import ColPali
OUTDIR_INTERPRETABILITY = Path("outputs/interpretability")
@dataclass
class InterpretabilityInput:
query: str
image: Image.Image
start_idx_token: int
end_idx_token: int
def generate_interpretability_plots(
model: ColPali,
processor: ColPaliProcessor,
query: str,
image: Image.Image,
savedir: str | Path | None = None,
add_special_prompt_to_doc: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Sanity checks
if len(model.active_adapters()) != 1:
raise ValueError("The model must have exactly one active adapter.")
if model.config.name_or_path not in VIT_CONFIG:
raise ValueError("The model must be referred to in the VIT_CONFIG dictionary.")
vit_config = VIT_CONFIG[model.config.name_or_path]
# Handle savepath
if not savedir:
savedir = OUTDIR_INTERPRETABILITY / str(uuid4())
print(f"No savepath provided. Results will be saved to: `{savedir}`.")
elif isinstance(savedir, str):
savedir = Path(savedir)
savedir.mkdir(parents=True, exist_ok=True)
# Resize the image to square
input_image_square = image.resize((vit_config.resolution, vit_config.resolution))
# Preprocess the inputs
input_text_processed = processor.process_text(query).to(model.device)
input_image_processed = processor.process_image(image, add_special_prompt=add_special_prompt_to_doc).to(
model.device
)
# Forward pass
with torch.no_grad():
output_text = model(**asdict(input_text_processed)) # (1, n_text_tokens, hidden_dim)
output_image = model(**asdict(input_image_processed)) # (1, n_patch_x * n_patch_y, hidden_dim)
if add_special_prompt_to_doc: # remove the special tokens
output_image = output_image[:, : processor.processor.image_seq_length, :]
n_patches_per_dim = vit_config.resolution // vit_config.patch_size
output_image = rearrange(
output_image, "b (h w) c -> b h w c", h=n_patches_per_dim, w=n_patches_per_dim
)
# Get the unnormalized attention map
attention_map = torch.einsum(
"bnk,bijk->bnij", output_text, output_image
)
attention_map_normalized = normalize_attention_map_per_query_token(attention_map)
# Get text token information
n_tokens = input_text_processed.input_ids.size(1)
text_tokens = processor.tokenizer.tokenize(processor.decode(input_text_processed.input_ids[0]))
print("Text tokens:")
pprint.pprint(text_tokens)
print("\n")
return attention_map, attention_map_normalized, text_tokens
def plot_attention_heatmap(
img: Image.Image,
patch_size: int,
image_resolution: int,
attention_map: npt.NDArray | torch.Tensor,
figsize: Tuple[int, int] = (8, 8),
style: Dict[str, Any] | str | None = None,
show_colorbar: bool = False,
show_axes: bool = False,
) -> Tuple[plt.Figure, plt.Axes]:
"""
Plot a heatmap of the attention map over the image.
The image must be square and `attention_map` must be normalized between 0 and 1.
"""
# Get the number of patches
if image_resolution % patch_size != 0:
raise ValueError("The image resolution must be divisible by the patch size.")
num_patches = image_resolution // patch_size
# Default style
if style is None:
style = {}
# Sanity checks
if isinstance(attention_map, torch.Tensor):
attention_map = cast(npt.NDArray, attention_map.cpu().numpy())
if attention_map.shape != (num_patches, num_patches):
raise ValueError("The shape of the patch_opacities tensor is not correct.")
if not np.all((0 <= attention_map) & (attention_map <= 1)):
raise ValueError("The patch_opacities tensor must have values between 0 and 1.")
# If the image is not square, raise an error
if img.size[0] != img.size[1]:
raise ValueError("The image must be square.")
# Get the image as a numpy array
img_array = np.array(img.convert("RGBA")) # (H, W, C) where the last channel is the alpha channel
# Get the attention map as a numpy array
attention_map_image = Image.fromarray((attention_map * 255).astype("uint8")).resize(
img.size, Image.Resampling.BICUBIC
)
# Create a figure
with plt.style.context(style):
fig, ax = plt.subplots(figsize=figsize)
ax.imshow(img_array)
im = ax.imshow(
attention_map_image,
cmap=sns.color_palette("mako", as_cmap=True),
alpha=0.5,
)
if show_colorbar:
fig.colorbar(im)
if not show_axes:
ax.set_axis_off()
fig.tight_layout()
return fig, ax
if __name__ == "__main__":
# Load model and processor
model_name = "../models/vidore--colpali/"
model = ColPali.from_pretrained("../models/google--paligemma-3b-mix-448", torch_dtype=torch.float16, device_map="cuda").eval()
model.load_adapter(model_name)
processor = ColPaliProcessor.from_pretrained(model_name)
model.config.name_or_path = "google/paligemma-3b-mix-448"
# Load image
image_path = "../Mimi.jpg"
image = Image.open(image_path)
query = "Where are the eyes and the ears of the cat?"
# Generate `attention_map`
attention_map, attention_map_normalized, text_tokens = generate_interpretability_plots(
model=model,
processor=processor,
query=query,
image=image,
savedir="../outputs/",
add_special_prompt_to_doc=True,
)
config = VIT_CONFIG[model.config.name_or_path]
patch_size = config.patch_size
resolution = config.resolution
# Generate attention heatmap
for token_idx in trange(1, len(text_tokens) - 1, desc="Saving attention maps..."):
fig, ax = plot_attention_heatmap(
image,
patch_size,
resolution,
attention_map_normalized[0, token_idx, :, :],
style="dark_background",
show_colorbar=True,
show_axes=False
)
fig.suptitle(f"Token #{token_idx}: `{text_tokens[token_idx]}`", color="white", fontsize=14)
savepath = Path(f"../outputs/mimi_token_{text_tokens[token_idx]}.png")
fig.savefig(savepath)
plt.close(fig)
在這裡我輸入的圖片是我家可愛的咪咪:
輸入的 Query 則是:Where are the eyes and the ears of the cat?
然後我挑選我保存的 mimi_token__ears.png 這張圖片(每個注意力圖都是針對單一 Query Token 下去繪製):
可以發現,模型真的非常精準地關注貓咪的耳朵部份,性能非常不錯。
問題排除 - 無法重現一致的預測分數
在當前的模型架構中,我們輸入的資料無論是圖片還是文字,最後並不會走到 PaliGemma 的最後一層 lm_head
,而是由 ColPali 模型定義的 custom_text_proj
這一個線性層接受 last_hidden_state
,再輸出最後的特徵分佈。
class ColPali(PaliGemmaPreTrainedModel):
def __init__(self, config):
super(ColPali, self).__init__(config=config)
self.model: PaliGemmaForConditionalGeneration = PaliGemmaForConditionalGeneration(config)
self.dim = 128
self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim)
self.main_input_name = "doc_input_ids"
def forward(self, *args, **kwargs):
"""
Forward pass through Llama and the linear layer for dimensionality reduction
Args:
- input_ids (torch.LongTensor): The input tokens tensor.
- attention_mask (torch.LongTensor): The attention mask tensor.
Returns:
- torch.Tensor: Embeddings of shape (batch_size, num_tokens, dim)
"""
outputs = self.model(*args, output_hidden_states=True, **kwargs)
last_hidden_states = outputs.hidden_states[-1] # (batch_size, sequence_length, hidden_size)
proj = self.custom_text_proj(last_hidden_states)
# normalize l2 norm
proj = proj / proj.norm(dim=-1, keepdim=True)
proj = proj * kwargs["attention_mask"].unsqueeze(-1)
return proj
在這個過程中,我們可以看到其實模型並沒有使用到 logits
的輸出(通常這是由 lm_head(last_hidden_states)
所得到),而是自己重新宣告了 custom_text_proj
。
self.custom_text_proj = nn.Linear(self.model.config.text_config.hidden_size, self.dim)
這樣一來,由於模型在每次讀取時都隨機初始化這一層,進而導致預測的分數無法復現。目前開發者在 Inconsistent Scores with Example Inference Script #11 中有提到這個問題,並預計在未來改善;不過當前如果要解決這個問題,可以把某次良好的初始權重保存下來,並在之後讀取模型時再次讀取固定好的權重回來:
# Save
torch.save(model.custom_text_proj.base_layer, "custom_text_proj.pth")
# Load
model.custom_text_proj.base_layer = torch.load("custom_text_proj.pth")
目前測試,都可以正常地重複預測結果。
結語
我認為以文搜圖這是多模態模型非常強力的應用場景,接下來相關的應用只會多不會少吧!而且、總比讓 AI 模型替我們解釋搞笑圖片來得有用多了 XD
下一步,若有機會我會想研究關於 ColPali 的微調,不過在那之前應該仍會優先關注於 LLM 微調技巧和加速推理任務。期待之後有更多的時間與 GPU(尤其是 GPU XD,時間就跟乳溝一樣可以擠一擠、但是 GPU 可是非常頑固的!)
References
- Paper - ColPali: Efficient Document Retrieval with Vision Language Models
- HuggingFace Model Hub - ColPali