Last Updated on 2024-11-22 by Clay
Many of the inference acceleration techniques I have studied, such as Speculative Decoding, predominantly use a threshold for the confidence scores of the draft model. This threshold determines how many draft tokens should be decoded before passing them to the target model for verification, thereby reducing the extra computational cost when the draft model operates with low confidence.
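To make that concrete, here is a minimal sketch of that draft-side gating. It is not taken from any specific paper or implementation; `draft_until_unconfident`, `gamma`, and `threshold` are illustrative placeholders, and for simplicity the sketch picks tokens greedily, whereas real implementations (including mine) usually sample them:

```python
import torch

@torch.no_grad()
def draft_until_unconfident(draft_model, input_ids, gamma: int = 5, threshold: float = 0.5):
    """Sketch: stop drafting as soon as the draft model's own confidence drops."""
    draft_tokens = []
    for _ in range(gamma):
        logits = draft_model(input_ids=input_ids).logits[:, -1, :]
        probs = torch.softmax(logits, dim=-1)
        confidence, next_token = probs.max(dim=-1, keepdim=True)

        # Gate on the *draft* model's confidence: if it is unsure about its own
        # guess, hand control back to the target model early.
        if confidence.item() < threshold:
            break

        draft_tokens.append(next_token)
        input_ids = torch.cat([input_ids, next_token], dim=-1)

    return input_ids, draft_tokens
```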
Although this method does achieve dynamic performance optimization, one concern lingered when I first encountered it: if the draft model is inherently weaker, especially on inputs it does not understand well, should I still trust its high confidence scores?
In other words, would it not make more sense to base the decision to activate speculative decoding on the target model's confidence score instead? The weaker draft model's high confidence may not be reliable (the target model ultimately has to verify its output anyway), whereas low confidence from the target model is reason enough not to delegate the decoding to the draft model.
To put it more plainly: The draft model's high confidence may simply be unwarranted self-assurance, and this alone shouldn't justify allowing it to continue generating tokens. Conversely, if the target model exhibits low confidence, it signals uncertainty, making it unreasonable to rely on the weaker draft model for predictions.
In short, perhaps I simply have not read widely enough, but I have yet to see a paper comparing these two gating signals: thresholding the draft model's confidence versus using the target model's confidence to decide when to enter speculative mode. That said, the acceleration in my initial tests was noticeable, so I am documenting it here as a reference for my ongoing study of inference acceleration techniques.
Background Knowledge
First, feel free to check out my GitHub, where I upload many of my implementations to the fast-llm-inference repository.
This note is a follow-up to my previous post: Speculative Decoding Implementation Notes (with Basic Experimental Results).
In the previous post, I documented my implementation of sampling-based Speculative Decoding, along with a simple experiment. Here, I will refine that code and compare the acceleration results by incorporating the target model's confidence score to decide whether to enable speculative decoding.
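As a quick recap of the verification rule used there (it comes from "Fast Inference from Transformers via Speculative Decoding", listed in the references below): for each draft token $x$ with draft probability $q(x)$ and target probability $p(x)$, the target model accepts the token with probability

$$
\min\!\left(1, \frac{p(x)}{q(x)}\right),
$$

which is equivalent to rejecting it with probability $\max\left(0,\ 1 - p(x)/q(x)\right)$. This is exactly what the `mask_to_reject` and `rejection_probs` tensors compute in the verification code below.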
After running several experiments with the original code, I obtained the following results:
Generate token number: 100
Generate speed: 34.81801971715937 tokens/sec
Speculative Decoding Spent Time: 2.8720760345458984 seconds.
Accept Rate: 0.34054054054054056
Generate token number: 100
Generate speed: 28.07058497562908 tokens/sec
Normal Target Model Decoding Spent Time: 3.562448024749756 seconds.
Generate token number: 100
Generate speed: 94.92253307563831 tokens/sec
Normal Draft Model Decoding Spent Time: 1.0534906387329102 seconds.
In summary:
- Draft model decoding speed: 94.92 tokens/sec
- Target model decoding speed: 28.07 tokens/sec
- Speculative Decoding speed: 34.82 tokens/sec
This corresponds to approximately a 1.24x acceleration over decoding with the target model alone (34.82 / 28.07 ≈ 1.24).
Using Target Model Confidence to Decide Whether to Enable Speculative Decoding
Motivated by the idea that when the current decoding step is uncertain, I should not entrust it to the weaker draft model, I made the following changes:
def target_speculative_decode(
target_model: torch.nn.Module,
target_tokenizer: PreTrainedTokenizerBase,
inputs: Dict[str, torch.Tensor],
draft_probs: torch.FloatTensor,
temperature: float = 1.0,
top_k: Optional[int] = 0, # Default is 0, it means do not select top-k tokens
top_p: Optional[float] = 1.0,
repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], bool, int]:
with torch.no_grad():
outputs = target_model(**inputs)
next_tokens, target_probs = sample_next_token(
logits=outputs.logits,
prefix_token_ids=inputs["input_ids"],
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
probs_num=draft_probs.shape[1] + 1,
)
next_token = next_tokens[:, -1:]
# Evaluation
indices = inputs["input_ids"][:, -draft_probs.shape[1]:]
eval_probs = target_probs[:, :-1, :]
expanded_indices = indices.unsqueeze(-1)
selected_draft_probs = torch.gather(draft_probs, dim=-1, index=expanded_indices)
selected_draft_probs = selected_draft_probs.squeeze(-1)
selected_eval_probs = torch.gather(eval_probs, dim=-1, index=expanded_indices)
selected_eval_probs = selected_eval_probs.squeeze(-1)
# Compare draft_prob and eval_prob, and check the reject_mask
mask_to_reject = selected_draft_probs > selected_eval_probs
    # Calculate rejection probability: 1 - (eval_prob / draft_prob)
rejection_probs = 1 - (selected_eval_probs / selected_draft_probs)
    # Generate random values to determine accept or reject
random_values = torch.rand_like(rejection_probs)
rejection_decisions = random_values < rejection_probs
# Get the final reject masks
rejection_masks = mask_to_reject & rejection_decisions
acceptance_mask = torch.ones_like(selected_draft_probs, dtype=torch.bool)
acceptance_mask[rejection_masks] = False
is_end = False
# Concat `input_ids`
confidence_score = 0
if torch.all(acceptance_mask):
input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)
confidence_score = target_probs[:, -1, next_token[0][0]].item()
print(f"Confidence for next token: {confidence_score:.4f}")
else:
new_input_ids = []
new_attention_mask = []
for batch_idx in range(next_tokens.shape[0]):
gamma = next_tokens.shape[1] - 1
start_idx = inputs["input_ids"].shape[1] - gamma
for pos_idx in range(acceptance_mask[batch_idx].shape[0]):
if (acceptance_mask[batch_idx][pos_idx] and inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id) or not acceptance_mask[batch_idx][pos_idx]:
inputs["input_ids"][batch_idx][start_idx+pos_idx] = next_tokens[batch_idx][pos_idx]
confidence_score = target_probs[batch_idx, pos_idx, next_tokens[batch_idx, pos_idx]].max().item()
print(f"Replacement Confidence for next token: {confidence_score:.4f}")
new_input_ids.append(inputs["input_ids"][batch_idx][:start_idx+pos_idx+1])
new_attention_mask.append(inputs["attention_mask"][batch_idx][:start_idx+pos_idx+1])
is_end = inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id
break
input_ids = pad_sequence(new_input_ids, batch_first=True, padding_value=target_tokenizer.pad_token_id)
attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)
inputs["input_ids"] = input_ids
inputs["attention_mask"] = attention_mask
# Keep generating if confidence_score is less than confidence threshold
while confidence_score < 0.5:
with torch.no_grad():
outputs = target_model(**inputs)
next_tokens, target_probs = sample_next_token(
logits=outputs.logits,
prefix_token_ids=inputs["input_ids"],
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
probs_num=1,
)
# Update `confidence_score`
next_token = next_tokens[:, -1:]
confidence_score = target_probs[0, -1, next_token[0][0]].item()
print(f"keep generate confidence_score: {confidence_score:.4f}")
input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)
inputs["input_ids"] = input_ids
inputs["attention_mask"] = attention_mask
is_end = inputs["input_ids"][0][-1].item() == target_tokenizer.eos_token_id
if is_end:
break
return inputs, is_end, calculate_continuous_acceptance(acceptance_mask)
Compared with the original code, the main addition is that I now read off the target model's confidence score during verification (and after each forced target-model step). Since the probability distribution is already available after sampling, this adds virtually no extra overhead in my implementation.
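To make the control flow easier to see in isolation, here is a stripped-down sketch of just the gating loop. It mirrors the `while confidence_score < 0.5` block above, picking up right after the verification step (where `confidence_score` and `inputs` have just been updated) and using the same `sample_next_token` helper from the full script below; the explicit `confidence_threshold` name is added here only for clarity:

```python
confidence_threshold = 0.5  # an arbitrary choice that I have not tuned

while confidence_score < confidence_threshold:
    with torch.no_grad():
        outputs = target_model(**inputs)

    # `sample_next_token` already returns the probability distribution,
    # so reading the confidence is free once the token is sampled
    next_tokens, target_probs = sample_next_token(
        logits=outputs.logits,
        prefix_token_ids=inputs["input_ids"],
        probs_num=1,
    )
    next_token = next_tokens[:, -1:]

    # "Confidence" is simply the probability the target model assigned
    # to the token it just sampled
    confidence_score = target_probs[0, -1, next_token[0][0]].item()

    # Append the token; speculative mode resumes only after confidence recovers
    inputs["input_ids"] = torch.cat([inputs["input_ids"], next_token], dim=-1)
    inputs["attention_mask"] = torch.cat(
        [inputs["attention_mask"], torch.ones_like(next_token)], dim=-1
    )

    if next_token[0, 0].item() == target_tokenizer.eos_token_id:
        break
```

In other words, while the target model's confidence stays below the threshold, it keeps decoding autoregressively on its own, and the draft model is only consulted again once the confidence recovers.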
Below are the complete source code and experimental setup:
- GPU: RTX 4060 8GB
- OS: Ubuntu 22.04
- Target model: HuggingFaceTB/SmolLM2-1.7B-Instruct
- Draft model: HuggingFaceTB/SmolLM2-135M-Instruct
from typing import Dict, List, Optional, Tuple
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import argparse
import copy
import time
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase
from sampling.sampling import sample_next_token
"""
python speculative_decoding/run_speculative_decoding.py \
--target_model_path HuggingFaceTB/SmolLM2-1.7B-Instruct \
--draft_model_path HuggingFaceTB/SmolLM2-135M-Instruct \
--device cuda:0 \
--question 'What is the capital of Taiwan. And why?' \
--gamma 5 \
--test_token_num 100
"""
def calculate_continuous_acceptance(acceptance_mask: torch.BoolTensor) -> int:
continuous_acceptance = 0
for accepted in acceptance_mask.long().squeeze(0):
if accepted == 1:
continuous_acceptance += 1
else:
break
return continuous_acceptance
def drafter_speculative_decode(
draft_model: torch.nn.Module,
draft_tokenizer: PreTrainedTokenizerBase,
inputs: Dict[str, torch.Tensor],
gamma: int = 10,
temperature: float = 1.0,
top_k: Optional[int] = 0, # Default is 0, it means do not select top-k tokens
top_p: Optional[float] = 1.0,
repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], torch.FloatTensor]:
draft_probs = []
for idx in range(gamma):
with torch.no_grad():
outputs = draft_model(**inputs)
next_tokens, probs = sample_next_token(
logits=outputs.logits,
prefix_token_ids=inputs["input_ids"],
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
)
draft_probs.append(probs)
input_ids = torch.cat([inputs["input_ids"], next_tokens[:, -1:]], dim=-1)
attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)
inputs["input_ids"] = input_ids
inputs["attention_mask"] = attention_mask
return inputs, torch.cat(draft_probs, dim=1)
def target_speculative_decode(
target_model: torch.nn.Module,
target_tokenizer: PreTrainedTokenizerBase,
inputs: Dict[str, torch.Tensor],
draft_probs: torch.FloatTensor,
temperature: float = 1.0,
top_k: Optional[int] = 0, # Default is 0, it means do not select top-k tokens
top_p: Optional[float] = 1.0,
repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], bool, int]:
with torch.no_grad():
outputs = target_model(**inputs)
next_tokens, target_probs = sample_next_token(
logits=outputs.logits,
prefix_token_ids=inputs["input_ids"],
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
probs_num=draft_probs.shape[1] + 1,
)
next_token = next_tokens[:, -1:]
# Evaluation
indices = inputs["input_ids"][:, -draft_probs.shape[1]:]
eval_probs = target_probs[:, :-1, :]
expanded_indices = indices.unsqueeze(-1)
selected_draft_probs = torch.gather(draft_probs, dim=-1, index=expanded_indices)
selected_draft_probs = selected_draft_probs.squeeze(-1)
selected_eval_probs = torch.gather(eval_probs, dim=-1, index=expanded_indices)
selected_eval_probs = selected_eval_probs.squeeze(-1)
# Compare draft_prob and eval_prob, and check the reject_mask
mask_to_reject = selected_draft_probs > selected_eval_probs
    # Calculate rejection probability: 1 - (eval_prob / draft_prob)
rejection_probs = 1 - (selected_eval_probs / selected_draft_probs)
    # Generate random values to determine accept or reject
random_values = torch.rand_like(rejection_probs)
rejection_decisions = random_values < rejection_probs
# Get the final reject masks
rejection_masks = mask_to_reject & rejection_decisions
acceptance_mask = torch.ones_like(selected_draft_probs, dtype=torch.bool)
acceptance_mask[rejection_masks] = False
is_end = False
# Concat `input_ids`
confidence_score = 0
if torch.all(acceptance_mask):
input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)
confidence_score = target_probs[:, -1, next_token[0][0]].item()
print(f"Confidence for next token: {confidence_score:.4f}")
else:
new_input_ids = []
new_attention_mask = []
for batch_idx in range(next_tokens.shape[0]):
gamma = next_tokens.shape[1] - 1
start_idx = inputs["input_ids"].shape[1] - gamma
for pos_idx in range(acceptance_mask[batch_idx].shape[0]):
if (acceptance_mask[batch_idx][pos_idx] and inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id) or not acceptance_mask[batch_idx][pos_idx]:
inputs["input_ids"][batch_idx][start_idx+pos_idx] = next_tokens[batch_idx][pos_idx]
confidence_score = target_probs[batch_idx, pos_idx, next_tokens[batch_idx, pos_idx]].max().item()
print(f"Replacement Confidence for next token: {confidence_score:.4f}")
new_input_ids.append(inputs["input_ids"][batch_idx][:start_idx+pos_idx+1])
new_attention_mask.append(inputs["attention_mask"][batch_idx][:start_idx+pos_idx+1])
is_end = inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id
break
input_ids = pad_sequence(new_input_ids, batch_first=True, padding_value=target_tokenizer.pad_token_id)
attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)
inputs["input_ids"] = input_ids
inputs["attention_mask"] = attention_mask
# Keep generating if confidence_score is less than confidence threshold
while confidence_score < 0.5:
with torch.no_grad():
outputs = target_model(**inputs)
next_tokens, target_probs = sample_next_token(
logits=outputs.logits,
prefix_token_ids=inputs["input_ids"],
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
probs_num=1,
)
# Update `confidence_score`
next_token = next_tokens[:, -1:]
confidence_score = target_probs[0, -1, next_token[0][0]].item()
print(f"keep generate confidence_score: {confidence_score:.4f}")
input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)
inputs["input_ids"] = input_ids
inputs["attention_mask"] = attention_mask
is_end = inputs["input_ids"][0][-1].item() == target_tokenizer.eos_token_id
if is_end:
break
return inputs, is_end, calculate_continuous_acceptance(acceptance_mask)
def run_test(args) -> None:
# Device
device = torch.device(args.device if args.device != "cpu" and torch.cuda.is_available() else "cpu")
print(device)
# Model path
target_model_path = args.target_model_path
draft_model_path = args.draft_model_path
# Load Tokenizer
draft_tokenizer = AutoTokenizer.from_pretrained(draft_model_path)
target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
# Load Model
draft_model = AutoModelForCausalLM.from_pretrained(draft_model_path, torch_dtype=torch.bfloat16).to(device)
target_model = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.bfloat16).to(device)
# Tokenize
messages = [
[
{
"role": "user",
"content": args.question,
},
],
]
input_text=draft_tokenizer.apply_chat_template(messages, tokenize=False)
inputs = draft_tokenizer(
input_text,
return_tensors="pt",
max_length=512,
truncation=True,
padding=True,
).to(device)
# Warm up the model (CUDA)
inputs_dummy = {k: v.clone() for k, v in inputs.items()}
with torch.no_grad():
draft_model(**inputs_dummy)
target_model(**inputs_dummy)
torch.cuda.synchronize()
is_end = False
# Record
raw_inputs = copy.deepcopy(inputs)
raw_token_num = raw_inputs["input_ids"].shape[1]
start_time = time.time()
total_draft_tokens = 0
total_accept_tokens = 0
gamma = args.gamma
max_new_tokens = args.test_token_num
while not is_end:
# Draft model
target_inputs, draft_probs = drafter_speculative_decode(
draft_model=draft_model,
draft_tokenizer=draft_tokenizer,
inputs=inputs,
gamma=gamma,
)
total_draft_tokens += gamma
# Target model
outputs, is_end, accept_tokens = target_speculative_decode(
target_model=target_model,
target_tokenizer=target_tokenizer,
inputs=target_inputs,
draft_probs=draft_probs,
)
total_accept_tokens += accept_tokens
inputs = outputs
if inputs["input_ids"].shape[1] - raw_token_num >= max_new_tokens:
break
generate_token_num = outputs["input_ids"].shape[1] - raw_token_num
spent_time = time.time() - start_time
print(f"Generate token number: {generate_token_num}")
print(f"Generate speed: {generate_token_num / spent_time} tokens/sec")
print(f"Speculative Decoding Spent Time: {spent_time} seconds.")
print(f"Accept Rate: {total_accept_tokens / total_draft_tokens}\n")
# Normal Target Model Speed
raw_inputs = copy.deepcopy(inputs)
start_time = time.time()
target_inputs, draft_probs = drafter_speculative_decode(
draft_model=target_model,
draft_tokenizer=draft_tokenizer,
inputs=raw_inputs,
gamma=args.test_token_num,
)
spent_time = time.time() - start_time
print(f"Generate token number: {max_new_tokens}")
print(f"Generate speed: {max_new_tokens / spent_time} tokens/sec")
print(f"Normal Target Model Decoding Spent Time: {spent_time} seconds.\n")
# Normal Draft Model Speed
raw_inputs = copy.deepcopy(inputs)
start_time = time.time()
target_inputs, draft_probs = drafter_speculative_decode(
draft_model=draft_model,
draft_tokenizer=draft_tokenizer,
inputs=raw_inputs,
gamma=args.test_token_num,
)
spent_time = time.time() - start_time
print(f"Generate token number: {max_new_tokens}")
print(f"Generate speed: {max_new_tokens / spent_time} tokens/sec")
print(f"Normal Draft Model Decoding Spent Time: {spent_time} seconds.\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--target_model_path", type=str, default="HuggingFaceTB/SmolLM2-1.7B-Instruct")
parser.add_argument("--draft_model_path", type=str, default="HuggingFaceTB/SmolLM2-135M-Instruct")
parser.add_argument("--device", type=str, default="cpu")
parser.add_argument("--question", type=str, default="What is the capital of Taiwan. And why?")
parser.add_argument("--gamma", type=int, default=5)
parser.add_argument("--test_token_num", type=int, default=100)
args = parser.parse_args()
run_test(args)
Output:
Generate token number: 102
Generate speed: 46.418809914955794 tokens/sec
Speculative Decoding Spent Time: 2.19738507270813 seconds.
Accept Rate: 0.5545454545454546
Generate token number: 100
Generate speed: 27.916420540976226 tokens/sec
Normal Target Model Decoding Spent Time: 3.5821211338043213 seconds.
Generate token number: 100
Generate speed: 96.10154773224576 tokens/sec
Normal Draft Model Decoding Spent Time: 1.0405659675598145 seconds.
To present the results more clearly:
| | Draft Model | Target Model | Speculative Decoding | Total Acceleration |
|---|---|---|---|---|
| Original | 94.92 tokens/sec | 28.07 tokens/sec | 34.82 tokens/sec | 1.24x |
| Target Threshold | 96.10 tokens/sec | 27.92 tokens/sec | 46.42 tokens/sec | 1.66x |
Intuitively, relying on the target model's confidence to decide whether to enable speculative decoding is a robust strategy, though its ceiling may be limited. Still, for someone like me with limited training resources, a steady acceleration gain like this is exactly what I need.
References
- Fast Inference from Transformers via Speculative Decoding
- EAGLE-2: Faster Inference of Language Models with Dynamic Draft Trees