Implementing Self-Speculative Decoding: Notes on a Layer-Skipping Transformer Model

Last Updated on 2024-11-10 by Clay

Introduction

Self-Speculative Decoding is a variant of Speculative Decoding. The original Speculative Decoding uses a draft model to accelerate the target model we actually want to run inference with: the draft model produces outputs similar to the target model's while being several times faster, and it is usually distilled from the target model.

Once the draft model has speculated a candidate token sequence to be verified, the target model predicts the next token after that candidate sequence and, in the same forward pass, produces its own probability distribution for each of the tokens the draft model proposed. A verification algorithm (there are several) then decides which of those tokens to accept. If most of the draft model's tokens are accepted, the target model has effectively decoded many tokens in a single forward pass, which is where the speed-up comes from.
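
To make the acceptance step concrete, here is a minimal, self-contained sketch of the standard speculative-sampling acceptance rule: each drafted token is accepted with probability min(1, p_target / p_draft), and verification stops at the first rejection. The function name and tensor layout are just my own illustration (not code from any particular library), and the resampling step that normally follows a rejection is omitted.

import torch

def verify_draft_tokens(
    draft_tokens: torch.LongTensor,  # shape (num_draft,): tokens proposed by the draft model
    p_draft: torch.FloatTensor,      # shape (num_draft, vocab_size): draft-model probabilities
    p_target: torch.FloatTensor,     # shape (num_draft, vocab_size): target-model probabilities
) -> int:
    """Return how many leading draft tokens are accepted."""
    accepted = 0
    for i, token in enumerate(draft_tokens.tolist()):
        # Accept the i-th drafted token with probability min(1, p_target / p_draft)
        ratio = (p_target[i, token] / p_draft[i, token]).item()
        if torch.rand(1).item() < ratio:
            accepted += 1
        else:
            break
    return accepted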

Self-Speculative Decoding, on the other hand, takes the extra VRAM cost of loading a separate draft model into account: it uses only a subset of its own layers to act as the draft model, and then uses all of the layers as the target model for verification. For details, see the paper Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding.

So the ultimate draft model is the model itself, turned into its own draft!

What I'm documenting today is exactly how to do that: take the Transformers source code, add a simple skipping mechanism, and end up with a LayerSkipModel that can skip Self-Attention Mechanism layers and MLP layers.

However, since I haven't yet implemented the paper's Bayesian Optimization procedure for selecting which layers to skip, I can't show final results yet.

I have written some brief introductions to Speculative Decoding before, which you can refer to.

I also keep my implementation on GitHub: https://github.com/ccs96307/fast-llm-inference. It is a project where I collect simple implementations of various inference-acceleration techniques, together with links to the papers I referenced. Feel free to star it or take a look!


Layer-Skipping Architecture Implementation

Basically, we only need to rewrite three classes: LlamaDecoderLayer, LlamaModel, and LlamaForCausalLM. Most of the code is copied directly from the HuggingFace Transformers implementation, so the copyright belongs to them, of course.

You could certainly write this with inheritance instead, but for me, staying at this level of abstraction feels just right and requires the least additional code.

In short, I did three things:

  1. Define the three classes with a LayerSkip prefix, explicitly define self.draft_mode and self.skip_layer_ids in their constructors, and wire up the modules each one initializes (LayerSkipLlamaForCausalLM initializes a LayerSkipLlamaModel, and LayerSkipLlamaModel initializes LayerSkipLlamaDecoderLayer as its decoder layers)
  2. Add two methods, set_skip_layer_ids() and set_draft_mode(), to each of the three classes, so that a setting made on the outermost LayerSkipLlamaForCausalLM is propagated all the way down to every LayerSkipLlamaDecoderLayer
  3. In the forward() method of LayerSkipLlamaDecoderLayer, decide whether to skip the attention block or the MLP block based on whether self.draft_mode is enabled and whether self.skip_layer_ids contains the current layer index

The full modified modeling code is shown below:
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint

from transformers import (
    LlamaConfig,
    LlamaPreTrainedModel,
    GenerationMixin,
)
from transformers.models.llama.modeling_llama import (
    LlamaMLP,
    LlamaRMSNorm,
    LlamaRotaryEmbedding,
    LLAMA_ATTENTION_CLASSES,
)
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_attn_mask_utils import AttentionMaskConverter


class LayerSkipLlamaDecoderLayer(torch.nn.Module):
    def __init__(
        self,
        config: LlamaConfig,
        layer_idx: int,
    ):
        super().__init__()

        # Set skip layer
        skip_layer_ids = {"attn": [], "mlp": []}
        self.draft_mode = False
        self.skip_layer_ids = skip_layer_ids
        self.layer_idx = layer_idx
    
        self.hidden_size = config.hidden_size

        self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)

        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def set_skip_layer_ids(self, skip_layer_ids: Dict[str, List[int]]):
        self.skip_layer_ids = skip_layer_ids

    def set_draft_mode(self, _mode: bool):
        self.draft_mode = _mode

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states

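        # Draft-mode modification: for layers listed in skip_layer_ids["attn"],
        # skip the whole self-attention block (input layernorm, attention, and
        # residual add) and pass the hidden states through unchanged.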
        if self.draft_mode and self.layer_idx in self.skip_layer_ids["attn"]:
            hidden_states = residual
            self_attn_weights = None
            present_key_value = None
        else:
            hidden_states = self.input_layernorm(hidden_states)

            # Self Attention
            hidden_states, self_attn_weights, present_key_value = self.self_attn(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )
            hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states

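        # Draft-mode modification: for layers listed in skip_layer_ids["mlp"],
        # skip the whole MLP block and pass the hidden states through unchanged.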
        if self.draft_mode and self.layer_idx in self.skip_layer_ids["mlp"]:
            hidden_states = residual
        else:
            hidden_states = self.post_attention_layernorm(hidden_states)
            hidden_states = self.mlp(hidden_states)
            hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
    

class LayerSkipLlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)

        # Set skip layer
        skip_layer_ids = {"attn": [], "mlp": []}
        self.draft_mode = False
        self.skip_layer_ids = skip_layer_ids

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = torch.nn.ModuleList(
            [LayerSkipLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def set_skip_layer_ids(self, skip_layer_ids: Dict[str, List[int]]):
        self.skip_layer_ids = skip_layer_ids

        for layer in self.layers:
            layer.set_skip_layer_ids(skip_layer_ids=skip_layer_ids)

    def set_draft_mode(self, _mode: bool):
        self.draft_mode = _mode

        for layer in self.layers:
            layer.set_draft_mode(_mode=_mode)

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # kept for BC (non `Cache` `past_key_values` inputs)
        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):
            return_legacy_cache = True
            if past_key_values is None:
                past_key_values = DynamicCache()
            else:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )
        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask
    

class LayerSkipLlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)

        # Set skip layer
        skip_layer_ids = {"attn": [], "mlp": []}
        self.draft_mode = False
        self.skip_layer_ids = skip_layer_ids

        self.model = LayerSkipLlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def set_skip_layer_ids(self, skip_layer_ids: Dict[str, List[int]]):
        assert "attn" in skip_layer_ids and "mlp" in skip_layer_ids, "`skip_layer_ids` must contain both `attn` and `mlp` keys!"
        assert isinstance(skip_layer_ids["attn"], list), "`skip_layer_ids['attn']` must be a list!"
        assert isinstance(skip_layer_ids["mlp"], list), "`skip_layer_ids['mlp']` must be a list!"

        for attn_layer_idx in skip_layer_ids["attn"]:
            assert attn_layer_idx < len(self.model.layers), f"attn_layer_idx {attn_layer_idx} is out of range (the model has {len(self.model.layers)} layers)"

        for mlp_layer_idx in skip_layer_ids["mlp"]:
            assert mlp_layer_idx < len(self.model.layers), f"mlp_layer_idx {mlp_layer_idx} is out of range (the model has {len(self.model.layers)} layers)"

        self.skip_layer_ids = skip_layer_ids
        self.model.set_skip_layer_ids(skip_layer_ids=skip_layer_ids)

        print("skip_layer_ids:", self.skip_layer_ids)

    def set_draft_mode(self, _mode: bool):
        self.draft_mode = _mode
        self.model.set_draft_mode(_mode=_mode)

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
            logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **loss_kwargs)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


Once that's done, let's test it!


Testing the LayerSkip Model

I'm not sure whether it's because the model I chose is only at the 1.7B scale, but it feels like the quality drops very, very quickly after skipping just a few layers. Perhaps that is one of the reasons Meta AI's later LayerSkip models always go through a training stage.

Still, by my rough measurements, LayerSkip really does save time! (Which is expected, since some computation is simply skipped.)

In any case, I have now reached the point where, once I decide which layers to skip, I can test the actual effect of Self-Speculative Decoding; I still need to implement my skip-layer search procedure, though.
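
Before the simple test below, here is a rough sketch of how I picture a single self-speculative decoding step once the skip layers are chosen: draft a few tokens greedily with the layer-skipped configuration, then verify them all in one full-model forward pass. This is only a conceptual sketch under simplifying assumptions (greedy drafting, exact-match verification, no KV cache, batch size 1), not the final implementation.

import torch

@torch.no_grad()
def self_speculative_step(model, input_ids: torch.LongTensor, gamma: int = 4) -> torch.LongTensor:
    # 1. Draft `gamma` tokens with the layer-skipped (draft) configuration
    model.set_draft_mode(True)
    draft_ids = input_ids
    for _ in range(gamma):
        logits = model(input_ids=draft_ids, use_cache=False).logits
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        draft_ids = torch.cat([draft_ids, next_token], dim=-1)

    # 2. Verify all drafted tokens in a single forward pass of the full model
    model.set_draft_mode(False)
    full_logits = model(input_ids=draft_ids, use_cache=False).logits
    target_tokens = full_logits[:, :-1, :].argmax(dim=-1)  # full model's greedy prediction for each next position

    # 3. Accept the longest prefix on which the draft and the full model agree
    prompt_len = input_ids.shape[-1]
    accepted = 0
    for i in range(draft_ids.shape[-1] - prompt_len):
        if draft_ids[0, prompt_len + i] == target_tokens[0, prompt_len + i - 1]:
            accepted += 1
        else:
            break

    # 4. Keep the accepted draft tokens plus one bonus token from the full model
    new_len = prompt_len + accepted
    bonus_token = full_logits[:, new_len - 1, :].argmax(dim=-1, keepdim=True)
    return torch.cat([draft_ids[:, :new_len], bonus_token], dim=-1)

Every accepted draft token means one fewer full-model forward pass, which is exactly where the Self-Speculative Decoding speed-up should come from.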

Below is a simple test comparing the original model with a version that skips a few layers. Here I skip layers 2, 15, and 18 for both the attention blocks and the MLP blocks.

import time

import torch
from transformers import AutoTokenizer
from layerskip_modeling.modeling_layerskip_llama import LayerSkipLlamaForCausalLM


if __name__ == "__main__":
    pretrained_model_name_or_path = "../models/HuggingFaceTB--SmolLM2-1.7B-Instruct/"
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
    model = LayerSkipLlamaForCausalLM.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.bfloat16).to(device)

    skip_layer_ids = {
        "attn": [
            2,
            15,
            18,
        ],
        "mlp": [
            2,
            15,
            18,
        ]
    }

    model.set_skip_layer_ids(skip_layer_ids=skip_layer_ids)


    messages = [
        [
            {
                "role": "user",
                "content": "What is the capital of Taiwan. And why?",
            },
        ],
    ]


    # Tokenize
    input_text = tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    ).to(device)

    prompt_token_num = inputs["input_ids"].shape[-1]

    # Original Model
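    # draft_mode=False: every decoder layer runs in full, i.e. the original (target) model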
    model.set_draft_mode(False)

    start_time = time.time()

    outputs = model.generate(**inputs, max_new_tokens=512)
    total_token_num = outputs.shape[-1]
    completion_token_num = total_token_num - prompt_token_num
    cost_time = time.time() - start_time

    token_per_second = completion_token_num / cost_time
    response = tokenizer.batch_decode(outputs)[0]

    print(f"{'='*15} Original Model {'='*15}")
    print(response)
    print()
    print(f"Completion Token Number: {completion_token_num}")
    print(f"Cost Time: {cost_time}, Speed: {token_per_second} token/sec\n")


    # LayerSkip Model
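    # draft_mode=True: layers listed in skip_layer_ids skip their attention / MLP blocks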
    model.set_draft_mode(True)

    start_time = time.time()

    outputs = model.generate(**inputs, max_new_tokens=512)
    total_token_num = outputs.shape[-1]
    completion_token_num = total_token_num - prompt_token_num
    cost_time = time.time() - start_time

    token_per_second = completion_token_num / cost_time
    response = tokenizer.batch_decode(outputs)[0]

    print(f"{'='*15} LayerSkip Model {'='*15}")
    print(response)
    print()
    print(f"Completion Token Number: {completion_token_num}")
    print(f"Cost Time: {cost_time}, Speed: {token_per_second} token/sec\n")


Output:

skip_layer_ids: {'attn': [2, 15, 18], 'mlp': [2, 15, 18]}
=============== Original Model ===============
<|im_start|>system
You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
<|im_start|>user
What is the capital of Taiwan. And why?<|im_end|>
<|im_start|>assistant
The capital of Taiwan is Taipei. It is the largest city in Taiwan and serves as the political, economic, and cultural center of the country. The reason for this is that Taipei was established as the capital city in 1949, following the Chinese Civil War, when the government of the Republic of China (ROC) relocated from mainland China to Taiwan. This decision was made to ensure the continuity of the ROC's political and administrative functions, and to maintain its claim to the entirety of China.<|im_end|>

Completion Token Number: 110
Cost Time: 2.2670738697052, Speed: 48.52069509949576 token/sec

=============== LayerSkip Model ===============
<|im_start|>system
You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
<|im_start|>user
What is the capital of Taiwan. And why?<|im_end|>
<|im_start|>assistant
The capital of Taiwan is Taipei.<|im_end|>

Completion Token Number: 13
Cost Time: 0.1961832046508789, Speed: 66.26459193147736 token/sec

I do think the original model's answer is clearly better, since it actually explains why Taipei is the capital of Taiwan. On the other hand, the LayerSkip model's quality has not dropped too much yet: the reply is still fluent, and the speed went from 48.52 tokens per second up to 66.26 tokens per second.

All in all, this was a fun exercise, and I plan to keep digging until I have implemented Self-Speculative Decoding end to end.


References

- Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding