Last Updated on 2024-12-17 by Clay
Introduction
Speculative Decoding and the KV Cache are both acceleration techniques that can be applied to Transformers. The former uses a faster draft model to speculatively generate several future decoding steps and lets the target model we want to accelerate verify them all in a single forward pass, saving the overhead of token-by-token autoregressive decoding. The latter exploits a property of the Transformer's causal attention mechanism: past tokens never attend to future tokens, so the computation results of past tokens can be cached and reused, avoiding redundant computation at every inference step.
For the related details, you can refer to the notes I have written previously.
These two techniques can be combined, and what follows is my approach and implementation process.
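As a quick refresher before the implementation: in the verification step, a draft token x sampled from the draft distribution p is accepted with probability min(1, q(x) / p(x)), where q is the target model's distribution; the verification code later in this post implements essentially this accept/reject rule. A minimal, self-contained sketch with made-up probabilities:

import torch

# Hypothetical probabilities that the draft model (p) and the target model (q)
# assign to three sampled draft tokens.
draft_p = torch.tensor([0.60, 0.20, 0.90])
target_q = torch.tensor([0.50, 0.40, 0.30])

# Accept each draft token with probability min(1, q / p),
# i.e. reject it with probability max(0, 1 - q / p).
accept_prob = torch.clamp(target_q / draft_p, max=1.0)
accepted = torch.rand_like(accept_prob) < accept_prob
print(accepted)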
KV Cache in Speculative Decoding
Since the transformers package developed by HuggingFace is deprecating the old tuple-based KV Cache format in 4.47.0 in favor of its own DynamicCache() class, I adopted DynamicCache() in this implementation as well. After using it, I found that its crop() method can easily be called to discard KV Cache entries that are no longer needed, which is very convenient.
This way, updating the KV Cache can be left to the model to handle automatically during inference, and all we have to do is decide, based on the outcome of the Speculative Decoding verification, whether to truncate part of the KV Cache.
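For reference, here is a minimal sketch of the DynamicCache workflow, using the small draft model from this post as a placeholder checkpoint; in recent transformers versions the cache object passed to the forward call is updated in place, and crop() trims it back afterwards:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

model_name = "HuggingFaceTB/SmolLM2-135M-Instruct"  # any small causal LM works here
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("Speculative decoding with a KV cache", return_tensors="pt")
cache = DynamicCache()

with torch.no_grad():
    model(**inputs, past_key_values=cache, use_cache=True)

print(cache.get_seq_length())  # equals inputs.input_ids.shape[1]

# Roll the cache back, e.g. after some speculative tokens are rejected
cache.crop(max_length=cache.get_seq_length() - 2)
print(cache.get_seq_length())  # two entries shorter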
Generally speaking, in Speculative Decoding we run into the following four situations:
- When we use the KV Cache, we only need to feed the model the last part of the sequence: the original input_ids shape goes from (batch_size, seq_len) to (batch_size, 1).
- When the target model evaluates the draft model's speculative tokens (gamma of them), the input we need has shape (batch_size, gamma + 1).
- When the gamma speculative tokens generated by the draft model are rejected by the target model down to only k accepted tokens:
    - The draft model's accumulated KV Cache of raw_kv_cache_length + gamma entries needs to be cropped to raw_kv_cache_length + k entries.
    - The target model's accumulated KV Cache of raw_kv_cache_length + gamma + 1 entries needs to be cropped to raw_kv_cache_length + k entries.
- When the target model accepts all of the draft model's tokens, the input_ids the draft model uses to predict the next batch of speculative tokens carries one extra token produced by the target model, so its shape is (batch_size, 2).
As long as we handle the situations above correctly, we can implement Speculative Decoding with a KV Cache.
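The length bookkeeping in the rejection case is just a couple of additions; here is a tiny sketch with placeholder numbers (raw_kv_cache_length, gamma, and k are illustrative values, not variables from the code below):

# Illustrative lengths for the rejection case
raw_kv_cache_length = 32  # tokens already cached before this round
gamma = 5                 # speculative tokens proposed by the draft model
k = 3                     # speculative tokens accepted by the target model

draft_cache_len = raw_kv_cache_length + gamma       # 37 entries before cropping
target_cache_len = raw_kv_cache_length + gamma + 1  # 38 entries before cropping
crop_to = raw_kv_cache_length + k                   # both cropped back to 35

print(draft_cache_len, target_cache_len, crop_to)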
Implementation Details
First, we need to import all the packages we will use. Among them, sample_next_token is a sampling function I implemented myself; for its concrete implementation, see my earlier note 大型語言模型的解碼採樣筆記.
from typing import Dict, List, Optional, Tuple, Union
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import argparse
import copy
import time
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModelForCausalLM, AutoTokenizer, Cache, DynamicCache, PreTrainedTokenizerBase
from sampling.sampling import sample_next_token
"""
python speculative_decoding/run_speculative_decoding.py \
--target_model_path HuggingFaceTB/SmolLM2-1.7B-Instruct \
--draft_model_path HuggingFaceTB/SmolLM2-135M-Instruct \
--device cuda:0 \
--question 'What is the capital of Taiwan. And why?' \
--gamma 5 \
--test_token_num 100
"""
def calculate_continuous_acceptance(acceptance_mask: torch.BoolTensor) -> int:
continuous_acceptance = 0
for accepted in acceptance_mask.long().squeeze(0):
if accepted == 1:
continuous_acceptance += 1
else:
break
return continuous_acceptance
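Given the function above, a mask whose first two positions are accepted and whose third is rejected yields 2, for example:

mask = torch.tensor([[True, True, False, True]])
print(calculate_continuous_acceptance(mask))  # 2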
Next comes the draft model's sampling. Here, based on the length of the current input, we either trim the KV Cache or adjust the length of the input_ids we feed in (with a KV Cache, the steady-state input size is 1).
def drafter_speculative_decode(
draft_model: torch.nn.Module,
draft_tokenizer: PreTrainedTokenizerBase,
inputs: Dict[str, torch.Tensor],
gamma: int = 10,
temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, meaning top-k filtering is not applied
top_p: Optional[float] = 1.0,
repetition_penalty: Optional[float] = 1.0,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
) -> Tuple[Dict[str, torch.Tensor], torch.FloatTensor, Optional[Union[Cache, List[torch.FloatTensor]]]]:
draft_probs = []
for idx in range(gamma):
raw_inputs_ids = inputs.input_ids
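        # If a KV cache already exists, feed only the tokens that are not cached yet;
        # if the cache is as long as (or longer than) the current sequence because
        # earlier speculative tokens were rejected, crop it first and feed only the last token.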
if isinstance(past_key_values, Cache) and past_key_values.get_seq_length() > 0:
distance = inputs.input_ids.shape[1] - past_key_values.get_seq_length()
if distance >= 1:
inputs.input_ids = inputs.input_ids[:, -distance:]
else:
past_key_values.crop(max_length=inputs.input_ids.shape[1]-1)
inputs.input_ids = inputs.input_ids[:, -1:]
with torch.no_grad():
outputs = draft_model(
input_ids=inputs.input_ids,
attention_mask=inputs.attention_mask,
past_key_values=past_key_values,
use_cache=past_key_values is not None,
)
past_key_values = outputs.past_key_values
next_tokens, probs = sample_next_token(
logits=outputs.logits,
prefix_token_ids=inputs.input_ids,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
)
draft_probs.append(probs)
input_ids = torch.cat([raw_inputs_ids, next_tokens[:, -1:]], dim=-1)
attention_mask = torch.cat([inputs.attention_mask, torch.ones(inputs.attention_mask.shape[0], 1).to(inputs.input_ids.device)], dim=-1)
inputs.input_ids = input_ids
inputs.attention_mask = attention_mask
return inputs, torch.cat(draft_probs, dim=1), past_key_values
Next is the target model's verification step. Here we likewise adjust the KV Cache according to the length of input_ids; the biggest difference is that once the target model rejects the draft model's speculative tokens, the target model's KV Cache must also be rolled back to the truncation point.
def target_speculative_decode(
target_model: torch.nn.Module,
target_tokenizer: PreTrainedTokenizerBase,
inputs: Dict[str, torch.Tensor],
draft_probs: torch.FloatTensor,
temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, meaning top-k filtering is not applied
top_p: Optional[float] = 1.0,
repetition_penalty: Optional[float] = 1.0,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
) -> Tuple[Dict[str, torch.Tensor], bool, int, Optional[Union[Cache, List[torch.FloatTensor]]]]:
raw_inputs_ids = inputs.input_ids
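    # Same trimming logic as in the drafter: feed only the uncached tokens,
    # or crop the cache back if it has grown past the current sequence.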
if isinstance(past_key_values, Cache) and past_key_values.get_seq_length() > 0:
distance = inputs.input_ids.shape[1] - past_key_values.get_seq_length()
if distance >= 1:
inputs.input_ids = inputs.input_ids[:, -distance:]
else:
past_key_values.crop(max_length=inputs.input_ids.shape[1]-1)
inputs.input_ids = inputs.input_ids[:, -1:]
with torch.no_grad():
outputs = target_model(
input_ids=inputs.input_ids,
attention_mask=inputs.attention_mask,
past_key_values=past_key_values,
use_cache=past_key_values is not None,
)
past_key_values = outputs.past_key_values
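    # Restore the full (untrimmed) sequence: sampling and the evaluation below
    # index the last gamma positions of the original input_ids.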
inputs.input_ids = raw_inputs_ids
next_tokens, target_probs = sample_next_token(
logits=outputs.logits,
diff_probs=draft_probs,
prefix_token_ids=inputs.input_ids,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
probs_num=draft_probs.shape[1] + 1,
)
next_token = next_tokens[:, -1:]
# Evaluation
indices = inputs.input_ids[:, -draft_probs.shape[1]:]
eval_probs = target_probs[:, :-1, :]
expanded_indices = indices.unsqueeze(-1)
selected_draft_probs = torch.gather(draft_probs, dim=-1, index=expanded_indices)
selected_draft_probs = selected_draft_probs.squeeze(-1)
selected_eval_probs = torch.gather(eval_probs, dim=-1, index=expanded_indices)
selected_eval_probs = selected_eval_probs.squeeze(-1)
# Compare draft_prob and eval_prob, and check the reject_mask
mask_to_reject = selected_draft_probs > selected_eval_probs
    # Calculate the rejection probability: 1 - (eval_prob / draft_prob)
rejection_probs = 1 - (selected_eval_probs / selected_draft_probs)
    # Generate random values to determine accept or reject
random_values = torch.rand_like(rejection_probs)
rejection_decisions = random_values < rejection_probs
# Get the final reject masks
rejection_masks = mask_to_reject & rejection_decisions
acceptance_mask = torch.ones_like(selected_draft_probs, dtype=torch.bool)
acceptance_mask[rejection_masks] = False
is_end = False
# Concat `input_ids`
if torch.all(acceptance_mask):
inputs.input_ids = torch.cat([inputs.input_ids, next_token], dim=-1)
inputs.attention_mask = torch.cat([inputs.attention_mask, torch.ones(inputs.attention_mask.shape[0], 1).to(inputs.input_ids.device)], dim=-1)
else:
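        # At least one draft token was rejected: find the first rejected position
        # (or an accepted EOS), overwrite it with the target model's token for that
        # position, and truncate input_ids / attention_mask right after it.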
new_input_ids = []
new_attention_mask = []
for batch_idx in range(next_tokens.shape[0]):
gamma = next_tokens.shape[1] - 1
start_idx = inputs.input_ids.shape[1] - gamma
for pos_idx in range(acceptance_mask[batch_idx].shape[0]):
if (acceptance_mask[batch_idx][pos_idx] and inputs.input_ids[batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id) or not acceptance_mask[batch_idx][pos_idx]:
inputs.input_ids[batch_idx][start_idx+pos_idx] = next_tokens[batch_idx][pos_idx]
new_input_ids.append(inputs.input_ids[batch_idx][:start_idx+pos_idx+1])
new_attention_mask.append(inputs.attention_mask[batch_idx][:start_idx+pos_idx+1])
is_end = inputs.input_ids[batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id
break
input_ids = pad_sequence(new_input_ids, batch_first=True, padding_value=target_tokenizer.pad_token_id)
attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)
inputs.input_ids = input_ids
inputs.attention_mask = attention_mask
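    # Roll back the target model's KV cache so it is one entry shorter than the
    # (possibly truncated) sequence; the missing position is re-computed on the
    # next forward pass.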
if isinstance(past_key_values, Cache) and inputs.input_ids.shape[1] <= past_key_values.get_seq_length():
past_key_values.crop(max_length=inputs.input_ids.shape[1]-1)
return inputs, is_end, calculate_continuous_acceptance(acceptance_mask), past_key_values
Finally, let's actually compare the speed of Speculative Decoding with and without the KV Cache:
def run_test(args) -> None:
# Device
device = torch.device(args.device if args.device != "cpu" and torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
# Model path
target_model_path = args.target_model_path
draft_model_path = args.draft_model_path
# Load Tokenizer
draft_tokenizer = AutoTokenizer.from_pretrained(draft_model_path)
target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
# Load Model
draft_model = AutoModelForCausalLM.from_pretrained(draft_model_path, torch_dtype=torch.bfloat16).to(device)
target_model = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.bfloat16).to(device)
# Tokenize
messages = [
[
{
"role": "user",
"content": args.question,
},
],
]
    input_text = draft_tokenizer.apply_chat_template(messages, tokenize=False)
inputs = draft_tokenizer(
input_text,
return_tensors="pt",
max_length=512,
truncation=True,
padding=True,
).to(device)
# Warm up the model (CUDA)
inputs_dummy = {k: v.clone() for k, v in inputs.items()}
with torch.no_grad():
draft_model(**inputs_dummy)
target_model(**inputs_dummy)
torch.cuda.synchronize()
is_end = False
# Record
raw_inputs = copy.deepcopy(inputs)
raw_token_num = raw_inputs.input_ids.shape[1]
start_time = time.time()
total_draft_tokens = 0
total_accept_tokens = 0
gamma = args.gamma
max_new_tokens = args.test_token_num
draft_past_key_values = None
target_past_key_values = None
while not is_end:
# Draft model
target_inputs, draft_probs, draft_past_key_values = drafter_speculative_decode(
draft_model=draft_model,
draft_tokenizer=draft_tokenizer,
inputs=inputs,
gamma=gamma,
temperature=0,
past_key_values=draft_past_key_values,
)
total_draft_tokens += gamma
# Target model
outputs, is_end, accept_tokens, target_past_key_values = target_speculative_decode(
target_model=target_model,
target_tokenizer=target_tokenizer,
inputs=target_inputs,
draft_probs=draft_probs,
temperature=0,
past_key_values=target_past_key_values,
)
total_accept_tokens += accept_tokens
inputs = outputs
if inputs.input_ids.shape[1] - raw_token_num >= max_new_tokens:
break
generate_token_num = outputs.input_ids.shape[1] - raw_token_num
spent_time = time.time() - start_time
print(f"(Without KV Cache) Generate token number: {generate_token_num}")
print(f"(Without KV Cache) Generate speed: {generate_token_num / spent_time} tokens/sec")
print(f"(Without KV Cache) Speculative Decoding Spent Time: {spent_time} seconds.")
print(f"(Without KV Cache) Accept Rate: {total_accept_tokens / total_draft_tokens}\n")
# KV Cache Speculative Decoding
is_end = False
# Record
inputs = copy.deepcopy(raw_inputs)
raw_token_num = inputs.input_ids.shape[1]
start_time = time.time()
total_draft_tokens = 0
total_accept_tokens = 0
gamma = args.gamma
max_new_tokens = args.test_token_num
draft_past_key_values = DynamicCache()
target_past_key_values = DynamicCache()
while not is_end:
# Draft model
target_inputs, draft_probs, draft_past_key_values = drafter_speculative_decode(
draft_model=draft_model,
draft_tokenizer=draft_tokenizer,
inputs=inputs,
gamma=gamma,
temperature=0,
past_key_values=draft_past_key_values,
)
total_draft_tokens += gamma
# Target model
outputs, is_end, accept_tokens, target_past_key_values = target_speculative_decode(
target_model=target_model,
target_tokenizer=target_tokenizer,
inputs=target_inputs,
draft_probs=draft_probs,
temperature=0,
past_key_values=target_past_key_values,
)
total_accept_tokens += accept_tokens
inputs = outputs
if inputs.input_ids.shape[1] - raw_token_num >= max_new_tokens:
break
generate_token_num = outputs.input_ids.shape[1] - raw_token_num
spent_time = time.time() - start_time
print(f"(KV Cache) Generate token number: {generate_token_num}")
print(f"(KV Cache) Generate speed: {generate_token_num / spent_time} tokens/sec")
print(f"(KV Cache) Speculative Decoding Spent Time: {spent_time} seconds.")
print(f"(KV Cache) Accept Rate: {total_accept_tokens / total_draft_tokens}\n")
Output:
(Without KV Cache) Generate token number: 101
(Without KV Cache) Generate speed: 51.32459081142281 tokens/sec
(Without KV Cache) Speculative Decoding Spent Time: 1.9678676128387451 seconds.
(Without KV Cache) Accept Rate: 0.7719298245614035
(KV Cache) Generate token number: 101
(KV Cache) Generate speed: 62.468003457069095 tokens/sec
(KV Cache) Speculative Decoding Spent Time: 1.6168277263641357 seconds.
(KV Cache) Accept Rate: 0.8035714285714286
We can see that there is indeed a speedup.
References
- https://pytorch.org/blog/hitchhikers-guide-speculative-decoding/
- Leveraging Speculative Sampling and KV-Cache Optimizations Together for Generative AI using OpenVINO