
Notes on Decoding and Sampling for Large Language Models

Last Updated on 2024-11-04 by Clay

When we use a large language model for generation, especially auto-regressive generation, the model is effectively performing a classification task with tens of thousands of classes. The classes are the entries of our vocabulary, usually called tokens, which are the smallest units that words are built from.

If we adopt greedy decoding, we simply always take the maximum of the logits from the model's final decoding layer and we are done; but if we want the generated output to be diverse and somewhat random, there are quite a few parameters we can use to reshape the logits into a probability distribution to sample from.
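
For reference, here is a minimal sketch of greedy decoding with made-up shapes (the variable names are illustrative only):

import torch

# Made-up logits: (batch_size=2, seq_length=10, vocab_size=50257)
logits = torch.randn(2, 10, 50257)

# Greedy decoding: always take the highest-scoring token at the last position
next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)  # shape: (2, 1)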

This note does not describe the standard implementation of well-known frameworks such as HuggingFace; it is just an experimental implementation I wrote while building my own inference-acceleration framework, so take it as a reference for the concepts only.


Sampling Parameters

There are many sampling parameters; this post covers only the most common ones:

  • Repetition penalty (repetition_penalty): lowers the decoding probability of tokens that have already appeared
  • Temperature (temperature): shrinks or widens the gaps between the logits of different tokens
  • Top-K: keeps only the K highest-scoring tokens as decoding candidates
  • Top-p: keeps the tokens whose cumulative probability stays within top_p as decoding candidates

In what follows, assume the logits we decode have shape (batch_size, vocab_size); I hide the sequence-length dimension seq_length, because we always decode, and adjust the probability distribution, at the last position only.

First, let's look at the order of the operations during decoding:

def sample_next_token(
    logits: torch.FloatTensor,
    prefix_token_ids: torch.LongTensor,
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
    repetition_penalty: float = 1.0,
    eps: float = 1e-7,
) -> torch.LongTensor:
    # Take only the logits at the last position: (batch_size, vocab_size)
    curr_logits = logits[:, -1, :]

    # Apply repetition penalty
    if repetition_penalty != 1.0:
        curr_logits = apply_repetition_penalty(
            logits=curr_logits,
            prefix_token_ids=prefix_token_ids,
            repetition_penalty=repetition_penalty,
        )

    # Apply temperature
    curr_logits = curr_logits / (temperature + eps)

    # Apply `top_k`
    curr_logits = top_k_filtering(logits=curr_logits, top_k=top_k)

    # Apply `top_p`
    curr_logits = top_p_filtering(logits=curr_logits, top_p=top_p)

    # Convert logits into probs
    probs = torch.softmax(curr_logits, dim=-1)

    # Sampling
    next_token = torch.multinomial(probs, num_samples=1)

    return next_token


When the model's final-layer output logits enter the sampling stage, their original shape is (batch_size, seq_length, vocab_size); here I always take the last position along seq_length.

Then, in order, we apply the repetition penalty, the temperature, top-k filtering, and top-p filtering; finally we convert the result into a probability distribution with Softmax, and torch.multinomial draws the next token to decode.

The repetition penalty works on each sequence separately: for every token that has appeared before, we multiply (or divide) the corresponding position in the current decoding logits by repetition_penalty, so that repeated tokens become less likely. Concretely, negative logits are multiplied by the penalty and positive ones divided by it, pushing both toward lower probability.

def apply_repetition_penalty(
    logits: torch.FloatTensor,
    prefix_token_ids: torch.LongTensor,
    repetition_penalty: float = 1.0,
) -> torch.FloatTensor:
    # For each sequence, penalize every token that already appeared: negative
    # logits are multiplied by the penalty, positive ones divided, so both
    # move toward lower probability (for repetition_penalty > 1.0)
    for batch_idx in range(prefix_token_ids.shape[0]):
        for token_id in set(prefix_token_ids[batch_idx].tolist()):
            if logits[batch_idx, token_id] < 0:
                logits[batch_idx, token_id] *= repetition_penalty
            else:
                logits[batch_idx, token_id] /= repetition_penalty
    return logits
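
As a quick sanity check, here is a toy example of what the penalty does (numbers made up for illustration; note that the function modifies logits in place, hence the clone()):

import torch

toy_logits = torch.tensor([[2.0, -1.0, 0.5]])  # one sequence, vocabulary of 3
toy_prefix = torch.tensor([[0, 1]])            # tokens 0 and 1 already appeared

print(apply_repetition_penalty(toy_logits.clone(), toy_prefix, repetition_penalty=1.2))
# tensor([[ 1.6667, -1.2000,  0.5000]]): 2.0 is divided by 1.2, -1.0 is multiplied
# by 1.2, and token 2 (never seen before) is left untouched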


Adjusting the temperature is very intuitive: we simply divide the logits by the temperature. A temperature below 1.0 widens the gaps between tokens, making high-scoring tokens even more likely to be decoded; a temperature above 1.0 shrinks the gaps, giving lower-probability tokens a better chance of being sampled.

One thing to watch out for: if the temperature in the denominator is 0, the division blows up, so I add a small eps value (my default is 1e-7) to the temperature to prevent that.

# Apply temperature
curr_logits = curr_logits / (temperature + eps)
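
A quick illustration of the effect with made-up logits:

import torch

x = torch.tensor([2.0, 1.0, 0.0])
print(torch.softmax(x / 0.5, dim=-1))  # sharper:  tensor([0.8668, 0.1173, 0.0159])
print(torch.softmax(x / 1.0, dim=-1))  # baseline: tensor([0.6652, 0.2447, 0.0900])
print(torch.softmax(x / 2.0, dim=-1))  # flatter:  tensor([0.5065, 0.3072, 0.1863])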



Top-k filtering is just as simple: I find the smallest value among the top-k logits along the last dimension, then use torch.where() to set every position smaller than that value to -float("Inf").

def top_k_filtering(logits: torch.FloatTensor, top_k: int) -> torch.FloatTensor:
    if top_k > 0:
        # Find the k-th largest logit per row; everything below it is masked to -inf
        values, _ = torch.topk(logits, top_k, dim=-1)
        min_values = values[:, -1].unsqueeze(dim=-1)
        logits = torch.where(logits < min_values, torch.full_like(logits, -float("Inf")), logits)

    return logits
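
For instance, with top_k=2 and some toy logits:

import torch

print(top_k_filtering(torch.tensor([[1.0, 3.0, 2.0, 0.5, -1.0]]), top_k=2))
# tensor([[-inf, 3., 2., -inf, -inf]]): only the two largest logits survive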



Top-p is a bit more involved. The idea is to take tokens from the highest probability downward until the cumulative probability exceeds top_p, and then sample only from that set; it is a "top X%" kind of cutoff.

As an example, suppose we have the probability distribution [0.4, 0.2, 0.15, 0.15, 0.1] (the probabilities necessarily sum to 1) and we set top_p to 0.8. The selection proceeds as follows:

  • Take 0.4; the cumulative probability is now 0.4
  • Take 0.2; the cumulative probability is now 0.6
  • Take 0.15; the cumulative probability is now 0.75
  • Take 0.15; the cumulative probability is now 0.9. Wait, that exceeds 0.8, so conceptually we drop this second 0.15 and sample only from the three elements [0.4, 0.2, 0.15].

In the implementation, we first sort the logits in descending order while keeping the pre-sort indices, since we need them afterwards to map the mask back; we then take the cumulative sum of softmax(logits) and set every token whose cumulative probability exceeds top_p to -float("Inf"). One subtlety: the remove-mask is shifted one position to the right, so the first token that crosses the threshold is still kept. This guarantees at least one candidate survives, and it means the implementation would actually keep four tokens from the example above, not three.

def top_p_filtering(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    
    # Mark the positions where the cumulative probability exceeds top_p
    sorted_indices_to_remove = cumulative_probs > top_p
    
    # Shift the mask one position to the right so that the first token above
    # the threshold is also kept (this guarantees at least one candidate)
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = False
    
    # Scatter the mask back to the original (unsorted) token order
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    logits = logits.masked_fill(indices_to_remove, -float("Inf"))
    
    return logits
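
To confirm the shift-right behavior on the example distribution from earlier, we can feed in log-probabilities as logits, so that softmax recovers the original probabilities exactly:

import torch

probs = torch.tensor([[0.4, 0.2, 0.15, 0.15, 0.1]])
print(top_p_filtering(probs.log(), top_p=0.8))
# Prints roughly tensor([[-0.9163, -1.6094, -1.8971, -1.8971, -inf]]):
# four tokens survive, because the one that crosses top_p = 0.8 (cumulative 0.9)
# is kept, and only the last token (probability 0.1) is masked out.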



Comparing the Results

Below is a comparison, under fixed hyperparameters, between my implementation and HuggingFace's (of course, HuggingFace applies additional sampling rules internally, so the two are not strictly equivalent). The model is GPT-2, and the sampling parameters are:

  • temperature = 0.1
  • top_k = 50
  • top_p = 0.9
  • repetition_penalty = 1.2

import warnings
warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=UserWarning)

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


# Settings
pretrained_model_name_or_path = "openai-community/gpt2"

# Model & Tokenizer
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path).cuda().eval()
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"


def apply_repetition_penalty(
    logits: torch.FloatTensor,
    prefix_token_ids: torch.LongTensor,
    repetition_penalty: float = 1.0,
) -> torch.FloatTensor:
    # For each sequence, penalize every token that already appeared: negative
    # logits are multiplied by the penalty, positive ones divided, so both
    # move toward lower probability (for repetition_penalty > 1.0)
    for batch_idx in range(prefix_token_ids.shape[0]):
        for token_id in set(prefix_token_ids[batch_idx].tolist()):
            if logits[batch_idx, token_id] < 0:
                logits[batch_idx, token_id] *= repetition_penalty
            else:
                logits[batch_idx, token_id] /= repetition_penalty
    return logits

def top_k_filtering(logits: torch.FloatTensor, top_k: int) -> torch.FloatTensor:
    if top_k > 0:
        # Find the k-th largest logit per row; everything below it is masked to -inf
        values, _ = torch.topk(logits, top_k, dim=-1)
        min_values = values[:, -1].unsqueeze(dim=-1)
        logits = torch.where(logits < min_values, torch.full_like(logits, -float("Inf")), logits)

    return logits


def top_p_filtering(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    
    # Mark the positions where the cumulative probability exceeds top_p
    sorted_indices_to_remove = cumulative_probs > top_p
    
    # Shift the mask one position to the right so that the first token above
    # the threshold is also kept (this guarantees at least one candidate)
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = False
    
    # Scatter the mask back to the original (unsorted) token order
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    logits = logits.masked_fill(indices_to_remove, -float("Inf"))
    
    return logits


def sample_next_token(
    logits: torch.FloatTensor,
    prefix_token_ids: torch.LongTensor,
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
    repetition_penalty: float = 1.0,
    eps: float = 1e-7,
) -> torch.LongTensor:
    # Take only the logits at the last position: (batch_size, vocab_size)
    curr_logits = logits[:, -1, :]

    # Apply repetition penalty
    if repetition_penalty != 1.0:
        curr_logits = apply_repetition_penalty(
            logits=curr_logits,
            prefix_token_ids=prefix_token_ids,
            repetition_penalty=repetition_penalty,
        )

    # Apply temperature
    curr_logits = curr_logits / (temperature + eps)

    # Apply `top_k`
    curr_logits = top_k_filtering(logits=curr_logits, top_k=top_k)

    # Apply `top_p`
    curr_logits = top_p_filtering(logits=curr_logits, top_p=top_p)

    # Convert logits into probs
    probs = torch.softmax(curr_logits, dim=-1)

    # Sampling
    next_token = torch.multinomial(probs, num_samples=1)

    return next_token


temperature = 0.1
top_k = 50
top_p = 0.9
repetition_penalty = 1.2


# Test data
sentences = [
    "Today is a nice day",
    "How are you?",
]

inputs = tokenizer(
    sentences,
    max_length=512,
    truncation=True,
    padding=True,
    return_tensors="pt",
).to("cuda:0")


print("=== My Sampling ===")
# Decode 10 new tokens step by step with our own sampling function
for _ in range(10):
    outputs = model(**inputs)

    next_tokens = sample_next_token(
        outputs.logits,
        prefix_token_ids=inputs["input_ids"],
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )

    input_ids = torch.cat([inputs["input_ids"], next_tokens], dim=-1)
    attention_mask = torch.cat(
        [
            inputs["attention_mask"],
            torch.ones(inputs["attention_mask"].shape[0], 1, dtype=inputs["attention_mask"].dtype, device=inputs["attention_mask"].device),
        ],
        dim=-1,
    )

    inputs["input_ids"] = input_ids
    inputs["attention_mask"] = attention_mask


for sent in tokenizer.batch_decode(inputs.input_ids):
    print(sent)

print("\n=== HuggingFace ===")


# Test data
sentences = [
    "Today is a nice day",
    "How are you?",
]

inputs = tokenizer(
    sentences,
    max_length=512,
    truncation=True,
    padding=True,
    return_tensors="pt",
).to("cuda:0")


# Note: without do_sample=True, generate() decodes greedily here, and recent
# transformers versions warn that these sampling parameters may be ignored
outputs = model.generate(
    **inputs,
    temperature=temperature,
    top_k=top_k,
    top_p=top_p,
    repetition_penalty=repetition_penalty,
)


for sent in tokenizer.batch_decode(outputs):
    print(sent)


Output:

=== My Sampling ===
Today is a nice day for the world to celebrate our country's independence.
How are you?<|endoftext|>The answer is simple. You're a young man

=== HuggingFace ===
Today is a nice day for the world to celebrate our country's independence.
<|endoftext|><|endoftext|><|endoftext|><|endoftext|>
How are you?<|endoftext|>The answer is simple: You're not a human being. Your brain works


As we can see, the two outputs are already quite close.

