Introduction
Speculative Decoding is a highly practical technique for accelerating inference. A small draft model quickly decodes several tokens in a row while keeping the sampling probability distribution at every step; the large target model, the one we actually want to speed up, then predicts the next token on top of that sequence and, in the same forward pass, produces the sampling distribution for every previous token position. We then use the target model's probs to verify the draft model's probs and accept only those speculative draft tokens that are reliable enough.
For a more detailed explanation of the underlying principle, you can refer to my notes on the Speculative Decoding paper published by the Google team: [Paper Reading] Fast Inference from Transformers via Speculative Decoding
Translated into code, the concept is really simple. Suppose the draft model has decoded k tokens. We compare, in order, the probability of each of these tokens against the probability the target model assigns to the same token, and there are two cases (a short sketch follows this list):
- draft model token prob <= target model token prob: always accept the draft model's token, because the target model would decode this token with at least as high a probability
- draft model token prob > target model token prob: reject this token with probability 1 - (target model token prob / draft model token prob)
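The rule above fits in a few lines. Here is a minimal, self-contained sketch of the per-token check; the names accept_draft_tokens, draft_token_probs, and target_token_probs are my own for illustration and are not part of the implementation further down:

import torch

def accept_draft_tokens(
    draft_token_probs: torch.Tensor,   # q(x): probability the draft model sampled each token with, shape (gamma,)
    target_token_probs: torch.Tensor,  # p(x): probability the target model assigns to the same tokens, shape (gamma,)
) -> int:
    """Return how many leading draft tokens are accepted."""
    for idx, (q, p) in enumerate(zip(draft_token_probs.tolist(), target_token_probs.tolist())):
        if q <= p:
            continue  # Case 1: the target model is at least as confident, always accept
        if torch.rand(1).item() < 1 - p / q:
            return idx  # Case 2: rejected with probability 1 - p / q; everything after is discarded
    return len(draft_token_probs)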
Implementation
Below is my design of the sampling parameters applied after the model computes its logits. It differs a little from HuggingFace's implementation, but for now this is how I control the sampling behavior. For a detailed explanation of the sampling parameters, see my other hands-on write-up: Notes on Decoding and Sampling for Large Language Models.
In this implementation note I distinguish logits and probs fairly strictly: logits are the raw final outputs computed by the model, with a theoretical range of (-inf, inf); probs are probabilities, the (0, 1) distribution obtained by passing the logits through softmax.
from typing import Tuple
import warnings
warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=UserWarning)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
def apply_repetition_penalty(
    logits: torch.FloatTensor,
    prefix_token_ids: torch.LongTensor,
    repetition_penalty: float = 1.0,
) -> torch.FloatTensor:
    batch_size, gamma, vocab_size = logits.shape
    seq_length = prefix_token_ids.shape[1]

    for batch_idx in range(batch_size):
        for gamma_idx in range(gamma):
            # Only the tokens that precede this position count as its prefix
            current_prefix = prefix_token_ids[batch_idx, :seq_length - gamma + gamma_idx + 1]

            unique_token_ids = set(current_prefix.tolist())
            for token_id in unique_token_ids:
                # Penalize tokens that already appeared in the prefix
                if logits[batch_idx, gamma_idx, token_id] > 0:
                    logits[batch_idx, gamma_idx, token_id] /= repetition_penalty
                else:
                    logits[batch_idx, gamma_idx, token_id] *= repetition_penalty

    return logits
def top_k_filtering(logits: torch.FloatTensor, top_k: int) -> torch.FloatTensor:
    if top_k > 0:
        values, _ = torch.topk(logits, top_k, dim=-1)
        min_values = values[:, :, -1].unsqueeze(dim=-1)
        logits = torch.where(logits < min_values, torch.full_like(logits, -float("Inf")), logits)

    return logits
def top_p_filtering(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

    # Find the positions where the cumulative probability exceeds top_p
    sorted_indices_to_remove = cumulative_probs > top_p

    # Shift the mask to the right so that at least one token is always kept
    sorted_indices_to_remove[:, :, 1:] = sorted_indices_to_remove[:, :, :-1].clone()
    sorted_indices_to_remove[:, :, 0] = False

    # Scatter the mask back to the original (unsorted) vocabulary order
    indices_to_remove = sorted_indices_to_remove.scatter(2, sorted_indices, sorted_indices_to_remove)
    logits = logits.masked_fill(indices_to_remove, -float("Inf"))

    return logits
def sample_next_token(
    logits: torch.FloatTensor,
    prefix_token_ids: torch.LongTensor,
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
    repetition_penalty: float = 1.0,
    eps: float = 1e-7,
    probs_num: int = 1,
) -> Tuple[torch.LongTensor, torch.FloatTensor]:
    curr_logits = logits[:, -probs_num:, :]

    # Apply repetition penalty
    if repetition_penalty != 1.0:
        curr_logits = apply_repetition_penalty(
            logits=curr_logits,
            prefix_token_ids=prefix_token_ids,
            repetition_penalty=repetition_penalty,
        )

    # Apply temperature
    curr_logits = curr_logits / (temperature + eps)

    # Apply `top_k`
    curr_logits = top_k_filtering(logits=curr_logits, top_k=top_k)

    # Apply `top_p`
    curr_logits = top_p_filtering(logits=curr_logits, top_p=top_p)

    # Convert logits into probs
    probs = torch.softmax(curr_logits, dim=-1)

    # Sampling: one token per requested position
    seq_tokens = []
    for seq_idx in range(probs.shape[1]):
        seq_token = torch.multinomial(probs[:, seq_idx, :], num_samples=1)
        seq_tokens.append(seq_token)

    seq_token_ids = torch.cat(seq_tokens, dim=1)

    return seq_token_ids, probs
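Assuming the code above is saved as sampling.py (the import in the next section suggests that file name), a quick, hedged sanity check of the interface looks like this; the random logits merely stand in for a real model output and only the shapes matter:

import torch
from sampling import sample_next_token

batch_size, seq_len, vocab_size = 1, 8, 100
logits = torch.randn(batch_size, seq_len, vocab_size)                  # stand-in for `outputs.logits`
prefix_token_ids = torch.randint(0, vocab_size, (batch_size, seq_len))

token_ids, probs = sample_next_token(
    logits=logits,
    prefix_token_ids=prefix_token_ids,
    temperature=0.7,
    top_k=50,
    top_p=0.9,
    probs_num=3,
)

print(token_ids.shape)  # torch.Size([1, 3]): one sampled token per requested position
print(probs.shape)      # torch.Size([1, 3, 100]): the post-filtering distributions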
Now we can move on to the actual Speculative Decoding implementation. Here I only go as far as the end of the verification step, and because batch_size > 1 means different sequences make different amounts of progress (padding them to the same length could cause problems), I only consider the batch_size=1 case for the moment.
First, import everything we need.
from typing import Dict, List, Optional, Tuple
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import LlamaForCausalLM, GPT2TokenizerFast, PreTrainedTokenizerBase
from sampling import sample_next_token
Here is the draft model speculatively decoding gamma tokens. I concatenate both the tokens and the attention mask back onto the original inputs, and keep all of the probs produced along the way.
def drafter_speculative_decode(
    draft_model: torch.nn.Module,
    draft_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    gamma: int = 10,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, which means no top-k filtering
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], torch.FloatTensor]:
    draft_probs = []

    for idx in range(gamma):
        with torch.no_grad():
            outputs = draft_model(**inputs)

        next_tokens, probs = sample_next_token(
            logits=outputs.logits,
            prefix_token_ids=inputs["input_ids"],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        draft_probs.append(probs)

        # Append the newly sampled token (and its attention mask) back onto the inputs
        input_ids = torch.cat([inputs["input_ids"], next_tokens[:, -1:]], dim=-1)
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)

        inputs["input_ids"] = input_ids
        inputs["attention_mask"] = attention_mask

    return inputs, torch.cat(draft_probs, dim=1)
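At this point inputs["input_ids"] has grown by gamma tokens and draft_probs stacks the per-step sampling distributions. As a hedged illustration of the shapes handed over to the verifier (assuming the models and a batch_size=1 prompt of length prompt_len are already loaded as in the main block further below):

target_inputs, draft_probs = drafter_speculative_decode(
    draft_model=draft_model,
    draft_tokenizer=draft_tokenizer,
    inputs=inputs,
    gamma=10,
)

print(target_inputs["input_ids"].shape)  # (1, prompt_len + 10): the prompt plus 10 draft tokens
print(draft_probs.shape)                 # (1, 10, vocab_size): one sampling distribution per draft token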
The target model's verification step is more involved. Although it runs forward only once, all gamma + 1 positions of its logits have to go through the sampling pipeline to obtain probs, and we also keep the decoded token at every position so that, when a draft model token is rejected, we can substitute it directly.
After that, following the rule described earlier, we compare the target model token prob with the draft model token prob and reject with a certain probability in the second case.
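Before the full function, here is how the positions line up; a hedged sketch using the same variable names as the code below, where gamma is the number of draft tokens:

# The target model runs once on: [prompt ..., d_1, d_2, ..., d_gamma]
# Its last (gamma + 1) logit positions give:
#   positions 0 .. gamma-1 -> the target's distributions for d_1 .. d_gamma (used for verification)
#   position  gamma        -> the distribution for the token *after* d_gamma (the bonus token)
#
# With target_probs of shape (batch, gamma + 1, vocab_size):
eval_probs = target_probs[:, :-1, :]   # aligns one-to-one with draft_probs
next_token = next_tokens[:, -1:]       # appended only if every draft token is accepted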
def target_speculative_decode(
    target_model: torch.nn.Module,
    target_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    draft_probs: torch.FloatTensor,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, which means no top-k filtering
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
) -> Dict[str, torch.Tensor]:
    with torch.no_grad():
        outputs = target_model(**inputs)

    next_tokens, target_probs = sample_next_token(
        logits=outputs.logits,
        prefix_token_ids=inputs["input_ids"],
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        probs_num=draft_probs.shape[1] + 1,
    )

    next_token = next_tokens[:, -1:]

    # Evaluation: gather the probabilities of the draft tokens under both models
    indices = inputs["input_ids"][:, -draft_probs.shape[1]:]
    eval_probs = target_probs[:, :-1, :]

    expanded_indices = indices.unsqueeze(-1)
    selected_draft_probs = torch.gather(draft_probs, dim=-1, index=expanded_indices)
    selected_draft_probs = selected_draft_probs.squeeze(-1)

    selected_eval_probs = torch.gather(eval_probs, dim=-1, index=expanded_indices)
    selected_eval_probs = selected_eval_probs.squeeze(-1)

    # Compare draft_prob and eval_prob to build the candidate rejection mask
    mask_to_reject = selected_draft_probs > selected_eval_probs

    # Calculate the rejection probability 1 - (eval_prob / draft_prob)
    rejection_probs = 1 - (selected_eval_probs / selected_draft_probs)

    # Generate random values to determine accept or reject
    random_values = torch.rand_like(rejection_probs)
    rejection_decisions = random_values < rejection_probs

    # Get the final rejection masks
    rejection_masks = mask_to_reject & rejection_decisions
    acceptance_mask = torch.ones_like(selected_draft_probs, dtype=torch.bool)
    acceptance_mask[rejection_masks] = False

    # Concat `input_ids`
    if torch.all(acceptance_mask):
        input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)
    else:
        new_input_ids = []
        new_attention_mask = []

        for batch_idx in range(next_tokens.shape[0]):
            for pos_idx in range(acceptance_mask[batch_idx].shape[0]):
                if not acceptance_mask[batch_idx][pos_idx]:
                    # Replace the first rejected draft token with the target model's own token
                    # and truncate everything after it
                    gamma = next_tokens.shape[1] - 1
                    start_idx = inputs["input_ids"].shape[1] - gamma

                    inputs["input_ids"][batch_idx][start_idx+pos_idx] = next_tokens[batch_idx][pos_idx]

                    new_input_ids.append(inputs["input_ids"][batch_idx][:start_idx+pos_idx+1])
                    new_attention_mask.append(inputs["attention_mask"][batch_idx][:start_idx+pos_idx+1])
                    break

        input_ids = pad_sequence(new_input_ids, batch_first=True, padding_value=target_tokenizer.pad_token_id)
        attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)

    inputs["input_ids"] = input_ids
    inputs["attention_mask"] = attention_mask

    return inputs
Finally, let's look at an actual run:
if __name__ == "__main__":
    # Settings
    target_model_path = "../models/HuggingFaceTB--SmolLM2-1.7B-Instruct/"
    draft_model_path = "../models/HuggingFaceTB--SmolLM2-135M-Instruct/"
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load Tokenizer
    draft_tokenizer = GPT2TokenizerFast.from_pretrained(draft_model_path)
    target_tokenizer = GPT2TokenizerFast.from_pretrained(target_model_path)

    # Load Model
    draft_model = LlamaForCausalLM.from_pretrained(draft_model_path, torch_dtype=torch.bfloat16).to(device)
    target_model = LlamaForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.bfloat16).to(device)

    # Tokenize
    messages = [
        [
            {
                "role": "user",
                "content": "What is the capital of Taiwan. And why?",
            },
        ],
    ]

    input_text = draft_tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = draft_tokenizer(
        input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    ).to(device)

    # Draft model
    target_inputs, draft_probs = drafter_speculative_decode(
        draft_model=draft_model,
        draft_tokenizer=draft_tokenizer,
        inputs=inputs,
        gamma=10,
    )

    print(target_inputs["input_ids"])
    print("".join(draft_tokenizer.batch_decode(target_inputs["input_ids"][0])))

    # Target model
    outputs = target_speculative_decode(
        target_model=target_model,
        target_tokenizer=target_tokenizer,
        inputs=target_inputs,
        draft_probs=draft_probs,
    )

    print(outputs["input_ids"])
    print("".join(target_tokenizer.batch_decode(outputs["input_ids"][0])))
Output:
tensor([[ 1, 9690, 198, 2683, 359, 253, 5356, 5646, 11173, 3365,
3511, 308, 34519, 28, 7018, 411, 407, 19712, 8182, 2,
198, 1, 4093, 198, 1780, 314, 260, 3575, 282, 15914,
30, 1350, 1701, 47, 2, 198, 1, 520, 9531, 198,
504, 3575, 282, 15914, 314, 12545]], device='cuda:0')
<|im_start|>system
You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
<|im_start|>user
What is the capital of Taiwan. And why?<|im_end|>
<|im_start|>assistant
The capital of Taiwan is Tai
tensor([[ 1, 9690, 198, 2683, 359, 253, 5356, 5646, 11173, 3365,
3511, 308, 34519, 28, 7018, 411, 407, 19712, 8182, 2,
198, 1, 4093, 198, 1780, 314, 260, 3575, 282, 15914,
30, 1350, 1701, 47, 2, 198, 1, 520, 9531, 198,
504, 3575, 282, 15914, 314, 12545, 46162]], device='cuda:0')
<|im_start|>system
You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
<|im_start|>user
What is the capital of Taiwan. And why?<|im_end|>
<|im_start|>assistant
The capital of Taiwan is Taipei
This is a happy case: the small model decoded 10 tokens, all of them were accepted, and the large model even got to decode an 11th token on top of that for free. What a bargain.
But if we try a vaguer prompt, we will see the large model reject the small model's speculation very quickly, with an attitude of "step aside, I'll handle this myself!"
if __name__ == "__main__":
    # Settings
    target_model_path = "../models/HuggingFaceTB--SmolLM2-1.7B-Instruct/"
    draft_model_path = "../models/HuggingFaceTB--SmolLM2-135M-Instruct/"
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load Tokenizer
    draft_tokenizer = GPT2TokenizerFast.from_pretrained(draft_model_path)
    target_tokenizer = GPT2TokenizerFast.from_pretrained(target_model_path)

    # Load Model
    draft_model = LlamaForCausalLM.from_pretrained(draft_model_path, torch_dtype=torch.bfloat16).to(device)
    target_model = LlamaForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.bfloat16).to(device)

    # Tokenize
    messages = [
        [
            {
                "role": "user",
                "content": "What???",
            },
        ],
    ]

    input_text = draft_tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = draft_tokenizer(
        input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    ).to(device)

    # Draft model
    target_inputs, draft_probs = drafter_speculative_decode(
        draft_model=draft_model,
        draft_tokenizer=draft_tokenizer,
        inputs=inputs,
        gamma=10,
    )

    print(target_inputs["input_ids"])
    print("".join(draft_tokenizer.batch_decode(target_inputs["input_ids"][0])))

    # Target model
    outputs = target_speculative_decode(
        target_model=target_model,
        target_tokenizer=target_tokenizer,
        inputs=target_inputs,
        draft_probs=draft_probs,
    )

    print(outputs["input_ids"])
    print("".join(target_tokenizer.batch_decode(outputs["input_ids"][0])))
Output:
tensor([[ 1, 9690, 198, 2683, 359, 253, 5356, 5646, 11173, 3365,
3511, 308, 34519, 28, 7018, 411, 407, 19712, 8182, 2,
198, 1, 4093, 198, 1780, 16693, 47, 2, 198, 1,
520, 9531, 198, 22234, 8165, 28, 198, 198, 42519]],
device='cuda:0')
<|im_start|>system
You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
<|im_start|>user
What???<|im_end|>
<|im_start|>assistant
Hey Sarah,
Hope
tensor([[ 1, 9690, 198, 2683, 359, 253, 5356, 5646, 11173, 3365,
3511, 308, 34519, 28, 7018, 411, 407, 19712, 8182, 2,
198, 1, 4093, 198, 1780, 16693, 47, 2, 198, 1,
520, 9531, 198, 57]], device='cuda:0')
<|im_start|>system
You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
<|im_start|>user
What???<|im_end|>
<|im_start|>assistant
I
We can see that the large model rejected the very first token and regenerated it itself. However, since our ultimate goal is for the small model to generate many times faster than the large model, even when the first token is rejected the time lost is not too significant. At least, that is the goal.
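To make that intuition more concrete, the Speculative Decoding paper gives the expected number of tokens emitted per target-model forward pass as (1 - α^(γ+1)) / (1 - α), where α is the acceptance rate. A small sketch of that formula (the α values here are made up for illustration):

def expected_tokens_per_iteration(alpha: float, gamma: int) -> float:
    """Expected tokens generated per target forward pass (Leviathan et al., 2023)."""
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)

for alpha in (0.3, 0.6, 0.9):
    print(alpha, round(expected_tokens_per_iteration(alpha, gamma=10), 2))
# 0.3 -> 1.43, 0.6 -> 2.49, 0.9 -> 6.86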
Given time, I expect to gradually flesh out many more implementations in my inference-acceleration framework. If you are interested, feel free to browse my GitHub: https://github.com/ccs96307/fast-llm-inference
References
- lucidrains/speculative-decoding: Explorations into some ...
- feifeibear/LLMSpeculativeSampling: Fast inference from ...