In self-speculative decoding, the draft model is derived from a subset of the target model's own layers, so finding a good "Layer Skip Strategy" is crucial: we need to skip enough layers to achieve a meaningful speedup, while keeping the draft model's predictions accurate enough that they are not frequently rejected by the target model.
Today’s implementation focuses on optimizing my previously implemented LayerSkip model using the Bayesian optimization framework Optuna, to determine which layers to skip.
Background Review
One challenge of writing a series of articles is that I often start in the middle, requiring me to add some background details to provide context.
For a detailed explanation of speculative decoding, you can refer to: [Paper Reading] Fast Inference from Transformers via Speculative Decoding and Speculative Decoding Implementation Note (with Simple Experimental Results)
If you want a simple introduction to Bayes’ Theorem, you can check out: A Note of Bayes' Theorem
For my implementation of the LayerSkip model in Self-Speculative Decoding, refer to: Self-Speculative Decoding Implementation: LayerSkip Transformer
Feel free to visit my GitHub repository, where I will continue updating implementations of inference acceleration techniques: https://github.com/ccs96307/fast-llm-inference
Implementation Description
Optuna is a well-known Python package commonly used for hyperparameter optimization in model training. In simple terms, it uses Bayesian optimization to automatically sample combinations from a user-defined search space, evaluates each one against a user-defined objective function, and gradually converges on a good combination.
In layman's terms, when the true objective cannot be differentiated and the intermediate process is a black box, we can still define the inputs and observe the outputs. Bayesian optimization then helps us find better combinations, though not necessarily the absolute best, since exhaustive enumeration is infeasible.
Can't we just enumerate every layer-skip combination and pick the best one? Even with only 20 layers, each with an attention and an MLP sub-layer, there are 2 ^ 40 = 1,099,511,627,776 combinations; with the 26 layers used in the code below, the space grows to 2 ^ 52. Exhaustive search is clearly out of the question.
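To make the workflow concrete, here is a minimal, generic sketch of how Optuna is typically used. The toy quadratic objective is purely for illustration and is not the layer-skip search itself; only the Optuna calls (suggest_float, create_study, optimize, best_params) matter here:

import optuna

def toy_objective(trial):
    # Optuna's sampler (TPE by default) proposes a candidate based on past trials
    x = trial.suggest_float("x", -10.0, 10.0)
    return (x - 2) ** 2  # the black-box score we want to minimize

study = optuna.create_study(direction="minimize")
study.optimize(toy_objective, n_trials=30)
print(study.best_params)  # should end up close to {'x': 2.0}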
Below, I’ve divided the implementation into several functions:
- calculate_continuous_acceptance(): Counts how many draft tokens are consecutively accepted by the target model (counting stops at the first rejection).
- drafter_speculative_decode(): Decoding function for the draft model; it generates the draft tokens and returns their probability distributions.
- target_speculative_decode(): Decoding function for the target model, which verifies the draft tokens against the draft probabilities (the acceptance rule is sketched right after this list).
- objective(): Search function for Optuna.
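For reference, the verification step in target_speculative_decode() follows the standard speculative sampling rule: a drafted token x with draft probability p(x) and target probability q(x) is always accepted when q(x) >= p(x); otherwise it is rejected with probability 1 - q(x) / p(x), which is the same as accepting with probability min(1, q(x) / p(x)). This is exactly what the rejection mask and rejection_probs in the code below compute.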
Although the paper optimizes the skip configuration for actual inference speed, I haven't set up GPUs or a proper test set yet. For now, I test with a single prompt and the Gemma-2 2B model, and define the optimization objective as the acceptance rate. Real benchmark results will be updated on GitHub and in this post later.
Full Implementation
import optuna
from typing import Dict, Optional, Tuple
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import copy
import time
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from layerskip_modeling.modeling_layerskip_gemma2 import LayerSkipGemma2ForCausalLM
from sampling.sampling import sample_next_token
def calculate_continuous_acceptance(acceptance_mask: torch.BoolTensor) -> int:
    continuous_acceptance = 0
    for accepted in acceptance_mask.long().squeeze(0):
        if accepted == 1:
            continuous_acceptance += 1
        else:
            break
    return continuous_acceptance
def drafter_speculative_decode(
    draft_model: torch.nn.Module,
    draft_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    gamma: int = 10,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, meaning top-k filtering is disabled
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
    draft_mode: bool = True,
) -> Tuple[Dict[str, torch.Tensor], torch.FloatTensor]:
    draft_model.set_draft_mode(draft_mode)
    draft_probs = []

    # Autoregressively sample `gamma` draft tokens and keep their probability distributions
    for idx in range(gamma):
        with torch.no_grad():
            outputs = draft_model(**inputs)

        next_tokens, probs = sample_next_token(
            logits=outputs.logits,
            prefix_token_ids=inputs["input_ids"],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        draft_probs.append(probs)
        input_ids = torch.cat([inputs["input_ids"], next_tokens[:, -1:]], dim=-1)
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)

        inputs["input_ids"] = input_ids
        inputs["attention_mask"] = attention_mask

    draft_model.set_draft_mode(True)

    return inputs, torch.cat(draft_probs, dim=1)
def target_speculative_decode(
    target_model: torch.nn.Module,
    target_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    draft_probs: torch.FloatTensor,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, meaning top-k filtering is disabled
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], bool, int]:
    target_model.set_draft_mode(False)

    with torch.no_grad():
        outputs = target_model(**inputs)

    next_tokens, target_probs = sample_next_token(
        logits=outputs.logits,
        prefix_token_ids=inputs["input_ids"],
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        probs_num=draft_probs.shape[1] + 1,
    )

    next_token = next_tokens[:, -1:]

    # Evaluation: gather the probabilities the draft and target models assigned to the drafted tokens
    indices = inputs["input_ids"][:, -draft_probs.shape[1]:]
    eval_probs = target_probs[:, :-1, :]

    expanded_indices = indices.unsqueeze(-1)
    selected_draft_probs = torch.gather(draft_probs, dim=-1, index=expanded_indices)
    selected_draft_probs = selected_draft_probs.squeeze(-1)

    selected_eval_probs = torch.gather(eval_probs, dim=-1, index=expanded_indices)
    selected_eval_probs = selected_eval_probs.squeeze(-1)

    # Compare draft_prob and eval_prob to build the rejection mask
    mask_to_reject = selected_draft_probs > selected_eval_probs

    # Calculate the rejection probability: 1 - (eval_prob / draft_prob)
    rejection_probs = 1 - (selected_eval_probs / selected_draft_probs)

    # Generate random values to determine accept or reject
    random_values = torch.rand_like(rejection_probs)
    rejection_decisions = random_values < rejection_probs

    # Get the final rejection masks
    rejection_masks = mask_to_reject & rejection_decisions
    acceptance_mask = torch.ones_like(selected_draft_probs, dtype=torch.bool)
    acceptance_mask[rejection_masks] = False

    is_end = False

    # Concat `input_ids`
    if torch.all(acceptance_mask):
        input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)
    else:
        new_input_ids = []
        new_attention_mask = []

        for batch_idx in range(next_tokens.shape[0]):
            gamma = next_tokens.shape[1] - 1
            start_idx = inputs["input_ids"].shape[1] - gamma

            for pos_idx in range(acceptance_mask[batch_idx].shape[0]):
                if (acceptance_mask[batch_idx][pos_idx] and inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id) or not acceptance_mask[batch_idx][pos_idx]:
                    # Replace the first rejected token with the target model's own prediction and truncate
                    inputs["input_ids"][batch_idx][start_idx+pos_idx] = next_tokens[batch_idx][pos_idx]

                    new_input_ids.append(inputs["input_ids"][batch_idx][:start_idx+pos_idx+1])
                    new_attention_mask.append(inputs["attention_mask"][batch_idx][:start_idx+pos_idx+1])

                    is_end = inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id
                    break

        input_ids = pad_sequence(new_input_ids, batch_first=True, padding_value=target_tokenizer.pad_token_id)
        attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)

    inputs["input_ids"] = input_ids
    inputs["attention_mask"] = attention_mask

    return inputs, is_end, calculate_continuous_acceptance(acceptance_mask)
def objective(trial):
    # Define the search space: every layer's attention and MLP sub-layer can be skipped (1) or kept (0)
    total_layers = 26

    # Determine whether to skip each `attn` sub-layer
    skip_attn_layers = []
    for i in range(total_layers):
        skip = trial.suggest_int(f'skip_attn_layer_{i}', 0, 1)
        if skip == 1:
            skip_attn_layers.append(i)

    # Determine whether to skip each `mlp` sub-layer
    skip_mlp_layers = []
    for i in range(total_layers):
        skip = trial.suggest_int(f'skip_mlp_layer_{i}', 0, 1)
        if skip == 1:
            skip_mlp_layers.append(i)

    # Prune trials that skip nothing (the draft model would be identical to the target model)
    if len(skip_attn_layers) == 0 and len(skip_mlp_layers) == 0:
        raise optuna.TrialPruned()

    skip_layer_ids = {
        "attn": skip_attn_layers,
        "mlp": skip_mlp_layers,
    }

    # Set the skip strategy
    model.set_skip_layer_ids(skip_layer_ids=skip_layer_ids)

    messages = [
        [
            {
                "role": "user",
                "content": "What is the capital of Taiwan. And why?",
            },
        ],
    ]

    input_text = tokenizer.apply_chat_template(messages, tokenize=False)

    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    ).to(device)

    is_end = False

    # Record
    raw_inputs = copy.deepcopy(inputs)
    raw_token_num = raw_inputs["input_ids"].shape[1]
    total_draft_tokens = 0
    total_accept_tokens = 0

    gamma = 5
    max_new_tokens = 100

    while not is_end:
        # Draft model
        target_inputs, draft_probs = drafter_speculative_decode(
            draft_model=model,
            draft_tokenizer=tokenizer,
            inputs=inputs,
            gamma=gamma,
        )

        total_draft_tokens += gamma

        # Target model
        outputs, is_end, accept_tokens = target_speculative_decode(
            target_model=model,
            target_tokenizer=tokenizer,
            inputs=target_inputs,
            draft_probs=draft_probs,
        )

        total_accept_tokens += accept_tokens

        inputs = outputs

        if inputs["input_ids"].shape[1] - raw_token_num >= max_new_tokens:
            break

    # Compute acceptance rate
    accept_rate = total_accept_tokens / total_draft_tokens

    print(f"attn_skip: {skip_attn_layers}, mlp_skip: {skip_mlp_layers}, Accept Rate: {accept_rate}")

    # We want to maximize `accept_rate`
    return accept_rate
if __name__ == "__main__":
    pretrained_model_name_or_path = "../models/google--gemma-2-2b-it/"
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
    model = LayerSkipGemma2ForCausalLM.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.bfloat16).to(device)

    # Init: start with no layers skipped
    skip_layer_ids = {
        "attn": [],
        "mlp": [],
    }

    model.set_skip_layer_ids(skip_layer_ids=skip_layer_ids)

    # Create the study and maximize the acceptance rate
    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=50)

    print("The best params:", study.best_params)
    print("The best accept_rate:", study.best_value)
Output:
[I 2024-11-13 13:26:42,633] Trial 49 finished with value: 0.04390243902439024 and parameters: {'skip_attn_layer_0': 1, 'skip_attn_layer_1': 0, 'skip_attn_layer_2': 1, 'skip_attn_layer_3': 1, 'skip_attn_layer_4': 0, 'skip_attn_layer_5': 0, 'skip_attn_layer_6': 1, 'skip_attn_layer_7': 0, 'skip_attn_layer_8': 1, 'skip_attn_layer_9': 1, 'skip_attn_layer_10': 0, 'skip_attn_layer_11': 0, 'skip_attn_layer_12': 0, 'skip_attn_layer_13': 0, 'skip_attn_layer_14': 1, 'skip_attn_layer_15': 1, 'skip_attn_layer_16': 1, 'skip_attn_layer_17': 0, 'skip_attn_layer_18': 0, 'skip_attn_layer_19': 1, 'skip_attn_layer_20': 1, 'skip_attn_layer_21': 0, 'skip_attn_layer_22': 0, 'skip_attn_layer_23': 0, 'skip_attn_layer_24': 0, 'skip_attn_layer_25': 1, 'skip_mlp_layer_0': 0, 'skip_mlp_layer_1': 0, 'skip_mlp_layer_2': 0, 'skip_mlp_layer_3': 0, 'skip_mlp_layer_4': 1, 'skip_mlp_layer_5': 0, 'skip_mlp_layer_6': 0, 'skip_mlp_layer_7': 1, 'skip_mlp_layer_8': 1, 'skip_mlp_layer_9': 1, 'skip_mlp_layer_10': 1, 'skip_mlp_layer_11': 0, 'skip_mlp_layer_12': 0, 'skip_mlp_layer_13': 0, 'skip_mlp_layer_14': 1, 'skip_mlp_layer_15': 0, 'skip_mlp_layer_16': 1, 'skip_mlp_layer_17': 0, 'skip_mlp_layer_18': 0, 'skip_mlp_layer_19': 1, 'skip_mlp_layer_20': 1, 'skip_mlp_layer_21': 1, 'skip_mlp_layer_22': 0, 'skip_mlp_layer_23': 0, 'skip_mlp_layer_24': 0, 'skip_mlp_layer_25': 1}. Best is trial 24 with value: 0.15.
The best params: {'skip_attn_layer_0': 0, 'skip_attn_layer_1': 0, 'skip_attn_layer_2': 0, 'skip_attn_layer_3': 0, 'skip_attn_layer_4': 1, 'skip_attn_layer_5': 0, 'skip_attn_layer_6': 0, 'skip_attn_layer_7': 0, 'skip_attn_layer_8': 0, 'skip_attn_layer_9': 1, 'skip_attn_layer_10': 0, 'skip_attn_layer_11': 0, 'skip_attn_layer_12': 0, 'skip_attn_layer_13': 0, 'skip_attn_layer_14': 0, 'skip_attn_layer_15': 1, 'skip_attn_layer_16': 0, 'skip_attn_layer_17': 1, 'skip_attn_layer_18': 1, 'skip_attn_layer_19': 1, 'skip_attn_layer_20': 1, 'skip_attn_layer_21': 0, 'skip_attn_layer_22': 0, 'skip_attn_layer_23': 0, 'skip_attn_layer_24': 0, 'skip_attn_layer_25': 1, 'skip_mlp_layer_0': 0, 'skip_mlp_layer_1': 0, 'skip_mlp_layer_2': 1, 'skip_mlp_layer_3': 0, 'skip_mlp_layer_4': 0, 'skip_mlp_layer_5': 0, 'skip_mlp_layer_6': 0, 'skip_mlp_layer_7': 1, 'skip_mlp_layer_8': 1, 'skip_mlp_layer_9': 1, 'skip_mlp_layer_10': 1, 'skip_mlp_layer_11': 0, 'skip_mlp_layer_12': 0, 'skip_mlp_layer_13': 0, 'skip_mlp_layer_14': 1, 'skip_mlp_layer_15': 0, 'skip_mlp_layer_16': 1, 'skip_mlp_layer_17': 0, 'skip_mlp_layer_18': 0, 'skip_mlp_layer_19': 1, 'skip_mlp_layer_20': 1, 'skip_mlp_layer_21': 1, 'skip_mlp_layer_22': 0, 'skip_mlp_layer_23': 0, 'skip_mlp_layer_24': 0, 'skip_mlp_layer_25': 1}
The best accept_rate: 0.15
The final test results were disappointing, with an acceptance rate of only 0.15. However, based on my preliminary tests, larger models tend to tolerate layer skipping better.
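If you want to reuse the configuration Optuna found, study.best_params can be converted back into the skip_layer_ids dictionary and applied with set_skip_layer_ids(). A minimal sketch under the same setup as above (the parsing of the parameter names here is my own convenience code, not part of the original script):

    # Rebuild the skip_layer_ids dict from the best trial's parameters
    best_skip_layer_ids = {"attn": [], "mlp": []}
    for name, value in study.best_params.items():
        if value != 1:
            continue
        layer_idx = int(name.rsplit("_", 1)[-1])  # e.g. 'skip_attn_layer_25' -> 25
        if name.startswith("skip_attn_layer_"):
            best_skip_layer_ids["attn"].append(layer_idx)
        elif name.startswith("skip_mlp_layer_"):
            best_skip_layer_ids["mlp"].append(layer_idx)

    # Apply the found strategy to the LayerSkip model
    model.set_skip_layer_ids(skip_layer_ids=best_skip_layer_ids)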
References
- dilab-zju/self-speculative-decoding: Code associated with ...
- Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding