Self-Speculative Decoding Implementation: LayerSkip Model, Bayesian Optimization, and Adaptive Draft-Exiting Mechanism (with gemma-2-9b-it Experiment Results)

Last Updated on 2024-11-19 by Clay

Over the past week, I dedicated some time to reproducing the Self-Speculative Decoding mechanism based on the ideas from the paper Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding, implementing the following modules:

  • A Decoder-only Transformer model with layer skipping (based on Llama and Gemma-2 architectures)
  • Adaptive Draft Exit Mechanism
  • Bayesian Optimization to discover the best layer-skipping strategy (optimizing draft model configurations)
  • Self-Speculative Decoding — achieving acceleration purely through the model itself

The following sections elaborate on each module. For a brief overview of the paper, you can refer to my previous notes: [Paper Reading] Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding


Layer-Skipping Decoder-only Transformer Model

In the HuggingFace transformers library, the model implementation already passes each layer’s index during decoding (mainly for indexing the per-layer KV cache). I added a new attribute skip_layer_ids to the model structure: when draft mode is enabled, any attention or MLP sub-layer whose layer index is listed in skip_layer_ids has its computation skipped.

Below is my modification to the Gemma-2 model architecture, based on the updated version of the transformers library.

# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, HybridCache
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel

from transformers import Gemma2Config


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class Gemma2PreTrainedModel(PreTrainedModel):
    config_class = Gemma2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Gemma2DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_quantized_cache = False
    _supports_static_cache = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @classmethod
    def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False):
        """
        Overloads `PreTrainedModel._check_and_enable_sdpa` so as to DISABLE torch SDPA by default on Gemma2 models.
        SDPA reduces the model performance on Gemma2 because of the logits softcapping.
        """
        config = super()._check_and_enable_sdpa(config, hard_check_only=hard_check_only)

        # if using the default path -> swap sdpa by eager
        if not hard_check_only and config._attn_implementation == "sdpa":
            config._attn_implementation = "eager"

        return config
    

class Gemma2RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        self.inv_freq.to(x.device)
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
    

class Gemma2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.scaling = config.query_pre_attn_scalar**-0.5

        if self.hidden_size % self.num_heads != 0:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
        self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
        self.rotary_emb = Gemma2RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "sliding_window": self.sliding_window,
                "cache_position": cache_position,
            }
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling

        if self.config.attn_logit_softcapping is not None:
            attn_weights = attn_weights / self.config.attn_logit_softcapping
            attn_weights = torch.tanh(attn_weights)
            attn_weights = attn_weights * self.config.attn_logit_softcapping
        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.view(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
    

class Gemma2SdpaAttention(Gemma2Attention):
    """
    Gemma2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `Gemma2Attention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
    SDPA API.
    """

    # Adapted from Gemma2Attention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "sliding_window": self.sliding_window,
                "cache_position": cache_position,
            }
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
            scale=self.scaling,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value
    

GEMMA2_ATTENTION_CLASSES = {
    "eager": Gemma2Attention,
    "sdpa": Gemma2SdpaAttention,
}


class Gemma2MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_activation]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
    

class Gemma2RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float())
        # Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        output = output * (1.0 + self.weight.float())
        return output.type_as(x)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"
    

class LayerSkipGemma2DecoderLayer(nn.Module):
    def __init__(self, config: Gemma2Config, layer_idx: int):
        super().__init__()

        # Set skip layer
        self.draft_mode = False
        self.skip_layer_ids = {"attn": [], "mlp": []}
        self.layer_idx = layer_idx

        self.hidden_size = config.hidden_size
        self.self_attn = GEMMA2_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
        self.mlp = Gemma2MLP(config)
        self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.config = config
        self.is_sliding = not bool(layer_idx % 2)
        self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.sliding_window = config.sliding_window
        self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def set_skip_layer_ids(self, skip_layer_ids: Dict[str, List[int]]):
        self.skip_layer_ids = skip_layer_ids

    def set_draft_mode(self, _mode: bool):
        self.draft_mode = _mode

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence
            kwargs (`dict`, *optional*):
                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
                into the model
        """
        if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
            # Flash-attn is a 2D tensor
            if self.config._attn_implementation == "flash_attention_2":
                if past_key_value is not None:  # when decoding
                    attention_mask = attention_mask[:, -self.sliding_window :]
            else:
                min_dtype = torch.finfo(hidden_states.dtype).min
                sliding_window_mask = torch.tril(
                    torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
                )
                attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
                if attention_mask.shape[-1] <= 1:  # when decoding
                    attention_mask = attention_mask[:, :, :, -self.sliding_window :]

        residual = hidden_states

        # In draft mode, skip this layer's entire self-attention block if its index is listed under "attn"
        if self.draft_mode and self.layer_idx in self.skip_layer_ids["attn"]:
            hidden_states = residual
            self_attn_weights = None
            present_key_value = None
        else:
            hidden_states = self.input_layernorm(hidden_states)

            # Self Attention
            hidden_states, self_attn_weights, present_key_value = self.self_attn(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )
            hidden_states = self.post_attention_layernorm(hidden_states)
            hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states

        # Likewise, skip this layer's MLP block in draft mode if its index is listed under "mlp"
        if self.draft_mode and self.layer_idx in self.skip_layer_ids["mlp"]:
            hidden_states = residual
        else:
            hidden_states = self.pre_feedforward_layernorm(hidden_states)
            hidden_states = self.mlp(hidden_states)
            hidden_states = self.post_feedforward_layernorm(hidden_states)
            hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
    

class LayerSkipGemma2Model(Gemma2PreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Gemma2DecoderLayer`]

    Args:
        config: Gemma2Config
    """

    def __init__(self, config: Gemma2Config):
        super().__init__(config)

        # Set skip layer
        self.draft_mode = False
        self.skip_layer_ids = {"attn": [], "mlp": []}

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [LayerSkipGemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def set_skip_layer_ids(self, skip_layer_ids: Dict[str, List[int]]):
        self.skip_layer_ids = skip_layer_ids

        for layer in self.layers:
            layer.set_skip_layer_ids(skip_layer_ids=skip_layer_ids)

    def set_draft_mode(self, _mode: bool):
        self.draft_mode = _mode

        for layer in self.layers:
            layer.set_draft_mode(_mode=_mode)

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[HybridCache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None and not self.training:
            batch_size, seq_len, _ = inputs_embeds.shape
            past_key_values = HybridCache(
                self.config,
                batch_size=batch_size,
                max_cache_len=seq_len,
                device=self.device,
                dtype=inputs_embeds.dtype,
            )

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        # embed positions
        hidden_states = inputs_embeds

        # normalized
        # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
        # See https://github.com/huggingface/transformers/pull/29402
        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
        hidden_states = hidden_states * normalizer

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = past_key_values if use_cache else None

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: HybridCache,
        output_attentions: bool,
    ):
        # Flash Attention currently doesn't support static cache, but Gemma2 works only with static cache.
        # So we will pass in the attention mask as is in any case, not only when there's padding. Then we'll use its shape
        # to cut out keys/values trailing 0 used in static cache. This workaround should be compile compatible
        # as it doesn't cause dynamic control issues.
        if self.config._attn_implementation == "flash_attention_2":
            return attention_mask

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        if isinstance(past_key_values, HybridCache):
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )
        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
                The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask


class LayerSkipGemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)

        # Set skip layer
        self.draft_mode = False
        self.skip_layer_ids = {"attn": [], "mlp": []}

        self.model = LayerSkipGemma2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def set_skip_layer_ids(self, skip_layer_ids: Dict[str, List[int]]):
        assert "attn" in skip_layer_ids and "mlp" in skip_layer_ids, "`skip_layer_ids` need to be set `attn` and `mlp`!"
        assert isinstance(skip_layer_ids["attn"], list), "`skip_layer_ids['attn']` need to be a list!"
        assert isinstance(skip_layer_ids["mlp"], list), "`skip_layer_ids['mlp']` need to be a list!"

        for attn_layer_idx in skip_layer_ids["attn"]:
            assert attn_layer_idx < len(self.model.layers), f"attn_layer_idx {attn_layer_idx} is out of Range ({len(self.model.layers)})" 
            
        for mlp_layer_idx in skip_layer_ids["mlp"]:
            assert mlp_layer_idx < len(self.model.layers), f"mlp_layer_idx {mlp_layer_idx} is out of Range ({len(self.model.layers)})"

        self.skip_layer_ids = skip_layer_ids
        self.model.set_skip_layer_ids(skip_layer_ids=skip_layer_ids)

        print("skip_layer_ids:", self.skip_layer_ids)

    def set_draft_mode(self, _mode: bool):
        self.draft_mode = _mode
        self.model.set_draft_mode(_mode=_mode)

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[HybridCache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
        if self.config.final_logit_softcapping is not None:
            logits = logits / self.config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.final_logit_softcapping

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        num_logits_to_keep=None,
        **kwargs,
    ):
        # Overwritten: has a special cache type, `HybridCache`

        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        if past_key_values is not None:
            if inputs_embeds is not None:  # Exception 1
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]
                # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
                # `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride
                # during the decoding. Here, simply using `.contiguous()` is not sufficient as in the
                # batch size = 1 case, `position_ids` is already contiguous but with varying stride
                # which retriggers a capture.
                position_ids = position_ids.clone(memory_format=torch.contiguous_format)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and cache_position[0] == 0:
            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
        else:
            # The clone here is for the same reason as for `position_ids`.
            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

        if (
            isinstance(past_key_values, HybridCache)
            and attention_mask.ndim == 2
            and not self.config._attn_implementation == "flash_attention_2"
        ):
            if model_inputs["inputs_embeds"] is not None:
                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
                device = model_inputs["inputs_embeds"].device
            else:
                batch_size, sequence_length = model_inputs["input_ids"].shape
                device = model_inputs["input_ids"].device

            attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
                attention_mask,
                sequence_length=sequence_length,
                target_length=past_key_values.get_max_cache_shape(),
                dtype=self.lm_head.weight.dtype,
                device=device,
                cache_position=cache_position,
                batch_size=batch_size,
            )

        if num_logits_to_keep is not None:
            model_inputs["num_logits_to_keep"] = num_logits_to_keep

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        return model_inputs
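
With the modified model in place, toggling the draft ("skippable") path only requires setting the skip lists and switching draft mode on or off. The snippet below is a minimal usage sketch under my own assumptions: the layer indices are arbitrary placeholders, and the import path matches the experiment script later in this post.

import torch
from transformers import AutoTokenizer

from layerskip_modeling.modeling_layerskip_gemma2 import LayerSkipGemma2ForCausalLM


pretrained_model_name_or_path = "google/gemma-2-9b-it"
model = LayerSkipGemma2ForCausalLM.from_pretrained(
    pretrained_model_name_or_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)

# Skip these decoder layers' attention / MLP sub-layers, but only while draft mode is enabled
model.set_skip_layer_ids({"attn": [4, 6, 8], "mlp": [5, 7, 9]})

model.set_draft_mode(True)   # the listed sub-layers are now bypassed (this is the draft model)
# ... draft-token generation goes here ...
model.set_draft_mode(False)  # full model again, used to verify the drafted tokens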

Adaptive Draft Exit Mechanism

For a detailed explanation, refer to my previous notes: [Paper Review] Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding. I won’t repeat the details here.

In simple terms, this mechanism dynamically controls how many tokens the draft model generates in each round. If the (smoothed) acceptance rate drops below the target, the confidence threshold for continuing to draft is raised, so the draft model exits earlier; if the acceptance rate meets the target, the threshold is lowered, letting the draft model keep generating even when it is less confident. For example, with the defaults used below (beta2 = 0.9, epsilon = 0.01), each round that falls short of the target nudges the stop threshold up by (1 - 0.9) * 0.01 = 0.001.

It’s akin to adjusting a child’s screen time based on their test scores: if they perform poorly (low acceptance rate), screen time is cut (the confidence threshold rises, so drafting stops sooner); if they do well, they get more screen time (the threshold drops, so drafting runs longer).

Here’s the implementation:

import torch


class AdaptiveDraftExitAduster:
    def __init__(
        self,
        target_matchness: float = 0.9,
        beta1: float = 0.5,
        beta2: float = 0.9,
        epsilon: float = 0.01,
        max_step_draft: int = 8,
    ):
        """Initialize DraftExitingAdjuster parameters
        :param target_matchness: matching degree target value
        :param beta1: sliding average coefficient for matching degree update
        :param beta2: th_stop_draft smooth update coefficient
        :param epsilon: Adjust the step size each time
        :param max_step_draft: The maximum number of steps for draft generation
        """

        self.save_init_status(
            target_matchness=target_matchness,
            beta1=beta1,
            beta2=beta2,
            epsilon=epsilon,
            max_step_draft=max_step_draft,
        )
        self.reset()

    def save_init_status(
        self,
        target_matchness: float,
        beta1: float,
        beta2: float,
        epsilon: float,
        max_step_draft: int,
    ) -> None:
        self.init_target_matchness = target_matchness
        self.init_beta1 = beta1
        self.init_beta2 = beta2
        self.init_epsilon = epsilon
        self.init_max_step_draft = max_step_draft

    def reset(self) -> None:
        self.target_matchness = self.init_target_matchness
        self.beta1 = self.init_beta1
        self.beta2 = self.init_beta2
        self.epsilon = self.init_epsilon
        self.max_step_draft = self.init_max_step_draft

        # Dynamic status
        self.curr_matchness = 0.0
        self.stop_draft_threshold = 0.5
        self.step_num = 0

    def update(self, num_matched_tokens, num_drafted_tokens) -> None:
        # Update matchness
        matchness = num_matched_tokens / num_drafted_tokens
        self.curr_matchness = self.beta1 * self.curr_matchness + (1 - self.beta1) * matchness

        # Calculate new exit threshold
        if num_drafted_tokens == self.max_step_draft:
            new_stop_draft_threshold = self.stop_draft_threshold

        elif self.curr_matchness <= self.target_matchness:
            new_stop_draft_threshold = self.stop_draft_threshold + self.epsilon
        else:
            new_stop_draft_threshold = self.stop_draft_threshold - self.epsilon

        self.stop_draft_threshold = self.beta2 * self.stop_draft_threshold + (1 - self.beta2) * new_stop_draft_threshold
        self.step_num += 1

    def should_exit(self, draft_prob: torch.FloatTensor) -> bool:
        return draft_prob < self.stop_draft_threshold

    def get_state(self):
        return {
            "curr_matchness": self.curr_matchness,
            "stop_draft_threshold": self.stop_draft_threshold,
            "step_num": self.step_num
        }


def calculate_continuous_acceptance(acceptance_mask: torch.BoolTensor) -> int:
    continuous_acceptance = 0
    for accepted in acceptance_mask.long().squeeze(0):
        if accepted == 1:
            continuous_acceptance += 1
        else:
            break
    return continuous_acceptance
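
To make the update rule concrete, here is a minimal sketch of how the adjuster plugs into a decoding loop. The token counts and probability are made-up values for illustration; the import path assumes the `utils.utils` module referenced later in this post.

import torch

from utils.utils import AdaptiveDraftExitAduster


adjuster = AdaptiveDraftExitAduster(target_matchness=0.9, beta1=0.5, beta2=0.9, epsilon=0.01, max_step_draft=8)

# After a draft/verify round where 3 of 5 drafted tokens were accepted, the smoothed matchness
# is still below the 0.9 target, so the stop threshold is nudged upward:
# 0.9 * 0.5 + 0.1 * (0.5 + 0.01) = 0.501 (the draft model will exit slightly earlier next round).
adjuster.update(num_matched_tokens=3, num_drafted_tokens=5)
print(adjuster.get_state())  # curr_matchness ≈ 0.3, stop_draft_threshold ≈ 0.501, step_num = 1

# While drafting, stop as soon as the sampled draft token's probability falls below the threshold
if adjuster.should_exit(draft_prob=torch.tensor(0.42)):
    print("Exit drafting early and hand everything to the target model for verification.")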

Bayesian Optimization for Layer-Skipping Strategy

I previously wrote about using Bayesian Optimization to search for effective layer-skipping strategies: Optimizing LayerSkip Models with Bayesian Search for an Effective Layer Skipping Strategy. While that initial version was incomplete, I have since refined it to serve as a foundation for these experiments.

Below is the class implementation of my Bayesian optimization searcher:

from typing import Dict, List, Callable, Optional

import gc
import copy
import json
import time

import optuna
import torch
from tqdm import tqdm

from utils.utils import AdaptiveDraftExitAduster


class LayerSkipStrategySearcher:
    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer,
        device,
        drafter_speculative_decode: Callable,
        target_speculative_decode: Callable,
        adjuster: Optional[AdaptiveDraftExitAduster] = None,
    ) -> None:
        self.model = model
        if not hasattr(self.model, "set_draft_mode"):
            raise TypeError(f"{type(self.model)} is not a LayerSkip-like model.")
        
        self.tokenizer = tokenizer
        self.device = device
        
        # Initialize skip layers
        self.reset_skip_layers()
        
        # Sampling methods
        self.drafter_speculative_decode = drafter_speculative_decode
        self.target_speculative_decode = target_speculative_decode
        
        # Threshold adjuster
        self.adjuster = adjuster
        
        # Cache for processed samples
        self._processed_samples = None

    def reset_skip_layers(self):
        """Reset skip layers to initial state"""
        self.skip_layer_ids = {
            "attn": [],
            "mlp": [],
        }
    
    @property
    def processed_samples(self):
        """Lazy loading of processed samples with proper cleanup"""
        if self._processed_samples is None:
            raise ValueError("Samples not initialized. Call prepare_samples first.")
        return self._processed_samples
    
    def prepare_samples(self, samples: List[List[Dict[str, str]]]):
        """Process and prepare samples with proper cleanup"""
        # Clear any existing processed samples
        if self._processed_samples is not None:
            del self._processed_samples
            self._processed_samples = None
            torch.cuda.empty_cache()
            gc.collect()
        
        # Process new samples
        self._processed_samples = [
            self.tokenizer(
                self.tokenizer.apply_chat_template(messages, tokenize=False),
                return_tensors="pt",
            ).to(self.device)
            for messages in samples
        ]

    def cleanup(self):
        """Clean up resources"""
        del self._processed_samples
        self._processed_samples = None
        torch.cuda.empty_cache()
        gc.collect()

    def optimize_acceptance_rate(
        self,
        samples: List[List[Dict[str, str]]],
        n_trials: int = 50,
        num_hidden_layers: int = 1,
    ) -> Dict:
        try:
            # Prepare samples
            self.prepare_samples(samples)
            
            # Define search space
            num_hidden_layers = getattr(self.model.config, "num_hidden_layers", num_hidden_layers)
            print(f"Total layers we can skip: {num_hidden_layers}")
            
            study = optuna.create_study(direction="maximize")
            study.optimize(
                lambda trial: self.objective_acceptance_rate(
                    trial=trial,
                    num_hidden_layers=num_hidden_layers,
                ),
                n_trials=n_trials,
                callbacks=[
                    lambda study, trial: print(f"Trial {trial.number}: {trial.value}"),
                    lambda study, trial: gc.collect(),  # Force garbage collection after each trial
                    lambda study, trial: torch.cuda.empty_cache()  # Clear CUDA cache after each trial
                ]
            )
            
            return {
                "best_params": study.best_params,
                "best_value": study.best_value
            }
        finally:
            self.cleanup()

    def optimize_speculative_speed(
        self,
        samples: List[List[Dict[str, str]]],
        n_trials: int = 50,
        num_hidden_layers: int = 1,
    ) -> Dict:
        try:
            # Prepare samples
            self.prepare_samples(samples)
            
            # Define search space
            num_hidden_layers = getattr(self.model.config, "num_hidden_layers", num_hidden_layers)
            print(f"Total layers we can skip: {num_hidden_layers}")
            
            study = optuna.create_study(direction="maximize")
            study.optimize(
                lambda trial: self.objective_speculative_speed(
                    trial=trial,
                    num_hidden_layers=num_hidden_layers,
                ),
                n_trials=n_trials,
                callbacks=[
                    lambda study, trial: print(f"Trial {trial.number}: {trial.value}"),
                    lambda study, trial: gc.collect(),
                    lambda study, trial: torch.cuda.empty_cache()
                ]
            )

            # Save best skip_layer_ids
            best_skip_layers = self._get_skip_layers(study.best_trial, num_hidden_layers)
            with open("skip_layer_ids.json", "w") as f:
                json.dump(best_skip_layers, f)
            
            return {
                "best_params": study.best_params,
                "best_value": study.best_value
            }
        finally:
            self.cleanup()

    def _run_inference(self, inputs, gamma=5, max_new_tokens=100):
        """Common inference logic for both objectives"""
        is_end = False

        trial_inputs = copy.deepcopy(inputs)
        raw_token_num = trial_inputs["input_ids"].shape[1]
        total_draft_tokens = 0
        total_accept_tokens = 0
        generated_tokens = 0
        
        while not is_end:
            # Draft model inference
            with torch.no_grad():  # Ensure no gradients are tracked
                target_inputs, draft_probs, real_generated_tokens = self.drafter_speculative_decode(
                    draft_model=self.model,
                    draft_tokenizer=self.tokenizer,
                    inputs=trial_inputs,
                    gamma=gamma,
                    confidence_threshold_adjuster=self.adjuster,
                )
            
            total_draft_tokens += real_generated_tokens if real_generated_tokens is not None else gamma
            
            # Target model inference
            with torch.no_grad():
                outputs, is_end, accept_tokens = self.target_speculative_decode(
                    target_model=self.model,
                    target_tokenizer=self.tokenizer,
                    inputs=target_inputs,
                    draft_probs=draft_probs,
                )
            
            if self.adjuster:
                self.adjuster.update(
                    num_matched_tokens=accept_tokens,
                    num_drafted_tokens=real_generated_tokens if real_generated_tokens is not None else gamma,
                )
            
            total_accept_tokens += accept_tokens
            trial_inputs = outputs
            
            generated_tokens = trial_inputs["input_ids"].shape[1] - raw_token_num
            if generated_tokens >= max_new_tokens:
                break
            
        # Free memory
        del target_inputs, draft_probs, trial_inputs
        torch.cuda.empty_cache()
        
        return total_accept_tokens, total_draft_tokens, generated_tokens

    def objective_speculative_speed(self, trial, num_hidden_layers: int):
        try:
            skip_layers = self._get_skip_layers(trial, num_hidden_layers)
            if not skip_layers:
                raise optuna.TrialPruned()
            
            self.model.set_skip_layer_ids(skip_layer_ids=skip_layers)
            
            if self.adjuster:
                self.adjuster.reset()
            
            start_time = time.time()
            total_generated_tokens = 0
            
            for inputs in tqdm(self.processed_samples):
                _, _, generated_tokens = self._run_inference(inputs)
                total_generated_tokens += generated_tokens
            
            token_per_second = total_generated_tokens / (time.time() - start_time)
            print(f"attn_skip: {skip_layers['attn']}, mlp_skip: {skip_layers['mlp']}, tokens/sec: {token_per_second}")
            
            return token_per_second
        
        except Exception as e:
            print(f"Trial failed with error: {str(e)}")
            raise optuna.TrialPruned()

    def objective_acceptance_rate(self, trial, num_hidden_layers: int):
        try:
            skip_layers = self._get_skip_layers(trial, num_hidden_layers)
            if not skip_layers:
                raise optuna.TrialPruned()
            
            self.model.set_skip_layer_ids(skip_layer_ids=skip_layers)
            
            total_accept_tokens_group = 0
            total_draft_tokens_group = 0
            
            for inputs in tqdm(self.processed_samples):
                accept_tokens, draft_tokens, _ = self._run_inference(inputs)
                total_accept_tokens_group += accept_tokens
                total_draft_tokens_group += draft_tokens
            
            accept_rate = total_accept_tokens_group / total_draft_tokens_group
            print(f"attn_skip: {skip_layers['attn']}, mlp_skip: {skip_layers['mlp']}, Accept Rate: {accept_rate}")
            
            return accept_rate
        
        except Exception as e:
            print(f"Trial failed with error: {str(e)}")
            raise optuna.TrialPruned()

    def _get_skip_layers(self, trial, num_hidden_layers: int) -> Dict[str, List[int]]:
        """Helper method to get skip layers configuration from trial"""
        skip_attn_layers = [i for i in range(num_hidden_layers) if trial.suggest_int(f"skip_attn_layer_{i}", 0, 1) == 1]
        skip_mlp_layers = [i for i in range(num_hidden_layers) if trial.suggest_int(f"skip_mlp_layer_{i}", 0, 1) == 1]
        
        if not skip_attn_layers and not skip_mlp_layers:
            return None
            
        return {
            "attn": skip_attn_layers,
            "mlp": skip_mlp_layers,
        }
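
Before moving on to the full experiment script, here is a minimal sketch of how I would drive the searcher. It assumes `drafter_speculative_decode` and `target_speculative_decode` are the decode functions defined in the experiment script below, and the two chat samples are placeholder prompts.

import torch
from transformers import AutoTokenizer

from layerskip_modeling.modeling_layerskip_gemma2 import LayerSkipGemma2ForCausalLM
from utils.utils import AdaptiveDraftExitAduster
from utils.optimization_searcher import LayerSkipStrategySearcher


device = torch.device("cuda:0")
model = LayerSkipGemma2ForCausalLM.from_pretrained("google/gemma-2-9b-it", torch_dtype=torch.bfloat16).to(device)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")

searcher = LayerSkipStrategySearcher(
    model=model,
    tokenizer=tokenizer,
    device=device,
    drafter_speculative_decode=drafter_speculative_decode,  # defined in the experiment script below
    target_speculative_decode=target_speculative_decode,    # defined in the experiment script below
    adjuster=AdaptiveDraftExitAduster(),
)

# Each sample is a chat-formatted conversation; these prompts are placeholders
samples = [
    [{"role": "user", "content": "Explain speculative decoding in one paragraph."}],
    [{"role": "user", "content": "Summarize the benefits of skipping transformer layers."}],
]

result = searcher.optimize_speculative_speed(samples=samples, n_trials=50)
print(result["best_params"], result["best_value"])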


Now it’s time to conduct real experiments:

import optuna
from typing import Dict, Optional, Tuple

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import copy
import time

from datasets import load_dataset
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from layerskip_modeling.modeling_layerskip_gemma2 import LayerSkipGemma2ForCausalLM
from sampling.sampling import sample_next_token
from utils.utils import calculate_continuous_acceptance, AdaptiveDraftExitAduster
from utils.optimization_searcher import LayerSkipStrategySearcher


def drafter_speculative_decode(
    draft_model: torch.nn.Module,
    draft_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    gamma: int = 10,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, meaning top-k filtering is disabled
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
    draft_mode: bool = True,
    confidence_threshold_adjuster: Optional[AdaptiveDraftExitAduster] = None,
) -> Tuple[Dict[str, torch.Tensor], torch.FloatTensor, int]:
    draft_model.set_draft_mode(draft_mode)
    draft_probs = []
    real_generated_tokens = 0

    for idx in range(gamma):
        with torch.no_grad():
            outputs = draft_model(**inputs)

        next_tokens, probs = sample_next_token(
            logits=outputs.logits,
            prefix_token_ids=inputs["input_ids"],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        draft_probs.append(probs)
        input_ids = torch.cat([inputs["input_ids"], next_tokens[:, -1:]], dim=-1)
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)

        inputs["input_ids"] = input_ids
        inputs["attention_mask"] = attention_mask

        real_generated_tokens += 1

        # Early exit
        if confidence_threshold_adjuster and confidence_threshold_adjuster.should_exit(draft_prob=probs[0, 0, next_tokens.item()]):
            # print(confidence_threshold_adjuster.get_state())
            break

    draft_model.set_draft_mode(False)

    return inputs, torch.cat(draft_probs, dim=1), real_generated_tokens


def target_speculative_decode(
    target_model: torch.nn.Module,
    target_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    draft_probs: torch.FloatTensor,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, meaning top-k filtering is disabled
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], bool, int]:
    target_model.set_draft_mode(False)
    with torch.no_grad():
        outputs = target_model(**inputs)

    next_tokens, target_probs = sample_next_token(
        logits=outputs.logits,
        prefix_token_ids=inputs["input_ids"],
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        probs_num=draft_probs.shape[1] + 1,
    )

    next_token = next_tokens[:, -1:]

    # Evaluation
    indices = inputs["input_ids"][:, -draft_probs.shape[1]:]

    eval_probs = target_probs[:, :-1, :]

    expanded_indices = indices.unsqueeze(-1)
    selected_draft_probs = torch.gather(draft_probs, dim=-1, index=expanded_indices)
    selected_draft_probs = selected_draft_probs.squeeze(-1)

    selected_eval_probs = torch.gather(eval_probs, dim=-1, index=expanded_indices)
    selected_eval_probs = selected_eval_probs.squeeze(-1)

    # Compare draft_prob and eval_prob to build the candidate rejection mask
    mask_to_reject = selected_draft_probs > selected_eval_probs

    # Calculate rejection probability 1 - (eval_prob / draft_prob)
    rejection_probs = 1 - (selected_eval_probs / selected_draft_probs)

    # Generate random values to decide acceptance or rejection
    random_values = torch.rand_like(rejection_probs)
    rejection_decisions = random_values < rejection_probs

    # Get the final reject masks
    rejection_masks = mask_to_reject & rejection_decisions
    acceptance_mask = torch.ones_like(selected_draft_probs, dtype=torch.bool)
    acceptance_mask[rejection_masks] = False

    is_end = False

    # Concat `input_ids`
    if torch.all(acceptance_mask):
        input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)
    else:
        new_input_ids = []
        new_attention_mask = []

        for batch_idx in range(next_tokens.shape[0]):
            gamma = next_tokens.shape[1] - 1
            start_idx = inputs["input_ids"].shape[1] - gamma

            for pos_idx in range(acceptance_mask[batch_idx].shape[0]):
                if (acceptance_mask[batch_idx][pos_idx] and inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id) or not acceptance_mask[batch_idx][pos_idx]:
                    inputs["input_ids"][batch_idx][start_idx+pos_idx] = next_tokens[batch_idx][pos_idx]

                    new_input_ids.append(inputs["input_ids"][batch_idx][:start_idx+pos_idx+1])
                    new_attention_mask.append(inputs["attention_mask"][batch_idx][:start_idx+pos_idx+1])
                    
                    is_end = inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id
                    break

        input_ids = pad_sequence(new_input_ids, batch_first=True, padding_value=target_tokenizer.pad_token_id)
        attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)

    inputs["input_ids"] = input_ids
    inputs["attention_mask"] = attention_mask

    return inputs, is_end, calculate_continuous_acceptance(acceptance_mask)


if __name__ == "__main__":
    # Load tokenizer and model
    pretrained_model_name_or_path = "../models/google--gemma-2-2b-it/"
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
    model = LayerSkipGemma2ForCausalLM.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.bfloat16).to(device)

    # Load dataset
    dataset = load_dataset("shibing624/sharegpt_gpt4")
    samples = dataset["train"]["conversations"][:10]
    samples = [[{"role": sample[0]["from"].replace("human", "user").replace("gpt", "assistant"), "content": sample[0]["value"]}] for sample in samples]

    # Adaptive draft threshold adjuster
    confidence_threshold_adjuster = AdaptiveDraftExitAduster(
        target_matchness=0.5,
        beta1=0.5,
        beta2=0.9,
        epsilon=0.1,
        max_step_draft=8,
    )

    # Searcher
    searcher = LayerSkipStrategySearcher(
        model=model,
        tokenizer=tokenizer,
        device=device,
        drafter_speculative_decode=drafter_speculative_decode,
        target_speculative_decode=target_speculative_decode,
        adjuster=confidence_threshold_adjuster,
    )

    searcher.optimize_speculative_speed(
        samples=samples,
        n_trials=1000,
    )


Output:

In my experiments, the best result came from the 846th iteration, achieving a decoding speed of 24.816 tokens/sec.

If layers are skipped arbitrarily, the decoding speed fluctuates around 16 - 18 tokens/sec. This highlights the value of using Bayesian Optimization to fine-tune layer-skipping hyperparameters.
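
As a side note on applying the search result: the best trial's parameters can be read back from the Optuna study and converted into the same skip-layer structure expected by set_skip_layer_ids. The snippet below is only a sketch of that wiring, not the repository's actual code; the study handle and the layer count are assumptions, while the parameter names mirror _get_skip_layers above.

import optuna

def best_skip_layers(study: optuna.Study, num_hidden_layers: int) -> dict:
    """Rebuild the skip configuration from the best trial's parameters (sketch)."""
    params = study.best_trial.params
    return {
        "attn": [i for i in range(num_hidden_layers) if params.get(f"skip_attn_layer_{i}") == 1],
        "mlp": [i for i in range(num_hidden_layers) if params.get(f"skip_mlp_layer_{i}") == 1],
    }

# Hypothetical usage, assuming `study` is the optuna.Study built during the search
# and the model has 42 decoder layers (as in gemma-2-9b-it):
# model.set_skip_layer_ids(skip_layer_ids=best_skip_layers(study, num_hidden_layers=42))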

Overall, however, my experiment was unsuccessful. Why? As you’ll see in the next section, the original model already decodes at around 29.647 tokens/sec without Self-Speculative Decoding.


Self-Speculative Decoding Results

Initially, I experimented with the gemma-2-2b-it architecture. However, skipping more than six layers caused slowdowns, consistent with the paper’s findings that smaller models hit bottlenecks sooner.

Experiments with gemma-2-9b-it showed slightly better results, but still failed to achieve acceleration. Below are the results for Self-Speculative Decoding, the standard gemma-2-9b-it decoding speed, and the decoding speed after layer skipping:

Generate token number: 102
Generate speed: 24.07665693982178 token/sec
Speculative Decoding Spent Time: 4.236468553543091 seconds.
Accept Rate: 0.4056603773584906

Generate token number: 100
Generate speed: 29.647649721089135 token/sec
Normal Target Model Decoding Spent Time: 3.37294864654541 seconds.

Generate token number: 100
Generate speed: 48.81880782264111 token/sec
Normal Draft Model Decoding Spent Time: 2.0483908653259277 seconds.


This demonstrates that even with extensive layer skipping, the draft model (~48.8 tokens/sec) is not even twice as fast as the original model (~29.6 tokens/sec), which leaves Self-Speculative Decoding too little headroom to pay for the extra verification passes. The relationship between decoding time and acceptance rate can be summarized in the following table (per-token times expressed in units of 10 ms for clarity):

Model State      | Original Model | Layer-Skipped Model | Time Needed for Acceleration (Assuming Full Acceptance + Extra Model Inference Time)
Time for Token 1 | 3.4            | 2                   | 2 + 3.4 = 5.4
Time for Token 2 | 6.8            | 4                   | 4 + 3.4 = 7.4
Time for Token 3 | 10.2           | 6                   | 6 + 3.4 = 9.4
Time for Token 4 | 13.6           | 8                   | 8 + 3.4 = 11.4
Time for Token 5 | 17.0           | 10                  | 10 + 3.4 = 13.4

Thus, when the draft model proposes five tokens per step, at least four of them (an acceptance rate of 0.8 or higher) must be accepted for the step to outperform the original model, which is a stringent requirement.
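
To make the break-even arithmetic explicit, the minimal sketch below reproduces the table's numbers: the per-token costs come from the measured throughputs above (1/29.6 tokens/sec ≈ 3.4 and 1/48.8 tokens/sec ≈ 2.0, in units of 10 ms), and, as in the table, only accepted draft tokens are counted.

import math

t_target = 3.4  # full gemma-2-9b-it: ~29.6 tokens/sec -> ~34 ms per token
t_draft = 2.0   # layer-skipped draft: ~48.8 tokens/sec -> ~20 ms per token
gamma = 5       # number of draft tokens proposed per speculative step

# One speculative step costs gamma draft passes plus one target verification pass.
step_cost = gamma * t_draft + t_target  # 10 + 3.4 = 13.4

# Plain decoding produces k tokens in k * t_target, so break-even needs k * t_target >= step_cost.
k_break_even = math.ceil(step_cost / t_target)  # ceil(13.4 / 3.4) = 4
print(f"{k_break_even}/{gamma} drafts must be accepted -> acceptance rate >= {k_break_even / gamma:.1f}")

Measured against the observed acceptance rate of roughly 0.41, this calculation makes it clear why the speculative run above ends up slower than plain decoding.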

Conversely, larger models are more likely to benefit from Self-Speculative Decoding because their baseline inference speeds are slower, and skipping layers has a smaller relative impact on performance. This increases the chance of finding a “minimal loss + significant cost reduction” sweet spot.

Alternatively, training approaches like Kangaroo or Meta AI’s LayerSkip can further improve layer-skipping performance, allowing more aggressive acceleration with minimal loss.

Frankly, I’m eager to improve these results. My next step is to adopt Kangaroo’s training approach to enhance Self-Speculative Decoding without making substantial architectural changes.

Lastly, feel free to follow my GitHub: https://github.com/ccs96307/fast-llm-inference. I’d appreciate any feedback or suggestions!

