Last Updated on 2024-11-04 by Clay
When we use a large language model for generation, especially auto-regressive generation, the model is really performing a classification task over tens of thousands of classes. The classes are the entries of our vocabulary, usually called tokens, the smallest units that make up the text.
If we want greedy decoding, we simply always take the maximum of the logits from the model's final decoding layer and we are done; but if we want the generated results to show diversity and a certain amount of randomness, then there are quite a few parameters we can tune to shape the logits into a probability distribution before sampling.
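As a point of comparison, greedy decoding needs none of these parameters; a minimal sketch (assuming logits of shape (batch_size, seq_length, vocab_size)) is just an argmax over the vocabulary at the last position:

import torch

def greedy_next_token(logits: torch.FloatTensor) -> torch.LongTensor:
    # Pick the single largest logit at the last position of each sequence
    return torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)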
What this note records is not the standard implementation of a well-known framework such as HuggingFace; it is merely an experimental implementation from my own work on an inference-acceleration framework, so treat it as a conceptual reference only.
Sampling Parameters
There are actually many sampling parameters; this post only covers the most common ones:
- Repetition penalty (repetition_penalty): lowers the decoding probability of tokens that have already appeared
- Temperature (temperature): shrinks or enlarges the gaps between different tokens' logits
- Top-K: only the K tokens with the largest logits enter the decoding candidates
- Top-p: only the tokens whose cumulative probability falls within top_p enter the decoding candidates
Below, we assume the logits to be decoded have shape (batch_size, vocab_size); I hide the sequence-length dimension seq_length because we always decode, and adjust the probability distribution, only at the last position.
First, let's look at the order of the decoding steps:
def sample_next_token(
    logits: torch.FloatTensor,
    prefix_token_ids: torch.LongTensor,
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
    repetition_penalty: float = 1.0,
    eps: float = 1e-7,
) -> torch.LongTensor:
    # Only decode at the last position of the sequence
    curr_logits = logits[:, -1, :]

    # Apply repetition penalty
    if repetition_penalty != 1.0:
        curr_logits = apply_repetition_penalty(
            logits=curr_logits,
            prefix_token_ids=prefix_token_ids,
            repetition_penalty=repetition_penalty,
        )

    # Apply temperature
    curr_logits = curr_logits / (temperature + eps)

    # Apply `top_k`
    curr_logits = top_k_filtering(logits=curr_logits, top_k=top_k)

    # Apply `top_p`
    curr_logits = top_p_filtering(logits=curr_logits, top_p=top_p)

    # Convert logits into probs
    probs = torch.softmax(curr_logits, dim=-1)

    # Sampling
    next_token = torch.multinomial(probs, num_samples=1)

    return next_token
When the model's final-layer output logits enter the sampling stage, their original shape is (batch_size, seq_length, vocab_size); here I always take the last position along seq_length.
Then, in order, we apply the repetition penalty, the temperature, top-k filtering, and top-p filtering, convert the result into a probability distribution with Softmax, and finally draw the next decoded token with torch.multinomial.
The repetition penalty works per sequence: for every token that appeared earlier, the corresponding position in the current decoding logits is multiplied by (or divided by) repetition_penalty, so that the token becomes less likely to be decoded again.
def apply_repetition_penalty(
    logits: torch.FloatTensor,
    prefix_token_ids: torch.LongTensor,
    repetition_penalty: float = 1.0,
) -> torch.FloatTensor:
    for batch_idx in range(prefix_token_ids.shape[0]):
        for token_id in set(prefix_token_ids[batch_idx].tolist()):
            # Divide positive logits and multiply negative ones so that
            # previously seen tokens always become less likely
            if logits[batch_idx, token_id] < 0:
                logits[batch_idx, token_id] *= repetition_penalty
            else:
                logits[batch_idx, token_id] /= repetition_penalty

    return logits
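As a quick sanity check (a made-up toy example, not part of the original experiments), positive logits of previously seen tokens shrink while negative ones become more negative:

import torch

# Hypothetical batch of 1 with a vocabulary of 5 tokens
logits = torch.tensor([[2.0, -1.0, 0.5, 3.0, -2.0]])
prefix_token_ids = torch.tensor([[0, 1]])  # tokens 0 and 1 already appeared

penalized = apply_repetition_penalty(
    logits=logits.clone(),
    prefix_token_ids=prefix_token_ids,
    repetition_penalty=1.2,
)
print(penalized)
# Token 0: 2.0 / 1.2 ≈ 1.67, token 1: -1.0 * 1.2 = -1.2, the rest unchanged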
Adjusting the temperature is very straightforward: we simply divide the logits by the temperature. If the temperature is below 1.0, this enlarges the gaps between tokens, making the tokens with larger values even more likely to be decoded; conversely, if it is above 1.0, the gaps shrink, giving lower-probability tokens a better chance of being sampled.
Note that if the temperature in the denominator is 0, we get an error, so I add a small eps value (my default is 1e-7) to the temperature to guard against it.
# Apply temperature
curr_logits = curr_logits / (temperature + eps)
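To make the effect concrete, here is a small made-up illustration of how dividing by temperatures below or above 1.0 sharpens or flattens the resulting softmax distribution:

import torch

logits = torch.tensor([2.0, 1.0, 0.1])

for temperature in (0.5, 1.0, 2.0):
    probs = torch.softmax(logits / temperature, dim=-1)
    print(temperature, probs)

# 0.5 concentrates the probability mass on the largest logit,
# while 2.0 spreads it more evenly across the tokens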
Top-k sampling is equally simple: I first find the smallest value among the top-k logits along the last dimension, then use torch.where() to set every position whose logit is smaller than that top-k minimum to -float("Inf").
def top_k_filtering(logits: torch.FloatTensor, top_k: int) -> torch.FloatTensor:
    if top_k > 0:
        values, indices = torch.topk(logits, top_k, dim=-1)
        min_values = values[:, -1].unsqueeze(dim=-1)

        # Mask out everything smaller than the k-th largest logit
        logits = torch.where(logits < min_values, torch.full_like(logits, -float("Inf")), logits)

    return logits
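A small made-up example of the behavior: with top_k = 2, only the two largest logits survive and every other position is masked out:

import torch

logits = torch.tensor([[1.0, 3.0, 0.5, 2.0]])
print(top_k_filtering(logits=logits, top_k=2))
# Only 3.0 and 2.0 remain; 1.0 and 0.5 are set to -inf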
Top-p is a bit more involved. The concept is that we take tokens one by one from the highest probability downward until the cumulative probability exceeds the top_p value, and then sample only from what we have taken; it is a "take the top portion of the probability mass" idea (also known as nucleus sampling).
For example, suppose we have the probability distribution [0.4, 0.2, 0.15, 0.15, 0.1] (the probabilities necessarily sum to 1) and we set top_p to 0.8. The selection goes like this (a short snippet below replays the bookkeeping):
- take 0.4, the cumulative probability is 0.4
- take 0.2, the cumulative probability is 0.6
- take 0.15, the cumulative probability is 0.75
- take 0.15, the cumulative probability is 0.9, which crosses top_p, so we stop here; this token is still kept, but everything after it is dropped, and we later sample only from the four elements [0.4, 0.2, 0.15, 0.15].
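This bookkeeping can be checked directly with torch.cumsum; the snippet below just replays the example above:

import torch

probs = torch.tensor([0.4, 0.2, 0.15, 0.15, 0.1])  # already sorted in descending order
cumulative = torch.cumsum(probs, dim=-1)            # [0.40, 0.60, 0.75, 0.90, 1.00]

top_p = 0.8
remove = cumulative > top_p       # [False, False, False, True, True]

# Shift right by one so the token that first crosses `top_p` is still kept
remove[1:] = remove[:-1].clone()
remove[0] = False
print(probs[~remove])             # tensor([0.4000, 0.2000, 0.1500, 0.1500])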
In the implementation, we first sort the logits in descending order while keeping the pre-sort indices, since we later need them to map the mask back onto the original positions; we then take the running sum of softmax(logits) and set every token whose cumulative probability has already exceeded top_p to -float("Inf") (the mask is shifted by one position so that the token that first crosses top_p is still kept).
def top_p_filtering(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

    # Mark the positions whose cumulative probability exceeds `top_p`
    sorted_indices_to_remove = cumulative_probs > top_p

    # Shift right by one so the token that first crosses `top_p` is kept
    # (this also guarantees at least one candidate remains)
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = False

    # Scatter the mask back to the original (unsorted) vocabulary order
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    logits = logits.masked_fill(indices_to_remove, -float("Inf"))

    return logits
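For completeness, feeding the function logits whose softmax reproduces the example distribution gives the same candidate set (the numbers are made up for illustration):

import torch

# Log-probabilities of the example distribution, so softmax() recovers it exactly
logits = torch.log(torch.tensor([[0.4, 0.2, 0.15, 0.15, 0.1]]))
filtered = top_p_filtering(logits=logits, top_p=0.8)
print(filtered)
# Only the last entry (p = 0.1) is masked to -inf; the other four stay candidates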
Results Comparison
Below is a comparison, under fixed hyperparameters, between my implementation and HuggingFace's (whose internals of course apply additional sampling rules, so the two are not strictly equivalent), using GPT-2 as the model. The sampling parameters are:
- temperature = 0.1
- top_k = 50
- top_p = 0.9
- repetition_penalty = 1.2
import warnings
warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=UserWarning)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Settings
pretrained_model_name_or_path = "openai-community/gpt2"
# Model & Tokenizer
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path).cuda().eval()
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
def apply_repetition_penalty(
    logits: torch.FloatTensor,
    prefix_token_ids: torch.LongTensor,
    repetition_penalty: float = 1.0,
) -> torch.FloatTensor:
    for batch_idx in range(prefix_token_ids.shape[0]):
        for token_id in set(prefix_token_ids[batch_idx].tolist()):
            # Divide positive logits and multiply negative ones so that
            # previously seen tokens always become less likely
            if logits[batch_idx, token_id] < 0:
                logits[batch_idx, token_id] *= repetition_penalty
            else:
                logits[batch_idx, token_id] /= repetition_penalty

    return logits
def top_k_filtering(logits: torch.FloatTensor, top_k: int) -> torch.FloatTensor:
    if top_k > 0:
        values, indices = torch.topk(logits, top_k, dim=-1)
        min_values = values[:, -1].unsqueeze(dim=-1)

        # Mask out everything smaller than the k-th largest logit
        logits = torch.where(logits < min_values, torch.full_like(logits, -float("Inf")), logits)

    return logits
def top_p_filtering(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

    # Mark the positions whose cumulative probability exceeds `top_p`
    sorted_indices_to_remove = cumulative_probs > top_p

    # Shift right by one so the token that first crosses `top_p` is kept
    # (this also guarantees at least one candidate remains)
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = False

    # Scatter the mask back to the original (unsorted) vocabulary order
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    logits = logits.masked_fill(indices_to_remove, -float("Inf"))

    return logits
def sample_next_token(
    logits: torch.FloatTensor,
    prefix_token_ids: torch.LongTensor,
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
    repetition_penalty: float = 1.0,
    eps: float = 1e-7,
) -> torch.LongTensor:
    # Only decode at the last position of the sequence
    curr_logits = logits[:, -1, :]

    # Apply repetition penalty
    if repetition_penalty != 1.0:
        curr_logits = apply_repetition_penalty(
            logits=curr_logits,
            prefix_token_ids=prefix_token_ids,
            repetition_penalty=repetition_penalty,
        )

    # Apply temperature
    curr_logits = curr_logits / (temperature + eps)

    # Apply `top_k`
    curr_logits = top_k_filtering(logits=curr_logits, top_k=top_k)

    # Apply `top_p`
    curr_logits = top_p_filtering(logits=curr_logits, top_p=top_p)

    # Convert logits into probs
    probs = torch.softmax(curr_logits, dim=-1)

    # Sampling
    next_token = torch.multinomial(probs, num_samples=1)

    return next_token
temperature = 0.1
top_k = 50
top_p = 0.9
repetition_penalty = 1.2
# Test data
sentences = [
    "Today is a nice day",
    "How are you?",
]

inputs = tokenizer(
    sentences,
    max_length=512,
    truncation=True,
    padding=True,
    return_tensors="pt",
).to("cuda:0")
print("=== My Sampling ===")
for idx in range(10):
outputs = model(**inputs)
next_tokens = sample_next_token(
outputs.logits,
prefix_token_ids=inputs["input_ids"],
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
)
input_ids = torch.cat([inputs["input_ids"], next_tokens], dim=-1)
attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device.type)], dim=-1)
inputs["input_ids"] = input_ids
inputs["attention_mask"] = attention_mask
for sent in tokenizer.batch_decode(inputs.input_ids):
print(sent)
print("\n=== HuggingFace ===")
# Test data
sentences = [
"Today is a nice day",
"How are you?",
]
inputs = tokenizer(
sentences,
max_length=512,
truncation=True,
padding=True,
return_tensors="pt",
).to("cuda:0")
outputs = model.generate(
**inputs,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
)
for sent in tokenizer.batch_decode(outputs):
print(sent)
Output:
=== My Sampling ===
Today is a nice day for the world to celebrate our country's independence.
How are you?<|endoftext|>The answer is simple. You're a young man
=== HuggingFace ===
Today is a nice day for the world to celebrate our country's independence.
<|endoftext|><|endoftext|><|endoftext|><|endoftext|>
How are you?<|endoftext|>The answer is simple: You're not a human being. Your brain works
We can see that the two outputs are already very close.