Last Updated on 2024-11-17 by Clay
Over the past week, I took some time to reproduce Self-Speculative Decoding following the approach of the paper Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding. The implementation covers the following modules:
- A layer-skipping decoder-only Transformer model (focusing on the Llama and Gemma-2 architectures)
- An adaptive draft-exiting mechanism
- Bayesian optimization to search for the best layer-skipping strategy (i.e. which combination of skipped layers makes the best draft model)
- Self-Speculative Decoding itself: acceleration that relies on nothing but the model
Below I walk through each of these modules. For a brief introduction to the paper, you can also refer to my earlier notes: [Paper Reading] Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding
Layer-Skipping Decoder-only Transformer Model
Since the model implementations in the HuggingFace transformers library already pass each decoder layer its own layer index (after all, the KV Cache of every layer may need to be stored), I added a skip_layer_ids attribute to the model architecture: whenever a layer's index appears in that list, the layer's computation is skipped while draft mode is enabled.
Below is my modification of the Gemma-2 architecture, based on a relatively recent version of transformers.
# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, HybridCache
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers import Gemma2Config
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class Gemma2PreTrainedModel(PreTrainedModel):
config_class = Gemma2Config
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["Gemma2DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = False
_supports_static_cache = True
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@classmethod
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False):
"""
Overloads `PreTrainedModel._check_and_enable_sdpa` so as to DISABLE torch SDPA by default on Gemma2 models.
SDPA reduces the model performance on Gemma2 because of the logits softcapping.
"""
config = super()._check_and_enable_sdpa(config, hard_check_only=hard_check_only)
# if using the default path -> swap sdpa by eager
if not hard_check_only and config._attn_implementation == "sdpa":
config._attn_implementation = "eager"
return config
class Gemma2RotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
@torch.no_grad()
def forward(self, x, position_ids, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
self.inv_freq.to(x.device)
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class Gemma2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = config.head_dim
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
self.scaling = config.query_pre_attn_scalar**-0.5
if self.hidden_size % self.num_heads != 0:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
self.sliding_window = config.sliding_window if not bool(layer_idx % 2) else None
self.rotary_emb = Gemma2RotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {
"sin": sin,
"cos": cos,
"sliding_window": self.sliding_window,
"cache_position": cache_position,
}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
if self.config.attn_logit_softcapping is not None:
attn_weights = attn_weights / self.config.attn_logit_softcapping
attn_weights = torch.tanh(attn_weights)
attn_weights = attn_weights * self.config.attn_logit_softcapping
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class Gemma2SdpaAttention(Gemma2Attention):
"""
Gemma2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`Gemma2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
SDPA API.
"""
# Adapted from Gemma2Attention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {
"sin": sin,
"cos": cos,
"sliding_window": self.sliding_window,
"cache_position": cache_position,
}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=is_causal,
scale=self.scaling,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, -1)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
GEMMA2_ATTENTION_CLASSES = {
"eager": Gemma2Attention,
"sdpa": Gemma2SdpaAttention,
}
class Gemma2MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_activation]
def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
class Gemma2RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.zeros(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float())
# Llama does x.to(float16) * w whilst Gemma2 is (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402
output = output * (1.0 + self.weight.float())
return output.type_as(x)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.eps}"
class LayerSkipGemma2DecoderLayer(nn.Module):
def __init__(self, config: Gemma2Config, layer_idx: int):
super().__init__()
# Set skip layer
self.draft_mode = False
self.skip_layer_ids = {"attn": [], "mlp": []}
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
self.self_attn = GEMMA2_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
self.mlp = Gemma2MLP(config)
self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.config = config
self.is_sliding = not bool(layer_idx % 2)
self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.sliding_window = config.sliding_window
self.post_attention_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def set_skip_layer_ids(self, skip_layer_ids: Dict[str, List[int]]):
self.skip_layer_ids = skip_layer_ids
def set_draft_mode(self, _mode: bool):
self.draft_mode = _mode
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*):
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
query_sequence_length, key_sequence_length)` if default attention is used.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
"""
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
# Flash-attn is a 2D tensor
if self.config._attn_implementation == "flash_attention_2":
if past_key_value is not None: # when decoding
attention_mask = attention_mask[:, -self.sliding_window :]
else:
min_dtype = torch.finfo(hidden_states.dtype).min
sliding_window_mask = torch.tril(
torch.ones_like(attention_mask, dtype=torch.bool), diagonal=-self.sliding_window
)
attention_mask = torch.where(sliding_window_mask, min_dtype, attention_mask)
if attention_mask.shape[-1] <= 1: # when decoding
attention_mask = attention_mask[:, :, :, -self.sliding_window :]
residual = hidden_states
if self.draft_mode and self.layer_idx in self.skip_layer_ids["attn"]:
hidden_states = residual
self_attn_weights = None
present_key_value = None
else:
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
if self.draft_mode and self.layer_idx in self.skip_layer_ids["mlp"]:
hidden_states = residual
else:
hidden_states = self.pre_feedforward_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = self.post_feedforward_layernorm(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
class LayerSkipGemma2Model(Gemma2PreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Gemma2DecoderLayer`]
Args:
config: Gemma2Config
"""
def __init__(self, config: Gemma2Config):
super().__init__(config)
# Set skip layer
self.draft_mode = False
self.skip_layer_ids = {"attn": [], "mlp": []}
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[LayerSkipGemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def set_skip_layer_ids(self, skip_layer_ids: Dict[str, List[int]]):
self.skip_layer_ids = skip_layer_ids
for layer in self.layers:
layer.set_skip_layer_ids(skip_layer_ids=skip_layer_ids)
def set_draft_mode(self, _mode: bool):
self.draft_mode = _mode
for layer in self.layers:
layer.set_draft_mode(_mode=_mode)
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[HybridCache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training and use_cache:
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None and not self.training:
batch_size, seq_len, _ = inputs_embeds.shape
past_key_values = HybridCache(
self.config,
batch_size=batch_size,
max_cache_len=seq_len,
device=self.device,
dtype=inputs_embeds.dtype,
)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
# embed positions
hidden_states = inputs_embeds
# normalized
# Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
# See https://github.com/huggingface/transformers/pull/29402
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
hidden_states = hidden_states * normalizer
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = past_key_values if use_cache else None
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: HybridCache,
output_attentions: bool,
):
        # Flash Attention currently doesn't support static cache, but Gemma2 works only with static cache.
        # So we will pass in the attention mask as is in any case, not only when there's padding. Then we'll use its shape
        # to cut out keys/values trailing 0s used in the static cache. This workaround should be compile-compatible
        # as it doesn't cause dynamic control issues.
if self.config._attn_implementation == "flash_attention_2":
return attention_mask
dtype, device = input_tensor.dtype, input_tensor.device
sequence_length = input_tensor.shape[1]
if isinstance(past_key_values, HybridCache):
target_length = past_key_values.get_max_cache_shape()
else:
target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
                The device to place the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class LayerSkipGemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
# Set skip layer
self.draft_mode = False
self.skip_layer_ids = {"attn": [], "mlp": []}
self.model = LayerSkipGemma2Model(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def set_skip_layer_ids(self, skip_layer_ids: Dict[str, List[int]]):
        assert "attn" in skip_layer_ids and "mlp" in skip_layer_ids, "`skip_layer_ids` must contain both `attn` and `mlp` keys!"
        assert isinstance(skip_layer_ids["attn"], list), "`skip_layer_ids['attn']` must be a list!"
        assert isinstance(skip_layer_ids["mlp"], list), "`skip_layer_ids['mlp']` must be a list!"
        for attn_layer_idx in skip_layer_ids["attn"]:
            assert attn_layer_idx < len(self.model.layers), f"attn_layer_idx {attn_layer_idx} is out of range ({len(self.model.layers)})"
        for mlp_layer_idx in skip_layer_ids["mlp"]:
            assert mlp_layer_idx < len(self.model.layers), f"mlp_layer_idx {mlp_layer_idx} is out of range ({len(self.model.layers)})"
self.skip_layer_ids = skip_layer_ids
self.model.set_skip_layer_ids(skip_layer_ids=skip_layer_ids)
print("skip_layer_ids:", self.skip_layer_ids)
def set_draft_mode(self, _mode: bool):
self.draft_mode = _mode
self.model.set_draft_mode(_mode=_mode)
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[HybridCache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
if self.config.final_logit_softcapping is not None:
logits = logits / self.config.final_logit_softcapping
logits = torch.tanh(logits)
logits = logits * self.config.final_logit_softcapping
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=None,
**kwargs,
):
# Overwritten: has a special cache type, `HybridCache`
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
if past_key_values is not None:
if inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
# `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride
# during the decoding. Here, simply using `.contiguous()` is not sufficient as in the
# batch size = 1 case, `position_ids` is already contiguous but with varying stride
# which retriggers a capture.
position_ids = position_ids.clone(memory_format=torch.contiguous_format)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
# The clone here is for the same reason as for `position_ids`.
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
if (
isinstance(past_key_values, HybridCache)
and attention_mask.ndim == 2
and not self.config._attn_implementation == "flash_attention_2"
):
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
device = model_inputs["inputs_embeds"].device
else:
batch_size, sequence_length = model_inputs["input_ids"].shape
device = model_inputs["input_ids"].device
attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_cache_shape(),
dtype=self.lm_head.weight.dtype,
device=device,
cache_position=cache_position,
batch_size=batch_size,
)
if num_logits_to_keep is not None:
model_inputs["num_logits_to_keep"] = num_logits_to_keep
model_inputs.update(
{
"position_ids": position_ids,
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
}
)
return model_inputs
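Before moving on, here is a minimal usage sketch of the interface above. It assumes the modified model is saved as layerskip_modeling/modeling_layerskip_gemma2.py (as in my repository) and that a Gemma-2 checkpoint is available locally; the skipped layer indices are purely illustrative:
import torch
from transformers import AutoTokenizer
from layerskip_modeling.modeling_layerskip_gemma2 import LayerSkipGemma2ForCausalLM

pretrained_model_name_or_path = "../models/google--gemma-2-2b-it/"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
model = LayerSkipGemma2ForCausalLM.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.bfloat16).to(device)

# Illustrative choice: skip the attention and MLP sub-layers of layers 10, 11 and 12 while drafting
model.set_skip_layer_ids(skip_layer_ids={"attn": [10, 11, 12], "mlp": [10, 11, 12]})

inputs = tokenizer("Speculative decoding is", return_tensors="pt").to(device)

# Draft mode on: the listed layers are skipped, giving a faster but approximate forward pass
model.set_draft_mode(True)
draft_logits = model(**inputs).logits

# Draft mode off: the full model runs, as in the verification step
model.set_draft_mode(False)
target_logits = model(**inputs).logits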
Adaptive Draft-Exiting Mechanism
For a detailed explanation, please refer to my earlier notes: [Paper Reading] Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding; I will not repeat it here.
In short, it is a dynamic mechanism that automatically controls how many tokens the draft model is allowed to speculate. When the draft model is not confident and the current acceptance rate falls short, we raise the confidence threshold required to keep drafting; conversely, when the current acceptance rate meets the target, we can try lowering the threshold so that the draft model keeps speculating even when it is less confident.
It is a bit like a child's exam results: if the grades are poor (the acceptance rate is low), we cut back their phone time (raise the confidence threshold for drafting); if the grades are good, we allow more phone time.
Here is the implementation:
import torch
class AdaptiveDraftExitAduster:
def __init__(
self,
target_matchness: float = 0.9,
beta1: float = 0.5,
beta2: float = 0.9,
epsilon: float = 0.01,
max_step_draft: int = 8,
):
"""Initialize DraftExitingAdjuster parameters
        :param target_matchness: target value for the token match (acceptance) rate
        :param beta1: moving-average coefficient for the match rate update
        :param beta2: smoothing coefficient for the stop_draft_threshold update
        :param epsilon: step size of each threshold adjustment
        :param max_step_draft: maximum number of draft generation steps
"""
self.save_init_status(
target_matchness=target_matchness,
beta1=beta1,
beta2=beta2,
epsilon=epsilon,
max_step_draft=max_step_draft,
)
self.reset()
def save_init_status(
self,
target_matchness: float,
beta1: float,
beta2: float,
epsilon: float,
max_step_draft: int,
) -> None:
self.init_target_matchness = target_matchness
self.init_beta1 = beta1
self.init_beta2 = beta2
self.init_epsilon = epsilon
self.init_max_step_draft = max_step_draft
def reset(self) -> None:
self.target_matchness = self.init_target_matchness
self.beta1 = self.init_beta1
self.beta2 = self.init_beta2
self.epsilon = self.init_epsilon
self.max_step_draft = self.init_max_step_draft
# Dynamic status
self.curr_matchness = 0.0
self.stop_draft_threshold = 0.5
self.step_num = 0
def update(self, num_matched_tokens, num_drafted_tokens) -> None:
# Update matchness
matchness = num_matched_tokens / num_drafted_tokens
self.curr_matchness = self.beta1 * self.curr_matchness + (1 - self.beta1) * matchness
# Calculate new exit threshold
if num_drafted_tokens == self.max_step_draft:
new_stop_draft_threshold = self.stop_draft_threshold
elif self.curr_matchness <= self.target_matchness:
new_stop_draft_threshold = self.stop_draft_threshold + self.epsilon
else:
new_stop_draft_threshold = self.stop_draft_threshold - self.epsilon
self.stop_draft_threshold = self.beta2 * self.stop_draft_threshold + (1 - self.beta2) * new_stop_draft_threshold
self.step_num += 1
def should_exit(self, draft_prob: torch.FloatTensor) -> bool:
return draft_prob < self.stop_draft_threshold
def get_state(self):
return {
"curr_matchness": self.curr_matchness,
"stop_draft_threshold": self.stop_draft_threshold,
"step_num": self.step_num
}
def calculate_continuous_acceptance(acceptance_mask: torch.BoolTensor) -> int:
continuous_acceptance = 0
for accepted in acceptance_mask.long().squeeze(0):
if accepted == 1:
continuous_acceptance += 1
else:
break
return continuous_acceptance
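As a quick illustration of how the adjuster behaves, here is a minimal sketch with made-up acceptance counts (in the real pipeline, update() receives the numbers produced by the draft/verify loop shown later):
import torch

adjuster = AdaptiveDraftExitAduster(target_matchness=0.9, beta1=0.5, beta2=0.9, epsilon=0.01, max_step_draft=8)

# Made-up draft/verify rounds: poor acceptance pushes the exit threshold up,
# good acceptance pulls it back down
for num_matched, num_drafted in [(1, 5), (2, 5), (5, 5), (5, 5)]:
    adjuster.update(num_matched_tokens=num_matched, num_drafted_tokens=num_drafted)
    print(adjuster.get_state())

# While drafting, stop as soon as the probability of the chosen token
# falls below the current threshold
if adjuster.should_exit(draft_prob=torch.tensor(0.42)):
    print("Stop drafting and hand the sequence to the target model for verification.")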
Bayesian Optimization to Search for the Best Layer-Skipping Strategy
I previously wrote a note on using Bayesian optimization to search for the best layer-skipping strategy of a LayerSkip model, but it was clearly not complete at the time; I have since strengthened it and use it as the foundation of these experiments.
Here is the class implementation of my Bayesian optimization searcher:
from typing import Dict, List, Callable, Optional
import gc
import copy
import json
import time
import optuna
import torch
from tqdm import tqdm
from utils.utils import AdaptiveDraftExitAduster
class LayerSkipStrategySearcher:
def __init__(
self,
model: torch.nn.Module,
tokenizer,
device,
drafter_speculative_decode: Callable,
target_speculative_decode: Callable,
adjuster: Optional[AdaptiveDraftExitAduster] = None,
) -> None:
self.model = model
if not hasattr(self.model, "set_draft_mode"):
raise TypeError(f"{type(self.model)} is not a LayerSkip-liked model.")
self.tokenizer = tokenizer
self.device = device
# Initialize skip layers
self.reset_skip_layers()
# Sampling methods
self.drafter_speculative_decode = drafter_speculative_decode
self.target_speculative_decode = target_speculative_decode
# Threshold adjuster
self.adjuster = adjuster
# Cache for processed samples
self._processed_samples = None
def reset_skip_layers(self):
"""Reset skip layers to initial state"""
self.skip_layer_ids = {
"attn": [],
"mlp": [],
}
@property
def processed_samples(self):
"""Lazy loading of processed samples with proper cleanup"""
if self._processed_samples is None:
raise ValueError("Samples not initialized. Call prepare_samples first.")
return self._processed_samples
def prepare_samples(self, samples: List[List[Dict[str, str]]]):
"""Process and prepare samples with proper cleanup"""
# Clear any existing processed samples
if self._processed_samples is not None:
del self._processed_samples
self._processed_samples = None
torch.cuda.empty_cache()
gc.collect()
# Process new samples
self._processed_samples = [
self.tokenizer(
self.tokenizer.apply_chat_template(messages, tokenize=False),
return_tensors="pt",
).to(self.device)
for messages in samples
]
def cleanup(self):
"""Clean up resources"""
del self._processed_samples
self._processed_samples = None
torch.cuda.empty_cache()
gc.collect()
def optimize_acceptance_rate(
self,
samples: List[List[Dict[str, str]]],
n_trials: int = 50,
num_hidden_layers: int = 1,
) -> Dict:
try:
# Prepare samples
self.prepare_samples(samples)
# Define search space
num_hidden_layers = getattr(self.model.config, "num_hidden_layers", num_hidden_layers)
print(f"Total layers we can skip: {num_hidden_layers}")
study = optuna.create_study(direction="maximize")
study.optimize(
lambda trial: self.objective_acceptance_rate(
trial=trial,
num_hidden_layers=num_hidden_layers,
),
n_trials=n_trials,
callbacks=[
lambda study, trial: print(f"Trial {trial.number}: {trial.value}"),
lambda study, trial: gc.collect(), # Force garbage collection after each trial
lambda study, trial: torch.cuda.empty_cache() # Clear CUDA cache after each trial
]
)
return {
"best_params": study.best_params,
"best_value": study.best_value
}
finally:
self.cleanup()
def optimize_speculative_speed(
self,
samples: List[List[Dict[str, str]]],
n_trials: int = 50,
num_hidden_layers: int = 1,
) -> Dict:
try:
# Prepare samples
self.prepare_samples(samples)
# Define search space
num_hidden_layers = getattr(self.model.config, "num_hidden_layers", num_hidden_layers)
print(f"Total layers we can skip: {num_hidden_layers}")
study = optuna.create_study(direction="maximize")
study.optimize(
lambda trial: self.objective_speculative_speed(
trial=trial,
num_hidden_layers=num_hidden_layers,
),
n_trials=n_trials,
callbacks=[
lambda study, trial: print(f"Trial {trial.number}: {trial.value}"),
lambda study, trial: gc.collect(),
lambda study, trial: torch.cuda.empty_cache()
]
)
# Save best skip_layer_ids
best_skip_layers = self._get_skip_layers(study.best_trial, num_hidden_layers)
with open("skip_layer_ids.json", "w") as f:
json.dump(best_skip_layers, f)
return {
"best_params": study.best_params,
"best_value": study.best_value
}
finally:
self.cleanup()
def _run_inference(self, inputs, gamma=5, max_new_tokens=100):
"""Common inference logic for both objectives"""
is_end = False
trial_inputs = copy.deepcopy(inputs)
raw_token_num = trial_inputs["input_ids"].shape[1]
total_draft_tokens = 0
total_accept_tokens = 0
generated_tokens = 0
while not is_end:
# Draft model inference
with torch.no_grad(): # Ensure no gradients are tracked
target_inputs, draft_probs, real_generated_tokens = self.drafter_speculative_decode(
draft_model=self.model,
draft_tokenizer=self.tokenizer,
inputs=trial_inputs,
gamma=gamma,
confidence_threshold_adjuster=self.adjuster,
)
total_draft_tokens += real_generated_tokens if real_generated_tokens is not None else gamma
# Target model inference
with torch.no_grad():
outputs, is_end, accept_tokens = self.target_speculative_decode(
target_model=self.model,
target_tokenizer=self.tokenizer,
inputs=target_inputs,
draft_probs=draft_probs,
)
if self.adjuster:
self.adjuster.update(
num_matched_tokens=accept_tokens,
num_drafted_tokens=real_generated_tokens if real_generated_tokens is not None else gamma,
)
total_accept_tokens += accept_tokens
trial_inputs = outputs
generated_tokens = trial_inputs["input_ids"].shape[1] - raw_token_num
if generated_tokens >= max_new_tokens:
break
# Free memory
del target_inputs, draft_probs, trial_inputs
torch.cuda.empty_cache()
return total_accept_tokens, total_draft_tokens, generated_tokens
def objective_speculative_speed(self, trial, num_hidden_layers: int):
try:
skip_layers = self._get_skip_layers(trial, num_hidden_layers)
if not skip_layers:
raise optuna.TrialPruned()
self.model.set_skip_layer_ids(skip_layer_ids=skip_layers)
if self.adjuster:
self.adjuster.reset()
start_time = time.time()
total_generated_tokens = 0
for inputs in tqdm(self.processed_samples):
_, _, generated_tokens = self._run_inference(inputs)
total_generated_tokens += generated_tokens
token_per_second = total_generated_tokens / (time.time() - start_time)
print(f"attn_skip: {skip_layers['attn']}, mlp_skip: {skip_layers['mlp']}, tokens/sec: {token_per_second}")
return token_per_second
except Exception as e:
print(f"Trial failed with error: {str(e)}")
raise optuna.TrialPruned()
def objective_acceptance_rate(self, trial, num_hidden_layers: int):
try:
skip_layers = self._get_skip_layers(trial, num_hidden_layers)
if not skip_layers:
raise optuna.TrialPruned()
self.model.set_skip_layer_ids(skip_layer_ids=skip_layers)
total_accept_tokens_group = 0
total_draft_tokens_group = 0
for inputs in tqdm(self.processed_samples):
accept_tokens, draft_tokens, _ = self._run_inference(inputs)
total_accept_tokens_group += accept_tokens
total_draft_tokens_group += draft_tokens
accept_rate = total_accept_tokens_group / total_draft_tokens_group
print(f"attn_skip: {skip_layers['attn']}, mlp_skip: {skip_layers['mlp']}, Accept Rate: {accept_rate}")
return accept_rate
except Exception as e:
print(f"Trial failed with error: {str(e)}")
raise optuna.TrialPruned()
def _get_skip_layers(self, trial, num_hidden_layers: int) -> Dict[str, List[int]]:
"""Helper method to get skip layers configuration from trial"""
skip_attn_layers = [i for i in range(num_hidden_layers) if trial.suggest_int(f"skip_attn_layer_{i}", 0, 1) == 1]
skip_mlp_layers = [i for i in range(num_hidden_layers) if trial.suggest_int(f"skip_mlp_layer_{i}", 0, 1) == 1]
if not skip_attn_layers and not skip_mlp_layers:
return None
return {
"attn": skip_attn_layers,
"mlp": skip_mlp_layers,
}
Now we can actually run the experiment:
import optuna
from typing import Dict, Optional, Tuple
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import copy
import time
from datasets import load_dataset
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from layerskip_modeling.modeling_layerskip_gemma2 import LayerSkipGemma2ForCausalLM
from sampling.sampling import sample_next_token
from utils.utils import calculate_continuous_acceptance, AdaptiveDraftExitAduster
from utils.optimization_searcher import LayerSkipStrategySearcher
def drafter_speculative_decode(
draft_model: torch.nn.Module,
draft_tokenizer: PreTrainedTokenizerBase,
inputs: Dict[str, torch.Tensor],
gamma: int = 10,
temperature: float = 1.0,
top_k: Optional[int] = 0, # Default is 0, it means do not select top-k tokens
top_p: Optional[float] = 1.0,
repetition_penalty: Optional[float] = 1.0,
draft_mode: bool = True,
confidence_threshold_adjuster: Optional[AdaptiveDraftExitAduster] = None,
) -> Tuple[Dict[str, torch.Tensor], torch.FloatTensor, int]:
draft_model.set_draft_mode(draft_mode)
draft_probs = []
real_generated_tokens = 0
for idx in range(gamma):
with torch.no_grad():
outputs = draft_model(**inputs)
next_tokens, probs = sample_next_token(
logits=outputs.logits,
prefix_token_ids=inputs["input_ids"],
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
)
draft_probs.append(probs)
input_ids = torch.cat([inputs["input_ids"], next_tokens[:, -1:]], dim=-1)
attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)
inputs["input_ids"] = input_ids
inputs["attention_mask"] = attention_mask
real_generated_tokens += 1
# Early exit
if confidence_threshold_adjuster and confidence_threshold_adjuster.should_exit(draft_prob=probs[0, 0, next_tokens.item()]):
# print(confidence_threshold_adjuster.get_state())
break
draft_model.set_draft_mode(False)
return inputs, torch.cat(draft_probs, dim=1), real_generated_tokens
def target_speculative_decode(
target_model: torch.nn.Module,
target_tokenizer: PreTrainedTokenizerBase,
inputs: Dict[str, torch.Tensor],
draft_probs: torch.FloatTensor,
temperature: float = 1.0,
top_k: Optional[int] = 0, # Default is 0, it means do not select top-k tokens
top_p: Optional[float] = 1.0,
repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], bool, int]:
target_model.set_draft_mode(False)
with torch.no_grad():
outputs = target_model(**inputs)
next_tokens, target_probs = sample_next_token(
logits=outputs.logits,
prefix_token_ids=inputs["input_ids"],
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
probs_num=draft_probs.shape[1] + 1,
)
next_token = next_tokens[:, -1:]
# Evaluation
indices = inputs["input_ids"][:, -draft_probs.shape[1]:]
eval_probs = target_probs[:, :-1, :]
expanded_indices = indices.unsqueeze(-1)
selected_draft_probs = torch.gather(draft_probs, dim=-1, index=expanded_indices)
selected_draft_probs = selected_draft_probs.squeeze(-1)
selected_eval_probs = torch.gather(eval_probs, dim=-1, index=expanded_indices)
selected_eval_probs = selected_eval_probs.squeeze(-1)
# Compare draft_prob and eval_prob, and check the reject_mask
mask_to_reject = selected_draft_probs > selected_eval_probs
    # Calculate the rejection probability: 1 - (eval_prob / draft_prob)
rejection_probs = 1 - (selected_eval_probs / selected_draft_probs)
    # Generate random values to determine accept or reject
random_values = torch.rand_like(rejection_probs)
rejection_decisions = random_values < rejection_probs
# Get the final reject masks
rejection_masks = mask_to_reject & rejection_decisions
acceptance_mask = torch.ones_like(selected_draft_probs, dtype=torch.bool)
acceptance_mask[rejection_masks] = False
is_end = False
# Concat `input_ids`
if torch.all(acceptance_mask):
input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)
else:
new_input_ids = []
new_attention_mask = []
for batch_idx in range(next_tokens.shape[0]):
gamma = next_tokens.shape[1] - 1
start_idx = inputs["input_ids"].shape[1] - gamma
for pos_idx in range(acceptance_mask[batch_idx].shape[0]):
if (acceptance_mask[batch_idx][pos_idx] and inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id) or not acceptance_mask[batch_idx][pos_idx]:
inputs["input_ids"][batch_idx][start_idx+pos_idx] = next_tokens[batch_idx][pos_idx]
new_input_ids.append(inputs["input_ids"][batch_idx][:start_idx+pos_idx+1])
new_attention_mask.append(inputs["attention_mask"][batch_idx][:start_idx+pos_idx+1])
is_end = inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id
break
input_ids = pad_sequence(new_input_ids, batch_first=True, padding_value=target_tokenizer.pad_token_id)
attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)
inputs["input_ids"] = input_ids
inputs["attention_mask"] = attention_mask
return inputs, is_end, calculate_continuous_acceptance(acceptance_mask)
if __name__ == "__main__":
# Load tokenizer and model
pretrained_model_name_or_path = "../models/google--gemma-2-2b-it/"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
model = LayerSkipGemma2ForCausalLM.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.bfloat16).to(device)
# Load dataset
dataset = load_dataset("shibing624/sharegpt_gpt4")
samples = dataset["train"]["conversations"][:10]
samples = [[{"role": sample[0]["from"].replace("human", "user").replace("gpt", "assistant"), "content": sample[0]["value"]}] for sample in samples]
# Adaptive draft threshold adjuster
confidence_threshold_adjuster = AdaptiveDraftExitAduster(
target_matchness=0.5,
beta1=0.5,
beta2=0.9,
epsilon=0.1,
max_step_draft=8,
)
# Searcher
searcher = LayerSkipStrategySearcher(
model=model,
tokenizer=tokenizer,
device=device,
drafter_speculative_decode=drafter_speculative_decode,
target_speculative_decode=target_speculative_decode,
adjuster=confidence_threshold_adjuster,
)
searcher.optimize_speculative_speed(
samples=samples,
n_trials=1000,
)
Output:
In my experiments, the best run was trial 846, where the decoding speed used as the optimization objective reached 24.816 tokens/sec.
If the layers to skip are chosen at random, the speed mostly hovers around 16 to 18 tokens/sec, so searching the layer-skipping hyperparameters strategically with Bayesian optimization clearly helps.
Overall, however, my experiment was a failure. Why? As the next section shows, decoding with the model directly, without Self-Speculative Decoding, already runs at about 29.647 tokens/sec.
Self-Speculative Decoding Experiment Results
The first architecture I tried was gemma-2-2b-it, but once more than about 6 layers were skipped, decoding actually started to slow down. This matches the figure in the paper that plots speed against the number of skipped layers; a smaller model simply hits this bottleneck sooner.
My results on gemma-2-9b-it are somewhat better than on the 2B model, but there is still no real speedup; it is in fact slower. Below, the first result is Self-Speculative Decoding, the second is the normal gemma-2-9b-it inference speed, and the last is the inference speed of gemma-2-9b-it with layers skipped.
Generate token number: 102
Generate speed: 24.07665693982178 token/sec
Speculative Decoding Spent Time: 4.236468553543091 seconds.
Accept Rate: 0.4056603773584906
Generate token number: 100
Generate speed: 29.647649721089135 token/sec
Normal Target Model Decoding Spent Time: 3.37294864654541 seconds.
Generate token number: 100
Generate speed: 48.81880782264111 token/sec
Normal Draft Model Decoding Spent Time: 2.0483908653259277 seconds.
We can see that even with this many layers skipped, the draft pass is still not twice as fast as running the full network. From the relationship between inference time and acceptance rate we can derive the following table directly (times are expressed in units of 10 ms, which makes them easiest to read):
| Network configuration | Original model | Skip-layer model | Time for the accelerated setup (assuming all drafts are accepted, plus one full-model forward pass) |
|---|---|---|---|
| Time to token 1 | 3.4 | 2 | 2 + 3.4 = 5.4 |
| Time to token 2 | 6.8 | 4 | 4 + 3.4 = 7.4 |
| Time to token 3 | 10.2 | 6 | 6 + 3.4 = 9.4 |
| Time to token 4 | 13.6 | 8 | 8 + 3.4 = 11.4 |
| Time to token 5 | 17.0 | 10 | 10 + 3.4 = 13.4 |
In other words, if the skip-layer draft model generates 5 tokens, the acceptance rate must be at least about 0.8 for it to actually beat the original model, which is quite a demanding requirement.
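The break-even point can also be computed directly from the measured speeds above; here is a small sketch (the formula ignores the bonus token produced by the verification step, so it is slightly pessimistic):
# Measured speeds from the runs above (tokens/sec)
target_tok_per_sec = 29.647649721089135  # full gemma-2-9b-it
draft_tok_per_sec = 48.81880782264111    # gemma-2-9b-it with layers skipped

target_time = 1 / target_tok_per_sec  # seconds per full-model forward pass
draft_time = 1 / draft_tok_per_sec    # seconds per skip-layer forward pass

gamma = 5  # drafted tokens per round

# One speculative round costs gamma draft passes plus one verification pass and,
# ignoring the bonus token, yields about acceptance_rate * gamma accepted tokens.
# Break even when the cost per token equals plain decoding:
#     (gamma * draft_time + target_time) / (acceptance_rate * gamma) = target_time
break_even_acceptance = (gamma * draft_time + target_time) / (gamma * target_time)
print(f"Break-even acceptance rate for gamma={gamma}: {break_even_acceptance:.2f}")  # roughly 0.81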
Looking at it the other way, intuition says that larger models are easier to accelerate with Self-Speculative Decoding: a larger model's baseline inference is slower, and because it has so many weights, skipping a few layers costs relatively little, so it is easier to find the sweet spot of a small quality loss combined with a significant reduction in inference cost.
Of course, another approach is to do what Kangaroo or Meta AI's LayerSkip do: use training to recover the performance lost by the skip-layer network, which allows much more aggressive acceleration once the loss is small enough.
Honestly, I would really like to improve this result. My current plan is to train in the style of Kangaroo, hoping to get genuine acceleration out of Self-Speculative Decoding without changing too much.
Finally, if you happened upon this post, you are welcome to keep following my GitHub: https://github.com/ccs96307/fast-llm-inference, and I am always happy to receive any feedback!
References
- Kangaroo: Lossless Self-Speculative Decoding via Double Early Exiting
- LayerSkip: Enabling Early Exit Inference and Self-Speculative Decoding