Last Updated on 2025-07-01 by Clay
Introduction
Speculative Decoding and KV Cache are both acceleration techniques applicable to Transformer models. The former uses a faster draft model to speculatively generate several subsequent tokens, which are then validated in a batch by the target model to reduce the cost of autoregressive decoding. The latter leverages the causal attention mechanism of Transformers—where past tokens do not attend to future tokens—to cache previously computed results and avoid redundant calculations during inference.
For more details, feel free to check out my previous notes:
- Implementation Notes on Speculative Decoding (with Simple Experimental Results)
- KV Cache: A Caching Mechanism to Accelerate Transformer Generation
These two techniques can be combined. Below is the thought process I followed, along with my implementation steps.
KV Cache in Speculative Decoding
Since Hugging Face's transformers library deprecates the older tuple-style KV Cache format in version 4.47.0 in favor of the new DynamicCache class, I adopted DynamicCache() in this implementation. I found that its crop() method makes it quite convenient to discard KV Cache segments that are no longer needed.
This allows us to let the model automatically handle KV Cache updates during inference. We only need to decide whether to truncate certain parts of the KV Cache based on the validation result of Speculative Decoding.
Generally, during Speculative Decoding, we encounter the following four scenarios:
- When using KV Cache, we only need to feed the last token of the sequence into the model: the `input_ids` shape changes from `(batch_size, seq_len)` to `(batch_size, 1)`
- When the target model evaluates the `gamma` speculative tokens generated by the draft model, the input shape becomes `(batch_size, gamma + 1)`
- If the target model accepts only `k` of the `gamma` tokens generated by the draft model:
  - The draft model's accumulated KV Cache of length `raw_kv_cache_length + gamma` needs to be cropped to `raw_kv_cache_length + k`
  - The target model's KV Cache of length `raw_kv_cache_length + gamma + 1` needs to be cropped to `raw_kv_cache_length + k`
- When the target model accepts all of the draft model's speculative tokens, the next speculative decoding step includes one extra token generated by the target model, so the input shape becomes `(batch_size, 2)`
By handling the above scenarios correctly, we can successfully implement Speculative Decoding with KV Cache.
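To make the bookkeeping in the rejection scenario concrete, here is the length arithmetic with made-up numbers (a 20-token prompt, gamma = 5 drafted tokens, k = 3 accepted):

```python
# Hypothetical example: 20-token prompt, gamma = 5 drafted tokens, k = 3 accepted.
raw_kv_cache_length = 20
gamma, k = 5, 3

# After drafting, the draft model has cached all gamma speculative tokens;
# after verification, the target model has additionally cached its own bonus token.
draft_cache_len = raw_kv_cache_length + gamma        # 25
target_cache_len = raw_kv_cache_length + gamma + 1   # 26

# After rejecting the last (gamma - k) tokens, both caches roll back
# to the last accepted position.
draft_cache_len = raw_kv_cache_length + k
target_cache_len = raw_kv_cache_length + k
print(draft_cache_len, target_cache_len)  # 23 23
```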
Implementation Details
First, we need to import all necessary packages. The sample_next_token function is a custom sampling function I implemented. You can refer to this article for details: Decoding Sampling Notes for Large Language Models
from typing import Dict, List, Optional, Tuple, Union
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import argparse
import copy
import time
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModelForCausalLM, AutoTokenizer, Cache, DynamicCache, PreTrainedTokenizerBase
from sampling.sampling import sample_next_token
"""
python speculative_decoding/run_speculative_decoding.py \
--target_model_path HuggingFaceTB/SmolLM2-1.7B-Instruct \
--draft_model_path HuggingFaceTB/SmolLM2-135M-Instruct \
--device cuda:0 \
--question 'What is the capital of Taiwan. And why?' \
--gamma 5 \
--test_token_num 100
"""
def calculate_continuous_acceptance(acceptance_mask: torch.BoolTensor) -> int:
    # Count accepted tokens from the left until the first rejection
    continuous_acceptance = 0
    for accepted in acceptance_mask.long().squeeze(0):
        if accepted == 1:
            continuous_acceptance += 1
        else:
            break
    return continuous_acceptance
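For example, the helper counts only the leading run of accepted tokens; a single rejection cuts the count even if later positions were marked accepted. The function is repeated here so the snippet runs on its own:

```python
import torch

def calculate_continuous_acceptance(acceptance_mask: torch.BoolTensor) -> int:
    # Count accepted tokens from the left until the first rejection
    continuous_acceptance = 0
    for accepted in acceptance_mask.long().squeeze(0):
        if accepted == 1:
            continuous_acceptance += 1
        else:
            break
    return continuous_acceptance

mask = torch.tensor([[True, True, False, True]])
print(calculate_continuous_acceptance(mask))  # 2
```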
Next, we implement the sampling step for the draft model. Here, we either crop the KV Cache or widen the input slice dynamically, depending on how far the cached length lags behind the current sequence length (with a warm cache, the default input shape is (batch_size, 1)).
def drafter_speculative_decode(
    draft_model: torch.nn.Module,
    draft_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    gamma: int = 10,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # A value of 0 disables top-k filtering
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
) -> Tuple[Dict[str, torch.Tensor], torch.FloatTensor, Optional[Union[Cache, List[torch.FloatTensor]]]]:
    draft_probs = []

    for idx in range(gamma):
        raw_inputs_ids = inputs.input_ids
        if isinstance(past_key_values, Cache) and past_key_values.get_seq_length() > 0:
            distance = inputs.input_ids.shape[1] - past_key_values.get_seq_length()
            if distance >= 1:
                # The cache lags behind the sequence: feed only the new tokens
                inputs.input_ids = inputs.input_ids[:, -distance:]
            else:
                # The cache is too long (e.g. after a rollback): crop it and feed the last token
                past_key_values.crop(max_length=inputs.input_ids.shape[1] - 1)
                inputs.input_ids = inputs.input_ids[:, -1:]

        with torch.no_grad():
            outputs = draft_model(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                past_key_values=past_key_values,
                use_cache=past_key_values is not None,
            )

        past_key_values = outputs.past_key_values

        next_tokens, probs = sample_next_token(
            logits=outputs.logits,
            prefix_token_ids=inputs.input_ids,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        draft_probs.append(probs)
        input_ids = torch.cat([raw_inputs_ids, next_tokens[:, -1:]], dim=-1)
        attention_mask = torch.cat([inputs.attention_mask, torch.ones(inputs.attention_mask.shape[0], 1).to(inputs.input_ids.device)], dim=-1)

        inputs.input_ids = input_ids
        inputs.attention_mask = attention_mask

    return inputs, torch.cat(draft_probs, dim=1), past_key_values
The following part handles the validation step with the target model. As with the draft model, we adjust the KV Cache based on input length. The key difference is that if the target model rejects part of the speculative tokens, its own KV Cache must also be rolled back accordingly.
def target_speculative_decode(
    target_model: torch.nn.Module,
    target_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    draft_probs: torch.FloatTensor,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # A value of 0 disables top-k filtering
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
) -> Tuple[Dict[str, torch.Tensor], bool, int, Optional[Union[Cache, List[torch.FloatTensor]]]]:
    raw_inputs_ids = inputs.input_ids
    if isinstance(past_key_values, Cache) and past_key_values.get_seq_length() > 0:
        distance = inputs.input_ids.shape[1] - past_key_values.get_seq_length()
        if distance >= 1:
            inputs.input_ids = inputs.input_ids[:, -distance:]
        else:
            past_key_values.crop(max_length=inputs.input_ids.shape[1] - 1)
            inputs.input_ids = inputs.input_ids[:, -1:]

    with torch.no_grad():
        outputs = target_model(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            past_key_values=past_key_values,
            use_cache=past_key_values is not None,
        )

    past_key_values = outputs.past_key_values
    inputs.input_ids = raw_inputs_ids

    next_tokens, target_probs = sample_next_token(
        logits=outputs.logits,
        diff_probs=draft_probs,
        prefix_token_ids=inputs.input_ids,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        probs_num=draft_probs.shape[1] + 1,
    )

    next_token = next_tokens[:, -1:]

    # Evaluation
    indices = inputs.input_ids[:, -draft_probs.shape[1]:]
    eval_probs = target_probs[:, :-1, :]

    expanded_indices = indices.unsqueeze(-1)
    selected_draft_probs = torch.gather(draft_probs, dim=-1, index=expanded_indices)
    selected_draft_probs = selected_draft_probs.squeeze(-1)

    selected_eval_probs = torch.gather(eval_probs, dim=-1, index=expanded_indices)
    selected_eval_probs = selected_eval_probs.squeeze(-1)

    # Compare draft_prob and eval_prob to build the candidate rejection mask
    mask_to_reject = selected_draft_probs > selected_eval_probs

    # Calculate the rejection probability: 1 - (eval_prob / draft_prob)
    rejection_probs = 1 - (selected_eval_probs / selected_draft_probs)

    # Draw random values to determine accept or reject
    random_values = torch.rand_like(rejection_probs)
    rejection_decisions = random_values < rejection_probs

    # Get the final rejection masks
    rejection_masks = mask_to_reject & rejection_decisions
    acceptance_mask = torch.ones_like(selected_draft_probs, dtype=torch.bool)
    acceptance_mask[rejection_masks] = False

    is_end = False

    # Concat `input_ids`
    if torch.all(acceptance_mask):
        inputs.input_ids = torch.cat([inputs.input_ids, next_token], dim=-1)
        inputs.attention_mask = torch.cat([inputs.attention_mask, torch.ones(inputs.attention_mask.shape[0], 1).to(inputs.input_ids.device)], dim=-1)
    else:
        new_input_ids = []
        new_attention_mask = []

        for batch_idx in range(next_tokens.shape[0]):
            gamma = next_tokens.shape[1] - 1
            start_idx = inputs.input_ids.shape[1] - gamma

            for pos_idx in range(acceptance_mask[batch_idx].shape[0]):
                if (acceptance_mask[batch_idx][pos_idx] and inputs.input_ids[batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id) or not acceptance_mask[batch_idx][pos_idx]:
                    # Replace the first rejected token with the target model's resampled token
                    inputs.input_ids[batch_idx][start_idx+pos_idx] = next_tokens[batch_idx][pos_idx]

                    new_input_ids.append(inputs.input_ids[batch_idx][:start_idx+pos_idx+1])
                    new_attention_mask.append(inputs.attention_mask[batch_idx][:start_idx+pos_idx+1])

                    is_end = inputs.input_ids[batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id
                    break

        input_ids = pad_sequence(new_input_ids, batch_first=True, padding_value=target_tokenizer.pad_token_id)
        attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)

        inputs.input_ids = input_ids
        inputs.attention_mask = attention_mask

    # If the sequence was rolled back, crop the target model's own cache to match
    if isinstance(past_key_values, Cache) and inputs.input_ids.shape[1] <= past_key_values.get_seq_length():
        past_key_values.crop(max_length=inputs.input_ids.shape[1] - 1)

    return inputs, is_end, calculate_continuous_acceptance(acceptance_mask), past_key_values
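The evaluation block above implements the standard speculative-sampling acceptance rule: a drafted token is kept with probability min(1, p_target / p_draft). A minimal standalone sketch of just that rule, with made-up probabilities:

```python
import torch

def accept_mask(draft_p: torch.Tensor, target_p: torch.Tensor) -> torch.Tensor:
    """Accept each drafted token with probability min(1, target_p / draft_p)."""
    ratio = (target_p / draft_p).clamp(max=1.0)
    # torch.rand draws from [0, 1), so a ratio of 1.0 is always accepted
    return torch.rand_like(ratio) < ratio

draft_p = torch.tensor([0.50, 0.20, 0.90])   # draft model's probs for its own tokens
target_p = torch.tensor([0.60, 0.05, 0.90])  # target model's probs for the same tokens
print(accept_mask(draft_p, target_p))  # positions 0 and 2 are always True; position 1 is random
```

The code in target_speculative_decode expresses the same test in its complementary form: a token is only a rejection candidate when draft_p exceeds target_p, and is then rejected with probability 1 - target_p / draft_p.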
Finally, let’s compare the performance of Speculative Decoding with and without KV Cache enabled:
def run_test(args) -> None:
    # Device
    device = torch.device(args.device if args.device != "cpu" and torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # Model path
    target_model_path = args.target_model_path
    draft_model_path = args.draft_model_path

    # Load Tokenizer
    draft_tokenizer = AutoTokenizer.from_pretrained(draft_model_path)
    target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)

    # Load Model
    draft_model = AutoModelForCausalLM.from_pretrained(draft_model_path, torch_dtype=torch.bfloat16).to(device)
    target_model = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.bfloat16).to(device)

    # Tokenize
    messages = [
        [
            {
                "role": "user",
                "content": args.question,
            },
        ],
    ]

    input_text = draft_tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = draft_tokenizer(
        input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    ).to(device)

    # Warm up the model (CUDA)
    inputs_dummy = {k: v.clone() for k, v in inputs.items()}
    with torch.no_grad():
        draft_model(**inputs_dummy)
        target_model(**inputs_dummy)
    torch.cuda.synchronize()

    is_end = False

    # Record
    raw_inputs = copy.deepcopy(inputs)
    raw_token_num = raw_inputs.input_ids.shape[1]
    start_time = time.time()

    total_draft_tokens = 0
    total_accept_tokens = 0
    gamma = args.gamma
    max_new_tokens = args.test_token_num

    draft_past_key_values = None
    target_past_key_values = None

    while not is_end:
        # Draft model
        target_inputs, draft_probs, draft_past_key_values = drafter_speculative_decode(
            draft_model=draft_model,
            draft_tokenizer=draft_tokenizer,
            inputs=inputs,
            gamma=gamma,
            temperature=0,
            past_key_values=draft_past_key_values,
        )

        total_draft_tokens += gamma

        # Target model
        outputs, is_end, accept_tokens, target_past_key_values = target_speculative_decode(
            target_model=target_model,
            target_tokenizer=target_tokenizer,
            inputs=target_inputs,
            draft_probs=draft_probs,
            temperature=0,
            past_key_values=target_past_key_values,
        )

        total_accept_tokens += accept_tokens
        inputs = outputs

        if inputs.input_ids.shape[1] - raw_token_num >= max_new_tokens:
            break

    generate_token_num = outputs.input_ids.shape[1] - raw_token_num
    spent_time = time.time() - start_time

    print(f"(Without KV Cache) Generate token number: {generate_token_num}")
    print(f"(Without KV Cache) Generate speed: {generate_token_num / spent_time} tokens/sec")
    print(f"(Without KV Cache) Speculative Decoding Spent Time: {spent_time} seconds.")
    print(f"(Without KV Cache) Accept Rate: {total_accept_tokens / total_draft_tokens}\n")

    # KV Cache Speculative Decoding
    is_end = False

    # Record
    inputs = copy.deepcopy(raw_inputs)
    raw_token_num = inputs.input_ids.shape[1]
    start_time = time.time()

    total_draft_tokens = 0
    total_accept_tokens = 0
    gamma = args.gamma
    max_new_tokens = args.test_token_num

    draft_past_key_values = DynamicCache()
    target_past_key_values = DynamicCache()

    while not is_end:
        # Draft model
        target_inputs, draft_probs, draft_past_key_values = drafter_speculative_decode(
            draft_model=draft_model,
            draft_tokenizer=draft_tokenizer,
            inputs=inputs,
            gamma=gamma,
            temperature=0,
            past_key_values=draft_past_key_values,
        )

        total_draft_tokens += gamma

        # Target model
        outputs, is_end, accept_tokens, target_past_key_values = target_speculative_decode(
            target_model=target_model,
            target_tokenizer=target_tokenizer,
            inputs=target_inputs,
            draft_probs=draft_probs,
            temperature=0,
            past_key_values=target_past_key_values,
        )

        total_accept_tokens += accept_tokens
        inputs = outputs

        if inputs.input_ids.shape[1] - raw_token_num >= max_new_tokens:
            break

    generate_token_num = outputs.input_ids.shape[1] - raw_token_num
    spent_time = time.time() - start_time

    print(f"(KV Cache) Generate token number: {generate_token_num}")
    print(f"(KV Cache) Generate speed: {generate_token_num / spent_time} tokens/sec")
    print(f"(KV Cache) Speculative Decoding Spent Time: {spent_time} seconds.")
    print(f"(KV Cache) Accept Rate: {total_accept_tokens / total_draft_tokens}\n")
Output:
(Without KV Cache) Generate token number: 101
(Without KV Cache) Generate speed: 51.32459081142281 tokens/sec
(Without KV Cache) Speculative Decoding Spent Time: 1.9678676128387451 seconds.
(Without KV Cache) Accept Rate: 0.7719298245614035
(KV Cache) Generate token number: 101
(KV Cache) Generate speed: 62.468003457069095 tokens/sec
(KV Cache) Speculative Decoding Spent Time: 1.6168277263641357 seconds.
(KV Cache) Accept Rate: 0.8035714285714286
As we can see, enabling KV Cache lifts throughput from roughly 51.3 to 62.5 tokens/sec, a speedup of about 22% in this run.
References
- https://pytorch.org/blog/hitchhikers-guide-speculative-decoding/
- Leveraging Speculative Sampling and KV-Cache Optimizations Together for Generative AI using OpenVINO