Last Updated on 2024-12-10 by Clay
Introduction
Kangaroo is an implementation of Self-Speculative Decoding that introduces a trainable adapter layer. Over the past few weeks, I have been working on fine-tuning its adapter layer and have achieved some preliminary results, which I am documenting here.
Kangaroo Architecture
Kangaroo is a variant of Self-Speculative Decoding, which itself is a branch of Speculative Decoding.
In the original Speculative Decoding method, there is a draft model and a target model (the one we aim to accelerate). Typically, the draft model and target model share the same vocabulary and the draft model performs inference much faster. The draft model speculatively generates several tokens, which are then validated by the target model in a single pass.
In the original paper (Fast Inference from Transformers via Speculative Decoding), the research team designed a validation algorithm for the target model to ensure that the output probability distribution after speculative decoding matches that of the original target model, thereby achieving acceleration without altering the decoding results.
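To make the verification step concrete, here is a minimal sketch of that accept/reject rule, assuming we already have the draft tokens and both models' probability distributions at each position (all names are illustrative, and edge cases such as an all-zero residual are ignored):
import torch

def verify_draft_tokens(draft_tokens, draft_probs, target_probs):
    """Accept/reject rule from speculative decoding (Leviathan et al.).

    draft_tokens: (gamma,) token ids proposed by the draft model
    draft_probs:  (gamma, vocab) draft distribution at each position
    target_probs: (gamma, vocab) target distribution at the same positions
    Returns the accepted tokens, plus one corrective token on rejection.
    """
    accepted = []
    for t in range(draft_tokens.shape[0]):
        token = draft_tokens[t].item()
        p = target_probs[t, token]
        q = draft_probs[t, token]
        if torch.rand(1).item() < min(1.0, (p / q).item()):
            # Accepted: provably identical to sampling from the target model.
            accepted.append(token)
        else:
            # Rejected: resample from the residual distribution max(0, p - q).
            residual = torch.clamp(target_probs[t] - draft_probs[t], min=0.0)
            residual = residual / residual.sum()
            accepted.append(torch.multinomial(residual, 1).item())
            return accepted  # stop at the first rejection
    return accepted  # all gamma tokens accepted; caller samples one bonus token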
Self-Speculative Decoding (Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding) leverages part of the target model's own network as the draft model. Simply put, each Transformer decoder layer consists of an attention module and an MLP module, and we can selectively skip these sub-layers to build a faster draft model.
The choice of which layers to skip is critical. Skipping too few layers won't yield significant speedup, thus failing to accelerate the target model. Conversely, skipping too many layers can cause the draft model's decoding to diverge significantly from the target model, potentially rendering the draft model's outputs unacceptable to the target model.
To balance the trade-off between inference speed and accuracy, Self-Speculative Decoding employs Bayesian Optimization to explore layer-skipping strategies that satisfy both requirements.
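A sketch of what that search might look like is below. This is not the paper's actual code: evaluate_skip_config is a hypothetical helper that would build the layer-skipping draft model, run draft-and-verify on a validation set, and return tokens/sec, and I substitute plain random search for the Gaussian-process optimizer to keep the sketch dependency-free:
import random

NUM_LAYERS = 32  # e.g., a Llama-scale model

def evaluate_skip_config(skipped_layers: set) -> float:
    """Hypothetical helper: build a draft model that skips `skipped_layers`,
    run draft-and-verify on a validation set, and return tokens/sec."""
    raise NotImplementedError

def search_skip_config(num_trials: int = 100, num_skipped: int = 12):
    # The paper uses Bayesian Optimization here; random search is a stand-in
    # with the same objective: maximize end-to-end decoding speed.
    best_config, best_speed = None, 0.0
    for _ in range(num_trials):
        candidate = set(random.sample(range(NUM_LAYERS), num_skipped))
        speed = evaluate_skip_config(candidate)
        if speed > best_speed:
            best_config, best_speed = candidate, speed
    return best_config, best_speed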
I previously conducted an experiment: Self-Speculative Decoding Implementation: LayerSkip Model, Bayesian Optimization, and Adaptive Draft-Exiting Mechanism (with Gemma-2-9B-it Results).
However, the experimental results were dismal. Due to hardware constraints, I could only test models up to the 9B scale, and at that size I saw no acceleration at all. In retrospect, it makes sense that the paper only reports results from 13B models upward; perhaps smaller models simply don't perform well when part of their own network serves as the draft model.
Thus, I decided to take on the challenge of implementing Kangaroo. The principle of Kangaroo aligns with Self-Speculative Decoding, but it adds a trainable component. Of course, naively fine-tuning the draft model would risk compromising the target model, since the draft model is built from the target model's own layers, and fine-tuning them could erode the target model's capability.
This led to the following architecture proposed by Kangaroo:
The draft model in Kangaroo consists primarily of the target model's shallow sub-network. Its hidden states are then processed by an additional trainable Adapter-Network, which bridges the shallow sub-network and the LM Head to produce logits.
Only this Adapter-Network is trainable: during training, we pass both the Adapter-Network's output and the output of the last remaining layer through the LM Head, and optimize the similarity of the two logit distributions via gradient descent.
From another perspective, the draft model can be seen as an alternative path that exits the inference process early to reach the LM Head directly. However, jumping straight from the shallow layers to the LM Head usually produces nonsensical tokens (this is expected: during pretraining, the shallow network was never meant to feed the LM Head directly), which is why an adapter is needed to bridge the shallow sub-network's output to the LM Head.
In this way, we can fine-tune the shallow sub-network-based draft model to align more closely with the target model without compromising the original model.
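Conceptually, the draft forward pass is just an early exit with a repair step in between. Here is a minimal, self-contained sketch with toy linear layers standing in for the real decoder layers (the actual implementation below also threads attention masks, position embeddings, and the KV cache through each call):
import torch
import torch.nn as nn

hidden, vocab = 64, 1000
shallow_layers = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(2)])  # frozen
adapter = nn.Linear(hidden, hidden)   # the only trainable piece
final_norm = nn.LayerNorm(hidden)     # shared with the target model
lm_head = nn.Linear(hidden, vocab)    # shared, frozen LM Head

def draft_forward(hidden_states: torch.Tensor) -> torch.Tensor:
    # Early exit: run only the shallow sub-network...
    for layer in shallow_layers:
        hidden_states = layer(hidden_states)
    # ...then repair its output so the LM Head can consume it.
    hidden_states = final_norm(adapter(hidden_states))
    return lm_head(hidden_states)  # draft logits

logits = draft_forward(torch.randn(1, 8, hidden))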
To conclude: while my earlier Self-Speculative Decoding experiments failed to achieve any acceleration, the final Kangaroo training achieved an inference speed improvement of approximately 10% ~ 15%. Honestly, after the initial setbacks, seeing this result felt like a tearful victory for me XDDD.
Code Implementation
I wrote a fair amount of code for this Kangaroo implementation, which you can find on my GitHub: https://github.com/ccs96307/fast-llm-inference. The repo includes implementations of various inference acceleration techniques. Feel free to contact me for discussion.
Here, I will briefly introduce three scripts:
- modeling_kangaroo_llama3.py: The modified Kangaroo model based on the Llama architecture
- train.py: The training script
- test_kangaroo.py: The script to test Kangaroo's acceleration performance
Starting with Kangaroo's architecture implementation, I abstractly divided the LlamaForCausalLM model into shallow_layers and remaining_layers, both extracted from self.model.layers. Following the original paper's design, my shallow network contains only two layers.
For the Adapter, I initially implemented the Attention module as described in the paper and later extended it to a full decoder layer with an MLP module: the Attention-only adapter had convergence issues, so I added more parameters in hopes of better results.
Regarding the loss function, I introduced an additional hard-label loss term, weighted and combined with the soft-label cross-entropy loss from the original paper. This seemed slightly more effective in my experiments, though I could not fully replicate the paper's results, possibly due to some suboptimal implementation details.
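In code, the combined objective looks roughly like the following (this mirrors the loss computation in the training-mode forward pass further down; alpha is the soft-label weight, 0.8 in my implementation):
import torch
import torch.nn.functional as F

def kangaroo_loss(draft_logits, target_logits, alpha: float = 0.8):
    draft_log_probs = F.log_softmax(draft_logits, dim=-1)
    target_probs = F.softmax(target_logits, dim=-1)
    # Soft labels: cross-entropy against the target model's full distribution.
    soft_loss = -(target_probs * draft_log_probs).sum(dim=-1).mean()
    # Hard labels: standard cross-entropy against the target's argmax tokens.
    hard_labels = target_probs.argmax(dim=-1)
    hard_loss = F.cross_entropy(
        draft_logits.view(-1, draft_logits.size(-1)), hard_labels.view(-1)
    )
    return alpha * soft_loss + (1 - alpha) * hard_loss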
I also implemented a prepare_dataset() method compatible with the official Kangaroo project, which extracts the outputs of the early-exit layer and the final layer as training data. This optimization avoids re-running the remaining layers on every training step. However, with the ShareGPT4 dataset the cached states required nearly 900 GB of disk space across my servers, so I kept both the optimized training method and the native approach of running the draft and target paths in a single forward pass.
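Here is a minimal sketch of how such precomputation could be wired up around my prepare_dataset() method, which returns the shallow and final hidden states; the file layout and helper names are illustrative:
import torch

@torch.no_grad()
def precompute_hidden_states(model, dataloader, save_dir: str):
    """Run the frozen model once and cache both exit points to disk, so that
    adapter training only needs the small adapter forward/backward pass."""
    model.eval()
    for step, (input_ids, attention_mask) in enumerate(dataloader):
        # prepare_dataset() returns (shallow_hidden_states, final_hidden_states)
        shallow_hs, final_hs = model.prepare_dataset(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        torch.save(
            {"shallow": shallow_hs.cpu(), "final": final_hs.cpu()},
            f"{save_dir}/batch_{step:06d}.pt",
        )
The disk blow-up is easy to reconstruct: with max_length = 512 tokens, Llama-3.1-8B's hidden size of 4096, and bf16 (2 bytes), one sample costs about 512 × 4096 × 2 bytes × 2 exit points ≈ 8 MB, so a corpus of roughly a hundred thousand samples already approaches the ~900 GB I observed.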
from typing import Any, List, Optional, Tuple, Union
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
from dataclasses import dataclass
import json
import torch
torch.set_default_dtype(torch.bfloat16)
from torch.nn.utils.rnn import pad_sequence
from transformers import LlamaForCausalLM
from transformers.models.llama.modeling_llama import Cache, DynamicCache, LlamaSdpaAttention, LlamaDecoderLayer, LlamaRMSNorm
from transformers.modeling_outputs import CausalLMOutputWithPast
from sampling.sampling import sample_next_token
from utils.utils import calculate_continuous_acceptance
@dataclass
class KangarooModelMode:
draft_only_mode: str = "draft_only"
target_only_mode: str = "target_only"
train_mode: str = "train"
@dataclass
class AdapterMode:
attention_only_mode: str = "attention_only"
mlp_only_mode: str = "mlp_only"
decoder_layer_mode: str = "decoder_layer"
class KangarooLlamaForCausalLM(LlamaForCausalLM):
def __init__(self, config):
super().__init__(config)
self.config = config
self.adapter_layer_mode = None
self.draft_mode_adapter_layer = None
self.shallow_layers = None
self.mode = KangarooModelMode.target_only_mode
self.confidence_threshold = 0.5
self.accept_rate = 0
self.total_accept_tokens = 0
self.total_draft_generated_token = 0
self.draft_temperature = 1.0
self.target_temperature = 1.0
self.alpha = 0.8
self.shallow_layer_num = 10
def set_skip_layer(self, shallow_layer_num: int) -> None:
self.shallow_layer_num = shallow_layer_num
self.shallow_layers = self.model.layers[:shallow_layer_num]
self.remaining_layers = self.model.layers[shallow_layer_num:]
def set_adapter_layer(self, _mode: str) -> None:
self.adapter_layer_mode = _mode
if _mode == AdapterMode.attention_only_mode:
self.attn_input_norm = LlamaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
self.attn_output_norm = LlamaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
self.draft_mode_adapter_layer = LlamaSdpaAttention(
config=self.config,
layer_idx=self.config.num_hidden_layers,
)
elif _mode == AdapterMode.decoder_layer_mode:
self.draft_mode_adapter_layer = LlamaDecoderLayer(
config=self.config,
layer_idx=self.config.num_hidden_layers,
)
def set_draft_mode(self) -> None:
self.mode = KangarooModelMode.draft_only_mode
def set_target_mode(self) -> None:
self.mode = KangarooModelMode.target_only_mode
def set_train_mode(self) -> None:
self.mode = KangarooModelMode.train_mode
def save_head(
self,
save_dir: str,
) -> None:
"""
Save the parameters of the draft_mode_adapter_layer, loss_history, and shallow_layer_num to the specified directory.
Args:
save_dir (str): Directory to save the adapter parameters.
"""
if not os.path.exists(save_dir):
os.makedirs(save_dir)
# Save the adapter's state dictionary
head_path = os.path.join(save_dir, "lm_head.pt")
torch.save(self.lm_head.state_dict(), head_path)
print(f"`lm_head` saved at {head_path}")
def save_norm(
self,
save_dir: str,
) -> None:
if not os.path.exists(save_dir):
os.makedirs(save_dir)
# Save the adapter's state dictionary
norm_path = os.path.join(save_dir, "norm.pt")
torch.save(self.model.norm.state_dict(), norm_path)
print(f"`norm` saved at {norm_path}")
def save_adapter(
self,
save_dir: str,
train_loss_history: List[float],
eval_loss_history: List[float],
) -> None:
"""
Save the parameters of the draft_mode_adapter_layer, loss_history, and shallow_layer_num to the specified directory.
Args:
save_dir (str): Directory to save the adapter parameters.
"""
if not os.path.exists(save_dir):
os.makedirs(save_dir)
# Save the adapter's state dictionary
adapter_path = os.path.join(save_dir, "draft_adapter.pt")
torch.save(self.draft_mode_adapter_layer.state_dict(), adapter_path)
print(f"Draft adapter saved at {adapter_path}")
# Save additional information (loss_history and shallow_layer_num)
metadata = {
"train_loss_history": train_loss_history,
"eval_loss_history": eval_loss_history,
"shallow_layer_num": self.shallow_layer_num
}
metadata_path = os.path.join(save_dir, "adapter_metadata.json")
with open(metadata_path, "w") as f:
json.dump(metadata, f)
print(f"Adapter metadata saved at {metadata_path}")
def load_adapter(self, load_dir: str) -> None:
"""
Load the parameters of the draft_mode_adapter_layer, loss_history, and shallow_layer_num from the specified directory.
Args:
load_dir (str): Directory to load the adapter parameters from.
Raises:
FileNotFoundError: If the adapter file does not exist in the specified directory.
"""
adapter_path = os.path.join(load_dir, "draft_adapter.pt")
if not os.path.exists(adapter_path):
raise FileNotFoundError(f"Draft adapter not found at {adapter_path}")
# Load the adapter's state dictionary
state_dict = torch.load(adapter_path, map_location=self.device)
self.draft_mode_adapter_layer.load_state_dict(state_dict=state_dict)
print(f"Draft adapter loaded from {adapter_path}")
def prepare_dataset(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
if self.shallow_layers is None:
            raise AttributeError("You have not set `shallow_layers` yet!")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Model
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.model.gradient_checkpointing and self.model.training and use_cache:
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.model.embed_tokens(input_ids)
# kept for BC (non `Cache` `past_key_values` inputs)
if use_cache and not isinstance(past_key_values, Cache):
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self.model._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.model.rotary_emb(hidden_states, position_ids)
# decoder layers
for decoder_layer in self.shallow_layers:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
shallow_hidden_states = hidden_states
# Remaining decoder layers
for decoder_layer in self.remaining_layers:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
hidden_states = self.model.norm(hidden_states)
return shallow_hidden_states, hidden_states
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
if self.mode == KangarooModelMode.target_only_mode:
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
num_logits_to_keep=num_logits_to_keep,
**loss_kwargs,
)
elif self.mode == KangarooModelMode.draft_only_mode:
if self.shallow_layers is None:
                raise AttributeError("You have not set `shallow_layers` yet!")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Model
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.model.gradient_checkpointing and self.model.training and use_cache:
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.model.embed_tokens(input_ids)
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self.model._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.model.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for decoder_layer in self.shallow_layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()
# Adapter
if self.adapter_layer_mode == AdapterMode.attention_only_mode:
residual = hidden_states
hidden_states = self.attn_input_norm(hidden_states)
                hidden_states, all_self_attns, past_key_values = self.draft_mode_adapter_layer(
                    hidden_states=hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )
hidden_states = residual + hidden_states
hidden_states = self.attn_output_norm(hidden_states)
elif self.adapter_layer_mode == AdapterMode.decoder_layer_mode:
layer_outputs = self.draft_mode_adapter_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
hidden_states = self.model.norm(hidden_states)
if self.config.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
logits = [torch.nn.functional.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
logits = torch.cat(logits, dim=-1)
else:
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **loss_kwargs)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=past_key_values,
hidden_states=hidden_states,
attentions=all_self_attns,
)
elif self.mode == KangarooModelMode.train_mode:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.model.gradient_checkpointing and self.model.training and use_cache:
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.model.embed_tokens(input_ids)
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self.model._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.model.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for decoder_layer in self.shallow_layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
# Cache hidden states
remaining_hidden_states = hidden_states
# Add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()
# Adapter
if self.adapter_layer_mode == AdapterMode.attention_only_mode:
residual = hidden_states
hidden_states = self.attn_input_norm(hidden_states)
hidden_states, all_self_attns, past_key_values = self.draft_mode_adapter_layer(
hidden_states=hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
                    past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = residual + hidden_states
hidden_states = self.attn_output_norm(hidden_states)
elif self.adapter_layer_mode == AdapterMode.decoder_layer_mode:
layer_outputs = self.draft_mode_adapter_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
hidden_states = self.model.norm(hidden_states)
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
# Remaining decoder layers
for decoder_layer in self.remaining_layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = decoder_layer(
remaining_hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
remaining_hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
remaining_hidden_states = self.model.norm(remaining_hidden_states)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
draft_logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) / self.draft_temperature
target_logits = self.lm_head(remaining_hidden_states[:, -num_logits_to_keep:, :]) / self.target_temperature
# Compute the log probabilities for both models
draft_log_probs = torch.nn.functional.log_softmax(draft_logits, dim=-1)
target_log_probs = torch.nn.functional.log_softmax(target_logits, dim=-1)
target_probs = torch.nn.functional.softmax(target_logits, dim=-1)
# Cross-entropy loss between target and draft model predictions
# kl_loss = torch.nn.functional.kl_div(draft_log_probs, target_probs, reduction="batchmean")
hard_labels = torch.argmax(target_probs, dim=-1)
soft_label_cross_entropy_loss = -(target_probs * draft_log_probs).sum(dim=-1).mean()
hard_label_loss = torch.nn.functional.cross_entropy(
draft_logits.view(-1, draft_logits.size(-1)), # Flatten logits
hard_labels.view(-1) # Flatten hard labels
)
loss = self.alpha * soft_label_cross_entropy_loss + (1 - self.alpha) * hard_label_loss
return CausalLMOutputWithPast(
loss=loss,
logits=target_logits,
past_key_values=past_key_values,
hidden_states=hidden_states,
attentions=all_self_attns,
)
@torch.no_grad()
def kangaroo_generate(
self,
eos_token_id: int,
pad_token_id: int,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
max_new_tokens: int = 100,
**kwargs,
) -> torch.LongTensor:
if self.shallow_layers is None:
            raise AttributeError("You have not set `shallow_layers` yet!")
confidence_score = 1.0
total_generate_tokens = 0
while total_generate_tokens < max_new_tokens:
draft_generate_tokens = 0
draft_probs = []
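            # Adaptive draft exiting: keep generating draft tokens only while the
            # draft model's top-1 confidence stays above `confidence_threshold`;
            # the actual check happens near the bottom of this inner loop.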
while confidence_score >= self.confidence_threshold:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.model.gradient_checkpointing and self.model.training and use_cache:
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.model.embed_tokens(input_ids)
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self.model._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.model.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for decoder_layer in self.shallow_layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
# Cache hidden states
remaining_hidden_states = hidden_states
# Add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()
# Adapter
if self.adapter_layer_mode == AdapterMode.attention_only_mode:
residual = hidden_states
hidden_states = self.attn_input_norm(hidden_states)
                    hidden_states, all_self_attns, past_key_values = self.draft_mode_adapter_layer(
                        hidden_states=hidden_states,
                        attention_mask=causal_mask,
                        position_ids=position_ids,
                        past_key_value=past_key_values,
                        output_attentions=output_attentions,
                        use_cache=use_cache,
                        cache_position=cache_position,
                        position_embeddings=position_embeddings,
                    )
hidden_states = residual + hidden_states
hidden_states = self.attn_output_norm(hidden_states)
elif self.adapter_layer_mode == AdapterMode.decoder_layer_mode:
layer_outputs = self.draft_mode_adapter_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
hidden_states = self.model.norm(hidden_states)
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
# Re-init
inputs_embeds = None
position_ids = None
cache_position = None
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
draft_logits = self.lm_head(hidden_states[:, -1:, :])
# Sampling and get the probabilities
next_tokens, probs = sample_next_token(
logits=draft_logits,
prefix_token_ids=input_ids,
)
draft_probs.append(probs)
input_ids = torch.cat([input_ids, next_tokens[:, -1:]], dim=-1)
attention_mask = torch.cat([attention_mask, torch.ones(attention_mask.shape[0], 1).to(input_ids.device)], dim=-1)
draft_generate_tokens += 1
self.total_draft_generated_token += 1
# Support bs=1
decode_token_id = next_tokens[:, -1].item()
if probs[:, -1, decode_token_id] < self.confidence_threshold or total_generate_tokens + draft_generate_tokens >= max_new_tokens:
draft_probs = torch.cat(draft_probs, dim=1)
break
# Use whole model for evaluating
for decoder_layer in self.remaining_layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = decoder_layer(
remaining_hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
remaining_hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
remaining_hidden_states = self.model.norm(remaining_hidden_states)
num_logits_to_keep = draft_probs.shape[1]
target_logits = self.lm_head(remaining_hidden_states[:, -num_logits_to_keep:, :])
target_input_ids = input_ids[:, :-1]
next_tokens, target_probs = sample_next_token(
logits=target_logits,
prefix_token_ids=target_input_ids,
probs_num=num_logits_to_keep,
)
# Evaluation
expanded_indices = input_ids[:, -draft_probs.shape[1]:].unsqueeze(-1)
            # Gather each token's probability from the draft and target distributions
selected_draft_probs = torch.gather(draft_probs, dim=-1, index=expanded_indices).squeeze(-1)
selected_eval_probs = torch.gather(target_probs, dim=-1, index=expanded_indices).squeeze(-1)
# Compare draft_prob and eval_prob, and check the reject_mask
mask_to_reject = selected_draft_probs > selected_eval_probs
            # Calculate rejection probability: 1 - (eval_prob / draft_prob)
rejection_probs = 1 - (selected_eval_probs / selected_draft_probs)
            # Generate random values to determine accept or reject
random_values = torch.rand_like(rejection_probs)
rejection_decisions = random_values < rejection_probs
# Get the final reject masks
rejection_masks = mask_to_reject & rejection_decisions
acceptance_mask = torch.ones_like(selected_draft_probs, dtype=torch.bool)
acceptance_mask[rejection_masks] = False
# Concat `input_ids`
if torch.all(acceptance_mask):
total_generate_tokens += draft_generate_tokens
else:
new_input_ids = []
new_attention_mask = []
is_end = False
for batch_idx in range(next_tokens.shape[0]):
gamma = next_tokens.shape[1]
start_idx = input_ids.shape[1] - gamma
for pos_idx in range(acceptance_mask[batch_idx].shape[0]):
total_generate_tokens += 1
if (acceptance_mask[batch_idx][pos_idx] and input_ids[batch_idx][start_idx+pos_idx].item() == eos_token_id) or not acceptance_mask[batch_idx][pos_idx]:
input_ids[batch_idx][start_idx+pos_idx] = next_tokens[batch_idx][pos_idx]
new_input_ids.append(input_ids[batch_idx][:start_idx+pos_idx+1])
new_attention_mask.append(attention_mask[batch_idx][:start_idx+pos_idx+1])
is_end = input_ids[batch_idx][start_idx+pos_idx].item() == eos_token_id
break
input_ids = pad_sequence(new_input_ids, batch_first=True, padding_value=pad_token_id)
attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)
self.total_accept_tokens += calculate_continuous_acceptance(acceptance_mask=acceptance_mask)
self.accept_rate = self.total_accept_tokens / self.total_draft_generated_token
if is_end:
break
return {"input_ids": input_ids}
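Putting the pieces together, setting up a Kangaroo model only takes a few calls (the paths are illustrative; the same sequence appears in the training and test scripts below):
import torch
from kangaroo_modeling.modeling_kangaroo_llama3 import KangarooLlamaForCausalLM

model = KangarooLlamaForCausalLM.from_pretrained(
    "../models/meta-llama--Meta-Llama-3.1-8B-Instruct",
    torch_dtype=torch.bfloat16,
)
model.set_skip_layer(shallow_layer_num=2)  # Split layers into shallow / remaining
model.set_adapter_layer("decoder_layer")   # Or "attention_only"
model.set_train_mode()                     # Or set_draft_mode() / set_target_mode()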
As for the training script, there is nothing particularly noteworthy; it simply prepares the dataset and trains the model in a standard manner.
from typing import Dict, List
import os
from dataclasses import dataclass
from datasets import load_dataset
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
from sklearn.model_selection import train_test_split
from kangaroo_modeling.modeling_kangaroo_llama3 import KangarooLlamaForCausalLM
class CustomDataset(Dataset):
    def __init__(self, inputs: Dict[str, torch.LongTensor], device: torch.device):
self.inputs = inputs
self.device = device
def __len__(self) -> int:
return self.inputs.input_ids.shape[0]
def __getitem__(self, index: int):
return (
self.inputs.input_ids[index].to(self.device),
self.inputs.attention_mask[index].to(self.device),
)
def main() -> None:
# Settings
epochs = 100
batch_size = 4
max_length = 512
lr = 5e-5
shallow_layer_num = 2
adapter_mode = "attention_only"
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
# Load model and tokenizer
pretrained_model_name_or_path = "../models/meta-llama--Meta-Llama-3.1-8B-Instruct"
# pretrained_model_name_or_path = "../models/HuggingFaceTB--SmolLM2-1.7B-Instruct"
model = KangarooLlamaForCausalLM.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token
model.set_skip_layer(shallow_layer_num=shallow_layer_num)
model.set_adapter_layer(_mode=adapter_mode)
model.set_train_mode()
model = model.to(device)
# Freeze all parameters
for param in model.parameters():
param.requires_grad = False
# Unfreeze adapter layer
model.draft_mode_adapter_layer.train()
for param in model.draft_mode_adapter_layer.parameters():
param.requires_grad = True
if hasattr(model, "attn_input_norm"):
print("Attention-Adapter!")
for param in model.attn_input_norm.parameters():
param.requires_grad = True
for param in model.attn_output_norm.parameters():
param.requires_grad = True
# Load dataset
dataset = load_dataset("shibing624/sharegpt_gpt4")
samples = dataset["train"]["conversations"]
samples = [[{"role": sample[0]["from"].replace("human", "user").replace("gpt", "assistant"), "content": sample[0]["value"]}] for sample in samples]
train_samples, eval_samples = train_test_split(samples, test_size=0.1, random_state=2999)
print(len(samples))
# Tokenized
train_inputs = tokenizer(
[tokenizer.apply_chat_template(messages, tokenize=False) for messages in train_samples],
return_tensors="pt",
padding=True,
truncation=True,
max_length=max_length,
)
eval_inputs = tokenizer(
[tokenizer.apply_chat_template(messages, tokenize=False) for messages in eval_samples],
return_tensors="pt",
padding=True,
truncation=True,
max_length=max_length,
)
train_dataset = CustomDataset(inputs=train_inputs, device=device)
eval_dataset = CustomDataset(inputs=eval_inputs, device=device)
# Dataloader
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)
# Optimizer
optimizer = torch.optim.AdamW(model.draft_mode_adapter_layer.parameters(), lr=lr)
# Training loop
for epoch in range(epochs):
model.train()
total_loss = 0
train_loss_history = []
eval_loss_history = []
for batch_idx, batch in enumerate(train_dataloader, 1):
input_ids, attention_mask = batch
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask,
num_logits_to_keep=max_length,
)
# Zero gradients
optimizer.zero_grad()
# Calculate loss
loss = outputs.loss
total_loss += loss.item()
train_loss_history.append(loss.item())
# Backward pass
loss.backward()
# Optimizer step
optimizer.step()
# Log training loss
avg_loss = total_loss / batch_idx
print(f"Train - Epoch [{epoch + 1}/{epochs}] Steps [{batch_idx}/{len(train_dataloader)}], Training Loss: {avg_loss:.4f}")
# Evaluate the model
model.eval()
eval_loss = 0
with torch.no_grad():
for batch_idx, batch in enumerate(eval_dataloader, 1):
input_ids, attention_mask = batch
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask,
num_logits_to_keep=max_length,
)
eval_loss += outputs.loss.item()
eval_loss_history.append(outputs.loss.item())
avg_loss = eval_loss / batch_idx
print(f"Eval - Epoch [{epoch + 1}/{epochs}] Steps [{batch_idx}/{len(eval_dataloader)}], Eval Loss: {avg_loss:.4f}")
# Save model checkpoint
save_dir = "./checkpoints/checkpoints_hce_attn_20241209/"
save_path = os.path.join(save_dir, f"epoch_{epoch+1}")
model.save_adapter(
save_path,
train_loss_history=train_loss_history,
eval_loss_history=eval_loss_history,
)
print(f"Adapter checkpoint saved at {save_path}")
if __name__ == "__main__":
main()
After training, I attempted to plot the loss curve:
As shown, the evaluation loss plateaued at approximately 0.88. In this case, we can proceed to test the speedup achieved by the Kangaroo architecture post-training.
from typing import Dict, List, Optional, Tuple
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import copy
import time
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from kangaroo_modeling.modeling_kangaroo_llama3 import KangarooLlamaForCausalLM
from sampling.sampling import sample_next_token
def calculate_continuous_acceptance(acceptance_mask: torch.BoolTensor) -> int:
continuous_acceptance = 0
for accepted in acceptance_mask.long().squeeze(0):
if accepted == 1:
continuous_acceptance += 1
else:
break
return continuous_acceptance
def drafter_speculative_decode(
draft_model: torch.nn.Module,
draft_tokenizer: PreTrainedTokenizerBase,
inputs: Dict[str, torch.Tensor],
gamma: int = 10,
temperature: float = 1.0,
top_k: Optional[int] = 0, # Default is 0, it means do not select top-k tokens
top_p: Optional[float] = 1.0,
repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], torch.FloatTensor]:
draft_probs = []
for idx in range(gamma):
with torch.no_grad():
outputs = draft_model(**inputs)
next_tokens, probs = sample_next_token(
logits=outputs.logits,
prefix_token_ids=inputs["input_ids"],
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
)
draft_probs.append(probs)
input_ids = torch.cat([inputs["input_ids"], next_tokens[:, -1:]], dim=-1)
attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)
inputs["input_ids"] = input_ids
inputs["attention_mask"] = attention_mask
return inputs, torch.cat(draft_probs, dim=1)
def target_speculative_decode(
target_model: torch.nn.Module,
target_tokenizer: PreTrainedTokenizerBase,
inputs: Dict[str, torch.Tensor],
draft_probs: torch.FloatTensor,
temperature: float = 1.0,
top_k: Optional[int] = 0, # Default is 0, it means do not select top-k tokens
top_p: Optional[float] = 1.0,
repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], bool, int]:
with torch.no_grad():
outputs = target_model(**inputs)
next_tokens, target_probs = sample_next_token(
logits=outputs.logits,
diff_probs=draft_probs,
prefix_token_ids=inputs["input_ids"],
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
probs_num=draft_probs.shape[1] + 1,
)
next_token = next_tokens[:, -1:]
# Evaluation
indices = inputs["input_ids"][:, -draft_probs.shape[1]:]
eval_probs = target_probs[:, :-1, :]
expanded_indices = indices.unsqueeze(-1)
selected_draft_probs = torch.gather(draft_probs, dim=-1, index=expanded_indices)
selected_draft_probs = selected_draft_probs.squeeze(-1)
selected_eval_probs = torch.gather(eval_probs, dim=-1, index=expanded_indices)
selected_eval_probs = selected_eval_probs.squeeze(-1)
# Compare draft_prob and eval_prob, and check the reject_mask
mask_to_reject = selected_draft_probs > selected_eval_probs
    # Calculate rejection probability: 1 - (eval_prob / draft_prob)
rejection_probs = 1 - (selected_eval_probs / selected_draft_probs)
    # Generate random values to determine accept or reject
random_values = torch.rand_like(rejection_probs)
rejection_decisions = random_values < rejection_probs
# Get the final reject masks
rejection_masks = mask_to_reject & rejection_decisions
acceptance_mask = torch.ones_like(selected_draft_probs, dtype=torch.bool)
acceptance_mask[rejection_masks] = False
is_end = False
# Concat `input_ids`
if torch.all(acceptance_mask):
input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)
else:
new_input_ids = []
new_attention_mask = []
for batch_idx in range(next_tokens.shape[0]):
gamma = next_tokens.shape[1] - 1
start_idx = inputs["input_ids"].shape[1] - gamma
for pos_idx in range(acceptance_mask[batch_idx].shape[0]):
if (acceptance_mask[batch_idx][pos_idx] and inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id) or not acceptance_mask[batch_idx][pos_idx]:
inputs["input_ids"][batch_idx][start_idx+pos_idx] = next_tokens[batch_idx][pos_idx]
new_input_ids.append(inputs["input_ids"][batch_idx][:start_idx+pos_idx+1])
new_attention_mask.append(inputs["attention_mask"][batch_idx][:start_idx+pos_idx+1])
is_end = inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id
break
input_ids = pad_sequence(new_input_ids, batch_first=True, padding_value=target_tokenizer.pad_token_id)
attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)
inputs["input_ids"] = input_ids
inputs["attention_mask"] = attention_mask
return inputs, is_end, calculate_continuous_acceptance(acceptance_mask)
def run_test() -> None:
# Device
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
# Model path
pretrained_model_name_or_path = "../models/meta-llama--Meta-Llama-3.1-8B-Instruct"
adapter_dir = "checkpoints/checkpoints_hce_decoder_layer_20241205/epoch_45/"
# Load Tokenizer
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token
# Load Model
model = KangarooLlamaForCausalLM.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.bfloat16)
model.set_skip_layer(shallow_layer_num=2)
model.set_adapter_layer("decoder_layer")
if adapter_dir:
model.load_adapter(adapter_dir)
model = model.to(device)
# Tokenize
messages = [
[
{
"role": "user",
"content": "What is the capital of Taiwan. And why?",
},
],
]
    input_text = tokenizer.apply_chat_template(messages, tokenize=False)
inputs = tokenizer(
input_text,
return_tensors="pt",
max_length=512,
truncation=True,
padding=True,
).to(device)
# Warm up the model (CUDA)
inputs_dummy = {k: v.clone() for k, v in inputs.items()}
with torch.no_grad():
model.set_draft_mode()
model(**inputs_dummy)
model.set_target_mode()
model(**inputs_dummy)
torch.cuda.synchronize()
# Record
raw_inputs = copy.deepcopy(inputs)
raw_token_num = raw_inputs["input_ids"].shape[1]
total_draft_tokens = 0
total_accept_tokens = 0
gamma = 1
max_new_tokens = 100
is_end = False
start_time = time.time()
while not is_end:
# Draft model
model.set_draft_mode()
target_inputs, draft_probs = drafter_speculative_decode(
draft_model=model,
draft_tokenizer=tokenizer,
inputs=inputs,
gamma=gamma,
temperature=0,
)
total_draft_tokens += gamma
# Target model
model.set_target_mode()
outputs, is_end, accept_tokens = target_speculative_decode(
target_model=model,
target_tokenizer=tokenizer,
inputs=target_inputs,
draft_probs=draft_probs,
temperature=1,
)
total_accept_tokens += accept_tokens
inputs = outputs
if inputs["input_ids"].shape[1] - raw_token_num >= max_new_tokens:
break
generate_token_num = outputs["input_ids"].shape[1] - raw_token_num
spent_time = time.time() - start_time
print(f"Generate token number: {generate_token_num}")
print(f"Generate speed: {generate_token_num / spent_time} tokens/sec")
print(f"Speculative Decoding Spent Time: {spent_time} seconds.")
print(f"Accept Rate: {total_accept_tokens / total_draft_tokens}\n")
# Normal Target Model Speed
inputs = copy.deepcopy(raw_inputs)
start_time = time.time()
target_inputs, draft_probs = drafter_speculative_decode(
draft_model=model,
draft_tokenizer=tokenizer,
inputs=inputs,
gamma=max_new_tokens,
)
spent_time = time.time() - start_time
print(f"Generate token number: {max_new_tokens}")
print(f"Generate speed: {max_new_tokens / spent_time} tokens/sec")
print(f"Normal Target Model Decoding Spent Time: {spent_time} seconds.\n")
if __name__ == "__main__":
run_test()
Output:
Generate token number: 100
Generate speed: 41.15903447942894 tokens/sec
Speculative Decoding Spent Time: 2.429600238800049 seconds.
Accept Rate: 0.2987012987012987
Generate token number: 100
Generate speed: 35.865311488261504 tokens/sec
Normal Target Model Decoding Spent Time: 2.7882094383239746 seconds.
This shows an approximate speedup of (41.15 - 35.86) / 35.86 ~= 14.75%.
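As a sanity check, the standard speculative decoding analysis (from the Fast Inference from Transformers via Speculative Decoding paper) says each verify step produces (1 - α^(γ+1)) / (1 - α) tokens in expectation, where α is the acceptance rate and γ is the number of draft tokens per step. Plugging in my measured numbers:
alpha, gamma = 0.2987, 1  # Measured acceptance rate; draft tokens per step
expected_tokens = (1 - alpha ** (gamma + 1)) / (1 - alpha)
print(expected_tokens)  # ~1.30 tokens per target model forward pass
Roughly 1.3 tokens per (expensive) target forward pass, minus the cost of the (cheap) draft passes, is consistent with the 10% ~ 15% end-to-end speedup measured above.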
References
- Kangaroo: Lossless Self-Speculative Decoding via Double Early Exiting
- GitHub - Kangaroo: Lossless Self-Speculative Decoding via Double Early Exiting