Last Updated on 2024-11-09 by Clay
Introduction
Speculative Decoding is an extremely practical inference acceleration technique. A small model (the draft model) rapidly decodes several tokens and keeps the probability distribution it sampled each one from. The larger target model, which is the one we actually want to accelerate, then scores all of these drafted tokens in a single forward pass. At each drafted position, the draft model's probability for the chosen token is checked against the target model's probability for the same token, and the drafted token is accepted if it is deemed sufficiently reliable.
For a more detailed explanation, you may refer to my previous post summarizing Google's Speculative Decoding paper: [Paper Summary] Fast Inference from Transformers via Speculative Decoding
Translating this into code is fairly straightforward. Suppose the draft model decodes gamma tokens. For each of them, we compare the probability the draft model assigned to the token with the probability the target model assigns to the same token, which leads to two cases:
- draft model token prob <= target model token prob: Always accept the drafted token, since the target model considers it at least as likely as the draft model did.
- draft model token prob > target model token prob: Reject the token with probability 1 - (target model token prob / draft model token prob); equivalently, accept it with probability target model token prob / draft model token prob.
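As a minimal sketch of just this acceptance rule for a single drafted token (the function name `accept_draft_token` is hypothetical and not part of this post's code), using only the Python standard library:

```python
import random


def accept_draft_token(p_target: float, p_draft: float, rng: random.Random) -> bool:
    """Acceptance test for one drafted token.

    p_target: probability the target model assigns to the drafted token
    p_draft:  probability the draft model assigned to it (> 0, since it was sampled)
    """
    if p_draft <= p_target:
        # The target model likes this token at least as much: always accept
        return True
    # Otherwise reject with probability 1 - p_target / p_draft,
    # i.e. accept with probability p_target / p_draft
    return rng.random() < p_target / p_draft


rng = random.Random(0)
print(accept_draft_token(0.6, 0.3, rng))  # True: draft prob <= target prob

trials = 100_000
accepted = sum(accept_draft_token(0.2, 0.5, rng) for _ in range(trials))
print(accepted / trials)  # roughly 0.4 = 0.2 / 0.5
```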
Implementation
Below is my design for applying the sampling parameters to the logits produced by the model. It differs slightly from Hugging Face's approach, but for now this is how the sampling parameters are controlled. For a detailed explanation of each sampling parameter, see my separate implementation post: Notes on Decoding and Sampling for Large Language Models
In this implementation note, I carefully distinguish between `logits` and `probs`: `logits` are the raw model outputs, which can in principle take any value in (-inf, inf), while `probs` are the probabilities obtained by applying softmax to the logits, each in [0, 1] and summing to 1.
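A quick illustration of the distinction, using `torch.softmax` as the sampling code below does (the logit values are arbitrary). Dividing the logits by a temperature before the softmax, as `sample_next_token` does, sharpens the probs when the temperature is below 1:

```python
import torch

logits = torch.tensor([2.0, 1.0, -1.0])        # raw scores: any real number
probs = torch.softmax(logits, dim=-1)          # each value in [0, 1], summing to 1
sharper = torch.softmax(logits / 0.5, dim=-1)  # temperature 0.5 < 1 sharpens the distribution
print(probs, sharper)
```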
```python
# sampling.py (imported later via `from sampling import sample_next_token`)
from typing import Tuple
import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=UserWarning)

import torch


def apply_repetition_penalty(
    logits: torch.FloatTensor,
    prefix_token_ids: torch.LongTensor,
    repetition_penalty: float = 1.0,
) -> torch.FloatTensor:
    # logits: (batch_size, num_positions, vocab_size). Position i predicts the token that
    # follows prefix_token_ids[:, :seq_length - num_positions + i + 1], so only that prefix is penalized.
    batch_size, gamma, vocab_size = logits.shape
    seq_length = prefix_token_ids.shape[1]

    for batch_idx in range(batch_size):
        for gamma_idx in range(gamma):
            current_prefix = prefix_token_ids[batch_idx, :seq_length - gamma + gamma_idx + 1]
            unique_token_ids = set(current_prefix.tolist())

            for token_id in unique_token_ids:
                # Push the logits of already-seen tokens toward "less likely"
                if logits[batch_idx, gamma_idx, token_id] > 0:
                    logits[batch_idx, gamma_idx, token_id] /= repetition_penalty
                else:
                    logits[batch_idx, gamma_idx, token_id] *= repetition_penalty

    return logits


def top_k_filtering(logits: torch.FloatTensor, top_k: int) -> torch.FloatTensor:
    if top_k > 0:
        # Keep the k largest logits at each position and mask the rest with -inf
        values, _ = torch.topk(logits, top_k, dim=-1)
        min_values = values[:, :, -1].unsqueeze(dim=-1)
        logits = torch.where(logits < min_values, torch.full_like(logits, -float("Inf")), logits)

    return logits


def top_p_filtering(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    # top_p >= 1.0 keeps every token; returning early also avoids rounding in cumsum dropping tail tokens
    if top_p >= 1.0:
        return logits

    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

    # Find the positions where the cumulative probability exceeds top_p
    sorted_indices_to_remove = cumulative_probs > top_p

    # Shift right by one so at least one token (the one crossing top_p) is kept
    sorted_indices_to_remove[:, :, 1:] = sorted_indices_to_remove[:, :, :-1].clone()
    sorted_indices_to_remove[:, :, 0] = False

    # Map the mask from sorted order back to the original vocabulary order
    indices_to_remove = sorted_indices_to_remove.scatter(2, sorted_indices, sorted_indices_to_remove)
    logits = logits.masked_fill(indices_to_remove, -float("Inf"))

    return logits


def sample_next_token(
    logits: torch.FloatTensor,
    prefix_token_ids: torch.LongTensor,
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
    repetition_penalty: float = 1.0,
    eps: float = 1e-7,
    probs_num: int = 1,
) -> Tuple[torch.LongTensor, torch.FloatTensor]:
    # Only the last `probs_num` positions are sampled (1 for the drafter, gamma + 1 for the target).
    # Cast to float32: softmax/cumsum in bfloat16 over a large vocabulary loses precision.
    curr_logits = logits[:, -probs_num:, :].float()

    # Apply repetition penalty
    if repetition_penalty != 1.0:
        curr_logits = apply_repetition_penalty(
            logits=curr_logits,
            prefix_token_ids=prefix_token_ids,
            repetition_penalty=repetition_penalty,
        )

    # Apply temperature
    curr_logits = curr_logits / (temperature + eps)

    # Apply `top_k`
    curr_logits = top_k_filtering(logits=curr_logits, top_k=top_k)

    # Apply `top_p`
    curr_logits = top_p_filtering(logits=curr_logits, top_p=top_p)

    # Convert logits into probs
    probs = torch.softmax(curr_logits, dim=-1)

    # Sample one token per position: token ids (batch, probs_num), probs (batch, probs_num, vocab)
    seq_tokens = []
    for seq_idx in range(probs.shape[1]):
        seq_token = torch.multinomial(probs[:, seq_idx, :], num_samples=1)
        seq_tokens.append(seq_token)

    seq_token_ids = torch.cat(seq_tokens, dim=1)

    return seq_token_ids, probs
```
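The two-line shift in `top_p_filtering` is easy to misread. A sketch on a single made-up 4-token distribution (the helper `top_p_remove_mask` is illustrative only and works on a 1-D vector rather than the 3-D tensors above) shows that shifting the mask right by one keeps the token whose probability pushes the cumulative sum past `top_p`:

```python
import torch


def top_p_remove_mask(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    """Boolean mask of tokens that nucleus (top-p) filtering removes, for a 1-D probs vector."""
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    remove_sorted = cumulative > top_p
    # Shift right by one: the token whose probability crosses top_p is kept
    remove_sorted[1:] = remove_sorted[:-1].clone()
    remove_sorted[0] = False
    # Scatter the mask from sorted order back to vocabulary order
    return remove_sorted.scatter(0, sorted_indices, remove_sorted)


probs = torch.tensor([0.10, 0.50, 0.25, 0.15])
print(top_p_remove_mask(probs, top_p=0.6).tolist())  # [True, False, False, True]
```

With `top_p=0.6`, the sorted probabilities 0.50, 0.25, 0.15, 0.10 have cumulative sums 0.50, 0.75, 0.90, 1.00. Without the shift only the 0.50 token would survive, keeping less probability mass than requested; with it, the 0.25 token that crosses 0.6 is kept as well.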
Next comes the Speculative Decoding implementation itself. Here I only test a single draft-and-verify round, up to the validation stage. For now I also assume batch_size=1: different sequences in a batch can accept different numbers of draft tokens, and realigning them would require padding, which this code does not handle yet.
First, let’s import all required libraries.
```python
from typing import Dict, List, Optional, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import LlamaForCausalLM, GPT2TokenizerFast, PreTrainedTokenizerBase

from sampling import sample_next_token
```
Here, the draft model decodes `gamma` tokens one at a time in speculative mode. After each step I append the new token ID and a matching attention-mask entry to the original `inputs`, and I keep every probability distribution along the way.
```python
def drafter_speculative_decode(
    draft_model: torch.nn.Module,
    draft_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    gamma: int = 10,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, which disables top-k filtering
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], torch.FloatTensor]:
    draft_probs = []

    for idx in range(gamma):
        # Full forward pass over the whole prefix (no KV cache)
        with torch.no_grad():
            outputs = draft_model(**inputs)

        next_tokens, probs = sample_next_token(
            logits=outputs.logits,
            prefix_token_ids=inputs["input_ids"],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        # probs: (batch, 1, vocab), the distribution the new token was sampled from
        draft_probs.append(probs)

        # Append the sampled token and extend the attention mask by one position
        input_ids = torch.cat([inputs["input_ids"], next_tokens[:, -1:]], dim=-1)
        new_mask = torch.ones(
            inputs["attention_mask"].shape[0], 1,
            dtype=inputs["attention_mask"].dtype,
            device=inputs["attention_mask"].device,
        )
        attention_mask = torch.cat([inputs["attention_mask"], new_mask], dim=-1)

        inputs["input_ids"] = input_ids
        inputs["attention_mask"] = attention_mask

    # Draft probabilities for all steps: (batch, gamma, vocab)
    return inputs, torch.cat(draft_probs, dim=1)
```
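The verification step on the target model is a bit more intricate. It runs only one forward pass, but it has to turn each of the last gamma + 1 logits into a probability distribution and sample a token from it: the first gamma positions are used to verify the drafted tokens, and the extra one supplies a bonus token when every draft is accepted. The first rejected draft token is replaced by the target model's sample at that position, and all drafted tokens after it are discarded.
As described earlier, for each drafted token we compare the target model's probability with the draft model's, and when the draft model was more confident, we reject the token with probability 1 - (target prob / draft prob).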
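One caveat worth keeping in mind: the code below replaces a rejected token with the target model's own sample at that position. The speculative sampling procedure in the paper instead resamples from the residual distribution norm(max(0, p_target - p_draft)), which is what makes the final output follow the target distribution exactly. A minimal sketch of that step, plus a Monte Carlo check on toy 3-token distributions (the function name and the numbers are illustrative assumptions, not part of this post's code):

```python
import torch


def resample_on_rejection(p_target: torch.Tensor, p_draft: torch.Tensor) -> torch.Tensor:
    """Draw a replacement token from norm(max(0, p_target - p_draft)).

    p_target, p_draft: 1-D probability vectors over the vocabulary at the rejected position.
    """
    residual = torch.clamp(p_target - p_draft, min=0.0)
    total = residual.sum()
    if total <= 0:
        # Only possible when the two distributions are identical, in which case
        # nothing is ever rejected; fall back to the target distribution
        residual, total = p_target, p_target.sum()
    return torch.multinomial(residual / total, num_samples=1)


# Monte Carlo check: accept-or-resample should reproduce p_target
torch.manual_seed(0)
p_target = torch.tensor([0.5, 0.3, 0.2])
p_draft = torch.tensor([0.2, 0.2, 0.6])
counts = torch.zeros(3)
for _ in range(20_000):
    x = torch.multinomial(p_draft, num_samples=1).item()  # the drafter proposes token x
    accept_prob = min(1.0, (p_target[x] / p_draft[x]).item())
    if torch.rand(()).item() < accept_prob:
        counts[x] += 1
    else:
        counts[resample_on_rejection(p_target, p_draft).item()] += 1
print(counts / counts.sum())  # close to p_target = [0.5, 0.3, 0.2]
```

In this toy setup a rejection happens with probability 0.4; if the replacement were drawn from p_target itself instead of the residual, the output frequencies would drift to about [0.40, 0.32, 0.28], over-weighting the token the drafter proposed too often.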
```python
def target_speculative_decode(
    target_model: torch.nn.Module,
    target_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    draft_probs: torch.FloatTensor,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, which disables top-k filtering
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
) -> Dict[str, torch.Tensor]:
    # A single forward pass scores every drafted token at once
    with torch.no_grad():
        outputs = target_model(**inputs)

    # Sample from the last gamma + 1 positions: gamma to verify the drafts, plus one bonus token
    next_tokens, target_probs = sample_next_token(
        logits=outputs.logits,
        prefix_token_ids=inputs["input_ids"],
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        probs_num=draft_probs.shape[1] + 1,
    )

    next_token = next_tokens[:, -1:]

    # Evaluation: look up both models' probabilities for each drafted token
    indices = inputs["input_ids"][:, -draft_probs.shape[1]:]

    eval_probs = target_probs[:, :-1, :]

    expanded_indices = indices.unsqueeze(-1)
    selected_draft_probs = torch.gather(draft_probs, dim=-1, index=expanded_indices)
    selected_draft_probs = selected_draft_probs.squeeze(-1)

    selected_eval_probs = torch.gather(eval_probs, dim=-1, index=expanded_indices)
    selected_eval_probs = selected_eval_probs.squeeze(-1)

    # Only tokens the draft model was more confident about can be rejected
    mask_to_reject = selected_draft_probs > selected_eval_probs

    # Calculate reject probability 1 - (eval_prob / draft_prob)
    rejection_probs = 1 - (selected_eval_probs / selected_draft_probs)

    # Generate random values to determine acceptance or rejection
    random_values = torch.rand_like(rejection_probs)
    rejection_decisions = random_values < rejection_probs

    # Get the final rejection masks
    rejection_masks = mask_to_reject & rejection_decisions
    acceptance_mask = torch.ones_like(selected_draft_probs, dtype=torch.bool)
    acceptance_mask[rejection_masks] = False

    # Concat `input_ids`
    if torch.all(acceptance_mask):
        # Every draft accepted: append the bonus token sampled by the target model
        input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
        new_mask = torch.ones(
            inputs["attention_mask"].shape[0], 1,
            dtype=inputs["attention_mask"].dtype,
            device=inputs["attention_mask"].device,
        )
        attention_mask = torch.cat([inputs["attention_mask"], new_mask], dim=-1)
    else:
        new_input_ids = []
        new_attention_mask = []

        for batch_idx in range(next_tokens.shape[0]):
            for pos_idx in range(acceptance_mask[batch_idx].shape[0]):
                if not acceptance_mask[batch_idx][pos_idx]:
                    # Replace the first rejected draft token with the target model's sample at
                    # that position (a simplification: the paper resamples from
                    # norm(max(0, p_target - p_draft)) instead), then drop the rest of the draft
                    gamma = next_tokens.shape[1] - 1
                    start_idx = inputs["input_ids"].shape[1] - gamma
                    inputs["input_ids"][batch_idx][start_idx+pos_idx] = next_tokens[batch_idx][pos_idx]

                    new_input_ids.append(inputs["input_ids"][batch_idx][:start_idx+pos_idx+1])
                    new_attention_mask.append(inputs["attention_mask"][batch_idx][:start_idx+pos_idx+1])

                    break

        input_ids = pad_sequence(new_input_ids, batch_first=True, padding_value=target_tokenizer.pad_token_id)
        attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)

    inputs["input_ids"] = input_ids
    inputs["attention_mask"] = attention_mask

    return inputs
```
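To make the indexing in the evaluation step concrete, here is a toy run of the same `torch.gather` calls with batch size 1, gamma = 3 and a 5-token vocabulary (all numbers are made up). The target model provides gamma + 1 rows; the last one belongs to the bonus token and is dropped before the comparison. The sketch also folds `mask_to_reject` into a `clamp`, which gives the same decisions because `torch.rand_like` never returns a negative number:

```python
import torch

# Toy sizes (hypothetical): batch=1, gamma=3 drafted tokens, vocab=5
draft_probs = torch.tensor([[[0.10, 0.60, 0.10, 0.10, 0.10],
                             [0.20, 0.20, 0.40, 0.10, 0.10],
                             [0.70, 0.10, 0.10, 0.05, 0.05]]])   # (1, 3, 5)
target_probs = torch.tensor([[[0.10, 0.80, 0.05, 0.03, 0.02],
                              [0.30, 0.30, 0.10, 0.20, 0.10],
                              [0.50, 0.20, 0.10, 0.10, 0.10],
                              [0.20, 0.20, 0.20, 0.20, 0.20]]])  # (1, 4, 5): gamma + 1 rows
drafted_ids = torch.tensor([[1, 2, 0]])  # the token the drafter sampled at each step

index = drafted_ids.unsqueeze(-1)  # (1, 3, 1): pick one vocabulary entry per position
q = torch.gather(draft_probs, dim=-1, index=index).squeeze(-1)              # q = 0.6, 0.4, 0.7
p = torch.gather(target_probs[:, :-1, :], dim=-1, index=index).squeeze(-1)  # p = 0.8, 0.1, 0.5

# Rejection probability per position; clamping at 0 plays the role of `mask_to_reject`
rejection_probs = torch.clamp(1 - p / q, min=0.0)  # 0.0, 0.75, about 0.286
rejected = torch.rand_like(rejection_probs) < rejection_probs
print(q, p, rejection_probs, rejected)
```

Position 0 can never be rejected (the target model is more confident than the drafter), position 1 is rejected three times out of four, and position 2 about 29% of the time.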
Finally, let’s examine the execution results:
```python
if __name__ == "__main__":
    # Settings
    target_model_path = "../models/HuggingFaceTB--SmolLM2-1.7B-Instruct/"
    draft_model_path = "../models/HuggingFaceTB--SmolLM2-135M-Instruct/"
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load Tokenizer
    # Drafted token IDs are fed straight into the target model, so both models must share one vocabulary
    draft_tokenizer = GPT2TokenizerFast.from_pretrained(draft_model_path)
    target_tokenizer = GPT2TokenizerFast.from_pretrained(target_model_path)

    # Load Model
    draft_model = LlamaForCausalLM.from_pretrained(draft_model_path, torch_dtype=torch.bfloat16).to(device)
    target_model = LlamaForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.bfloat16).to(device)

    # Tokenize
    messages = [
        [
            {
                "role": "user",
                "content": "What is the capital of Taiwan. And why?",
            },
        ],
    ]

    input_text = draft_tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = draft_tokenizer(
        input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    ).to(device)

    # Draft model
    target_inputs, draft_probs = drafter_speculative_decode(
        draft_model=draft_model,
        draft_tokenizer=draft_tokenizer,
        inputs=inputs,
        gamma=10,
    )

    print(target_inputs["input_ids"])
    print("".join(draft_tokenizer.batch_decode(target_inputs["input_ids"][0])))

    # Target model
    outputs = target_speculative_decode(
        target_model=target_model,
        target_tokenizer=target_tokenizer,
        inputs=target_inputs,
        draft_probs=draft_probs,
    )

    print(outputs["input_ids"])
    print("".join(target_tokenizer.batch_decode(outputs["input_ids"][0])))
```
Output:
```
tensor([[ 1, 9690, 198, 2683, 359, 253, 5356, 5646, 11173, 3365,
3511, 308, 34519, 28, 7018, 411, 407, 19712, 8182, 2,
198, 1, 4093, 198, 1780, 314, 260, 3575, 282, 15914,
30, 1350, 1701, 47, 2, 198, 1, 520, 9531, 198,
504, 3575, 282, 15914, 314, 12545]], device='cuda:0')
<|im_start|>system
You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
<|im_start|>user
What is the capital of Taiwan. And why?<|im_end|>
<|im_start|>assistant
The capital of Taiwan is Tai
tensor([[ 1, 9690, 198, 2683, 359, 253, 5356, 5646, 11173, 3365,
3511, 308, 34519, 28, 7018, 411, 407, 19712, 8182, 2,
198, 1, 4093, 198, 1780, 314, 260, 3575, 282, 15914,
30, 1350, 1701, 47, 2, 198, 1, 520, 9531, 198,
504, 3575, 282, 15914, 314, 12545, 46162]], device='cuda:0')
<|im_start|>system
You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
<|im_start|>user
What is the capital of Taiwan. And why?<|im_end|>
<|im_start|>assistant
The capital of Taiwan is Taipei
```
This is a positive example: the draft model decoded 10 tokens, all of them were accepted, and the target model then contributed an 11th token ("pei", completing "Taipei") from the same forward pass. A great success.
However, if we look at a more ambiguous question, we might see the large model rejecting the small model's draft more quickly, with a bit of an "I’ll take over" attitude.
```python
if __name__ == "__main__":
    # Settings
    target_model_path = "../models/HuggingFaceTB--SmolLM2-1.7B-Instruct/"
    draft_model_path = "../models/HuggingFaceTB--SmolLM2-135M-Instruct/"
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load Tokenizer
    draft_tokenizer = GPT2TokenizerFast.from_pretrained(draft_model_path)
    target_tokenizer = GPT2TokenizerFast.from_pretrained(target_model_path)

    # Load Model
    draft_model = LlamaForCausalLM.from_pretrained(draft_model_path, torch_dtype=torch.bfloat16).to(device)
    target_model = LlamaForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.bfloat16).to(device)

    # Tokenize
    messages = [
        [
            {
                "role": "user",
                "content": "What???",
            },
        ],
    ]

    input_text = draft_tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = draft_tokenizer(
        input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    ).to(device)

    # Draft model
    target_inputs, draft_probs = drafter_speculative_decode(
        draft_model=draft_model,
        draft_tokenizer=draft_tokenizer,
        inputs=inputs,
        gamma=10,
    )

    print(target_inputs["input_ids"])
    print("".join(draft_tokenizer.batch_decode(target_inputs["input_ids"][0])))

    # Target model
    outputs = target_speculative_decode(
        target_model=target_model,
        target_tokenizer=target_tokenizer,
        inputs=target_inputs,
        draft_probs=draft_probs,
    )

    print(outputs["input_ids"])
    print("".join(target_tokenizer.batch_decode(outputs["input_ids"][0])))
```
Output:
```
tensor([[ 1, 9690, 198, 2683, 359, 253, 5356, 5646, 11173, 3365,
3511, 308, 34519, 28, 7018, 411, 407, 19712, 8182, 2,
198, 1, 4093, 198, 1780, 16693, 47, 2, 198, 1,
520, 9531, 198, 22234, 8165, 28, 198, 198, 42519]],
device='cuda:0')
<|im_start|>system
You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
<|im_start|>user
What???<|im_end|>
<|im_start|>assistant
Hey Sarah,
Hope
tensor([[ 1, 9690, 198, 2683, 359, 253, 5356, 5646, 11173, 3365,
3511, 308, 34519, 28, 7018, 411, 407, 19712, 8182, 2,
198, 1, 4093, 198, 1780, 16693, 47, 2, 198, 1,
520, 9531, 198, 57]], device='cuda:0')
<|im_start|>system
You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
<|im_start|>user
What???<|im_end|>
<|im_start|>assistant
I
```
As we can see, the target model accepted the first four drafted tokens (`<|im_start|>assistant` followed by a newline), rejected the fifth one ("Hey"), replaced it with its own token ("I"), and discarded the rest of the draft. Since the draft model is supposed to be much faster than the target model, even a rejection this early should not cost much time; at least, that is the goal.
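The paper puts numbers on this trade-off. Under its simplifying assumption that each drafted token is accepted independently with the same probability alpha, one target forward pass yields on average (1 - alpha^(gamma+1)) / (1 - alpha) tokens, and if one draft step costs c times a target step, the expected walltime improvement is that quantity divided by (gamma * c + 1). A small sketch of both formulas (the function names are my own):

```python
def expected_tokens_per_target_pass(alpha: float, gamma: int) -> float:
    """Expected tokens produced per target forward pass, assuming i.i.d. acceptance rate alpha."""
    if alpha == 1.0:
        # Every draft is accepted, plus the bonus token
        return gamma + 1.0
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)


def expected_speedup(alpha: float, gamma: int, c: float) -> float:
    """Expected walltime improvement factor; c = cost of one draft step / cost of one target step."""
    return expected_tokens_per_target_pass(alpha, gamma) / (gamma * c + 1)


for gamma in (1, 3, 5, 10):
    print(gamma, expected_tokens_per_target_pass(0.8, gamma), expected_speedup(0.8, gamma, 0.05))
```

In the worst case (alpha = 0) every round still yields one token from the target pass, so the loss is bounded by the gamma draft steps, a slowdown factor of gamma * c + 1; this is why a very cheap draft model matters.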
Over time, I intend to improve several aspects of the speculative decoding framework. Feel free to check out my GitHub if you’re interested: https://github.com/ccs96307/fast-llm-inference
(2024/11/06 Updated) Experimental Results
Initially, I wanted to compare the speed of my Speculative Decoding implementation against direct decoding. Surprisingly, direct decoding, with the target model alone or with the draft model alone, was faster than speculative decoding! It turned out that for direct decoding I was calling the `.generate()` method from transformers, which is already heavily optimized. So I changed the direct-decoding baselines to use my own custom sampling as well.
This adjustment yielded results that I found quite satisfying.
```python
from typing import Dict, List, Optional, Tuple

import copy
import time

import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase

from sampling import sample_next_token


def calculate_continuous_acceptance(acceptance_mask: torch.Tensor) -> int:
    # Count the drafted tokens accepted before the first rejection (assumes batch_size=1)
    continuous_acceptance = 0
    for accepted in acceptance_mask.long().squeeze(0):
        if accepted == 1:
            continuous_acceptance += 1
        else:
            break
    return continuous_acceptance


def drafter_speculative_decode(
    draft_model: torch.nn.Module,
    draft_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    gamma: int = 10,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, which disables top-k filtering
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], torch.FloatTensor]:
    draft_probs = []

    for idx in range(gamma):
        with torch.no_grad():
            outputs = draft_model(**inputs)

        next_tokens, probs = sample_next_token(
            logits=outputs.logits,
            prefix_token_ids=inputs["input_ids"],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        draft_probs.append(probs)

        # Append the sampled token and extend the attention mask by one position
        input_ids = torch.cat([inputs["input_ids"], next_tokens[:, -1:]], dim=-1)
        new_mask = torch.ones(
            inputs["attention_mask"].shape[0], 1,
            dtype=inputs["attention_mask"].dtype,
            device=inputs["attention_mask"].device,
        )
        attention_mask = torch.cat([inputs["attention_mask"], new_mask], dim=-1)

        inputs["input_ids"] = input_ids
        inputs["attention_mask"] = attention_mask

    return inputs, torch.cat(draft_probs, dim=1)


def target_speculative_decode(
    target_model: torch.nn.Module,
    target_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    draft_probs: torch.FloatTensor,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, which disables top-k filtering
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], bool, int]:
    with torch.no_grad():
        outputs = target_model(**inputs)

    next_tokens, target_probs = sample_next_token(
        logits=outputs.logits,
        prefix_token_ids=inputs["input_ids"],
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        probs_num=draft_probs.shape[1] + 1,
    )

    next_token = next_tokens[:, -1:]

    # Evaluation
    indices = inputs["input_ids"][:, -draft_probs.shape[1]:]

    eval_probs = target_probs[:, :-1, :]

    expanded_indices = indices.unsqueeze(-1)
    selected_draft_probs = torch.gather(draft_probs, dim=-1, index=expanded_indices)
    selected_draft_probs = selected_draft_probs.squeeze(-1)

    selected_eval_probs = torch.gather(eval_probs, dim=-1, index=expanded_indices)
    selected_eval_probs = selected_eval_probs.squeeze(-1)

    # Compare draft_prob and eval_prob, and check the reject_mask
    mask_to_reject = selected_draft_probs > selected_eval_probs

    # Calculate reject probability 1 - (eval_prob / draft_prob)
    rejection_probs = 1 - (selected_eval_probs / selected_draft_probs)

    # Generate random values to determine acceptance or rejection
    random_values = torch.rand_like(rejection_probs)
    rejection_decisions = random_values < rejection_probs

    # Get the final rejection masks
    rejection_masks = mask_to_reject & rejection_decisions
    acceptance_mask = torch.ones_like(selected_draft_probs, dtype=torch.bool)
    acceptance_mask[rejection_masks] = False

    # An accepted EOS among the drafted tokens also ends this round
    gamma = draft_probs.shape[1]
    start_idx = inputs["input_ids"].shape[1] - gamma
    eos_mask = inputs["input_ids"][:, start_idx:] == target_tokenizer.eos_token_id

    is_end = False

    # Concat `input_ids`
    if torch.all(acceptance_mask) and not torch.any(eos_mask):
        # Every draft accepted and none is EOS: append the target model's bonus token
        input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
        new_mask = torch.ones(
            inputs["attention_mask"].shape[0], 1,
            dtype=inputs["attention_mask"].dtype,
            device=inputs["attention_mask"].device,
        )
        attention_mask = torch.cat([inputs["attention_mask"], new_mask], dim=-1)
        is_end = bool((next_token == target_tokenizer.eos_token_id).all())
    else:
        new_input_ids = []
        new_attention_mask = []

        for batch_idx in range(next_tokens.shape[0]):
            for pos_idx in range(gamma):
                token_pos = start_idx + pos_idx
                rejected = not acceptance_mask[batch_idx][pos_idx]
                if rejected:
                    # Replace the rejected draft token with the target model's sample at this position
                    inputs["input_ids"][batch_idx][token_pos] = next_tokens[batch_idx][pos_idx]
                if rejected or eos_mask[batch_idx][pos_idx]:
                    # Keep everything up to this token (the replacement or the accepted EOS)
                    new_input_ids.append(inputs["input_ids"][batch_idx][:token_pos + 1])
                    new_attention_mask.append(inputs["attention_mask"][batch_idx][:token_pos + 1])
                    is_end = inputs["input_ids"][batch_idx][token_pos].item() == target_tokenizer.eos_token_id
                    break

        input_ids = pad_sequence(new_input_ids, batch_first=True, padding_value=target_tokenizer.pad_token_id)
        attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)

    inputs["input_ids"] = input_ids
    inputs["attention_mask"] = attention_mask

    return inputs, is_end, calculate_continuous_acceptance(acceptance_mask)
```
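Since `calculate_continuous_acceptance` only handles a single sequence (it relies on `squeeze(0)`), here is a vectorized alternative that works row by row for any batch size; `cumprod` zeroes every position after the first rejection, so the sum counts only the leading run of accepted tokens. This is a sketch with a hypothetical name, not part of the post's code:

```python
import torch


def continuous_acceptance(acceptance_mask: torch.Tensor) -> torch.Tensor:
    """Length of the accepted prefix for each row of a (batch, gamma) boolean mask."""
    # cumprod turns every position after the first False into 0
    return acceptance_mask.long().cumprod(dim=-1).sum(dim=-1)


mask = torch.tensor([[True, True, False, True],
                     [True, True, True, True]])
print(continuous_acceptance(mask).tolist())  # [2, 4]
```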
Output:
```
Generate token number: 101
Speculative Decoding Spent Time: 2.8416874408721924 seconds.
Accept Rate: 0.40588235294117647
Generate token number: 100
Normal Target Model Decoding Spent Time: 3.5783841609954834 seconds.
Generate token number: 100
Normal Draft Model Decoding Spent Time: 1.0656843185424805 seconds.
```
Speculative decoding took about 2.84 seconds to decode roughly 100 tokens (101), with an acceptance rate of about 0.41.
Decoding 100 tokens directly took 3.58 seconds with the target model, and only 1.07 seconds with the draft model.
So in this run, speculative decoding did speed up the target model's decoding, by roughly 1.26x (3.58 s vs. 2.84 s). A successful experiment!
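The listing above omits the timing harness itself. For readers who want to reproduce the comparison, here is a hypothetical sketch of a plain token-by-token baseline that samples with custom code instead of `.generate()`, timed with `time.perf_counter()`. `model_fn` stands in for something like `lambda ids: model(input_ids=ids).logits`, and the toy model exists only to make the sketch runnable; like the drafter in this post, the loop re-runs the full prefix at each step without a KV cache:

```python
import time
from typing import Callable

import torch


def autoregressive_decode(
    model_fn: Callable[[torch.Tensor], torch.Tensor],
    input_ids: torch.Tensor,
    max_new_tokens: int,
    eos_token_id: int,
    temperature: float = 1.0,
) -> torch.Tensor:
    """Token-by-token decoding with temperature sampling (no KV cache)."""
    for _ in range(max_new_tokens):
        with torch.no_grad():
            logits = model_fn(input_ids)[:, -1, :]  # (batch, vocab) for the next position
        probs = torch.softmax(logits / max(temperature, 1e-7), dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)  # (batch, 1)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        if (next_token == eos_token_id).all():
            break
    return input_ids


VOCAB_SIZE = 8


def toy_model(ids: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for a causal LM: strongly favors (last token + 1) % VOCAB_SIZE."""
    logits = torch.full((*ids.shape, VOCAB_SIZE), -10.0)
    logits.scatter_(-1, ((ids + 1) % VOCAB_SIZE).unsqueeze(-1), 10.0)
    return logits


torch.manual_seed(0)
start = time.perf_counter()
out = autoregressive_decode(toy_model, torch.tensor([[0]]), max_new_tokens=5, eos_token_id=7)
elapsed = time.perf_counter() - start
print(out.tolist(), f"{elapsed:.4f} seconds")
```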