Last Updated on 2024-11-19 by Clay
Over the past week, I dedicated some time to reproducing the Self-Speculative Decoding mechanism based on the ideas from the paper Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding, implementing the following modules:
- A Decoder-only Transformer model with layer skipping (based on Llama and Gemma-2 architectures)
- Adaptive Draft Exit Mechanism
- Bayesian Optimization to discover the best layer-skipping strategy (optimizing draft model configurations)
- Self-Speculative Decoding — achieving acceleration purely through the model itself
The following sections elaborate on each module. For a brief overview of the paper, you can refer to my previous notes: [Paper Reading] Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding
Layer-Skipping Decoder-only Transformer Model
In the HuggingFace transformers library, the model implementation already passes each layer's index during decoding (mainly for storing that layer's KV Cache). I added a new attribute, skip_layer_ids, to the model structure: whenever a layer index is listed in skip_layer_ids, that layer's computation is skipped while draft mode is enabled.
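For instance, a configuration like the one below (the layer indices are purely illustrative, not a tuned setting) tells the draft pass to skip the attention sub-layers of layers 4 and 6 and the MLP sub-layers of layers 5 and 7, while the verify pass still runs every layer:
skip_layer_ids = {
    "attn": [4, 6],  # attention sub-layers skipped in draft mode (illustrative indices)
    "mlp": [5, 7],   # MLP sub-layers skipped in draft mode (illustrative indices)
}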
Below is my modification of the Gemma-2 model architecture, based on the current version of the transformers library.
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, HybridCache
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers import Gemma2Config
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class Gemma2PreTrainedModel(PreTrainedModel):
config_class = Gemma2Config
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["Gemma2DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = False
_supports_static_cache = True
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@classmethod
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False):
"""
Overloads `PreTrainedModel._check_and_enable_sdpa` so as to DISABLE torch SDPA by default on Gemma2 models.
SDPA reduces the model performance on Gemma2 because of the logits softcapping.
"""
config = super()._check_and_enable_sdpa(config, hard_check_only=hard_check_only)
# if using the default path -> swap sdpa by eager
if not hard_check_only and config._attn_implementation == "sdpa":
config._attn_implementation = "eager"
return config
class Gemma2RotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
@torch.no_grad()
def forward(self, x, position_ids, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
self.inv_freq.to(x.device)
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class Gemma2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = config.head_dim
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
self.scaling = config.query_pre_attn_scalar**-0.5
if self.hidden_size % self.num_heads != 0:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
self.rotary_emb = Gemma2RotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {
"sin": sin,
"cos": cos,
"sliding_window": self.sliding_window,
"cache_position": cache_position,
}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
if self.config.attn_logit_softcapping is not None:
attn_weights = attn_weights / self.config.attn_logit_softcapping
attn_weights = torch.tanh(attn_weights)
attn_weights = attn_weights * self.config.attn_logit_softcapping
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class Gemma2SdpaAttention(Gemma2Attention):
"""
Gemma2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`Gemma2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
SDPA API.
"""
# Adapted from Gemma2Attention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {
"sin": sin,
"cos": cos,
"sliding_window": self.sliding_window,
"cache_position": cache_position,
}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
scale=self.scaling,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
GEMMA2_ATTENTION_CLASSES = {
"eager": Gemma2Attention,
"sdpa": Gemma2SdpaAttention,
}
class Gemma2MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_activation]
def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
class Gemma2RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.zeros(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float())
# Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402
output = output * (1.0 + self.weight.float())
return output.type_as(x)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.eps}"
class LayerSkipGemma2DecoderLayer(nn.Module):
def __init__(self, config: Gemma2Config, layer_idx: int):
super().__init__()
# Set skip layer
self.draft_mode = False
self.skip_layer_ids = {"attn": [], "mlp": []}
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
self.self_attn = GEMMA2_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
self.mlp = Gemma2MLP(config)
self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.config = config
self.is_sliding = not bool(layer_idx % 2)
self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.sliding_window = config.sliding_window
self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def set_skip_layer_ids(self, skip_layer_ids: Dict[str, List[int]]):
self.skip_layer_ids = skip_layer_ids
def set_draft_mode(self, _mode: bool):
self.draft_mode = _mode
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*):
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
query_sequence_length, key_sequence_length)` if default attention is used.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
"""
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
# Flash-attn is a 2D tensor
if self.config._attn_implementation == "flash_attention_2":
if past_key_value is not None: # when decoding
attention_mask = attention_mask[:, -self.sliding_window :]
else:
min_dtype = torch.finfo(hidden_states.dtype).min
sliding_window_mask = torch.tril(
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
if attention_mask.shape[-1] <= 1: # when decoding
attention_mask = attention_mask[:, :, :, -self.sliding_window :]
residual = hidden_states
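        # Key modification for layer skipping: in draft mode, skip this layer's attention sub-layer entirely and pass the residual through unchanged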
if self.draft_mode and self.layer_idx in self.skip_layer_ids["attn"]:
hidden_states = residual
self_attn_weights = None
present_key_value = None
else:
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
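        # Likewise, in draft mode the MLP sub-layer can be skipped, leaving the residual untouched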
if self.draft_mode and self.layer_idx in self.skip_layer_ids["mlp"]:
hidden_states = residual
else:
hidden_states = self.pre_feedforward_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = self.post_feedforward_layernorm(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
class LayerSkipGemma2Model(Gemma2PreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Gemma2DecoderLayer`]
Args:
config: Gemma2Config
"""
def __init__(self, config: Gemma2Config):
super().__init__(config)
# Set skip layer
self.draft_mode = False
self.skip_layer_ids = {"attn": [], "mlp": []}
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[LayerSkipGemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def set_skip_layer_ids(self, skip_layer_ids: Dict[str, List[int]]):
self.skip_layer_ids = skip_layer_ids
for layer in self.layers:
layer.set_skip_layer_ids(skip_layer_ids=skip_layer_ids)
def set_draft_mode(self, _mode: bool):
self.draft_mode = _mode
for layer in self.layers:
layer.set_draft_mode(_mode=_mode)
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[HybridCache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training and use_cache:
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None and not self.training:
batch_size, seq_len, _ = inputs_embeds.shape
past_key_values = HybridCache(
self.config,
batch_size=batch_size,
max_cache_len=seq_len,
device=self.device,
dtype=inputs_embeds.dtype,
)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
# embed positions
hidden_states = inputs_embeds
# normalized
# Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
# See https://github.com/huggingface/transformers/pull/29402
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
hidden_states = hidden_states * normalizer
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = past_key_values if use_cache else None
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: HybridCache,
output_attentions: bool,
):
# Flash Attention currently doesn't support static cache but Gemma2 work only with static cache.
# So we will pass in attention mask as is in any case, not only when ther's padding. Then we'll use its shape
# to cut out keys/values trailing 0 used in static cache. This workaround should be compile compatible
# as it doesn't cause dynamic control issues.
if self.config._attn_implementation == "flash_attention_2":
return attention_mask
dtype, device = input_tensor.dtype, input_tensor.device
sequence_length = input_tensor.shape[1]
if isinstance(past_key_values, HybridCache):
target_length = past_key_values.get_max_cache_shape()
else:
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
                The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class LayerSkipGemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
# Set skip layer
self.draft_mode = False
self.skip_layer_ids = {"attn": [], "mlp": []}
self.model = LayerSkipGemma2Model(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def set_skip_layer_ids(self, skip_layer_ids: Dict[str, List[int]]):
assert "attn" in skip_layer_ids and "mlp" in skip_layer_ids, "`skip_layer_ids` need to be set `attn` and `mlp`!"
assert isinstance(skip_layer_ids["attn"], list), "`skip_layer_ids['attn']` need to be a list!"
assert isinstance(skip_layer_ids["mlp"], list), "`skip_layer_ids['mlp']` need to be a list!"
for attn_layer_idx in skip_layer_ids["attn"]:
assert attn_layer_idx < len(self.model.layers), f"attn_layer_idx {attn_layer_idx} is out of Range ({len(self.model.layers)})"
for mlp_layer_idx in skip_layer_ids["mlp"]:
assert mlp_layer_idx < len(self.model.layers), f"mlp_layer_idx {mlp_layer_idx} is out of Range ({len(self.model.layers)})"
self.skip_layer_ids = skip_layer_ids
self.model.set_skip_layer_ids(skip_layer_ids=skip_layer_ids)
print("skip_layer_ids:", self.skip_layer_ids)
def set_draft_mode(self, _mode: bool):
self.draft_mode = _mode
self.model.set_draft_mode(_mode=_mode)
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[HybridCache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
if self.config.final_logit_softcapping is not None:
logits = logits / self.config.final_logit_softcapping
logits = torch.tanh(logits)
logits = logits * self.config.final_logit_softcapping
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=None,
**kwargs,
):
# Overwritten: has a special cache type, `HybridCache`
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
if past_key_values is not None:
if inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
# `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride
# during the decoding. Here, simply using `.contiguous()` is not sufficient as in the
# batch size = 1 case, `position_ids` is already contiguous but with varying stride
# which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
if (
isinstance(past_key_values, HybridCache)
and attention_mask.ndim == 2
and not self.config._attn_implementation == "flash_attention_2"
):
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
device = model_inputs["inputs_embeds"].device
else:
batch_size, sequence_length = model_inputs["input_ids"].shape
device = model_inputs["input_ids"].device
attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_cache_shape(),
dtype=self.lm_head.weight.dtype,
device=device,
cache_position=cache_position,
batch_size=batch_size,
)
if num_logits_to_keep is not None:
model_inputs["num_logits_to_keep"] = num_logits_to_keep
model_inputs.update(
{
"position_ids": position_ids,
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
}
)
return model_inputs
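Before moving on, here is a minimal usage sketch of the class above. The checkpoint path is the local one used in the experiment script later in this post, the skipped layer indices are placeholders, and the module path simply assumes the code above is saved under layerskip_modeling/. Draft mode runs the pruned network, while normal mode still runs the full model for verification.
import torch
from transformers import AutoTokenizer

# Assumes the class above is saved as layerskip_modeling/modeling_layerskip_gemma2.py,
# the module path used by the experiment script later in this post.
from layerskip_modeling.modeling_layerskip_gemma2 import LayerSkipGemma2ForCausalLM

pretrained_model_name_or_path = "../models/google--gemma-2-2b-it/"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
model = LayerSkipGemma2ForCausalLM.from_pretrained(
    pretrained_model_name_or_path,
    torch_dtype=torch.bfloat16,
).to(device).eval()

# Illustrative (untuned) skip configuration: skip both sub-layers of layers 10 and 11 in draft mode
model.set_skip_layer_ids(skip_layer_ids={"attn": [10, 11], "mlp": [10, 11]})

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(device)

model.set_draft_mode(True)   # the configured sub-layers are skipped
with torch.no_grad():
    draft_logits = model(**inputs).logits

model.set_draft_mode(False)  # full model, as used for verification
with torch.no_grad():
    verify_logits = model(**inputs).logits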
Adaptive Draft Exit Mechanism
For a detailed explanation, refer to my previous notes: [Paper Review] Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding. I won’t repeat the details here.
In simple terms, this mechanism dynamically controls how many tokens the draft model generates before handing off to verification. If the measured acceptance rate falls below the target, the confidence threshold for continuing to draft is raised, so the draft model exits earlier; conversely, if the acceptance rate is high enough, the threshold is lowered, allowing the draft model to keep going even on lower-confidence tokens.
It’s akin to adjusting a child’s screen time based on their test scores. If they perform poorly (low acceptance rate), screen time is reduced (higher decoding confidence threshold); if they do well, they get more screen time.
Here’s the implementation:
import torch
class AdaptiveDraftExitAduster:
def __init__(
self,
target_matchness: float = 0.9,
beta1: float = 0.5,
beta2: float = 0.9,
epsilon: float = 0.01,
max_step_draft: int = 8,
):
"""Initialize DraftExitingAdjuster parameters
:param target_matchness: matching degree target value
:param beta1: sliding average coefficient for matching degree update
:param beta2: th_stop_draft smooth update coefficient
:param epsilon: Adjust the step size each time
:param max_step_draft: The maximum number of steps for draft generation
"""
self.save_init_status(
target_matchness=target_matchness,
beta1=beta1,
beta2=beta2,
epsilon=epsilon,
max_step_draft=max_step_draft,
)
self.reset()
def save_init_status(
self,
target_matchness: float,
beta1: float,
beta2: float,
epsilon: float,
max_step_draft: int,
) -> None:
self.init_target_matchness = target_matchness
self.init_beta1 = beta1
self.init_beta2 = beta2
self.init_epsilon = epsilon
self.init_max_step_draft = max_step_draft
def reset(self) -> None:
self.target_matchness = self.init_target_matchness
self.beta1 = self.init_beta1
self.beta2 = self.init_beta2
self.epsilon = self.init_epsilon
self.max_step_draft = self.init_max_step_draft
# Dynamic status
self.curr_matchness = 0.0
self.stop_draft_threshold = 0.5
self.step_num = 0
def update(self, num_matched_tokens, num_drafted_tokens) -> None:
# Update matchness
matchness = num_matched_tokens / num_drafted_tokens
self.curr_matchness = self.beta1 * self.curr_matchness + (1 - self.beta1) * matchness
# Calculate new exit threshold
if num_drafted_tokens == self.max_step_draft:
new_stop_draft_threshold = self.stop_draft_threshold
elif self.curr_matchness <= self.target_matchness:
new_stop_draft_threshold = self.stop_draft_threshold + self.epsilon
else:
new_stop_draft_threshold = self.stop_draft_threshold - self.epsilon
self.stop_draft_threshold = self.beta2 * self.stop_draft_threshold + (1 - self.beta2) * new_stop_draft_threshold
self.step_num += 1
def should_exit(self, draft_prob: torch.FloatTensor) -> bool:
return draft_prob < self.stop_draft_threshold
def get_state(self):
return {
"curr_matchness": self.curr_matchness,
"stop_draft_threshold": self.stop_draft_threshold,
"step_num": self.step_num
}
def calculate_continuous_acceptance(acceptance_mask: torch.BoolTensor) -> int:
continuous_acceptance = 0
for accepted in acceptance_mask.long().squeeze(0):
if accepted == 1:
continuous_acceptance += 1
else:
break
return continuous_acceptance
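As a quick sanity check of the adjuster's behavior (the numbers below are made up, and the class is imported from utils.utils as the later scripts do): low acceptance keeps the smoothed match rate below the target, so the stop-draft threshold creeps upward and drafting exits earlier; once the match rate climbs above the target, the threshold is nudged back down.
import torch
from utils.utils import AdaptiveDraftExitAduster

adjuster = AdaptiveDraftExitAduster(target_matchness=0.9, beta1=0.5, beta2=0.9, epsilon=0.01, max_step_draft=8)

# A weak round: only 1 of 5 drafted tokens accepted -> the smoothed match rate stays
# below the target, so the stop-draft threshold is nudged upward (exit drafting earlier).
adjuster.update(num_matched_tokens=1, num_drafted_tokens=5)
print(adjuster.get_state())

# Several strong rounds: once the smoothed match rate exceeds the target,
# the threshold starts moving back down (allow longer drafts).
# Note: rounds that used the full max_step_draft leave the threshold unchanged.
for _ in range(6):
    adjuster.update(num_matched_tokens=5, num_drafted_tokens=5)
print(adjuster.get_state())

# During drafting: stop as soon as the chosen token's probability drops below the threshold
print(adjuster.should_exit(draft_prob=torch.tensor(0.42)))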
Bayesian Optimization for Layer-Skipping Strategy
I previously wrote about using Bayesian Optimization to search for optimal layer-skipping strategies: Optimizing LayerSkip Models with Bayesian Search for an Effective Layer Skipping Strategy. While that initial version was incomplete, I have since refined it to serve as a foundation for these experiments.
Below is the class implementation of my Bayesian optimization searcher:
from typing import Dict, List, Callable, Optional
import gc
import copy
import json
import time
import optuna
import torch
from tqdm import tqdm
from utils.utils import AdaptiveDraftExitAduster
class LayerSkipStrategySearcher:
def __init__(
self,
model: torch.nn.Module,
tokenizer,
device,
drafter_speculative_decode: Callable,
target_speculative_decode: Callable,
adjuster: Optional[AdaptiveDraftExitAduster] = None,
) -> None:
self.model = model
if not hasattr(self.model, "set_draft_mode"):
raise TypeError(f"{type(self.model)} is not a LayerSkip-liked model.")
self.tokenizer = tokenizer
self.device = device
# Initialize skip layers
self.reset_skip_layers()
# Sampling methods
self.drafter_speculative_decode = drafter_speculative_decode
self.target_speculative_decode = target_speculative_decode
# Threshold adjuster
self.adjuster = adjuster
# Cache for processed samples
self._processed_samples = None
def reset_skip_layers(self):
"""Reset skip layers to initial state"""
self.skip_layer_ids = {
"attn": [],
"mlp": [],
}
@property
def processed_samples(self):
"""Lazy loading of processed samples with proper cleanup"""
if self._processed_samples is None:
raise ValueError("Samples not initialized. Call prepare_samples first.")
return self._processed_samples
def prepare_samples(self, samples: List[List[Dict[str, str]]]):
"""Process and prepare samples with proper cleanup"""
# Clear any existing processed samples
if self._processed_samples is not None:
del self._processed_samples
self._processed_samples = None
torch.cuda.empty_cache()
gc.collect()
# Process new samples
self._processed_samples = [
self.tokenizer(
self.tokenizer.apply_chat_template(messages, tokenize=False),
return_tensors="pt",
).to(self.device)
for messages in samples
]
def cleanup(self):
"""Clean up resources"""
del self._processed_samples
self._processed_samples = None
torch.cuda.empty_cache()
gc.collect()
def optimize_acceptance_rate(
self,
samples: List[List[Dict[str, str]]],
n_trials: int = 50,
num_hidden_layers: int = 1,
) -> Dict:
try:
# Prepare samples
self.prepare_samples(samples)
# Define search space
num_hidden_layers = getattr(self.model.config, "num_hidden_layers", num_hidden_layers)
print(f"Total layers we can skip: {num_hidden_layers}")
study = optuna.create_study(direction="maximize")
study.optimize(
lambda trial: self.objective_acceptance_rate(
trial=trial,
num_hidden_layers=num_hidden_layers,
),
n_trials=n_trials,
callbacks=[
lambda study, trial: print(f"Trial {trial.number}: {trial.value}"),
lambda study, trial: gc.collect(), # Force garbage collection after each trial
lambda study, trial: torch.cuda.empty_cache() # Clear CUDA cache after each trial
]
)
return {
"best_params": study.best_params,
"best_value": study.best_value
}
finally:
self.cleanup()
def optimize_speculative_speed(
self,
samples: List[List[Dict[str, str]]],
n_trials: int = 50,
num_hidden_layers: int = 1,
) -> Dict:
try:
# Prepare samples
self.prepare_samples(samples)
# Define search space
num_hidden_layers = getattr(self.model.config, "num_hidden_layers", num_hidden_layers)
print(f"Total layers we can skip: {num_hidden_layers}")
study = optuna.create_study(direction="maximize")
study.optimize(
lambda trial: self.objective_speculative_speed(
trial=trial,
num_hidden_layers=num_hidden_layers,
),
n_trials=n_trials,
callbacks=[
lambda study, trial: print(f"Trial {trial.number}: {trial.value}"),
lambda study, trial: gc.collect(),
lambda study, trial: torch.cuda.empty_cache()
]
)
# Save best skip_layer_ids
best_skip_layers = self._get_skip_layers(study.best_trial, num_hidden_layers)
with open("skip_layer_ids.json", "w") as f:
json.dump(best_skip_layers, f)
return {
"best_params": study.best_params,
"best_value": study.best_value
}
finally:
self.cleanup()
def _run_inference(self, inputs, gamma=5, max_new_tokens=100):
"""Common inference logic for both objectives"""
is_end = False
trial_inputs = copy.deepcopy(inputs)
raw_token_num = trial_inputs["input_ids"].shape[1]
total_draft_tokens = 0
total_accept_tokens = 0
generated_tokens = 0
while not is_end:
# Draft model inference
with torch.no_grad(): # Ensure no gradients are tracked
target_inputs, draft_probs, real_generated_tokens = self.drafter_speculative_decode(
draft_model=self.model,
draft_tokenizer=self.tokenizer,
inputs=trial_inputs,
gamma=gamma,
confidence_threshold_adjuster=self.adjuster,
)
total_draft_tokens += real_generated_tokens if real_generated_tokens is not None else gamma
# Target model inference
with torch.no_grad():
outputs, is_end, accept_tokens = self.target_speculative_decode(
target_model=self.model,
target_tokenizer=self.tokenizer,
inputs=target_inputs,
draft_probs=draft_probs,
)
if self.adjuster:
self.adjuster.update(
num_matched_tokens=accept_tokens,
num_drafted_tokens=real_generated_tokens if real_generated_tokens is not None else gamma,
)
total_accept_tokens += accept_tokens
trial_inputs = outputs
generated_tokens = trial_inputs["input_ids"].shape[1] - raw_token_num
if generated_tokens >= max_new_tokens:
break
# Free memory
del target_inputs, draft_probs, trial_inputs
torch.cuda.empty_cache()
return total_accept_tokens, total_draft_tokens, generated_tokens
def objective_speculative_speed(self, trial, num_hidden_layers: int):
try:
skip_layers = self._get_skip_layers(trial, num_hidden_layers)
if not skip_layers:
raise optuna.TrialPruned()
self.model.set_skip_layer_ids(skip_layer_ids=skip_layers)
if self.adjuster:
self.adjuster.reset()
start_time = time.time()
total_generated_tokens = 0
for inputs in tqdm(self.processed_samples):
_, _, generated_tokens = self._run_inference(inputs)
total_generated_tokens += generated_tokens
token_per_second = total_generated_tokens / (time.time() - start_time)
print(f"attn_skip: {skip_layers['attn']}, mlp_skip: {skip_layers['mlp']}, tokens/sec: {token_per_second}")
return token_per_second
except Exception as e:
print(f"Trial failed with error: {str(e)}")
raise optuna.TrialPruned()
def objective_acceptance_rate(self, trial, num_hidden_layers: int):
try:
skip_layers = self._get_skip_layers(trial, num_hidden_layers)
if not skip_layers:
raise optuna.TrialPruned()
self.model.set_skip_layer_ids(skip_layer_ids=skip_layers)
total_accept_tokens_group = 0
total_draft_tokens_group = 0
for inputs in tqdm(self.processed_samples):
accept_tokens, draft_tokens, _ = self._run_inference(inputs)
total_accept_tokens_group += accept_tokens
total_draft_tokens_group += draft_tokens
accept_rate = total_accept_tokens_group / total_draft_tokens_group
print(f"attn_skip: {skip_layers['attn']}, mlp_skip: {skip_layers['mlp']}, Accept Rate: {accept_rate}")
return accept_rate
except Exception as e:
print(f"Trial failed with error: {str(e)}")
raise optuna.TrialPruned()
def _get_skip_layers(self, trial, num_hidden_layers: int) -> Dict[str, List[int]]:
"""Helper method to get skip layers configuration from trial"""
skip_attn_layers = [i for i in range(num_hidden_layers) if trial.suggest_int(f"skip_attn_layer_{i}", 0, 1) == 1]
skip_mlp_layers = [i for i in range(num_hidden_layers) if trial.suggest_int(f"skip_mlp_layer_{i}", 0, 1) == 1]
if not skip_attn_layers and not skip_mlp_layers:
return None
return {
"attn": skip_attn_layers,
"mlp": skip_mlp_layers,
}
Now it’s time to conduct real experiments:
import optuna
from typing import Dict, Optional, Tuple
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import copy
import time
from datasets import load_dataset
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from layerskip_modeling.modeling_layerskip_gemma2 import LayerSkipGemma2ForCausalLM
from sampling.sampling import sample_next_token
from utils.utils import calculate_continuous_acceptance, AdaptiveDraftExitAduster
from utils.optimization_searcher import LayerSkipStrategySearcher
def drafter_speculative_decode(
draft_model: torch.nn.Module,
draft_tokenizer: PreTrainedTokenizerBase,
inputs: Dict[str, torch.Tensor],
gamma: int = 10,
temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, meaning top-k filtering is disabled
top_p: Optional[float] = 1.0,
repetition_penalty: Optional[float] = 1.0,
draft_mode: bool = True,
confidence_threshold_adjuster: Optional[AdaptiveDraftExitAduster] = None,
) -> Tuple[Dict[str, torch.Tensor], torch.FloatTensor, int]:
draft_model.set_draft_mode(draft_mode)
draft_probs = []
real_generated_tokens = 0
for idx in range(gamma):
with torch.no_grad():
outputs = draft_model(**inputs)
next_tokens, probs = sample_next_token(
logits=outputs.logits,
prefix_token_ids=inputs["input_ids"],
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
)
draft_probs.append(probs)
input_ids = torch.cat([inputs["input_ids"], next_tokens[:, -1:]], dim=-1)
attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)
inputs["input_ids"] = input_ids
inputs["attention_mask"] = attention_mask
real_generated_tokens += 1
# Early exit
if confidence_threshold_adjuster and confidence_threshold_adjuster.should_exit(draft_prob=probs[0, 0, next_tokens.item()]):
# print(confidence_threshold_adjuster.get_state())
break
draft_model.set_draft_mode(False)
return inputs, torch.cat(draft_probs, dim=1), real_generated_tokens
def target_speculative_decode(
target_model: torch.nn.Module,
target_tokenizer: PreTrainedTokenizerBase,
inputs: Dict[str, torch.Tensor],
draft_probs: torch.FloatTensor,
temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, meaning top-k filtering is disabled
top_p: Optional[float] = 1.0,
repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], bool, int]:
target_model.set_draft_mode(False)
with torch.no_grad():
outputs = target_model(**inputs)
next_tokens, target_probs = sample_next_token(
logits=outputs.logits,
prefix_token_ids=inputs["input_ids"],
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
probs_num=draft_probs.shape[1] + 1,
)
next_token = next_tokens[:, -1:]
# Evaluation
indices = inputs["input_ids"][:, -draft_probs.shape[1]:]
eval_probs = target_probs[:, :-1, :]
expanded_indices = indices.unsqueeze(-1)
selected_draft_probs = torch.gather(draft_probs, dim=-1, index=expanded_indices)
selected_draft_probs = selected_draft_probs.squeeze(-1)
selected_eval_probs = torch.gather(eval_probs, dim=-1, index=expanded_indices)
selected_eval_probs = selected_eval_probs.squeeze(-1)
    # Mark positions where the draft model assigned a higher probability than the target model
    mask_to_reject = selected_draft_probs > selected_eval_probs
    # Rejection probability: 1 - (target_prob / draft_prob)
    rejection_probs = 1 - (selected_eval_probs / selected_draft_probs)
    # Draw random values to decide whether each candidate token is accepted or rejected
random_values = torch.rand_like(rejection_probs)
rejection_decisions = random_values < rejection_probs
# Get the final reject masks
rejection_masks = mask_to_reject & rejection_decisions
acceptance_mask = torch.ones_like(selected_draft_probs, dtype=torch.bool)
acceptance_mask[rejection_masks] = False
is_end = False
# Concat `input_ids`
if torch.all(acceptance_mask):
input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)
else:
new_input_ids = []
new_attention_mask = []
for batch_idx in range(next_tokens.shape[0]):
gamma = next_tokens.shape[1] - 1
start_idx = inputs["input_ids"].shape[1] - gamma
for pos_idx in range(acceptance_mask[batch_idx].shape[0]):
if (acceptance_mask[batch_idx][pos_idx] and inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id) or not acceptance_mask[batch_idx][pos_idx]:
inputs["input_ids"][batch_idx][start_idx+pos_idx] = next_tokens[batch_idx][pos_idx]
new_input_ids.append(inputs["input_ids"][batch_idx][:start_idx+pos_idx+1])
new_attention_mask.append(inputs["attention_mask"][batch_idx][:start_idx+pos_idx+1])
is_end = inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id
break
input_ids = pad_sequence(new_input_ids, batch_first=True, padding_value=target_tokenizer.pad_token_id)
attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)
inputs["input_ids"] = input_ids
inputs["attention_mask"] = attention_mask
return inputs, is_end, calculate_continuous_acceptance(acceptance_mask)
if __name__ == "__main__":
# Load tokenizer and model
pretrained_model_name_or_path = "../models/google--gemma-2-2b-it/"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
model = LayerSkipGemma2ForCausalLM.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.bfloat16).to(device)
# Load dataset
dataset = load_dataset("shibing624/sharegpt_gpt4")
samples = dataset["train"]["conversations"][:10]
samples = [[{"role": sample[0]["from"].replace("human", "user").replace("gpt", "assistant"), "content": sample[0]["value"]}] for sample in samples]
# Adaptive draft threshold adjuster
confidence_threshold_adjuster = AdaptiveDraftExitAduster(
target_matchness=0.5,
beta1=0.5,
beta2=0.9,
epsilon=0.1,
max_step_draft=8,
)
# Searcher
searcher = LayerSkipStrategySearcher(
model=model,
tokenizer=tokenizer,
device=device,
drafter_speculative_decode=drafter_speculative_decode,
target_speculative_decode=target_speculative_decode,
adjuster=confidence_threshold_adjuster,
)
searcher.optimize_speculative_speed(
samples=samples,
n_trials=1000,
)
Output:
In my experiments, the best result came from the 846th iteration, achieving a decoding speed of 24.816 tokens/sec.
If layers are skipped arbitrarily, the decoding speed fluctuates around 16 – 18 tokens/sec. This highlights the value of using Bayesian Optimization to fine-tune layer-skipping hyperparameters.
However, overall, my experiment was unsuccessful. Why? As you’ll see in the next section, the decoding speed of the original model without Self-Speculative Decoding is already around 29.647 tokens/sec.
Self-Speculative Decoding Results
Initially, I experimented with the gemma-2-2b-it architecture. However, skipping more than six layers caused slowdowns, consistent with the paper’s findings that smaller models hit bottlenecks sooner.
Experiments with gemma-2-9b-it showed slightly better results, but still failed to achieve acceleration. Below are the results for Self-Speculative Decoding, the standard gemma-2-9b-it decoding speed, and the decoding speed after layer skipping:
Generate token number: 102
Generate speed: 24.07665693982178 token/sec
Speculative Decoding Spent Time: 4.236468553543091 seconds.
Accept Rate: 0.4056603773584906
Generate token number: 100
Generate speed: 29.647649721089135 token/sec
Normal Target Model Decoding Spent Time: 3.37294864654541 seconds.
Generate token number: 100
Generate speed: 48.81880782264111 token/sec
Normal Draft Model Decoding Spent Time: 2.0483908653259277 seconds.
This shows that even with extensive layer skipping, the draft model's decoding speed is still less than double that of the original model. The relationship between decoding time and acceptance rate can be summarized in the following table (time units converted to 10 ms for clarity):
| Model State | Original Model | Layer-Skipped Model | Time Needed for Acceleration (Assuming Full Acceptance + Extra Model Inference Time) |
| --- | --- | --- | --- |
| Time for Token 1 | 3.4 | 2 | 2 + 3.4 = 5.4 |
| Time for Token 2 | 6.8 | 4 | 4 + 3.4 = 7.4 |
| Time for Token 3 | 10.2 | 6 | 6 + 3.4 = 9.4 |
| Time for Token 4 | 13.6 | 8 | 8 + 3.4 = 11.4 |
| Time for Token 5 | 17.0 | 10 | 10 + 3.4 = 13.4 |
Thus, for the draft model to produce five tokens, the acceptance rate must be at least 0.8 to outperform the original model—a stringent requirement.
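To make the break-even arithmetic explicit, here is a small back-of-the-envelope check using the rounded per-token costs from the table (3.4 and 2.0, in units of 10 ms). It compares one draft-then-verify cycle against letting the target model generate the same number of accepted tokens on its own, and, like the table, ignores the bonus token produced by the verification step:
# Rounded per-token costs from the table above, in units of 10 ms
t_target = 3.4  # one decoding step of the full gemma-2-9b-it model
t_draft = 2.0   # one decoding step of the layer-skipped draft model

gamma = 5  # tokens drafted per cycle
cycle_cost = gamma * t_draft + t_target  # draft gamma tokens, then verify once

for accepted in range(1, gamma + 1):
    baseline_cost = accepted * t_target  # target model generating those tokens alone
    verdict = "faster" if cycle_cost < baseline_cost else "slower"
    print(f"accepted {accepted}/{gamma}: speculative {cycle_cost:.1f} vs. target-only {baseline_cost:.1f} -> {verdict}")

# With these numbers the cycle only wins when at least 4 of the 5 drafted
# tokens are accepted, i.e. an acceptance rate of roughly 0.8 or higher.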
Conversely, larger models are more likely to benefit from Self-Speculative Decoding because their baseline inference speeds are slower, and skipping layers has a smaller relative impact on performance. This increases the chance of finding a “minimal loss + significant cost reduction” sweet spot.
Alternatively, training approaches like Kangaroo or Meta AI’s LayerSkip can further improve layer-skipping performance, allowing more aggressive acceleration with minimal loss.
Frankly, I’m eager to improve these results. My next step is to adopt Kangaroo’s training approach to enhance Self-Speculative Decoding without making substantial architectural changes.
Lastly, feel free to follow my GitHub: https://github.com/ccs96307/fast-llm-inference. I’d appreciate any feedback or suggestions!
References
- Kangaroo: Lossless Self-Speculative Decoding via Double Early Exiting
- LayerSkip: Enabling Early Exit Inference and Self-Speculative Decoding