
A Complete Implementation of Self-Speculative Decoding: LayerSkip Model, Bayesian Optimization, and Adaptive Draft-Exiting Mechanism (with gemma-2-9b-it Experiment Results)

Last Updated on 2024-11-17 by Clay

Over the past week, I found some time to reproduce Self-Speculative Decoding following the approach of the paper Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding. The implementation consists of the following modules:

  • A layer-skipping decoder-only Transformer model (focused on the Llama and Gemma-2 architectures)
  • An adaptive draft-exiting mechanism
  • Bayesian optimization to search for the best layer-skipping strategy (i.e. which combination of skipped layers makes the best draft model)
  • Self-Speculative Decoding itself: acceleration that relies on nothing but the model's own layers

Below, I go through each of these modules in turn. For a brief introduction to the paper, you can also refer to my earlier note: [Paper Reading] Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding


Layer-Skipping Decoder-Only Transformer Model

Since the model implementations in the HuggingFace transformers package already pass each decoder layer's index into the layer (after all, every layer's KV cache may need to be stored), I simply store an extra skip_layer_ids attribute on the model: whenever a layer's index appears in that list, the layer's computation is skipped while draft mode is enabled.
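
Concretely, skip_layer_ids is just a dict holding two lists of layer indices, as in this purely illustrative sketch:

skip_layer_ids = {
    "attn": [2, 3],  # skip the self-attention blocks of layers 2 and 3 while drafting
    "mlp": [4],      # skip the MLP block of layer 4 while drafting
}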

Below is my modified Gemma-2 architecture, based on a relatively recent version of transformers.

# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint

from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, HybridCache
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel

from transformers import Gemma2Config


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class Gemma2PreTrainedModel(PreTrainedModel):
    config_class = Gemma2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Gemma2DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_quantized_cache = False
    _supports_static_cache = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @classmethod
    def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False):
        """
        Overloads `PreTrainedModel._check_and_enable_sdpa` so as to DISABLE torch SDPA by default on Gemma2 models.
        SDPA reduces the model performance on Gemma2 because of the logits softcapping.
        """
        config = super()._check_and_enable_sdpa(config, hard_check_only=hard_check_only)

        # if using the default path -> swap sdpa by eager
        if not hard_check_only and config._attn_implementation == "sdpa":
            config._attn_implementation = "eager"

        return config
    

class Gemma2RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        self.inv_freq.to(x.device)
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
    

class Gemma2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.scaling = config.query_pre_attn_scalar**-0.5

        if self.hidden_size % self.num_heads != 0:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
        self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
        self.rotary_emb = Gemma2RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "sliding_window": self.sliding_window,
                "cache_position": cache_position,
            }
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling

        if self.config.attn_logit_softcapping is not None:
            attn_weights = attn_weights / self.config.attn_logit_softcapping
            attn_weights = torch.tanh(attn_weights)
            attn_weights = attn_weights * self.config.attn_logit_softcapping
        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.view(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
    

class Gemma2SdpaAttention(Gemma2Attention):
    """
    Gemma2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `Gemma2Attention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
    SDPA API.
    """

    # Adapted from Gemma2Attention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "sliding_window": self.sliding_window,
                "cache_position": cache_position,
            }
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
            scale=self.scaling,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value
    

GEMMA2_ATTENTION_CLASSES = {
    "eager": Gemma2Attention,
    "sdpa": Gemma2SdpaAttention,
}


class Gemma2MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_activation]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
    

class Gemma2RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float())
        # Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        output = output * (1.0 + self.weight.float())
        return output.type_as(x)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"
    

class LayerSkipGemma2DecoderLayer(nn.Module):
    def __init__(self, config: Gemma2Config, layer_idx: int):
        super().__init__()

        # Set skip layer
        self.draft_mode = False
        self.skip_layer_ids = {"attn": [], "mlp": []}
        self.layer_idx = layer_idx

        self.hidden_size = config.hidden_size
        self.self_attn = GEMMA2_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
        self.mlp = Gemma2MLP(config)
        self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.config = config
        self.is_sliding = not bool(layer_idx % 2)
        self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.sliding_window = config.sliding_window
        self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def set_skip_layer_ids(self, skip_layer_ids: Dict[str, List[int]]):
        self.skip_layer_ids = skip_layer_ids

    def set_draft_mode(self, _mode: bool):
        self.draft_mode = _mode

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence
            kwargs (`dict`, *optional*):
                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
                into the model
        """
        if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
            # Flash-attn is a 2D tensor
            if self.config._attn_implementation == "flash_attention_2":
                if past_key_value is not None:  # when decoding
                    attention_mask = attention_mask[:, -self.sliding_window :]
            else:
                min_dtype = torch.finfo(hidden_states.dtype).min
                sliding_window_mask = torch.tril(
                    torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
                )
                attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
                if attention_mask.shape[-1] <= 1:  # when decoding
                    attention_mask = attention_mask[:, :, :, -self.sliding_window :]

        residual = hidden_states

        if self.draft_mode and self.layer_idx in self.skip_layer_ids["attn"]:
            hidden_states = residual
            self_attn_weights = None
            present_key_value = None
        else:
            hidden_states = self.input_layernorm(hidden_states)

            # Self Attention
            hidden_states, self_attn_weights, present_key_value = self.self_attn(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )
            hidden_states = self.post_attention_layernorm(hidden_states)
            hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states


        if self.draft_mode and self.layer_idx in self.skip_layer_ids["mlp"]:
            hidden_states = residual
        else:
            hidden_states = self.pre_feedforward_layernorm(hidden_states)
            hidden_states = self.mlp(hidden_states)
            hidden_states = self.post_feedforward_layernorm(hidden_states)
            hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
    

class LayerSkipGemma2Model(Gemma2PreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Gemma2DecoderLayer`]

    Args:
        config: Gemma2Config
    """

    def __init__(self, config: Gemma2Config):
        super().__init__(config)

        # Set skip layer
        self.draft_mode = False
        self.skip_layer_ids = {"attn": [], "mlp": []}

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [LayerSkipGemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def set_skip_layer_ids(self, skip_layer_ids: Dict[str, List[int]]):
        self.skip_layer_ids = skip_layer_ids

        for layer in self.layers:
            layer.set_skip_layer_ids(skip_layer_ids=skip_layer_ids)

    def set_draft_mode(self, _mode: bool):
        self.draft_mode = _mode

        for layer in self.layers:
            layer.set_draft_mode(_mode=_mode)

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[HybridCache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None and not self.training:
            batch_size, seq_len, _ = inputs_embeds.shape
            past_key_values = HybridCache(
                self.config,
                batch_size=batch_size,
                max_cache_len=seq_len,
                device=self.device,
                dtype=inputs_embeds.dtype,
            )

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        # embed positions
        hidden_states = inputs_embeds

        # normalized
        # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
        # See https://github.com/huggingface/transformers/pull/29402
        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
        hidden_states = hidden_states * normalizer

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = past_key_values if use_cache else None

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: HybridCache,
        output_attentions: bool,
    ):
        # Flash Attention currently doesn't support static cache but Gemma2 works only with static cache.
        # So we will pass in the attention mask as-is in any case, not only when there's padding. Then we'll use its shape
        # to cut out the keys/values trailing zeros used in the static cache. This workaround should be compile-compatible
        # as it doesn't cause dynamic control issues.
        if self.config._attn_implementation == "flash_attention_2":
            return attention_mask

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        if isinstance(past_key_values, HybridCache):
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )
        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
                The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask


class LayerSkipGemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)

        # Set skip layer
        self.draft_mode = False
        self.skip_layer_ids = {"attn": [], "mlp": []}

        self.model = LayerSkipGemma2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def set_skip_layer_ids(self, skip_layer_ids: Dict[str, List[int]]):
        assert "attn" in skip_layer_ids and "mlp" in skip_layer_ids, "`skip_layer_ids` need to be set `attn` and `mlp`!"
        assert isinstance(skip_layer_ids["attn"], list), "`skip_layer_ids['attn']` need to be a list!"
        assert isinstance(skip_layer_ids["mlp"], list), "`skip_layer_ids['mlp']` need to be a list!"

        for attn_layer_idx in skip_layer_ids["attn"]:
            assert attn_layer_idx < len(self.model.layers), f"attn_layer_idx {attn_layer_idx} is out of Range ({len(self.model.layers)})" 
            
        for mlp_layer_idx in skip_layer_ids["mlp"]:
            assert mlp_layer_idx < len(self.model.layers), f"mlp_layer_idx {mlp_layer_idx} is out of Range ({len(self.model.layers)})"

        self.skip_layer_ids = skip_layer_ids
        self.model.set_skip_layer_ids(skip_layer_ids=skip_layer_ids)

        print("skip_layer_ids:", self.skip_layer_ids)

    def set_draft_mode(self, _mode: bool):
        self.draft_mode = _mode
        self.model.set_draft_mode(_mode=_mode)

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[HybridCache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
        if self.config.final_logit_softcapping is not None:
            logits = logits / self.config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.final_logit_softcapping

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        num_logits_to_keep=None,
        **kwargs,
    ):
        # Overwritten: has a special cache type, `HybridCache`

        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        if past_key_values is not None:
            if inputs_embeds is not None:  # Exception 1
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]
                # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
                # `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride
                # during the decoding. Here, simply using `.contiguous()` is not sufficient as in the
                # batch size = 1 case, `position_ids` is already contiguous but with varying stride
                # which retriggers a capture.
                position_ids = position_ids.clone(memory_format=torch.contiguous_format)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and cache_position[0] == 0:
            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
        else:
            # The clone here is for the same reason as for `position_ids`.
            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

        if (
            isinstance(past_key_values, HybridCache)
            and attention_mask.ndim == 2
            and not self.config._attn_implementation == "flash_attention_2"
        ):
            if model_inputs["inputs_embeds"] is not None:
                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
                device = model_inputs["inputs_embeds"].device
            else:
                batch_size, sequence_length = model_inputs["input_ids"].shape
                device = model_inputs["input_ids"].device

            attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
                attention_mask,
                sequence_length=sequence_length,
                target_length=past_key_values.get_max_cache_shape(),
                dtype=self.lm_head.weight.dtype,
                device=device,
                cache_position=cache_position,
                batch_size=batch_size,
            )

        if num_logits_to_keep is not None:
            model_inputs["num_logits_to_keep"] = num_logits_to_keep

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        return model_inputs
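
With the classes above in place (saved as layerskip_modeling/modeling_layerskip_gemma2.py, matching the import path used in the experiment script further below), a minimal usage sketch looks like this. The checkpoint name, device, and skipped layer ids here are only illustrative assumptions, not the searched optimum:

from transformers import AutoTokenizer
import torch

from layerskip_modeling.modeling_layerskip_gemma2 import LayerSkipGemma2ForCausalLM

pretrained_model_name = "google/gemma-2-9b-it"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
model = LayerSkipGemma2ForCausalLM.from_pretrained(
    pretrained_model_name,
    torch_dtype=torch.bfloat16,
).to("cuda:0")

# Register which layers to skip when draft mode is enabled
model.set_skip_layer_ids({"attn": [2, 3], "mlp": [4]})

inputs = tokenizer("Hello, how are you?", return_tensors="pt").to("cuda:0")

model.set_draft_mode(True)    # forward passes now skip the registered layers (draft model)
draft_outputs = model(**inputs)

model.set_draft_mode(False)   # full model again (verification model)
verify_outputs = model(**inputs)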

Adaptive Draft-Exiting Mechanism

For the detailed explanation, please see my earlier note: [Paper Reading] Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding; I won't repeat it here.

In short, it is a dynamic mechanism that automatically controls how many tokens the draft model gets to speculate. When the draft model is not confident and its tokens are currently not being accepted, we raise the confidence threshold it must clear to keep decoding; conversely, once the current acceptance rate meets the target, we can try lowering the draft model's confidence threshold a bit, so that it keeps speculating even when it is less sure of itself.

It is a bit like a kid and their exam results: if the kid does poorly (the acceptance rate is low), we cut back their phone time (raise the confidence threshold for drafting); if they do well, we let them play a bit longer (lower the threshold).
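
Written out, the update that the code below performs after every verification round is an exponential moving average of the acceptance rate plus a smoothed step on the stop-draft threshold (the threshold is left unchanged when the drafter already used up all max_step_draft steps). Using the same names as the code:

curr_matchness = beta1 * curr_matchness + (1 - beta1) * (num_matched_tokens / num_drafted_tokens)
# If the smoothed match rate is still below the target, raise the threshold; otherwise lower it
new_threshold = stop_draft_threshold + epsilon if curr_matchness <= target_matchness else stop_draft_threshold - epsilon
stop_draft_threshold = beta2 * stop_draft_threshold + (1 - beta2) * new_threshold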

Here is the implementation:

import torch


class AdaptiveDraftExitAduster:
    def __init__(
        self,
        target_matchness: float = 0.9,
        beta1: float = 0.5,
        beta2: float = 0.9,
        epsilon: float = 0.01,
        max_step_draft: int = 8,
    ):
        """Initialize DraftExitingAdjuster parameters
        :param target_matchness: matching degree target value
        :param beta1: sliding average coefficient for matching degree update
        :param beta2: th_stop_draft smooth update coefficient
        :param epsilon: Adjust the step size each time
        :param max_step_draft: The maximum number of steps for draft generation
        """

        self.save_init_status(
            target_matchness=target_matchness,
            beta1=beta1,
            beta2=beta2,
            epsilon=epsilon,
            max_step_draft=max_step_draft,
        )
        self.reset()

    def save_init_status(
        self,
        target_matchness: float,
        beta1: float,
        beta2: float,
        epsilon: float,
        max_step_draft: int,
    ) -> None:
        self.init_target_matchness = target_matchness
        self.init_beta1 = beta1
        self.init_beta2 = beta2
        self.init_epsilon = epsilon
        self.init_max_step_draft = max_step_draft

    def reset(self) -> None:
        self.target_matchness = self.init_target_matchness
        self.beta1 = self.init_beta1
        self.beta2 = self.init_beta2
        self.epsilon = self.init_epsilon
        self.max_step_draft = self.init_max_step_draft

        # Dynamic status
        self.curr_matchness = 0.0
        self.stop_draft_threshold = 0.5
        self.step_num = 0

    def update(self, num_matched_tokens, num_drafted_tokens) -> None:
        # Update matchness
        matchness = num_matched_tokens / num_drafted_tokens
        self.curr_matchness = self.beta1 * self.curr_matchness + (1 - self.beta1) * matchness

        # Calculate new exit threshold
        if num_drafted_tokens == self.max_step_draft:
            new_stop_draft_threshold = self.stop_draft_threshold

        elif self.curr_matchness <= self.target_matchness:
            new_stop_draft_threshold = self.stop_draft_threshold + self.epsilon
        else:
            new_stop_draft_threshold = self.stop_draft_threshold - self.epsilon

        self.stop_draft_threshold = self.beta2 * self.stop_draft_threshold + (1 - self.beta2) * new_stop_draft_threshold
        self.step_num += 1

    def should_exit(self, draft_prob: torch.FloatTensor) -> bool:
        return draft_prob < self.stop_draft_threshold

    def get_state(self):
        return {
            "curr_matchness": self.curr_matchness,
            "stop_draft_threshold": self.stop_draft_threshold,
            "step_num": self.step_num
        }


def calculate_continuous_acceptance(acceptance_mask: torch.BoolTensor) -> int:
    continuous_acceptance = 0
    for accepted in acceptance_mask.long().squeeze(0):
        if accepted == 1:
            continuous_acceptance += 1
        else:
            break
    return continuous_acceptance
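
As a quick sanity check of how the threshold moves, the toy loop below (assuming the helpers above live in utils/utils.py, as the later imports suggest) feeds in made-up matched/drafted counts and prints how the smoothed match rate and the stop-draft threshold evolve:

import torch

from utils.utils import AdaptiveDraftExitAduster

adjuster = AdaptiveDraftExitAduster(target_matchness=0.9, max_step_draft=8)

# Simulate a few verification rounds with (matched tokens, drafted tokens)
for matched, drafted in [(2, 5), (3, 5), (5, 5), (5, 5)]:
    adjuster.update(num_matched_tokens=matched, num_drafted_tokens=drafted)
    print(adjuster.get_state())

# During drafting, exit early as soon as the sampled token's probability falls below the threshold
if adjuster.should_exit(draft_prob=torch.tensor(0.42)):
    print("Stop drafting and hand over to the target model")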

Bayesian Optimization for the Best Layer-Skipping Strategy

I previously wrote a note on using Bayesian optimization to search for the best layer-skipping strategy of a LayerSkip model, but it was clearly still incomplete at the time. I have since kept improving it, and it serves as the basis for the experiments here.

Here is the class implementation of my Bayesian optimization searcher:

from typing import Dict, List, Callable, Optional

import gc
import copy
import json
import time

import optuna
import torch
from tqdm import tqdm

from utils.utils import AdaptiveDraftExitAduster


class LayerSkipStrategySearcher:
    def __init__(
        self,
        model: torch.nn.Module,
        tokenizer,
        device,
        drafter_speculative_decode: Callable,
        target_speculative_decode: Callable,
        adjuster: Optional[AdaptiveDraftExitAduster] = None,
    ) -> None:
        self.model = model
        if not hasattr(self.model, "set_draft_mode"):
            raise TypeError(f"{type(self.model)} is not a LayerSkip-liked model.")
        
        self.tokenizer = tokenizer
        self.device = device
        
        # Initialize skip layers
        self.reset_skip_layers()
        
        # Sampling methods
        self.drafter_speculative_decode = drafter_speculative_decode
        self.target_speculative_decode = target_speculative_decode
        
        # Threshold adjuster
        self.adjuster = adjuster
        
        # Cache for processed samples
        self._processed_samples = None

    def reset_skip_layers(self):
        """Reset skip layers to initial state"""
        self.skip_layer_ids = {
            "attn": [],
            "mlp": [],
        }
    
    @property
    def processed_samples(self):
        """Lazy loading of processed samples with proper cleanup"""
        if self._processed_samples is None:
            raise ValueError("Samples not initialized. Call prepare_samples first.")
        return self._processed_samples
    
    def prepare_samples(self, samples: List[List[Dict[str, str]]]):
        """Process and prepare samples with proper cleanup"""
        # Clear any existing processed samples
        if self._processed_samples is not None:
            del self._processed_samples
            self._processed_samples = None
            torch.cuda.empty_cache()
            gc.collect()
        
        # Process new samples
        self._processed_samples = [
            self.tokenizer(
                self.tokenizer.apply_chat_template(messages, tokenize=False),
                return_tensors="pt",
            ).to(self.device)
            for messages in samples
        ]

    def cleanup(self):
        """Clean up resources"""
        del self._processed_samples
        self._processed_samples = None
        torch.cuda.empty_cache()
        gc.collect()

    def optimize_acceptance_rate(
        self,
        samples: List[List[Dict[str, str]]],
        n_trials: int = 50,
        num_hidden_layers: int = 1,
    ) -> Dict:
        try:
            # Prepare samples
            self.prepare_samples(samples)
            
            # Define search space
            num_hidden_layers = getattr(self.model.config, "num_hidden_layers", num_hidden_layers)
            print(f"Total layers we can skip: {num_hidden_layers}")
            
            study = optuna.create_study(direction="maximize")
            study.optimize(
                lambda trial: self.objective_acceptance_rate(
                    trial=trial,
                    num_hidden_layers=num_hidden_layers,
                ),
                n_trials=n_trials,
                callbacks=[
                    lambda study, trial: print(f"Trial {trial.number}: {trial.value}"),
                    lambda study, trial: gc.collect(),  # Force garbage collection after each trial
                    lambda study, trial: torch.cuda.empty_cache()  # Clear CUDA cache after each trial
                ]
            )
            
            return {
                "best_params": study.best_params,
                "best_value": study.best_value
            }
        finally:
            self.cleanup()

    def optimize_speculative_speed(
        self,
        samples: List[List[Dict[str, str]]],
        n_trials: int = 50,
        num_hidden_layers: int = 1,
    ) -> Dict:
        try:
            # Prepare samples
            self.prepare_samples(samples)
            
            # Define search space
            num_hidden_layers = getattr(self.model.config, "num_hidden_layers", num_hidden_layers)
            print(f"Total layers we can skip: {num_hidden_layers}")
            
            study = optuna.create_study(direction="maximize")
            study.optimize(
                lambda trial: self.objective_speculative_speed(
                    trial=trial,
                    num_hidden_layers=num_hidden_layers,
                ),
                n_trials=n_trials,
                callbacks=[
                    lambda study, trial: print(f"Trial {trial.number}: {trial.value}"),
                    lambda study, trial: gc.collect(),
                    lambda study, trial: torch.cuda.empty_cache()
                ]
            )

            # Save best skip_layer_ids
            best_skip_layers = self._get_skip_layers(study.best_trial, num_hidden_layers)
            with open("skip_layer_ids.json", "w") as f:
                json.dump(best_skip_layers, f)
            
            return {
                "best_params": study.best_params,
                "best_value": study.best_value
            }
        finally:
            self.cleanup()

    def _run_inference(self, inputs, gamma=5, max_new_tokens=100):
        """Common inference logic for both objectives"""
        is_end = False

        trial_inputs = copy.deepcopy(inputs)
        raw_token_num = trial_inputs["input_ids"].shape[1]
        total_draft_tokens = 0
        total_accept_tokens = 0
        generated_tokens = 0
        
        while not is_end:
            # Draft model inference
            with torch.no_grad():  # Ensure no gradients are tracked
                target_inputs, draft_probs, real_generated_tokens = self.drafter_speculative_decode(
                    draft_model=self.model,
                    draft_tokenizer=self.tokenizer,
                    inputs=trial_inputs,
                    gamma=gamma,
                    confidence_threshold_adjuster=self.adjuster,
                )
            
            total_draft_tokens += real_generated_tokens if real_generated_tokens is not None else gamma
            
            # Target model inference
            with torch.no_grad():
                outputs, is_end, accept_tokens = self.target_speculative_decode(
                    target_model=self.model,
                    target_tokenizer=self.tokenizer,
                    inputs=target_inputs,
                    draft_probs=draft_probs,
                )
            
            if self.adjuster:
                self.adjuster.update(
                    num_matched_tokens=accept_tokens,
                    num_drafted_tokens=real_generated_tokens if real_generated_tokens is not None else gamma,
                )
            
            total_accept_tokens += accept_tokens
            trial_inputs = outputs
            
            generated_tokens = trial_inputs["input_ids"].shape[1] - raw_token_num
            if generated_tokens >= max_new_tokens:
                break
            
        # Free memory
        del target_inputs, draft_probs, trial_inputs
        torch.cuda.empty_cache()
        
        return total_accept_tokens, total_draft_tokens, generated_tokens

    def objective_speculative_speed(self, trial, num_hidden_layers: int):
        try:
            skip_layers = self._get_skip_layers(trial, num_hidden_layers)
            if not skip_layers:
                raise optuna.TrialPruned()
            
            self.model.set_skip_layer_ids(skip_layer_ids=skip_layers)
            
            if self.adjuster:
                self.adjuster.reset()
            
            start_time = time.time()
            total_generated_tokens = 0
            
            for inputs in tqdm(self.processed_samples):
                _, _, generated_tokens = self._run_inference(inputs)
                total_generated_tokens += generated_tokens
            
            token_per_second = total_generated_tokens / (time.time() - start_time)
            print(f"attn_skip: {skip_layers['attn']}, mlp_skip: {skip_layers['mlp']}, tokens/sec: {token_per_second}")
            
            return token_per_second
        
        except Exception as e:
            print(f"Trial failed with error: {str(e)}")
            raise optuna.TrialPruned()

    def objective_acceptance_rate(self, trial, num_hidden_layers: int):
        try:
            skip_layers = self._get_skip_layers(trial, num_hidden_layers)
            if not skip_layers:
                raise optuna.TrialPruned()
            
            self.model.set_skip_layer_ids(skip_layer_ids=skip_layers)
            
            total_accept_tokens_group = 0
            total_draft_tokens_group = 0
            
            for inputs in tqdm(self.processed_samples):
                accept_tokens, draft_tokens, _ = self._run_inference(inputs)
                total_accept_tokens_group += accept_tokens
                total_draft_tokens_group += draft_tokens
            
            accept_rate = total_accept_tokens_group / total_draft_tokens_group
            print(f"attn_skip: {skip_layers['attn']}, mlp_skip: {skip_layers['mlp']}, Accept Rate: {accept_rate}")
            
            return accept_rate
        
        except Exception as e:
            print(f"Trial failed with error: {str(e)}")
            raise optuna.TrialPruned()

    def _get_skip_layers(self, trial, num_hidden_layers: int) -> Dict[str, List[int]]:
        """Helper method to get skip layers configuration from trial"""
        skip_attn_layers = [i for i in range(num_hidden_layers) if trial.suggest_int(f"skip_attn_layer_{i}", 0, 1) == 1]
        skip_mlp_layers = [i for i in range(num_hidden_layers) if trial.suggest_int(f"skip_mlp_layer_{i}", 0, 1) == 1]
        
        if not skip_attn_layers and not skip_mlp_layers:
            return None
            
        return {
            "attn": skip_attn_layers,
            "mlp": skip_mlp_layers,
        }


With all of that in place, here is the actual experiment script:

import optuna
from typing import Dict, Optional, Tuple

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import copy
import time

from datasets import load_dataset
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from layerskip_modeling.modeling_layerskip_gemma2 import LayerSkipGemma2ForCausalLM
from sampling.sampling import sample_next_token
from utils.utils import calculate_continuous_acceptance, AdaptiveDraftExitAduster
from utils.optimization_searcher import LayerSkipStrategySearcher


def drafter_speculative_decode(
    draft_model: torch.nn.Module,
    draft_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    gamma: int = 10,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, it means do not select top-k tokens
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
    draft_mode: bool = True,
    confidence_threshold_adjuster: Optional[AdaptiveDraftExitAduster] = None,
) -> Tuple[Dict[str, torch.Tensor], torch.FloatTensor, int]:
    draft_model.set_draft_mode(draft_mode)
    draft_probs = []
    real_generated_tokens = 0

    for idx in range(gamma):
        with torch.no_grad():
            outputs = draft_model(**inputs)

        next_tokens, probs = sample_next_token(
            logits=outputs.logits,
            prefix_token_ids=inputs["input_ids"],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        draft_probs.append(probs)
        input_ids = torch.cat([inputs["input_ids"], next_tokens[:, -1:]], dim=-1)
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)

        inputs["input_ids"] = input_ids
        inputs["attention_mask"] = attention_mask

        real_generated_tokens += 1

        # Early exit
        if confidence_threshold_adjuster and confidence_threshold_adjuster.should_exit(draft_prob=probs[0, 0, next_tokens.item()]):
            # print(confidence_threshold_adjuster.get_state())
            break

    draft_model.set_draft_mode(False)

    return inputs, torch.cat(draft_probs, dim=1), real_generated_tokens


def target_speculative_decode(
    target_model: torch.nn.Module,
    target_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    draft_probs: torch.FloatTensor,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, meaning top-k filtering is disabled
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], bool, int]:
    target_model.set_draft_mode(False)
    with torch.no_grad():
        outputs = target_model(**inputs)

    next_tokens, target_probs = sample_next_token(
        logits=outputs.logits,
        prefix_token_ids=inputs["input_ids"],
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        probs_num=draft_probs.shape[1] + 1,
    )

    next_token = next_tokens[:, -1:]

    # Evaluation
    indices = inputs["input_ids"][:, -draft_probs.shape[1]:]

    eval_probs = target_probs[:, :-1, :]

    expanded_indices = indices.unsqueeze(-1)
    selected_draft_probs = torch.gather(draft_probs, dim=-1, index=expanded_indices)
    selected_draft_probs = selected_draft_probs.squeeze(-1)

    selected_eval_probs = torch.gather(eval_probs, dim=-1, index=expanded_indices)
    selected_eval_probs = selected_eval_probs.squeeze(-1)

    # Compare draft_prob and eval_prob, and check the reject_mask
    mask_to_reject = selected_draft_probs > selected_eval_probs

    # Calculate rejection probability: 1 - (eval_prob / draft_prob)
    rejection_probs = 1 - (selected_eval_probs / selected_draft_probs)

    # Generate random values to determine whether to accept or reject
    random_values = torch.rand_like(rejection_probs)
    rejection_decisions = random_values < rejection_probs

    # Get the final reject masks
    rejection_masks = mask_to_reject & rejection_decisions
    acceptance_mask = torch.ones_like(selected_draft_probs, dtype=torch.bool)
    acceptance_mask[rejection_masks] = False

    is_end = False

    # Concat `input_ids`
    if torch.all(acceptance_mask):
        input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)
    else:
        new_input_ids = []
        new_attention_mask = []

        for batch_idx in range(next_tokens.shape[0]):
            gamma = next_tokens.shape[1] - 1
            start_idx = inputs["input_ids"].shape[1] - gamma

            for pos_idx in range(acceptance_mask[batch_idx].shape[0]):
                if (acceptance_mask[batch_idx][pos_idx] and inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id) or not acceptance_mask[batch_idx][pos_idx]:
                    inputs["input_ids"][batch_idx][start_idx+pos_idx] = next_tokens[batch_idx][pos_idx]

                    new_input_ids.append(inputs["input_ids"][batch_idx][:start_idx+pos_idx+1])
                    new_attention_mask.append(inputs["attention_mask"][batch_idx][:start_idx+pos_idx+1])
                    
                    is_end = inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id
                    break

        input_ids = pad_sequence(new_input_ids, batch_first=True, padding_value=target_tokenizer.pad_token_id)
        attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)

    inputs["input_ids"] = input_ids
    inputs["attention_mask"] = attention_mask

    return inputs, is_end, calculate_continuous_acceptance(acceptance_mask)


if __name__ == "__main__":
    # Load tokenizer and model
    pretrained_model_name_or_path = "../models/google--gemma-2-2b-it/"
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
    model = LayerSkipGemma2ForCausalLM.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.bfloat16).to(device)

    # Load dataset
    dataset = load_dataset("shibing624/sharegpt_gpt4")
    samples = dataset["train"]["conversations"][:10]
    samples = [[{"role": sample[0]["from"].replace("human", "user").replace("gpt", "assistant"), "content": sample[0]["value"]}] for sample in samples]

    # Adaptive draft threshold adjuster
    confidence_threshold_adjuster = AdaptiveDraftExitAduster(
        target_matchness=0.5,
        beta1=0.5,
        beta2=0.9,
        epsilon=0.1,
        max_step_draft=8,
    )

    # Searcher
    searcher = LayerSkipStrategySearcher(
        model=model,
        tokenizer=tokenizer,
        device=device,
        drafter_speculative_decode=drafter_speculative_decode,
        target_speculative_decode=target_speculative_decode,
        adjuster=confidence_threshold_adjuster,
    )

    searcher.optimize_speculative_speed(
        samples=samples,
        n_trials=1000,
    )


Output:


In my experiment, the best result came from trial 846, whose objective value (decoding speed) reached 24.816 tokens/sec.

If I pick the layers to skip at random, the speed basically hovers around 16 - 18 tokens/sec, so searching the layer-skipping hyperparameters strategically with Bayesian optimization really does help.
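
If you want to reuse the best trial afterwards, the 0/1 per-layer parameters can be mapped straight back into a skip_layer_ids dictionary. A minimal sketch, assuming the searcher keeps its Optuna study as `searcher.study` (if yours does not, the same mapping works on any study built from the search space above):

# Assumption: the searcher stores its Optuna study as `searcher.study`.
best_params = searcher.study.best_trial.params
num_layers = model.config.num_hidden_layers

best_skip_layers = {
    "attn": [i for i in range(num_layers) if best_params.get(f"skip_attn_layer_{i}") == 1],
    "mlp": [i for i in range(num_layers) if best_params.get(f"skip_mlp_layer_{i}") == 1],
}

# Apply the best configuration before running speculative decoding again
model.set_skip_layer_ids(skip_layer_ids=best_skip_layers)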

That said, overall my experiment counts as a failure. Why? As the next section shows, plain decoding with the model itself, without Self-Speculative Decoding, already runs at about 29.647 tokens/sec.


Self-Speculative Decoding Experiment Results

The architecture I tried first was gemma-2-2b-it, but once more than about 6 layers are skipped, decoding actually slows down. This matches the skipped-layer-count figure in the paper; a smaller model simply hits that bottleneck sooner.

My results on gemma-2-9b-it are somewhat better than on the 2B model, but there is still no real speedup; it is actually slower. Below, the first block is the Self-Speculative Decoding result, the second is the normal gemma-2-9b-it inference speed, and the last is the inference speed of gemma-2-9b-it with layers skipped.

Generate token number: 102
Generate speed: 24.07665693982178 token/sec
Speculative Decoding Spent Time: 4.236468553543091 seconds.
Accept Rate: 0.4056603773584906

Generate token number: 100
Generate speed: 29.647649721089135 token/sec
Normal Target Model Decoding Spent Time: 3.37294864654541 seconds.

Generate token number: 100
Generate speed: 48.81880782264111 token/sec
Normal Draft Model Decoding Spent Time: 2.0483908653259277 seconds.


As you can see, even after skipping that many layers, the draft model is still not twice as fast as running the full network. Based on that, the relationship between inference time and acceptance rate can be laid out in the following table (times are easiest to read in units of 10 ms):

| Network state | Original model | Layer-skipping model | Speculative time needed (assume all drafts accepted + one extra target-model pass) |
|---|---|---|---|
| Time to generate 1 token | 3.4 | 2 | 2 + 3.4 = 5.4 |
| Time to generate 2 tokens | 6.8 | 4 | 4 + 3.4 = 7.4 |
| Time to generate 3 tokens | 10.2 | 6 | 6 + 3.4 = 9.4 |
| Time to generate 4 tokens | 13.6 | 8 | 8 + 3.4 = 11.4 |
| Time to generate 5 tokens | 17.0 | 10 | 10 + 3.4 = 13.4 |

In other words, if the layer-skipping draft model drafts 5 tokens, the acceptance rate needs to be at least about 0.8 to actually beat the original model running on its own, which is a fairly harsh requirement.
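
To make the arithmetic explicit, here is a quick back-of-the-envelope check of the table above. It only uses the measured per-token times (about 3.4 and 2.0 in units of 10 ms) and, like the table, ignores the bonus token produced by the verification pass:

# Per-token times in units of 10 ms, taken from the measurements above:
# target ~ 29.6 tokens/sec -> ~3.4, skip-layer draft ~ 48.8 tokens/sec -> ~2.0
t_target, t_draft = 3.4, 2.0

for n in range(1, 6):
    baseline = n * t_target               # target model decodes n tokens on its own
    speculative = n * t_draft + t_target  # draft n tokens + one verification pass
    print(f"{n} tokens: baseline {baseline:.1f}, speculative (all accepted) {speculative:.1f}")

# Minimum accepted tokens k (out of n drafted) for the speculative round to win:
#   n * t_draft + t_target < k * t_target
n = 5
k_min = n * t_draft / t_target + 1
print(f"Need about {k_min:.1f} of {n} drafted tokens accepted (~{k_min / n:.2f} acceptance rate)")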

Looking at it the other way, though, intuition says that larger models should be easier to accelerate with Self-Speculative Decoding: their baseline inference is slower, and with so many weights the cost of skipping a few layers is relatively small, so it is easier to find the sweet spot of "small quality loss + significantly lower inference cost".

Of course, another approach is to do what Kangaroo or Meta AI's LayerSkip do: use training to compensate for the layer-skipping network's quality loss, which allows much more aggressive speedups once that loss is small enough.

Honestly, I really want to improve this result. My current plan is to start training in the Kangaroo style, hoping that without too many changes I can get a genuine speedup out of Self-Speculative Decoding.

Finally, if you happened to read this far, feel free to keep following my GitHub: https://github.com/ccs96307/fast-llm-inference, and any feedback is always very welcome!

