Last Updated on 2024-12-10 by Clay
Introduction
Kangaroo is a Self-Speculative Decoding implementation that introduces a trainable Adapter layer. I have spent the past few weeks fine-tuning its adapter and finally have some preliminary results, so I am recording them here.
Kangaroo Architecture
Kangaroo is a variant of Self-Speculative Decoding, which is itself a branch of Speculative Decoding.
In the original Speculative Decoding scheme, we have a draft model and a target model (the model we ultimately want to accelerate). The draft model usually shares the target model's vocabulary and runs much faster; it speculatively generates several tokens, which the target model then verifies in a single forward pass.
In the original paper (Fast Inference from Transformers via Speculative Decoding), the research team specifically designed the target model's verification algorithm so that decoding with the draft model's speculations preserves exactly the same output distribution as the original target model, i.e., acceleration without changing the decoding results.
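As a quick refresher, here is a minimal sketch of that verification rule (my own illustration with toy tensors, not the paper's reference code): each drafted token x is accepted with probability min(1, p(x)/q(x)), where q is the draft distribution and p is the target distribution; on the first rejection, a replacement token is sampled from the normalized residual max(0, p - q).

import torch

def verify_draft_tokens(draft_tokens, draft_probs, target_probs):
    # draft_tokens: (gamma,) token ids proposed by the draft model
    # draft_probs / target_probs: (gamma, vocab_size) per-step distributions
    accepted = []
    for t, token in enumerate(draft_tokens.tolist()):
        p, q = target_probs[t, token], draft_probs[t, token]
        # Accept with probability min(1, p / q)
        if torch.rand(()) < torch.clamp(p / q, max=1.0):
            accepted.append(token)
        else:
            # On rejection, resample from the residual distribution max(0, p - q)
            residual = torch.clamp(target_probs[t] - draft_probs[t], min=0.0)
            accepted.append(torch.multinomial(residual / residual.sum(), 1).item())
            break  # everything after the first rejection is discarded
    return accepted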
Self-Speculative Decoding (Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding) instead uses part of the target model's own network as the draft model. In short, every Transformer block contains an attention layer and an MLP layer, and we can choose whether to skip each of them.
Which layers to skip matters: skip too few and the draft model is not fast enough to accelerate the target model; skip too many and the draft model's speculative decoding may drift so far from the target model that its tokens are almost never accepted.
To balance skip-layer inference speed against accuracy, Self-Speculative Decoding uses Bayesian optimization to search for a skipping strategy that satisfies both requirements.
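Conceptually, the layer-skipping draft pass looks something like the sketch below (the sub-layer names self_attn_block / mlp_block are my own shorthand, not the Draft & Verify reference code); Bayesian optimization then searches over the skip_attn / skip_mlp sets.

def draft_forward(hidden_states, layers, skip_attn: set, skip_mlp: set):
    # Run the target model's own decoder layers, omitting the attention/MLP
    # sub-layers selected by the search; the more we skip, the faster the draft.
    for i, layer in enumerate(layers):
        if i not in skip_attn:
            hidden_states = hidden_states + layer.self_attn_block(hidden_states)
        if i not in skip_mlp:
            hidden_states = hidden_states + layer.mlp_block(hidden_states)
    return hidden_states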
However, my experimental results back then were frankly disastrous. Due to hardware constraints I could only test models up to the 9B scale, and I simply could not get any speedup. No wonder the paper's reported results start at the 13B scale; it may well be that at smaller scales, a draft model assembled from part of the network just does not work well.
So I turned to implementing Kangaroo. Kangaroo shares the same spirit as Self-Speculative Decoding, but it allows the model to be 'trained'. Of course, if we trained the draft model directly we could no longer guarantee the target model's performance: the draft model is part of the target model, and fine-tuning it could cost the target model some of its capabilities.
Kangaroo therefore proposes the following architecture:
Kangaroo's draft model basically consists of a shallow sub-network of the target model, plus an additional trainable Adapter-Network that produces the hidden-state output, which is finally passed through the LM Head to obtain the logits.
This Adapter-Network is trainable: during training, we pass both the Adapter-Network's output and the last remaining layer's output through the LM Head to obtain vocabulary-sized logits, and use gradient descent to make the two logit outputs as similar as possible.
Viewed from another angle, the draft model is an alternative path that exits the inference pipeline early and jumps straight to the LM Head. Jumping directly from a shallow layer to the head would make the LLM decode strange tokens (which is expected: the shallow layers were never meant, during the original pretraining, to feed the final head directly), so we need an extra Adapter as a bridge between the shallow network's output and the LM Head.
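In pseudo-form, the two paths share everything except the middle. The sketch below is mine and mirrors the implementation later in this post; the callables are passed in explicitly just to keep it self-contained:

def draft_path(x, shallow_layers, adapter, final_norm, lm_head):
    h = shallow_layers(x)          # only the first few decoder layers
    h = adapter(h)                 # the trainable bridge
    return lm_head(final_norm(h))  # early exit straight to the LM head

def target_path(x, shallow_layers, remaining_layers, final_norm, lm_head):
    h = remaining_layers(shallow_layers(x))  # the full, untouched network
    return lm_head(final_norm(h))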
This way, we can fine-tune the shallow-network draft model to sit closer to the target model without affecting the original model at all.
Spoiler first: in my Self-Speculative Decoding experiments I could not accelerate an 8B-scale model no matter what I tried, but at the end of Kangaroo training I finally achieved roughly a 10% ~ 15% inference speedup with Kangaroo's method. Honestly, after the setbacks of Self-Speculative Decoding and the early Kangaroo experiments, seeing this result was enough to move me to tears XDDD
Implementation Source Code
My Kangaroo implementation involves quite a bit of code; for the full details, see my GitHub: https://github.com/ccs96307/fast-llm-inference, where I collect my various inference-acceleration implementations. Feel free to reach out to discuss anytime.
Here, I will briefly introduce three scripts:
- modeling_kangaroo_llama3.py: the Kangaroo model I modified from the Llama architecture
- train.py: the training script
- test_kangaroo.py: tests the speedup brought by the Kangaroo model
First, the Kangaroo architecture itself, described abstractly: I subclass LlamaForCausalLM and split the original self.model.layers into shallow_layers and remaining_layers. Following the original paper's design, my shallow network has only two layers.
For the Adapter, I currently implement both an Attention-module version (the paper's original method) and a full decoder-layer version that also includes the MLP module. I added the latter because training converged poorly with Attention alone, so I added more parameters in the hope of better results.
For the loss function, I additionally introduce a hard_labels loss that can be weighted and combined with the soft-label cross-entropy from the original paper. In my experiments this seemed to work slightly better, although I never reproduced results as good as the paper's, so some part of my implementation is probably still suboptimal.
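A minimal sketch of the combined loss (my own formulation of what the full forward() below computes; alpha is the soft/hard weighting knob):

import torch
import torch.nn.functional as F

def kangaroo_distill_loss(draft_logits, target_logits, alpha: float = 0.8):
    target_probs = F.softmax(target_logits, dim=-1)
    draft_log_probs = F.log_softmax(draft_logits, dim=-1)
    # Soft labels: cross-entropy against the target model's full distribution
    soft_loss = -(target_probs * draft_log_probs).sum(dim=-1).mean()
    # Hard labels: standard cross-entropy against the target's argmax tokens
    hard_labels = target_logits.argmax(dim=-1)
    hard_loss = F.cross_entropy(
        draft_logits.reshape(-1, draft_logits.size(-1)),
        hard_labels.reshape(-1),
    )
    return alpha * soft_loss + (1 - alpha) * hard_loss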
Another thing worth mentioning: I implement a prepare_dataset() method that, in line with the official Kangaroo project, extracts only the early-exit layer's output and the last layer's output as training data, saving the intermediate forward computation during training. However, on many of my servers, storing the sharegpt4 dataset this way would take nearly 900 GB of disk space, so I wrote out both approaches: this precomputed one, and the naive one that simply runs the draft model and target model in full.
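The precomputed variant boils down to something like this sketch (the file layout and names here are hypothetical; only the prepare_dataset() call matches the class below):

import torch

@torch.no_grad()
def dump_training_pairs(model, dataloader, out_dir: str):
    # Run the frozen model once, storing the early-exit hidden states together
    # with the final-layer hidden states; the adapter can then be trained on
    # these pairs without re-running the full model every epoch.
    for i, (input_ids, attention_mask) in enumerate(dataloader):
        shallow_h, final_h = model.prepare_dataset(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        torch.save({"shallow": shallow_h.cpu(), "final": final_h.cpu()}, f"{out_dir}/pair_{i}.pt")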
from typing import Any, List, Optional, Tuple, Union
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
from dataclasses import dataclass
import json
import torch
torch.set_default_dtype(torch.bfloat16)
from torch.nn.utils.rnn import pad_sequence
from transformers import LlamaForCausalLM
from transformers.models.llama.modeling_llama import Cache, DynamicCache, LlamaSdpaAttention, LlamaDecoderLayer, LlamaRMSNorm
from transformers.modeling_outputs import CausalLMOutputWithPast
from sampling.sampling import sample_next_token
from utils.utils import calculate_continuous_acceptance
@dataclass
class KangarooModelMode:
draft_only_mode: str = "draft_only"
target_only_mode: str = "target_only"
train_mode: str = "train"
@dataclass
class AdapterMode:
attention_only_mode: str = "attention_only"
mlp_only_mode: str = "mlp_only"
decoder_layer_mode: str = "decoder_layer"
class KangarooLlamaForCausalLM(LlamaForCausalLM):
def __init__(self, config):
super().__init__(config)
self.config = config
self.adapter_layer_mode = None
self.draft_mode_adapter_layer = None
self.shallow_layers = None
self.mode = KangarooModelMode.target_only_mode
self.confidence_threshold = 0.5
self.accept_rate = 0
self.total_accept_tokens = 0
self.total_draft_generated_token = 0
self.draft_temperature = 1.0
self.target_temperature = 1.0
self.alpha = 0.8
self.shallow_layer_num = 10
def set_skip_layer(self, shallow_layer_num: int) -> None:
self.shallow_layer_num = shallow_layer_num
self.shallow_layers = self.model.layers[:shallow_layer_num]
self.remaining_layers = self.model.layers[shallow_layer_num:]
def set_adapter_layer(self, _mode: str) -> None:
self.adapter_layer_mode = _mode
if _mode == AdapterMode.attention_only_mode:
self.attn_input_norm = LlamaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
self.attn_output_norm = LlamaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
self.draft_mode_adapter_layer = LlamaSdpaAttention(
config=self.config,
layer_idx=self.config.num_hidden_layers,
)
elif _mode == AdapterMode.decoder_layer_mode:
self.draft_mode_adapter_layer = LlamaDecoderLayer(
config=self.config,
layer_idx=self.config.num_hidden_layers,
)
def set_draft_mode(self) -> None:
self.mode = KangarooModelMode.draft_only_mode
def set_target_mode(self) -> None:
self.mode = KangarooModelMode.target_only_mode
def set_train_mode(self) -> None:
self.mode = KangarooModelMode.train_mode
def save_head(
self,
save_dir: str,
) -> None:
"""
Save the `lm_head` parameters to the specified directory.
Args:
save_dir (str): Directory to save the `lm_head` parameters.
"""
if not os.path.exists(save_dir):
os.makedirs(save_dir)
# Save the adapter's state dictionary
head_path = os.path.join(save_dir, "lm_head.pt")
torch.save(self.lm_head.state_dict(), head_path)
print(f"`lm_head` saved at {head_path}")
def save_norm(
self,
save_dir: str,
) -> None:
if not os.path.exists(save_dir):
os.makedirs(save_dir)
# Save the adapter's state dictionary
norm_path = os.path.join(save_dir, "norm.pt")
torch.save(self.model.norm.state_dict(), norm_path)
print(f"`norm` saved at {norm_path}")
def save_adapter(
self,
save_dir: str,
train_loss_history: List[float],
eval_loss_history: List[float],
) -> None:
"""
Save the parameters of the draft_mode_adapter_layer, loss_history, and shallow_layer_num to the specified directory.
Args:
save_dir (str): Directory to save the adapter parameters.
"""
if not os.path.exists(save_dir):
os.makedirs(save_dir)
# Save the adapter's state dictionary
adapter_path = os.path.join(save_dir, "draft_adapter.pt")
torch.save(self.draft_mode_adapter_layer.state_dict(), adapter_path)
print(f"Draft adapter saved at {adapter_path}")
# Save additional information (loss_history and shallow_layer_num)
metadata = {
"train_loss_history": train_loss_history,
"eval_loss_history": eval_loss_history,
"shallow_layer_num": self.shallow_layer_num
}
metadata_path = os.path.join(save_dir, "adapter_metadata.json")
with open(metadata_path, "w") as f:
json.dump(metadata, f)
print(f"Adapter metadata saved at {metadata_path}")
def load_adapter(self, load_dir: str) -> None:
"""
Load the parameters of the draft_mode_adapter_layer, loss_history, and shallow_layer_num from the specified directory.
Args:
load_dir (str): Directory to load the adapter parameters from.
Raises:
FileNotFoundError: If the adapter file does not exist in the specified directory.
"""
adapter_path = os.path.join(load_dir, "draft_adapter.pt")
if not os.path.exists(adapter_path):
raise FileNotFoundError(f"Draft adapter not found at {adapter_path}")
# Load the adapter's state dictionary
state_dict = torch.load(adapter_path, map_location=self.device)
self.draft_mode_adapter_layer.load_state_dict(state_dict=state_dict)
print(f"Draft adapter loaded from {adapter_path}")
def prepare_dataset(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
if self.shallow_layers is None:
raise AttributeError(f"You do not set the `shallow_layers`!")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Model
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.model.gradient_checkpointing and self.model.training and use_cache:
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.model.embed_tokens(input_ids)
# kept for BC (non `Cache` `past_key_values` inputs)
if use_cache and not isinstance(past_key_values, Cache):
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self.model._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.model.rotary_emb(hidden_states, position_ids)
# decoder layers
for decoder_layer in self.shallow_layers:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
shallow_hidden_states = hidden_states
# Remaining decoder layers
for decoder_layer in self.remaining_layers:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
hidden_states = self.model.norm(hidden_states)
return shallow_hidden_states, hidden_states
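# `forward()` dispatches on the current mode:
#   - target_only: falls through to the vanilla LlamaForCausalLM forward
#   - draft_only:  shallow layers -> adapter -> norm -> LM head (early exit)
#   - train:       runs both the draft path and the remaining layers, then
#                  distills the draft logits toward the target logits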
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
if self.mode == KangarooModelMode.target_only_mode:
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
num_logits_to_keep=num_logits_to_keep,
**loss_kwargs,
)
elif self.mode == KangarooModelMode.draft_only_mode:
if self.shallow_layers is None:
raise AttributeError(f"You do not set the `shallow_layers`!")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Model
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.model.gradient_checkpointing and self.model.training and use_cache:
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.model.embed_tokens(input_ids)
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self.model._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.model.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for decoder_layer in self.shallow_layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()
# Adapter
if self.adapter_layer_mode == AdapterMode.attention_only_mode:
residual = hidden_states
hidden_states = self.attn_input_norm(hidden_states)
hidden_states, all_self_attns, past_key_values = self.draft_mode_adapter_layer(
hidden_states=hidden_states,
attention_mask=causal_mask,  # use the 4D causal mask, consistent with the decoder layers
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = residual + hidden_states
hidden_states = self.attn_output_norm(hidden_states)
elif self.adapter_layer_mode == AdapterMode.decoder_layer_mode:
layer_outputs = self.draft_mode_adapter_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
hidden_states = self.model.norm(hidden_states)
if self.config.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
logits = [torch.nn.functional.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
logits = torch.cat(logits, dim=-1)
else:
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **loss_kwargs)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=past_key_values,
hidden_states=hidden_states,
attentions=all_self_attns,
)
elif self.mode == KangarooModelMode.train_mode:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.model.gradient_checkpointing and self.model.training and use_cache:
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.model.embed_tokens(input_ids)
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self.model._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.model.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for decoder_layer in self.shallow_layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
# Cache hidden states
remaining_hidden_states = hidden_states
# Add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()
# Adapter
if self.adapter_layer_mode == AdapterMode.attention_only_mode:
residual = hidden_states
hidden_states = self.attn_input_norm(hidden_states)
hidden_states, all_self_attns, past_key_values = self.draft_mode_adapter_layer(
hidden_states=hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = residual + hidden_states
hidden_states = self.attn_output_norm(hidden_states)
elif self.adapter_layer_mode == AdapterMode.decoder_layer_mode:
layer_outputs = self.draft_mode_adapter_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
hidden_states = self.model.norm(hidden_states)
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
# Remaining decoder layers
for decoder_layer in self.remaining_layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = decoder_layer(
remaining_hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
remaining_hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
remaining_hidden_states = self.model.norm(remaining_hidden_states)
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
draft_logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) / self.draft_temperature
target_logits = self.lm_head(remaining_hidden_states[:, -num_logits_to_keep:, :]) / self.target_temperature
# Compute the log probabilities for both models
draft_log_probs = torch.nn.functional.log_softmax(draft_logits, dim=-1)
target_log_probs = torch.nn.functional.log_softmax(target_logits, dim=-1)
target_probs = torch.nn.functional.softmax(target_logits, dim=-1)
# Cross-entropy loss between target and draft model predictions
# kl_loss = torch.nn.functional.kl_div(draft_log_probs, target_probs, reduction="batchmean")
hard_labels = torch.argmax(target_probs, dim=-1)
soft_label_cross_entropy_loss = -(target_probs * draft_log_probs).sum(dim=-1).mean()
hard_label_loss = torch.nn.functional.cross_entropy(
draft_logits.view(-1, draft_logits.size(-1)), # Flatten logits
hard_labels.view(-1) # Flatten hard labels
)
loss = self.alpha * soft_label_cross_entropy_loss + (1 - self.alpha) * hard_label_loss
return CausalLMOutputWithPast(
loss=loss,
logits=target_logits,
past_key_values=past_key_values,
hidden_states=hidden_states,
attentions=all_self_attns,
)
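# `kangaroo_generate()` keeps drafting tokens through the early-exit path
# while the draft distribution stays confident; once confidence falls below
# `confidence_threshold` (or the token budget is hit), the remaining layers
# verify the whole drafted span in one pass with the accept/reject rule.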
@torch.no_grad()
def kangaroo_generate(
self,
eos_token_id: int,
pad_token_id: int,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
max_new_tokens: int = 100,
**kwargs,
) -> torch.LongTensor:
if self.shallow_layers is None:
raise AttributeError(f"You do not set the `shallow_layers`!")
confidence_score = 1.0
total_generate_tokens = 0
while total_generate_tokens < max_new_tokens:
draft_generate_tokens = 0
draft_probs = []
while confidence_score >= self.confidence_threshold:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.model.gradient_checkpointing and self.model.training and use_cache:
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.model.embed_tokens(input_ids)
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self.model._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.model.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for decoder_layer in self.shallow_layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
# Cache hidden states
remaining_hidden_states = hidden_states
# Add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()
# Adapter
if self.adapter_layer_mode == AdapterMode.attention_only_mode:
residual = hidden_states
hidden_states = self.attn_input_norm(hidden_states)
hidden_states, all_self_attns, past_key_values = self.draft_mode_adapter_layer(
hidden_states=hidden_states,
attention_mask=causal_mask,  # use the 4D causal mask, consistent with the decoder layers
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = residual + hidden_states
hidden_states = self.attn_output_norm(hidden_states)
elif self.adapter_layer_mode == AdapterMode.decoder_layer_mode:
layer_outputs = self.draft_mode_adapter_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
hidden_states = self.model.norm(hidden_states)
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
# Re-init
inputs_embeds = None
position_ids = None
cache_position = None
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
draft_logits = self.lm_head(hidden_states[:, -1:, :])
# Sampling and get the probabilities
next_tokens, probs = sample_next_token(
logits=draft_logits,
prefix_token_ids=input_ids,
)
draft_probs.append(probs)
input_ids = torch.cat([input_ids, next_tokens[:, -1:]], dim=-1)
attention_mask = torch.cat([attention_mask, torch.ones(attention_mask.shape[0], 1).to(input_ids.device)], dim=-1)
draft_generate_tokens += 1
self.total_draft_generated_token += 1
# Only batch size = 1 is supported here
decode_token_id = next_tokens[:, -1].item()
if probs[:, -1, decode_token_id] < self.confidence_threshold or total_generate_tokens + draft_generate_tokens >= max_new_tokens:
draft_probs = torch.cat(draft_probs, dim=1)
break
# Use whole model for evaluating
for decoder_layer in self.remaining_layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
layer_outputs = decoder_layer(
remaining_hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
remaining_hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
remaining_hidden_states = self.model.norm(remaining_hidden_states)
num_logits_to_keep = draft_probs.shape[1]
target_logits = self.lm_head(remaining_hidden_states[:, -num_logits_to_keep:, :])
target_input_ids = input_ids[:, :-1]
next_tokens, target_probs = sample_next_token(
logits=target_logits,
prefix_token_ids=target_input_ids,
probs_num=num_logits_to_keep,
)
# Evaluation
expanded_indices = input_ids[:, -draft_probs.shape[1]:].unsqueeze(-1)
# Gather the probability of each drafted token under both distributions
selected_draft_probs = torch.gather(draft_probs, dim=-1, index=expanded_indices).squeeze(-1)
selected_eval_probs = torch.gather(target_probs, dim=-1, index=expanded_indices).squeeze(-1)
# Compare draft_prob and eval_prob, and check the reject_mask
mask_to_reject = selected_draft_probs > selected_eval_probs
# Calculate rejection probability 1 - (eval_prob / draft_prob)
rejection_probs = 1 - (selected_eval_probs / selected_draft_probs)
# Generate random values to determine accept or reject
random_values = torch.rand_like(rejection_probs)
rejection_decisions = random_values < rejection_probs
# Get the final reject masks
rejection_masks = mask_to_reject & rejection_decisions
acceptance_mask = torch.ones_like(selected_draft_probs, dtype=torch.bool)
acceptance_mask[rejection_masks] = False
# Concat `input_ids`
if torch.all(acceptance_mask):
total_generate_tokens += draft_generate_tokens
else:
new_input_ids = []
new_attention_mask = []
is_end = False
for batch_idx in range(next_tokens.shape[0]):
gamma = next_tokens.shape[1]
start_idx = input_ids.shape[1] - gamma
for pos_idx in range(acceptance_mask[batch_idx].shape[0]):
total_generate_tokens += 1
if (acceptance_mask[batch_idx][pos_idx] and input_ids[batch_idx][start_idx+pos_idx].item() == eos_token_id) or not acceptance_mask[batch_idx][pos_idx]:
input_ids[batch_idx][start_idx+pos_idx] = next_tokens[batch_idx][pos_idx]
new_input_ids.append(input_ids[batch_idx][:start_idx+pos_idx+1])
new_attention_mask.append(attention_mask[batch_idx][:start_idx+pos_idx+1])
is_end = input_ids[batch_idx][start_idx+pos_idx].item() == eos_token_id
break
input_ids = pad_sequence(new_input_ids, batch_first=True, padding_value=pad_token_id)
attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)
self.total_accept_tokens += calculate_continuous_acceptance(acceptance_mask=acceptance_mask)
self.accept_rate = self.total_accept_tokens / self.total_draft_generated_token
if is_end:
break
return {"input_ids": input_ids}
As for the training script, there is not much worth detailing: it just declares the dataset and the model in a standard fashion and runs the training.
from typing import Dict, List
import os
from dataclasses import dataclass
from datasets import load_dataset
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
from sklearn.model_selection import train_test_split
from kangaroo_modeling.modeling_kangaroo_llama3 import KangarooLlamaForCausalLM
class CustomDataset(Dataset):
def __init__(self, inputs: Dict[str, torch.LongTensor], device: torch.DeviceObjType):
self.inputs = inputs
self.device = device
def __len__(self) -> int:
return self.inputs.input_ids.shape[0]
def __getitem__(self, index: int):
return (
self.inputs.input_ids[index].to(self.device),
self.inputs.attention_mask[index].to(self.device),
)
def main() -> None:
# Settings
epochs = 100
batch_size = 4
max_length = 512
lr = 5e-5
shallow_layer_num = 2
adapter_mode = "attention_only"
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
# Load model and tokenizer
pretrained_model_name_or_path = "../models/meta-llama--Meta-Llama-3.1-8B-Instruct"
# pretrained_model_name_or_path = "../models/HuggingFaceTB--SmolLM2-1.7B-Instruct"
model = KangarooLlamaForCausalLM.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token
model.set_skip_layer(shallow_layer_num=shallow_layer_num)
model.set_adapter_layer(_mode=adapter_mode)
model.set_train_mode()
model = model.to(device)
# Freeze all parameters
for param in model.parameters():
param.requires_grad = False
# Unfreeze adapter layer
model.draft_mode_adapter_layer.train()
for param in model.draft_mode_adapter_layer.parameters():
param.requires_grad = True
if hasattr(model, "attn_input_norm"):
print("Attention-Adapter!")
for param in model.attn_input_norm.parameters():
param.requires_grad = True
for param in model.attn_output_norm.parameters():
param.requires_grad = True
# Load dataset
dataset = load_dataset("shibing624/sharegpt_gpt4")
samples = dataset["train"]["conversations"]
samples = [[{"role": sample[0]["from"].replace("human", "user").replace("gpt", "assistant"), "content": sample[0]["value"]}] for sample in samples]
train_samples, eval_samples = train_test_split(samples, test_size=0.1, random_state=2999)
print(len(samples))
# Tokenized
train_inputs = tokenizer(
[tokenizer.apply_chat_template(messages, tokenize=False) for messages in train_samples],
return_tensors="pt",
padding=True,
truncation=True,
max_length=max_length,
)
eval_inputs = tokenizer(
[tokenizer.apply_chat_template(messages, tokenize=False) for messages in eval_samples],
return_tensors="pt",
padding=True,
truncation=True,
max_length=max_length,
)
train_dataset = CustomDataset(inputs=train_inputs, device=device)
eval_dataset = CustomDataset(inputs=eval_inputs, device=device)
# Dataloader
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)
# Optimizer
optimizer = torch.optim.AdamW(model.draft_mode_adapter_layer.parameters(), lr=lr)
# Training loop
for epoch in range(epochs):
model.train()
total_loss = 0
train_loss_history = []
eval_loss_history = []
for batch_idx, batch in enumerate(train_dataloader, 1):
input_ids, attention_mask = batch
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask,
num_logits_to_keep=max_length,
)
# Zero gradients
optimizer.zero_grad()
# Calculate loss
loss = outputs.loss
total_loss += loss.item()
train_loss_history.append(loss.item())
# Backward pass
loss.backward()
# Optimizer step
optimizer.step()
# Log training loss
avg_loss = total_loss / batch_idx
print(f"Train - Epoch [{epoch + 1}/{epochs}] Steps [{batch_idx}/{len(train_dataloader)}], Training Loss: {avg_loss:.4f}")
# Evaluate the model
model.eval()
eval_loss = 0
with torch.no_grad():
for batch_idx, batch in enumerate(eval_dataloader, 1):
input_ids, attention_mask = batch
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask,
num_logits_to_keep=max_length,
)
eval_loss += outputs.loss.item()
eval_loss_history.append(outputs.loss.item())
avg_loss = eval_loss / batch_idx
print(f"Eval - Epoch [{epoch + 1}/{epochs}] Steps [{batch_idx}/{len(eval_dataloader)}], Eval Loss: {avg_loss:.4f}")
# Save model checkpoint
save_dir = "./checkpoints/checkpoints_hce_attn_20241209/"
save_path = os.path.join(save_dir, f"epoch_{epoch+1}")
model.save_adapter(
save_path,
train_loss_history=train_loss_history,
eval_loss_history=eval_loss_history,
)
print(f"Adapter checkpoint saved at {save_path}")
if __name__ == "__main__":
main()
After training finished, I plotted my loss curves:
As you can see, the eval loss plateaus at around 0.88. With that, let's test how much speedup the trained Kangaroo architecture actually delivers.
from typing import Dict, List, Optional, Tuple
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import copy
import time
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from kangaroo_modeling.modeling_kangaroo_llama3 import KangarooLlamaForCausalLM
from sampling.sampling import sample_next_token
def calculate_continuous_acceptance(acceptance_mask: torch.BoolTensor) -> int:
continuous_acceptance = 0
for accepted in acceptance_mask.long().squeeze(0):
if accepted == 1:
continuous_acceptance += 1
else:
break
return continuous_acceptance
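# Drafting phase: generate `gamma` tokens with the draft path, recording each
# step's probability distribution for the later verification.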
def drafter_speculative_decode(
draft_model: torch.nn.Module,
draft_tokenizer: PreTrainedTokenizerBase,
inputs: Dict[str, torch.Tensor],
gamma: int = 10,
temperature: float = 1.0,
top_k: Optional[int] = 0, # Default is 0, it means do not select top-k tokens
top_p: Optional[float] = 1.0,
repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], torch.FloatTensor]:
draft_probs = []
for idx in range(gamma):
with torch.no_grad():
outputs = draft_model(**inputs)
next_tokens, probs = sample_next_token(
logits=outputs.logits,
prefix_token_ids=inputs["input_ids"],
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
)
draft_probs.append(probs)
input_ids = torch.cat([inputs["input_ids"], next_tokens[:, -1:]], dim=-1)
attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)
inputs["input_ids"] = input_ids
inputs["attention_mask"] = attention_mask
return inputs, torch.cat(draft_probs, dim=1)
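# Verification phase: a single target-model forward over the drafted span,
# followed by the standard speculative-decoding accept/reject rule.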
def target_speculative_decode(
target_model: torch.nn.Module,
target_tokenizer: PreTrainedTokenizerBase,
inputs: Dict[str, torch.Tensor],
draft_probs: torch.FloatTensor,
temperature: float = 1.0,
top_k: Optional[int] = 0, # Default is 0, it means do not select top-k tokens
top_p: Optional[float] = 1.0,
repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], bool, int]:
with torch.no_grad():
outputs = target_model(**inputs)
next_tokens, target_probs = sample_next_token(
logits=outputs.logits,
diff_probs=draft_probs,
prefix_token_ids=inputs["input_ids"],
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
probs_num=draft_probs.shape[1] + 1,
)
next_token = next_tokens[:, -1:]
# Evaluation
indices = inputs["input_ids"][:, -draft_probs.shape[1]:]
eval_probs = target_probs[:, :-1, :]
expanded_indices = indices.unsqueeze(-1)
selected_draft_probs = torch.gather(draft_probs, dim=-1, index=expanded_indices)
selected_draft_probs = selected_draft_probs.squeeze(-1)
selected_eval_probs = torch.gather(eval_probs, dim=-1, index=expanded_indices)
selected_eval_probs = selected_eval_probs.squeeze(-1)
# Compare draft_prob and eval_prob, and check the reject_mask
mask_to_reject = selected_draft_probs > selected_eval_probs
# Calculate rejection probability 1 - (eval_prob / draft_prob)
rejection_probs = 1 - (selected_eval_probs / selected_draft_probs)
# Generate random values to determine accept or reject
random_values = torch.rand_like(rejection_probs)
rejection_decisions = random_values < rejection_probs
# Get the final reject masks
rejection_masks = mask_to_reject & rejection_decisions
acceptance_mask = torch.ones_like(selected_draft_probs, dtype=torch.bool)
acceptance_mask[rejection_masks] = False
is_end = False
# Concat `input_ids`
if torch.all(acceptance_mask):
input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)
else:
new_input_ids = []
new_attention_mask = []
for batch_idx in range(next_tokens.shape[0]):
gamma = next_tokens.shape[1] - 1
start_idx = inputs["input_ids"].shape[1] - gamma
for pos_idx in range(acceptance_mask[batch_idx].shape[0]):
if (acceptance_mask[batch_idx][pos_idx] and inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id) or not acceptance_mask[batch_idx][pos_idx]:
inputs["input_ids"][batch_idx][start_idx+pos_idx] = next_tokens[batch_idx][pos_idx]
new_input_ids.append(inputs["input_ids"][batch_idx][:start_idx+pos_idx+1])
new_attention_mask.append(inputs["attention_mask"][batch_idx][:start_idx+pos_idx+1])
is_end = inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id
break
input_ids = pad_sequence(new_input_ids, batch_first=True, padding_value=target_tokenizer.pad_token_id)
attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)
inputs["input_ids"] = input_ids
inputs["attention_mask"] = attention_mask
return inputs, is_end, calculate_continuous_acceptance(acceptance_mask)
def run_test() -> None:
# Device
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
# Model path
pretrained_model_name_or_path = "../models/meta-llama--Meta-Llama-3.1-8B-Instruct"
adapter_dir = "checkpoints/checkpoints_hce_decoder_layer_20241205/epoch_45/"
# Load Tokenizer
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token
# Load Model
model = KangarooLlamaForCausalLM.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.bfloat16)
model.set_skip_layer(shallow_layer_num=2)
model.set_adapter_layer("decoder_layer")
if adapter_dir:
model.load_adapter(adapter_dir)
model = model.to(device)
# Tokenize
messages = [
[
{
"role": "user",
"content": "What is the capital of Taiwan. And why?",
},
],
]
input_text = tokenizer.apply_chat_template(messages, tokenize=False)
inputs = tokenizer(
input_text,
return_tensors="pt",
max_length=512,
truncation=True,
padding=True,
).to(device)
# Warm up the model (CUDA)
inputs_dummy = {k: v.clone() for k, v in inputs.items()}
with torch.no_grad():
model.set_draft_mode()
model(**inputs_dummy)
model.set_target_mode()
model(**inputs_dummy)
torch.cuda.synchronize()
# Record
raw_inputs = copy.deepcopy(inputs)
raw_token_num = raw_inputs["input_ids"].shape[1]
total_draft_tokens = 0
total_accept_tokens = 0
gamma = 1
max_new_tokens = 100
is_end = False
start_time = time.time()
while not is_end:
# Draft model
model.set_draft_mode()
target_inputs, draft_probs = drafter_speculative_decode(
draft_model=model,
draft_tokenizer=tokenizer,
inputs=inputs,
gamma=gamma,
temperature=0,
)
total_draft_tokens += gamma
# Target model
model.set_target_mode()
outputs, is_end, accept_tokens = target_speculative_decode(
target_model=model,
target_tokenizer=tokenizer,
inputs=target_inputs,
draft_probs=draft_probs,
temperature=1,
)
total_accept_tokens += accept_tokens
inputs = outputs
if inputs["input_ids"].shape[1] - raw_token_num >= max_new_tokens:
break
generate_token_num = outputs["input_ids"].shape[1] - raw_token_num
spent_time = time.time() - start_time
print(f"Generate token number: {generate_token_num}")
print(f"Generate speed: {generate_token_num / spent_time} tokens/sec")
print(f"Speculative Decoding Spent Time: {spent_time} seconds.")
print(f"Accept Rate: {total_accept_tokens / total_draft_tokens}\n")
# Normal Target Model Speed
inputs = copy.deepcopy(raw_inputs)
start_time = time.time()
target_inputs, draft_probs = drafter_speculative_decode(
draft_model=model,
draft_tokenizer=tokenizer,
inputs=inputs,
gamma=max_new_tokens,
)
spent_time = time.time() - start_time
print(f"Generate token number: {max_new_tokens}")
print(f"Generate speed: {max_new_tokens / spent_time} tokens/sec")
print(f"Normal Target Model Decoding Spent Time: {spent_time} seconds.\n")
if __name__ == "__main__":
run_test()
Output:
Generate token number: 100
Generate speed: 41.15903447942894 tokens/sec
Speculative Decoding Spent Time: 2.429600238800049 seconds.
Accept Rate: 0.2987012987012987
Generate token number: 100
Generate speed: 35.865311488261504 tokens/sec
Normal Target Model Decoding Spent Time: 2.7882094383239746 seconds.
As you can see, that is a speedup of roughly (41.15 - 35.86) / 35.86 ~= 14.75%.
References
- Kangaroo: Lossless Self-Speculative Decoding via Double Early Exiting
- GitHub - Kangaroo: Lossless Self-Speculative Decoding via Double Early Exiting