Last Updated on 2024-11-21 by Clay
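Many of the inference-acceleration techniques I have been reading about, such as Speculative Decoding and its variants, set a threshold on the draft model's confidence score. That threshold decides how many draft tokens to generate before the target model verifies them. The goal is to avoid spending extra time on speculation when the draft model is not confident.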
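To make the draft-side rule concrete, here is a minimal pure-Python sketch of a draft loop that stops early once the draft model's own confidence drops below a threshold. The names `draft_until_unsure` and `draft_step` are hypothetical. `draft_step` is assumed to return the next token and the probability the draft model assigned to it. This is an illustration of the general idea, not code from any particular paper.

```python
from typing import Callable, List, Tuple


def draft_until_unsure(
    draft_step: Callable[[List[int]], Tuple[int, float]],  # hypothetical: (token, draft prob)
    prefix: List[int],
    max_draft: int = 5,
    threshold: float = 0.5,
) -> List[int]:
    """Draft up to `max_draft` tokens, stopping early once the draft
    model's confidence in its own sampled token falls below `threshold`."""
    drafted: List[int] = []
    for _ in range(max_draft):
        token, prob = draft_step(prefix + drafted)
        # Keep the token that triggered the stop so the target model can still verify it.
        drafted.append(token)
        if prob < threshold:
            break
    return drafted
```

Whether to keep or drop the low-confidence token that ends the draft is a design choice. This sketch keeps it.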
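This dynamic approach does improve performance. But the first time I saw it, I had a nagging doubt: if there is a real capability gap between my draft model and my target model, and especially if the draft model's understanding is weak, should I trust the draft model even when its confidence score is high?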
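Conversely, would it make more sense to use the target model's confidence score to decide whether to enable draft speculation? We already doubt that a weaker draft model's high confidence is reliable, and it is the target model that makes the accept decision anyway. By the same logic, a low confidence score from the target model should be reason enough not to hand the current decoding step to the draft model.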
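Put more plainly: I assume the draft model's high confidence may just be overconfidence, so it should not by itself justify letting the draft model keep generating. On the other hand, if the target model's confidence is low, the target model itself is not sure it can decode this step well, so there is even less reason to let the draft model speculate.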
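Under this assumption the control flow flips, and the target model's latest confidence decides who decodes next. The sketch below is a hypothetical, model-free simulation of that policy. The callables `target_step` and `speculate_round` stand in for one target-only decoding step and one draft-then-verify round. Each is assumed to return the new token(s) plus the target model's confidence in the last token it sampled. It only illustrates the switching logic, since the real implementation works on tensors.

```python
from typing import Callable, List, Tuple

# Hypothetical signatures: each callable receives the current token list and
# returns newly produced token(s) plus the target model's confidence in the
# last token it sampled.
TargetStep = Callable[[List[int]], Tuple[int, float]]
SpeculateRound = Callable[[List[int]], Tuple[List[int], float]]


def gated_decode(
    target_step: TargetStep,
    speculate_round: SpeculateRound,
    prefix: List[int],
    max_new_tokens: int,
    threshold: float = 0.5,
) -> Tuple[List[int], int]:
    """Alternate between draft-then-verify rounds and target-only steps,
    based on the target model's most recent confidence score."""
    out = list(prefix)
    rounds = 0
    confidence = 1.0  # start optimistic so the first step uses speculation
    while len(out) - len(prefix) < max_new_tokens:
        if confidence >= threshold:
            # Target was confident: let the draft model speculate.
            new_tokens, confidence = speculate_round(out)
            rounds += 1
        else:
            # Target was unsure: decode one token with the target model alone.
            token, confidence = target_step(out)
            new_tokens = [token]
        out.extend(new_tokens)
    return out[len(prefix):], rounds
```

Like a typical benchmark loop, it only checks the length between rounds, so the output can slightly exceed `max_new_tokens`.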
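Conclusion first. Perhaps because my reading is limited, I have not yet found a paper that directly compares two options: using the draft model's confidence to decide whether to keep drafting, or using the target model's confidence to decide whether to enable draft speculation. In other words, which one is better? In the simple setting I tested, though, I saw a clear and direct speedup. I am recording it here as a reference for my own research on inference acceleration.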
Background
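First, feel free to follow my GitHub. Many of my implementations live in the fast-llm-inference project: https://github.com/ccs96307/fast-llm-inference
Strictly speaking, this note is a follow-up to a previous one: Speculative Decoding Implementation Notes (with Simple Experimental Results)
In that previous post, I documented my own implementation of sampling and Speculative Decoding, along with some simple experimental results. Here I build on that code. I add the target model's confidence score as the criterion for enabling draft-and-verify mode, then compare the speedup.
I ran the old code several times, and the results were broadly similar: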
```text
Generate token number: 100
Generate speed: 34.81801971715937 tokens/sec
Speculative Decoding Spent Time: 2.8720760345458984 seconds.
Accept Rate: 0.34054054054054056

Generate token number: 100
Generate speed: 28.07058497562908 tokens/sec
Normal Target Model Decoding Spent Time: 3.562448024749756 seconds.

Generate token number: 100
Generate speed: 94.92253307563831 tokens/sec
Normal Draft Model Decoding Spent Time: 1.0534906387329102 seconds.
```
In short:
- Draft model decoding speed: 94.92 tokens/sec
- Target model decoding speed: 28.07 tokens/sec
- Speculative Decoding speed: 34.82 tokens/sec
That is roughly a 1.24x speedup (34.82 / 28.07) over decoding with the target model alone.
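Using the target model's confidence score to decide whether to let the draft model speculate
Following the principle "if I am not sure about the current decoding step, I should not bother the weaker draft model with it," I made the following changes: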
```python
def target_speculative_decode(
    target_model: torch.nn.Module,
    target_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    draft_probs: torch.FloatTensor,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, it means do not select top-k tokens
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], bool, int]:
    with torch.no_grad():
        outputs = target_model(**inputs)

    next_tokens, target_probs = sample_next_token(
        logits=outputs.logits,
        prefix_token_ids=inputs["input_ids"],
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        probs_num=draft_probs.shape[1] + 1,
    )

    next_token = next_tokens[:, -1:]

    # Evaluation
    indices = inputs["input_ids"][:, -draft_probs.shape[1]:]

    eval_probs = target_probs[:, :-1, :]

    expanded_indices = indices.unsqueeze(-1)
    selected_draft_probs = torch.gather(draft_probs, dim=-1, index=expanded_indices)
    selected_draft_probs = selected_draft_probs.squeeze(-1)

    selected_eval_probs = torch.gather(eval_probs, dim=-1, index=expanded_indices)
    selected_eval_probs = selected_eval_probs.squeeze(-1)

    # Compare draft_prob and eval_prob, and check the reject_mask
    mask_to_reject = selected_draft_probs > selected_eval_probs

    # Calculate reject probability 1 - (eval_prob / draft_prob)
    rejection_probs = 1 - (selected_eval_probs / selected_draft_probs)

    # Generate random values to determine accept or reject
    random_values = torch.rand_like(rejection_probs)
    rejection_decisions = random_values < rejection_probs

    # Get the final reject masks
    rejection_masks = mask_to_reject & rejection_decisions
    acceptance_mask = torch.ones_like(selected_draft_probs, dtype=torch.bool)
    acceptance_mask[rejection_masks] = False

    is_end = False

    # Concat `input_ids`
    confidence_score = 0
    if torch.all(acceptance_mask):
        input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)
        # Target model's probability for the bonus token it just sampled
        confidence_score = target_probs[:, -1, next_token[0][0]].item()
        print(f"Confidence for next token: {confidence_score:.4f}")
        is_end = next_token[0][0].item() == target_tokenizer.eos_token_id
    else:
        new_input_ids = []
        new_attention_mask = []

        for batch_idx in range(next_tokens.shape[0]):
            gamma = next_tokens.shape[1] - 1
            start_idx = inputs["input_ids"].shape[1] - gamma

            for pos_idx in range(acceptance_mask[batch_idx].shape[0]):
                if (acceptance_mask[batch_idx][pos_idx] and inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id) or not acceptance_mask[batch_idx][pos_idx]:
                    inputs["input_ids"][batch_idx][start_idx+pos_idx] = next_tokens[batch_idx][pos_idx]
                    # Target model's probability for the replacement token
                    confidence_score = target_probs[batch_idx, pos_idx, next_tokens[batch_idx, pos_idx]].item()
                    print(f"Replacement Confidence for next token: {confidence_score:.4f}")

                    new_input_ids.append(inputs["input_ids"][batch_idx][:start_idx+pos_idx+1])
                    new_attention_mask.append(inputs["attention_mask"][batch_idx][:start_idx+pos_idx+1])

                    is_end = inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id
                    break

        input_ids = pad_sequence(new_input_ids, batch_first=True, padding_value=target_tokenizer.pad_token_id)
        attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)

    inputs["input_ids"] = input_ids
    inputs["attention_mask"] = attention_mask

    # Keep generating with the target model alone while its confidence is below the threshold.
    # `not is_end` prevents decoding past an EOS token that was just emitted.
    while not is_end and confidence_score < 0.5:
        with torch.no_grad():
            outputs = target_model(**inputs)

        next_tokens, target_probs = sample_next_token(
            logits=outputs.logits,
            prefix_token_ids=inputs["input_ids"],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            probs_num=1,
        )

        # Update `confidence_score`
        next_token = next_tokens[:, -1:]
        confidence_score = target_probs[0, -1, next_token[0][0]].item()
        print(f"keep generate confidence_score: {confidence_score:.4f}")

        input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)

        inputs["input_ids"] = input_ids
        inputs["attention_mask"] = attention_mask

        is_end = inputs["input_ids"][0][-1].item() == target_tokenizer.eos_token_id
        if is_end:
            break

    return inputs, is_end, calculate_continuous_acceptance(acceptance_mask)
```
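Compared with the original code, the main change is that I read the target model's confidence score during verification, meaning the probability it assigns to the token it sampled. Whenever that score is below 0.5, the target model keeps decoding alone before control goes back to the draft model. My implementation already has the target model's probability distribution after sampling, so reading the confidence score adds almost no extra overhead.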
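The accept/reject step in `target_speculative_decode` follows the standard speculative-sampling rule: a drafted token is kept with probability min(1, p/q), where p and q are the target and draft probabilities of that token. Below is a scalar, pure-Python sketch of just that rule, plus the count of leading accepted tokens, which is what `calculate_continuous_acceptance` in the full script computes. The function names are hypothetical, and the sketch leaves out choosing a replacement token at the rejected position.

```python
import random
from typing import List


def accept_draft_tokens(p_target: List[float], q_draft: List[float], rng: random.Random) -> List[bool]:
    """Per-position accept/reject decisions for drafted tokens.

    p_target[i] / q_draft[i]: probability the target / draft model assigned
    to the i-th drafted token. When q > p the token is rejected with
    probability 1 - p/q; when p >= q it is always accepted.
    """
    mask = []
    for p, q in zip(p_target, q_draft):
        reject = q > p and rng.random() < 1 - p / q
        mask.append(not reject)
    return mask


def count_leading_accepts(mask: List[bool]) -> int:
    """Number of accepted tokens before the first rejection."""
    count = 0
    for accepted in mask:
        if not accepted:
            break
        count += 1
    return count


mask = accept_draft_tokens([0.5, 0.0, 0.9], [0.4, 0.3, 0.1], random.Random(0))
print(mask, count_leading_accepts(mask))  # → [True, False, True] 1
```

Only the leading run of accepted tokens counts. The third token above is discarded even though it passed its own test, because everything after the first rejection is regenerated.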
Here is my full source code and experimental setup:
- GPU: RTX 4060 8GB
- Target model: HuggingFaceTB/SmolLM2-1.7B-Instruct
- Draft model: HuggingFaceTB/SmolLM2-135M-Instruct
```python
from typing import Dict, List, Optional, Tuple

import os
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import argparse
import copy
import time

import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase

from sampling.sampling import sample_next_token


"""
python speculative_decoding/run_speculative_decoding.py \
    --target_model_path HuggingFaceTB/SmolLM2-1.7B-Instruct \
    --draft_model_path HuggingFaceTB/SmolLM2-135M-Instruct \
    --device cuda:0 \
    --question 'What is the capital of Taiwan. And why?' \
    --gamma 5 \
    --test_token_num 100
"""


def calculate_continuous_acceptance(acceptance_mask: torch.BoolTensor) -> int:
    continuous_acceptance = 0
    for accepted in acceptance_mask.long().squeeze(0):
        if accepted == 1:
            continuous_acceptance += 1
        else:
            break
    return continuous_acceptance


def drafter_speculative_decode(
    draft_model: torch.nn.Module,
    draft_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    gamma: int = 10,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, it means do not select top-k tokens
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], torch.FloatTensor]:
    draft_probs = []

    for idx in range(gamma):
        with torch.no_grad():
            outputs = draft_model(**inputs)

        next_tokens, probs = sample_next_token(
            logits=outputs.logits,
            prefix_token_ids=inputs["input_ids"],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        draft_probs.append(probs)
        input_ids = torch.cat([inputs["input_ids"], next_tokens[:, -1:]], dim=-1)
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)

        inputs["input_ids"] = input_ids
        inputs["attention_mask"] = attention_mask

    return inputs, torch.cat(draft_probs, dim=1)


def target_speculative_decode(
    target_model: torch.nn.Module,
    target_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    draft_probs: torch.FloatTensor,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, it means do not select top-k tokens
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], bool, int]:
    with torch.no_grad():
        outputs = target_model(**inputs)

    next_tokens, target_probs = sample_next_token(
        logits=outputs.logits,
        prefix_token_ids=inputs["input_ids"],
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        probs_num=draft_probs.shape[1] + 1,
    )

    next_token = next_tokens[:, -1:]

    # Evaluation
    indices = inputs["input_ids"][:, -draft_probs.shape[1]:]

    eval_probs = target_probs[:, :-1, :]

    expanded_indices = indices.unsqueeze(-1)
    selected_draft_probs = torch.gather(draft_probs, dim=-1, index=expanded_indices)
    selected_draft_probs = selected_draft_probs.squeeze(-1)

    selected_eval_probs = torch.gather(eval_probs, dim=-1, index=expanded_indices)
    selected_eval_probs = selected_eval_probs.squeeze(-1)

    # Compare draft_prob and eval_prob, and check the reject_mask
    mask_to_reject = selected_draft_probs > selected_eval_probs

    # Calculate reject probability 1 - (eval_prob / draft_prob)
    rejection_probs = 1 - (selected_eval_probs / selected_draft_probs)

    # Generate random values to determine accept or reject
    random_values = torch.rand_like(rejection_probs)
    rejection_decisions = random_values < rejection_probs

    # Get the final reject masks
    rejection_masks = mask_to_reject & rejection_decisions
    acceptance_mask = torch.ones_like(selected_draft_probs, dtype=torch.bool)
    acceptance_mask[rejection_masks] = False

    is_end = False

    # Concat `input_ids`
    confidence_score = 0
    if torch.all(acceptance_mask):
        input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)
        # Target model's probability for the bonus token it just sampled
        confidence_score = target_probs[:, -1, next_token[0][0]].item()
        print(f"Confidence for next token: {confidence_score:.4f}")
        is_end = next_token[0][0].item() == target_tokenizer.eos_token_id
    else:
        new_input_ids = []
        new_attention_mask = []

        for batch_idx in range(next_tokens.shape[0]):
            gamma = next_tokens.shape[1] - 1
            start_idx = inputs["input_ids"].shape[1] - gamma

            for pos_idx in range(acceptance_mask[batch_idx].shape[0]):
                if (acceptance_mask[batch_idx][pos_idx] and inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id) or not acceptance_mask[batch_idx][pos_idx]:
                    inputs["input_ids"][batch_idx][start_idx+pos_idx] = next_tokens[batch_idx][pos_idx]
                    # Target model's probability for the replacement token
                    confidence_score = target_probs[batch_idx, pos_idx, next_tokens[batch_idx, pos_idx]].item()
                    print(f"Replacement Confidence for next token: {confidence_score:.4f}")

                    new_input_ids.append(inputs["input_ids"][batch_idx][:start_idx+pos_idx+1])
                    new_attention_mask.append(inputs["attention_mask"][batch_idx][:start_idx+pos_idx+1])

                    is_end = inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id
                    break

        input_ids = pad_sequence(new_input_ids, batch_first=True, padding_value=target_tokenizer.pad_token_id)
        attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)

    inputs["input_ids"] = input_ids
    inputs["attention_mask"] = attention_mask

    # Keep generating with the target model alone while its confidence is below the threshold.
    # `not is_end` prevents decoding past an EOS token that was just emitted.
    while not is_end and confidence_score < 0.5:
        with torch.no_grad():
            outputs = target_model(**inputs)

        next_tokens, target_probs = sample_next_token(
            logits=outputs.logits,
            prefix_token_ids=inputs["input_ids"],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            probs_num=1,
        )

        # Update `confidence_score`
        next_token = next_tokens[:, -1:]
        confidence_score = target_probs[0, -1, next_token[0][0]].item()
        print(f"keep generate confidence_score: {confidence_score:.4f}")

        input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)

        inputs["input_ids"] = input_ids
        inputs["attention_mask"] = attention_mask

        is_end = inputs["input_ids"][0][-1].item() == target_tokenizer.eos_token_id
        if is_end:
            break

    return inputs, is_end, calculate_continuous_acceptance(acceptance_mask)


def run_test(args) -> None:
    # Device
    device = torch.device(args.device if args.device != "cpu" and torch.cuda.is_available() else "cpu")
    print(device)

    # Model path
    target_model_path = args.target_model_path
    draft_model_path = args.draft_model_path

    # Load Tokenizer
    draft_tokenizer = AutoTokenizer.from_pretrained(draft_model_path)
    target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)

    # Load Model
    draft_model = AutoModelForCausalLM.from_pretrained(draft_model_path, torch_dtype=torch.bfloat16).to(device)
    target_model = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.bfloat16).to(device)

    # Tokenize
    messages = [
        [
            {
                "role": "user",
                "content": args.question,
            },
        ],
    ]

    input_text = draft_tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = draft_tokenizer(
        input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    ).to(device)

    # Warm up the model (CUDA)
    inputs_dummy = {k: v.clone() for k, v in inputs.items()}
    with torch.no_grad():
        draft_model(**inputs_dummy)
        target_model(**inputs_dummy)
    if device.type == "cuda":
        torch.cuda.synchronize()

    is_end = False

    # Record
    raw_inputs = copy.deepcopy(inputs)
    raw_token_num = raw_inputs["input_ids"].shape[1]
    start_time = time.time()

    total_draft_tokens = 0
    total_accept_tokens = 0
    gamma = args.gamma
    max_new_tokens = args.test_token_num

    while not is_end:
        # Draft model
        target_inputs, draft_probs = drafter_speculative_decode(
            draft_model=draft_model,
            draft_tokenizer=draft_tokenizer,
            inputs=inputs,
            gamma=gamma,
        )

        total_draft_tokens += gamma

        # Target model
        outputs, is_end, accept_tokens = target_speculative_decode(
            target_model=target_model,
            target_tokenizer=target_tokenizer,
            inputs=target_inputs,
            draft_probs=draft_probs,
        )

        total_accept_tokens += accept_tokens

        inputs = outputs

        if inputs["input_ids"].shape[1] - raw_token_num >= max_new_tokens:
            break

    generate_token_num = outputs["input_ids"].shape[1] - raw_token_num
    spent_time = time.time() - start_time

    print(f"Generate token number: {generate_token_num}")
    print(f"Generate speed: {generate_token_num / spent_time} tokens/sec")
    print(f"Speculative Decoding Spent Time: {spent_time} seconds.")
    print(f"Accept Rate: {total_accept_tokens / total_draft_tokens}\n")

    # Normal Target Model Speed
    # Note: `inputs` now holds the prompt plus the tokens generated above, so both
    # baselines below decode from that longer prefix rather than from the original prompt.
    raw_inputs = copy.deepcopy(inputs)
    start_time = time.time()
    # Reuse the drafting loop for plain autoregressive decoding with the target model
    target_inputs, draft_probs = drafter_speculative_decode(
        draft_model=target_model,
        draft_tokenizer=draft_tokenizer,
        inputs=raw_inputs,
        gamma=args.test_token_num,
    )

    spent_time = time.time() - start_time

    print(f"Generate token number: {max_new_tokens}")
    print(f"Generate speed: {max_new_tokens / spent_time} tokens/sec")
    print(f"Normal Target Model Decoding Spent Time: {spent_time} seconds.\n")

    # Normal Draft Model Speed
    raw_inputs = copy.deepcopy(inputs)
    start_time = time.time()
    target_inputs, draft_probs = drafter_speculative_decode(
        draft_model=draft_model,
        draft_tokenizer=draft_tokenizer,
        inputs=raw_inputs,
        gamma=args.test_token_num,
    )

    spent_time = time.time() - start_time

    print(f"Generate token number: {max_new_tokens}")
    print(f"Generate speed: {max_new_tokens / spent_time} tokens/sec")
    print(f"Normal Draft Model Decoding Spent Time: {spent_time} seconds.\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--target_model_path", type=str, default="HuggingFaceTB/SmolLM2-1.7B-Instruct")
    parser.add_argument("--draft_model_path", type=str, default="HuggingFaceTB/SmolLM2-135M-Instruct")
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--question", type=str, default="What is the capital of Taiwan. And why?")
    parser.add_argument("--gamma", type=int, default=5)
    parser.add_argument("--test_token_num", type=int, default=100)
    args = parser.parse_args()

    run_test(args)
```
Output:
```text
Generate token number: 102
Generate speed: 46.418809914955794 tokens/sec
Speculative Decoding Spent Time: 2.19738507270813 seconds.
Accept Rate: 0.5545454545454546

Generate token number: 100
Generate speed: 27.916420540976226 tokens/sec
Normal Target Model Decoding Spent Time: 3.5821211338043213 seconds.

Generate token number: 100
Generate speed: 96.10154773224576 tokens/sec
Normal Draft Model Decoding Spent Time: 1.0405659675598145 seconds.
```
Putting the results side by side makes the difference clear:
| | draft model | target model | Speculative Decoding | Speedup vs. target model |
|---|---|---|---|---|
| Original | 94.92 tokens/sec | 28.07 tokens/sec | 34.82 tokens/sec | 1.24x |
| Target Threshold | 96.10 tokens/sec | 27.92 tokens/sec | 46.42 tokens/sec | 1.66x |
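The speedup column is the Speculative Decoding throughput divided by the target-only throughput from the same run. A quick check with the numbers from the table (the helper name `speedup` is just for illustration):

```python
def speedup(spec_tokens_per_sec: float, target_tokens_per_sec: float) -> float:
    """Throughput ratio of speculative decoding over target-only decoding."""
    return spec_tokens_per_sec / target_tokens_per_sec


runs = {
    "Original": (34.82, 28.07),
    "Target Threshold": (46.42, 27.92),
}
for name, (spec_tps, target_tps) in runs.items():
    print(f"{name}: {speedup(spec_tps, target_tps):.2f}x")
```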
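My intuition is that using the target model's confidence score to decide whether to enable draft speculation is a robust strategy, though its ceiling may not be high. Still, for someone like me with scarce training resources, a reliable improvement in speedup matters a lot.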
References
- Fast Inference from Transformers via Speculative Decoding
- EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees