Last Updated on 2024-11-10 by Clay
Introduction
Self-Speculative Decoding is a variant of Speculative Decoding. The original Speculative Decoding uses a draft model to accelerate the target model we actually want to run: the draft model is expected to produce outputs similar to the target model's while decoding several times faster, and it is usually distilled from the target model.
Once the draft model has speculated a sequence of candidate tokens, the target model predicts the token that follows the candidate sequence and, in the same forward pass, obtains its own probability distribution for each of the tokens the draft model proposed. A verification algorithm then decides which draft tokens to accept. If many draft-decoded tokens are accepted, the target model has effectively decoded many tokens in a single forward pass, which is where the speed-up comes from.
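As a rough illustration of this verification step (a minimal sketch, not the API of any particular library), the standard speculative-sampling acceptance rule looks roughly like the following. Here draft_probs, target_probs, and draft_token_ids are assumed to have already been gathered from the draft and target forward passes, and resampling from the residual distribution after a rejection is omitted.

import torch

def verify_draft_tokens(
    draft_probs: torch.Tensor,      # (num_draft_tokens, vocab_size): draft model's distributions
    target_probs: torch.Tensor,     # (num_draft_tokens, vocab_size): target model's distributions
    draft_token_ids: torch.Tensor,  # (num_draft_tokens,): tokens the draft model actually sampled
) -> int:
    """Return the number of accepted draft tokens (the accepted prefix length)."""
    accepted = 0
    for i, token_id in enumerate(draft_token_ids.tolist()):
        p = target_probs[i, token_id].item()  # target probability of the proposed token
        q = draft_probs[i, token_id].item()   # draft probability of the proposed token
        # Speculative sampling: accept with probability min(1, p / q).
        # On rejection, one would resample from the normalized residual
        # distribution max(0, target - draft); that step is omitted here.
        if torch.rand(1).item() < min(1.0, p / q):
            accepted += 1
        else:
            break
    return accepted

The more draft tokens survive this loop, the more tokens the target model confirms per forward pass.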
Self-Speculative Decoding, in turn, takes into account the extra VRAM needed to load a separate draft model: it runs only a subset of its own layers to act as the draft model, and then verifies with all of the layers as the target model. For details, see the paper Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding.
What I am writing down today is how to take the HuggingFace Transformers source code, add a simple skip mechanism, and end up with a LayerSkip model that can skip selected self-attention layers and MLP layers.
However, since I have not yet implemented the paper's Bayesian Optimization procedure for choosing which layers to skip, I cannot attach full results yet.
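For reference, the selection problem is to find a set of skipped layers that makes the draft pass as fast as possible while keeping its outputs close enough to the full model that most draft tokens get accepted. The paper does this with Bayesian optimization; as a purely hypothetical stand-in until I implement that, a naive random search could look like the sketch below, where the evaluate callback (which would have to measure something like end-to-end tokens per second of the self-speculative loop) is an assumption, not part of the original code.

import random
from typing import Callable, Dict, List

def random_search_skip_layers(
    num_layers: int,
    evaluate: Callable[[Dict[str, List[int]]], float],  # hypothetical: higher score = better config
    num_trials: int = 50,
    max_skips: int = 8,
) -> Dict[str, List[int]]:
    """Naive stand-in for the paper's Bayesian optimization over skip-layer sets."""
    best_score = float("-inf")
    best_config: Dict[str, List[int]] = {"attn": [], "mlp": []}

    for _ in range(num_trials):
        # Randomly pick which attention / MLP blocks to skip in draft mode.
        candidate = {
            "attn": sorted(random.sample(range(num_layers), k=random.randint(0, min(max_skips, num_layers)))),
            "mlp": sorted(random.sample(range(num_layers), k=random.randint(0, min(max_skips, num_layers)))),
        }
        score = evaluate(candidate)  # e.g. tokens/sec of the full self-speculative loop
        if score > best_score:
            best_score, best_config = score, candidate

    return best_config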
I have written some brief introductions to Speculative Decoding before; see:
- [Paper Reading] Fast Inference from Transformers via Speculative Decoding
- Speculative Decoding implementation notes (with simple experiment results)
I am also putting my implementation on GitHub: https://github.com/ccs96307/fast-llm-inference. It is a repo that collects my simple implementations of various inference-acceleration techniques, together with links to the papers I referenced. Feel free to bookmark or star it!
Implementing the Layer-Skipping Architecture
Essentially, all we need to do is rewrite three classes: LlamaDecoderLayer, LlamaModel, and LlamaForCausalLM. Most of the code is copied directly from the HuggingFace Transformers implementation, so the copyright belongs to them.
You could of course write this with inheritance instead, but for me this level of abstraction is just right and requires the least amount of extra code.
In short, I did three things:
- Define the three classes with a LayerSkip prefix, explicitly initialize self.draft_mode and self.skip_layer_ids in their constructors, and set up the inner modules accordingly (LayerSkipLlamaForCausalLM instantiates a LayerSkipLlamaModel, and LayerSkipLlamaModel instantiates LayerSkipLlamaDecoderLayer for each decoder layer)
- Add two methods, set_skip_layer_ids() and set_draft_mode(), to all three classes, so that a setting made on the outermost LayerSkipLlamaForCausalLM is propagated all the way down to every LayerSkipLlamaDecoderLayer
- In the forward() method of LayerSkipLlamaDecoderLayer, decide whether to skip the attention block or the MLP block based on whether self.draft_mode is enabled and whether self.skip_layer_ids contains the current layer index
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from transformers import (
LlamaConfig,
LlamaPreTrainedModel,
GenerationMixin,
)
from transformers.models.llama.modeling_llama import (
LlamaMLP,
LlamaRMSNorm,
LlamaRotaryEmbedding,
LLAMA_ATTENTION_CLASSES,
)
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
class LayerSkipLlamaDecoderLayer(torch.nn.Module):
def __init__(
self,
config: LlamaConfig,
layer_idx: int,
):
super().__init__()
# Set skip layer
skip_layer_ids = {"attn": [], "mlp": []}
self.draft_mode = False
self.skip_layer_ids = skip_layer_ids
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def set_skip_layer_ids(self, skip_layer_ids: Dict[str, List[int]]):
self.skip_layer_ids = skip_layer_ids
def set_draft_mode(self, _mode: bool):
self.draft_mode = _mode
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
if self.draft_mode and self.layer_idx in self.skip_layer_ids["attn"]:
hidden_states = residual
self_attn_weights = None
present_key_value = None
else:
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
if self.draft_mode and self.layer_idx in self.skip_layer_ids["mlp"]:
hidden_states = residual
else:
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
class LayerSkipLlamaModel(LlamaPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
Args:
config: LlamaConfig
"""
def __init__(self, config: LlamaConfig):
super().__init__(config)
# Set skip layer
skip_layer_ids = {"attn": [], "mlp": []}
self.draft_mode = False
self.skip_layer_ids = skip_layer_ids
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = torch.nn.ModuleList(
[LayerSkipLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = LlamaRotaryEmbedding(config=config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def set_skip_layer_ids(self, skip_layer_ids: Dict[str, List[int]]):
self.skip_layer_ids = skip_layer_ids
for layer in self.layers:
layer.set_skip_layer_ids(skip_layer_ids=skip_layer_ids)
def set_draft_mode(self, _mode: bool):
self.draft_mode = _mode
for layer in self.layers:
layer.set_draft_mode(_mode=_mode)
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training and use_cache:
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
class LayerSkipLlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
# Set skip layer
skip_layer_ids = {"attn": [], "mlp": []}
self.draft_mode = False
self.skip_layer_ids = skip_layer_ids
self.model = LayerSkipLlamaModel(config)
self.vocab_size = config.vocab_size
self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def set_skip_layer_ids(self, skip_layer_ids: Dict[str, List[int]]):
assert "attn" in skip_layer_ids and "mlp" in skip_layer_ids, "`skip_layer_ids` need to be set `attn` and `mlp`!"
assert isinstance(skip_layer_ids["attn"], list), "`skip_layer_ids['attn']` need to be a list!"
assert isinstance(skip_layer_ids["mlp"], list), "`skip_layer_ids['mlp']` need to be a list!"
for attn_layer_idx in skip_layer_ids["attn"]:
assert attn_layer_idx < len(self.model.layers), f"attn_layer_idx {attn_layer_idx} is out of Range ({len(self.model.layers)})"
for mlp_layer_idx in skip_layer_ids["mlp"]:
assert mlp_layer_idx < len(self.model.layers), f"mlp_layer_idx {mlp_layer_idx} is out of Range ({len(self.model.layers)})"
self.skip_layer_ids = skip_layer_ids
self.model.set_skip_layer_ids(skip_layer_ids=skip_layer_ids)
print("skip_layer_ids:", self.skip_layer_ids)
def set_draft_mode(self, _mode: bool):
self.draft_mode = _mode
self.model.set_draft_mode(_mode=_mode)
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
hidden_states = outputs[0]
if self.config.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
logits = torch.cat(logits, dim=-1)
else:
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **loss_kwargs)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
Once that is done, let's test it!
Testing the LayerSkip Model
I am not sure whether it is because the model I chose is only around 1.7B parameters, but quality seems to drop off very quickly after skipping just a few layers. Perhaps this is one of the reasons why Meta AI's later LayerSkip models always go through a training stage.
Still, by my rough measurement, LayerSkip does save time, which is only natural, since some of the computation is simply skipped.
In any case, I have now reached the stage where, once I decide which layers to skip, I can test the effect of Self-Speculative Decoding; I still need to implement my skip-layer search method, though.
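To make that target concrete, here is a minimal sketch of the draft-then-verify loop I plan to run on top of the classes above, assuming batch size 1, greedy decoding, and the hypothetical entry point self_speculative_generate(). It re-encodes the whole sequence on every pass (no KV-cache reuse) and uses exact-match verification instead of probabilistic acceptance, so it only shows where set_draft_mode() fits in, not the final implementation.

import torch

@torch.no_grad()
def self_speculative_generate(model, input_ids: torch.Tensor, gamma: int = 4, max_new_tokens: int = 64) -> torch.Tensor:
    """Greedy draft-then-verify sketch: no KV-cache reuse, no probabilistic acceptance,
    and the 'bonus' token after a fully accepted draft is skipped for simplicity."""
    target_length = input_ids.shape[-1] + max_new_tokens

    while input_ids.shape[-1] < target_length:
        # 1. Draft: run only the non-skipped layers and greedily propose `gamma` tokens.
        model.set_draft_mode(True)
        draft_ids = input_ids
        for _ in range(gamma):
            draft_logits = model(input_ids=draft_ids).logits
            next_token = draft_logits[:, -1:, :].argmax(dim=-1)
            draft_ids = torch.cat([draft_ids, next_token], dim=-1)

        # 2. Verify: a single full-model forward pass over prompt + drafted tokens.
        model.set_draft_mode(False)
        verify_logits = model(input_ids=draft_ids).logits
        # Logits at position j predict token j + 1, so these are the full model's
        # greedy choices at each drafted position.
        verify_tokens = verify_logits[:, -gamma - 1:-1, :].argmax(dim=-1)
        drafted = draft_ids[:, -gamma:]

        # 3. Accept the longest prefix on which the draft agrees with the full model,
        #    then append the full model's own token at the first mismatch (if any).
        matches = (verify_tokens == drafted).long()[0]
        num_accept = int(matches.cumprod(dim=-1).sum().item())
        input_ids = torch.cat(
            [input_ids, drafted[:, :num_accept], verify_tokens[:, num_accept:num_accept + 1]],
            dim=-1,
        )

    return input_ids

In the actual algorithm the verification is probabilistic (as in the acceptance sketch earlier) and the KV cache is shared between the draft and verify passes, so treat this purely as the shape of the control flow.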
Below is a simple test comparing the original model against a model that skips a few layers. Here I skip layers 2, 15, and 18 for both the attention blocks and the MLP blocks.
import time
import torch
from transformers import AutoTokenizer
from layerskip_modeling.modeling_layerskip_llama import LayerSkipLlamaForCausalLM
if __name__ == "__main__":
pretrained_model_name_or_path = "../models/HuggingFaceTB--SmolLM2-1.7B-Instruct/"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
model = LayerSkipLlamaForCausalLM.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.bfloat16).to(device)
skip_layer_ids = {
"attn": [
2,
15,
18,
],
"mlp": [
2,
15,
18,
]
}
model.set_skip_layer_ids(skip_layer_ids=skip_layer_ids)
messages = [
[
{
"role": "user",
"content": "What is the capital of Taiwan. And why?",
},
],
]
# Tokenize
input_text=tokenizer.apply_chat_template(messages, tokenize=False)
inputs = tokenizer(
input_text,
return_tensors="pt",
max_length=512,
truncation=True,
padding=True,
).to(device)
prompt_token_num = inputs["input_ids"].shape[-1]
# Original Model
model.set_draft_mode(False)
start_time = time.time()
outputs = model.generate(**inputs, max_new_tokens=512)
total_token_num = outputs.shape[-1]
completion_token_num = total_token_num - prompt_token_num
cost_time = time.time() - start_time
token_per_second = completion_token_num / cost_time
response = tokenizer.batch_decode(outputs)[0]
print(f"{'='*15} Original Model {'='*15}")
print(response)
print()
print(f"Completion Token Number: {completion_token_num}")
print(f"Cost Time: {cost_time}, Speed: {token_per_second} token/sec\n")
# LayerSkip Model
model.set_draft_mode(True)
start_time = time.time()
outputs = model.generate(**inputs, max_new_tokens=512)
total_token_num = outputs.shape[-1]
completion_token_num = total_token_num - prompt_token_num
cost_time = time.time() - start_time
token_per_second = completion_token_num / cost_time
response = tokenizer.batch_decode(outputs)[0]
print(f"{'='*15} LayerSkip Model {'='*15}")
print(response)
print()
print(f"Completion Token Number: {completion_token_num}")
print(f"Cost Time: {cost_time}, Speed: {token_per_second} token/sec\n")
Output:
skip_layer_ids: {'attn': [2, 15, 18], 'mlp': [2, 15, 18]}
=============== Original Model ===============
<|im_start|>system
You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
<|im_start|>user
What is the capital of Taiwan. And why?<|im_end|>
<|im_start|>assistant
The capital of Taiwan is Taipei. It is the largest city in Taiwan and serves as the political, economic, and cultural center of the country. The reason for this is that Taipei was established as the capital city in 1949, following the Chinese Civil War, when the government of the Republic of China (ROC) relocated from mainland China to Taiwan. This decision was made to ensure the continuity of the ROC's political and administrative functions, and to maintain its claim to the entirety of China.<|im_end|>
Completion Token Number: 110
Cost Time: 2.2670738697052, Speed: 48.52069509949576 token/sec
=============== LayerSkip Model ===============
<|im_start|>system
You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
<|im_start|>user
What is the capital of Taiwan. And why?<|im_end|>
<|im_start|>assistant
The capital of Taiwan is Taipei.<|im_end|>
Completion Token Number: 13
Cost Time: 0.1961832046508789, Speed: 66.26459193147736 token/sec
I do think the original model's answer is clearly better, since it actually explains why Taipei is the capital of Taiwan. On the other hand, the LayerSkip model has not lost too much capability here: the reply is still fluent, and the speed went from 48.52 tokens per second up to 66.26.
All in all, this was a fun exercise, and I plan to keep digging until Self-Speculative Decoding is fully implemented.
References
- Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding
- Kangaroo: Lossless Self-Speculative Decoding via Double Early Exiting