Last Updated on 2024-11-12 by Clay
Introduction
Self-Speculative Decoding is a variant of Speculative Decoding. The original Speculative Decoding method uses a draft model to accelerate inference of the target model. The draft model is a smaller model, often from the same family as the target model or distilled from it. Its outputs are reasonably close to the target's, and it runs several times faster.
Once the draft model has proposed a sequence of candidate tokens, the target model verifies them in a single forward pass. This pass yields the target's own probability distribution at every drafted position, plus one more position for the next token. A verification algorithm then decides how many of the drafted tokens to accept. This lets the target model produce multiple tokens per inference cycle, which is where the acceleration comes from.
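The verification step can be made concrete with a minimal sketch of its simplest variant, greedy verification. The sketch assumes the target model's argmax prediction at each drafted position has already been computed in one forward pass. The function name greedy_verify is hypothetical. The original Speculative Decoding paper uses a rejection-sampling rule instead, which also covers sampled (non-greedy) decoding.

```python
from typing import List


def greedy_verify(draft_tokens: List[int], target_argmax: List[int]) -> List[int]:
    """Greedy verification (hypothetical helper, for illustration only).

    target_argmax[i] is the target model's greedy prediction for draft position i,
    computed in one forward pass over prompt + draft. It has len(draft_tokens) + 1
    entries; the last one is the prediction after the whole draft.
    """
    assert len(target_argmax) == len(draft_tokens) + 1
    accepted: List[int] = []
    for i, tok in enumerate(draft_tokens):
        if tok != target_argmax[i]:
            # First mismatch: use the target's own token instead, then stop.
            accepted.append(target_argmax[i])
            return accepted
        accepted.append(tok)
    # Every drafted token matched, so the target's extra prediction comes for free.
    accepted.append(target_argmax[-1])
    return accepted
```

Each call returns at least one new token (the target's own prediction) and at most len(draft_tokens) + 1 tokens per verification pass.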
Self-Speculative Decoding avoids the extra VRAM cost of loading a separate draft model. Instead, the model uses a subset of its own layers as the draft model. It then verifies the drafted tokens by running all of its layers as the target model. Details can be found in the paper Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding.
Today, I'm documenting how to build a LayerSkip model by reusing the original HuggingFace Transformers code and adding a skip mechanism. This mechanism lets us bypass the self-attention and MLP sublayers of selected decoder layers as needed.
However, I haven't yet implemented the Bayesian Optimization procedure the paper uses to select which layers to skip. As a result, I can't report Self-Speculative Decoding results yet.
For a more basic introduction to Speculative Decoding, you can refer to my previous notes:
- [Paper Review] Fast Inference from Transformers via Speculative Decoding
- Notes on Speculative Decoding Implementation (with Simple Experimental Results)
I'll also upload my implementation to GitHub: https://github.com/ccs96307/fast-llm-inference. This repository contains various implementations for accelerating inference, along with references to the papers I consulted. Feel free to bookmark or star it; everyone is welcome to check it out!
LayerSkip Architecture Implementation
Basically, we only need to rewrite three classes: LlamaDecoderLayer, LlamaModel, and LlamaForCausalLM. I've largely copied HuggingFace Transformers' implementation, so the copyright belongs to them.
Of course, inheritance could be used, but this level of abstraction works perfectly for me while keeping the code additions minimal.
In essence, I did three main things:
- Defined the three classes with the LayerSkip prefix and initialized self.draft_mode and self.skip_layer_ids. Each class is built from its LayerSkip counterparts: LayerSkipLlamaForCausalLM contains LayerSkipLlamaModel, and LayerSkipLlamaModel builds its layers from LayerSkipLlamaDecoderLayer.
- Added set_skip_layer_ids() and set_draft_mode() methods to each class. These settings propagate from the outermost LayerSkipLlamaForCausalLM down to every LayerSkipLlamaDecoderLayer.
- In the forward() method of LayerSkipLlamaDecoderLayer, the attention or MLP sublayer is skipped under two conditions. self.draft_mode must be enabled, and the current layer index must appear in self.skip_layer_ids["attn"] or self.skip_layer_ids["mlp"], respectively.
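The decision made in forward() can be illustrated with a toy, framework-free sketch. Here each hypothetical "layer" is a pair of plain Python functions standing in for the attention and MLP sublayers, and normalization is left out. The point is only that a skipped sublayer reduces the residual update h = h + f(h) to h = h.

```python
from typing import Callable, Dict, List, Sequence, Tuple

# A toy sublayer maps a scalar hidden state to its residual update.
Sublayer = Callable[[float], float]


def toy_decoder_stack(
    h: float,
    layers: Sequence[Tuple[Sublayer, Sublayer]],
    skip_layer_ids: Dict[str, List[int]],
    draft_mode: bool,
) -> float:
    """Run toy residual blocks (attention, then MLP), skipping sublayers in draft mode.

    Normalization is omitted to keep the example small.
    """
    for layer_idx, (attn, mlp) in enumerate(layers):
        # A skipped sublayer leaves only the residual path, i.e. the identity.
        if not (draft_mode and layer_idx in skip_layer_ids["attn"]):
            h = h + attn(h)
        if not (draft_mode and layer_idx in skip_layer_ids["mlp"]):
            h = h + mlp(h)
    return h
```

This mirrors why the skip branches in LayerSkipLlamaDecoderLayer simply set hidden_states = residual. The skip ids are ignored entirely when draft_mode is False.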
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint

from transformers import (
    LlamaConfig,
    LlamaPreTrainedModel,
    GenerationMixin,
)
from transformers.models.llama.modeling_llama import (
    LlamaMLP,
    LlamaRMSNorm,
    LlamaRotaryEmbedding,
    LLAMA_ATTENTION_CLASSES,
)
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

class LayerSkipLlamaDecoderLayer(torch.nn.Module):
    def __init__(
        self,
        config: LlamaConfig,
        layer_idx: int,
    ):
        super().__init__()

        # Set skip layer
        skip_layer_ids = {"attn": [], "mlp": []}
        self.draft_mode = False
        self.skip_layer_ids = skip_layer_ids
        self.layer_idx = layer_idx

        self.hidden_size = config.hidden_size

        self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)

        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def set_skip_layer_ids(self, skip_layer_ids: Dict[str, List[int]]):
        self.skip_layer_ids = skip_layer_ids

    def set_draft_mode(self, _mode: bool):
        self.draft_mode = _mode

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states

        if self.draft_mode and self.layer_idx in self.skip_layer_ids["attn"]:
            # Draft mode: skip self-attention, so only the residual path remains (identity).
            # Note that this layer's KV cache is not updated while it is skipped.
            hidden_states = residual
            self_attn_weights = None
            present_key_value = None
        else:
            hidden_states = self.input_layernorm(hidden_states)

            # Self Attention
            hidden_states, self_attn_weights, present_key_value = self.self_attn(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )
            hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        if self.draft_mode and self.layer_idx in self.skip_layer_ids["mlp"]:
            # Draft mode: skip the MLP, so only the residual path remains (identity).
            hidden_states = residual
        else:
            hidden_states = self.post_attention_layernorm(hidden_states)
            hidden_states = self.mlp(hidden_states)
            hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs

class LayerSkipLlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LayerSkipLlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)

        # Set skip layer
        skip_layer_ids = {"attn": [], "mlp": []}
        self.draft_mode = False
        self.skip_layer_ids = skip_layer_ids

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = torch.nn.ModuleList(
            [LayerSkipLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def set_skip_layer_ids(self, skip_layer_ids: Dict[str, List[int]]):
        self.skip_layer_ids = skip_layer_ids

        for layer in self.layers:
            layer.set_skip_layer_ids(skip_layer_ids=skip_layer_ids)

    def set_draft_mode(self, _mode: bool):
        self.draft_mode = _mode

        for layer in self.layers:
            layer.set_draft_mode(_mode=_mode)

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # kept for BC (non `Cache` `past_key_values` inputs)
        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):
            return_legacy_cache = True
            if past_key_values is None:
                past_key_values = DynamicCache()
            else:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )
        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

            hidden_states = layer_outputs[0]

            if use_cache and layer_outputs[2 if output_attentions else 1] is not None:
                # Layers whose attention is skipped return None instead of the cache, so keep
                # the last non-None one (every layer shares the same Cache instance).
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask

class LayerSkipLlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)

        # Set skip layer
        skip_layer_ids = {"attn": [], "mlp": []}
        self.draft_mode = False
        self.skip_layer_ids = skip_layer_ids

        self.model = LayerSkipLlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def set_skip_layer_ids(self, skip_layer_ids: Dict[str, List[int]]):
        assert "attn" in skip_layer_ids and "mlp" in skip_layer_ids, "`skip_layer_ids` must contain both `attn` and `mlp` keys!"
        assert isinstance(skip_layer_ids["attn"], list), "`skip_layer_ids['attn']` must be a list!"
        assert isinstance(skip_layer_ids["mlp"], list), "`skip_layer_ids['mlp']` must be a list!"

        for attn_layer_idx in skip_layer_ids["attn"]:
            assert attn_layer_idx < len(self.model.layers), f"attn_layer_idx {attn_layer_idx} is out of range ({len(self.model.layers)})"

        for mlp_layer_idx in skip_layer_ids["mlp"]:
            assert mlp_layer_idx < len(self.model.layers), f"mlp_layer_idx {mlp_layer_idx} is out of range ({len(self.model.layers)})"

        self.skip_layer_ids = skip_layer_ids
        self.model.set_skip_layer_ids(skip_layer_ids=skip_layer_ids)
        print("skip_layer_ids:", self.skip_layer_ids)

    def set_draft_mode(self, _mode: bool):
        self.draft_mode = _mode
        self.model.set_draft_mode(_mode=_mode)

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
            logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **loss_kwargs)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
Once completed, let’s test the results!
Testing the Effectiveness of the LayerSkip Model
With a 1.7B model, skipping even a few layers seems to noticeably degrade output quality. This may explain why Meta AI's LayerSkip models need dedicated training to compensate.
Still, my rough measurements show that LayerSkip does save time. This is expected, since some computations are simply omitted.
Currently, once the layers to skip have been chosen by hand, I can test the effects of Self-Speculative Decoding. However, I still need to implement a method for selecting those layers automatically.
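Until the Bayesian Optimization step is in place, random search over candidate skip sets is one simple placeholder. The sketch below is exactly that: a stand-in, not the paper's method. random_search_skip_layers and score_fn are hypothetical names. score_fn is assumed to be a callback that, for example, runs Self-Speculative Decoding on a few calibration prompts and returns a higher-is-better metric such as tokens per second.

```python
import random
from typing import Callable, Dict, List, Tuple

SkipLayerIds = Dict[str, List[int]]


def random_search_skip_layers(
    num_layers: int,
    num_skip: int,
    score_fn: Callable[[SkipLayerIds], float],
    num_trials: int = 50,
    seed: int = 0,
) -> Tuple[SkipLayerIds, float]:
    """Sample random skip sets and keep the one with the highest score.

    Hypothetical stand-in for Bayesian Optimization. Each trial independently samples
    `num_skip` distinct attention layers and `num_skip` distinct MLP layers to skip.
    """
    rng = random.Random(seed)  # fixed seed makes the search reproducible
    best_ids: SkipLayerIds = {"attn": [], "mlp": []}
    best_score = float("-inf")
    for _ in range(num_trials):
        candidate: SkipLayerIds = {
            "attn": sorted(rng.sample(range(num_layers), num_skip)),
            "mlp": sorted(rng.sample(range(num_layers), num_skip)),
        }
        score = score_fn(candidate)
        if score > best_score:
            best_ids, best_score = candidate, score
    return best_ids, best_score
```

The returned dictionary has the same {"attn": [...], "mlp": [...]} shape that set_skip_layer_ids() expects. A smarter optimizer could later replace the sampling loop without changing that interface.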
Below is a simple test comparing the original model with a model that skips a few layers. I chose to skip both the attention and MLP sublayers of layers 2, 15, and 18.
import time

import torch
from transformers import AutoTokenizer

from layerskip_modeling.modeling_layerskip_llama import LayerSkipLlamaForCausalLM


if __name__ == "__main__":
    pretrained_model_name_or_path = "../models/HuggingFaceTB--SmolLM2-1.7B-Instruct/"
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
    model = LayerSkipLlamaForCausalLM.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.bfloat16).to(device)

    skip_layer_ids = {
        "attn": [2, 15, 18],
        "mlp": [2, 15, 18],
    }

    model.set_skip_layer_ids(skip_layer_ids=skip_layer_ids)

    messages = [
        [
            {
                "role": "user",
                "content": "What is the capital of Taiwan. And why?",
            },
        ],
    ]

    # Tokenize
    input_text = tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    ).to(device)

    prompt_token_num = inputs["input_ids"].shape[-1]

    # Original Model
    model.set_draft_mode(False)

    start_time = time.time()

    outputs = model.generate(**inputs, max_new_tokens=512)
    total_token_num = outputs.shape[-1]
    completion_token_num = total_token_num - prompt_token_num
    cost_time = time.time() - start_time

    token_per_second = completion_token_num / cost_time
    response = tokenizer.batch_decode(outputs)[0]

    print(f"{'='*15} Original Model {'='*15}")
    print(response)
    print()
    print(f"Completion Token Number: {completion_token_num}")
    print(f"Cost Time: {cost_time}, Speed: {token_per_second} token/sec\n")

    # LayerSkip Model
    model.set_draft_mode(True)

    start_time = time.time()

    outputs = model.generate(**inputs, max_new_tokens=512)
    total_token_num = outputs.shape[-1]
    completion_token_num = total_token_num - prompt_token_num
    cost_time = time.time() - start_time

    token_per_second = completion_token_num / cost_time
    response = tokenizer.batch_decode(outputs)[0]

    print(f"{'='*15} LayerSkip Model {'='*15}")
    print(response)
    print()
    print(f"Completion Token Number: {completion_token_num}")
    print(f"Cost Time: {cost_time}, Speed: {token_per_second} token/sec\n")
Output:
skip_layer_ids: {'attn': [2, 15, 18], 'mlp': [2, 15, 18]}
=============== Original Model ===============
<|im_start|>system
You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
<|im_start|>user
What is the capital of Taiwan. And why?<|im_end|>
<|im_start|>assistant
The capital of Taiwan is Taipei. It is the largest city in Taiwan and serves as the political, economic, and cultural center of the country. The reason for this is that Taipei was established as the capital city in 1949, following the Chinese Civil War, when the government of the Republic of China (ROC) relocated from mainland China to Taiwan. This decision was made to ensure the continuity of the ROC's political and administrative functions, and to maintain its claim to the entirety of China.<|im_end|>
Completion Token Number: 110
Cost Time: 2.2670738697052, Speed: 48.52069509949576 token/sec
=============== LayerSkip Model ===============
<|im_start|>system
You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
<|im_start|>user
What is the capital of Taiwan. And why?<|im_end|>
<|im_start|>assistant
The capital of Taiwan is Taipei.<|im_end|>
Completion Token Number: 13
Cost Time: 0.1961832046508789, Speed: 66.26459193147736 token/sec
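In my opinion, the original model clearly gives a more comprehensive answer, since it also explains why Taipei is Taiwan's capital. Still, the LayerSkip model produces a coherent, if much shorter, response. Its speed rises from 48.52 to 66.26 tokens per second, roughly 1.37× faster. This comes from a single run in which the two models generated 110 and 13 tokens, respectively, so it is only a rough comparison.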
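In the script above, the original-model run comes first, so its timing also includes one-off warm-up costs. A fairer measurement could make an untimed warm-up call and then average several timed runs. The helper below is a hypothetical sketch of that idea. generate_fn is assumed to be any zero-argument callable that runs one generation and returns the number of new tokens, for example a wrapper around model.generate that computes completion_token_num. On a GPU you would typically also call torch.cuda.synchronize() before reading the clock.

```python
import time
from typing import Callable


def measure_throughput(
    generate_fn: Callable[[], int],
    warmup_runs: int = 1,
    timed_runs: int = 3,
) -> float:
    """Return average tokens/second over `timed_runs` calls, after untimed warm-up calls."""
    for _ in range(warmup_runs):
        # Untimed: absorbs one-off costs such as CUDA initialization and allocator warm-up.
        generate_fn()

    total_tokens = 0
    total_time = 0.0
    for _ in range(timed_runs):
        start = time.perf_counter()  # monotonic, high-resolution clock for benchmarking
        total_tokens += generate_fn()
        total_time += time.perf_counter() - start
    return total_tokens / total_time
```

Calling this once with draft mode off and once with it on gives two throughput numbers measured under the same conditions.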
Overall, this was an interesting implementation, and I plan to continue refining it until I have fully implemented Self-Speculative Decoding.
References
- Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding
- Kangaroo: Lossless Self-Speculative Decoding via Double Early Exiting