Self-Speculative Decoding Implementation: LayerSkip Transformer

Last Updated on 2024-11-12 by Clay

Introduction

Self-Speculative Decoding is a variant of Speculative Decoding. The original Speculative Decoding method uses a draft model to speed up inference of the target model. The draft model, typically distilled from the target model, produces output of similar quality while running several times faster.

Once the draft model has proposed a sequence of candidate tokens, the target model verifies them in a single forward pass: it predicts the next token and, at the same time, produces its own probability distribution at every position the draft model filled in. A verification algorithm then decides which draft tokens to accept, so the target model can emit several tokens per inference cycle, which is where the acceleration comes from.
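To make the verification step concrete, here is a minimal sketch of one possible verification rule: greedy exact matching. This is an assumption chosen for simplicity, not necessarily the rule used in any particular paper (sampling-based schemes compare probabilities instead). The function name `greedy_verify` and the plain integer token lists are hypothetical.

```python
from typing import List


def greedy_verify(draft_tokens: List[int], target_tokens: List[int]) -> List[int]:
    """Return the tokens kept after one verification pass.

    target_tokens[i] is the target model's greedy choice at position i,
    obtained from one forward pass over the prefix plus all draft tokens,
    so it holds len(draft_tokens) + 1 entries.
    """
    if len(target_tokens) != len(draft_tokens) + 1:
        raise ValueError("target_tokens must have one more entry than draft_tokens")

    accepted: List[int] = []
    for draft_tok, target_tok in zip(draft_tokens, target_tokens):
        if draft_tok != target_tok:
            # First mismatch: keep the target's token and discard the rest of the draft.
            accepted.append(target_tok)
            return accepted
        accepted.append(draft_tok)

    # Every draft token matched, so the target's extra prediction is a bonus token.
    accepted.append(target_tokens[-1])
    return accepted


print(greedy_verify([5, 7, 9], [5, 7, 2, 4]))  # [5, 7, 2]
```

Even in the worst case (the first draft token is rejected) one token is still produced, so under this rule a verification pass never yields fewer tokens than a normal decoding step.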

Self-Speculative Decoding addresses the additional VRAM cost of loading a separate draft model. Instead, the model uses a subset of its own layers as the draft model and then verifies the draft with its full set of layers acting as the target model. Detailed explanations can be found in the paper Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding.

Ultimately, the draft model becomes... me, transformed into a draft model!

Today, I’m documenting how to create a LayerSkipModel by reusing the original Transformers code and adding a skip mechanism. This allows us to bypass self-attention and MLP layers as needed.

However, I haven’t yet completed the mechanism for selecting skip layers through Bayesian Optimization as described in the paper, so I can’t provide results yet.

For a more basic introduction to Speculative Decoding, you can refer to my previous notes:

Additionally, I’ll upload my implementation on GitHub: https://github.com/ccs96307/fast-llm-inference. This repository contains various implementations to accelerate inference, along with references to the papers I consulted. Feel free to bookmark or star it, and I welcome everyone to check it out!


LayerSkip Architecture Implementation

Basically, we only need to rewrite three classes: LlamaDecoderLayer, LlamaModel, and LlamaForCausalLM. I’ve largely copied HuggingFace Transformers’ implementation, so the copyright belongs to them.

Of course, inheritance could be used, but this level of abstraction works perfectly for me while keeping the code additions minimal.

In essence, I did three main things:

  1. Defined the three classes with the LayerSkip prefix and initialized self.draft_mode and self.skip_layer_ids. Each model initialization aligns with its components (LayerSkipLlamaForCausalLM contains LayerSkipLlamaModel, and LayerSkipLlamaModel initializes with LayerSkipLlamaDecoderLayer).
  2. Added set_skip_layer_ids() and set_draft_mode() methods to each class. These settings propagate from the outermost LayerSkipLlamaForCausalLM down to every LayerSkipLlamaDecoderLayer.
  3. In the forward() method of LayerSkipLlamaDecoderLayer, decisions to skip the attention mechanism or MLP layers are made based on whether self.draft_mode is enabled and if the current decoding layer is among self.skip_layer_ids.
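Before the full listing, the control flow of these three steps can be sketched in isolation. The following is a toy illustration only: `ToyDecoderLayer` and `ToyModel` are hypothetical stand-ins that operate on a single float instead of tensors, and the constant "sublayers" exist just to make the effect of skipping visible.

```python
from typing import Callable, Dict, List


class ToyDecoderLayer:
    """Stand-in for a decoder layer: two residual sublayers acting on a float."""

    def __init__(self, layer_idx: int, attn: Callable[[float], float], mlp: Callable[[float], float]):
        self.layer_idx = layer_idx
        self.attn = attn
        self.mlp = mlp
        self.draft_mode = False
        self.skip_layer_ids: Dict[str, List[int]] = {"attn": [], "mlp": []}

    def _skipped(self, kind: str) -> bool:
        # A sublayer is skipped only in draft mode and only if this layer is listed.
        return self.draft_mode and self.layer_idx in self.skip_layer_ids[kind]

    def forward(self, hidden: float) -> float:
        if not self._skipped("attn"):
            hidden = hidden + self.attn(hidden)  # residual connection
        if not self._skipped("mlp"):
            hidden = hidden + self.mlp(hidden)  # residual connection
        return hidden


class ToyModel:
    """Outer model that pushes both settings down to every layer."""

    def __init__(self, num_layers: int):
        self.layers = [
            ToyDecoderLayer(i, attn=lambda h: 1.0, mlp=lambda h: 10.0)
            for i in range(num_layers)
        ]

    def set_skip_layer_ids(self, skip_layer_ids: Dict[str, List[int]]) -> None:
        for layer in self.layers:
            layer.skip_layer_ids = skip_layer_ids

    def set_draft_mode(self, mode: bool) -> None:
        for layer in self.layers:
            layer.draft_mode = mode

    def forward(self, hidden: float) -> float:
        for layer in self.layers:
            hidden = layer.forward(hidden)
        return hidden


model = ToyModel(num_layers=3)
model.set_skip_layer_ids({"attn": [1], "mlp": [1, 2]})
print(model.forward(0.0))  # 33.0 (draft mode off, nothing is skipped)
model.set_draft_mode(True)
print(model.forward(0.0))  # 12.0 (layer 1 fully skipped, layer 2 keeps attention only)
```

Because a skipped sublayer simply leaves the residual stream untouched, the same weights serve both modes and switching between them costs nothing beyond flipping a flag.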
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint

from transformers import (
    LlamaConfig,
    LlamaPreTrainedModel,
    GenerationMixin,
)
from transformers.models.llama.modeling_llama import (
    LlamaMLP,
    LlamaRMSNorm,
    LlamaRotaryEmbedding,
    LLAMA_ATTENTION_CLASSES,
)
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_attn_mask_utils import AttentionMaskConverter


class LayerSkipLlamaDecoderLayer(torch.nn.Module):
    def __init__(
        self,
        config: LlamaConfig,
        layer_idx: int,
    ):
        super().__init__()

        # Set skip layer
        skip_layer_ids = {"attn": [], "mlp": []}
        self.draft_mode = False
        self.skip_layer_ids = skip_layer_ids
        self.layer_idx = layer_idx
    
        self.hidden_size = config.hidden_size

        self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)

        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def set_skip_layer_ids(self, skip_layer_ids: Dict[str, List[int]]):
        self.skip_layer_ids = skip_layer_ids

    def set_draft_mode(self, _mode: bool):
        self.draft_mode = _mode

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states

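        if self.draft_mode and self.layer_idx in self.skip_layer_ids["attn"]:
            # Skip attention: pass the input through unchanged, but keep returning the
            # cache object so that skipping the final layer does not drop the KV cache.
            hidden_states = residual
            self_attn_weights = None
            present_key_value = past_key_value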
        else:
            hidden_states = self.input_layernorm(hidden_states)

            # Self Attention
            hidden_states, self_attn_weights, present_key_value = self.self_attn(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )
            hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states

        if self.draft_mode and self.layer_idx in self.skip_layer_ids["mlp"]:
            hidden_states = residual
        else:
            hidden_states = self.post_attention_layernorm(hidden_states)
            hidden_states = self.mlp(hidden_states)
            hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
    

class LayerSkipLlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LayerSkipLlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)

        # Set skip layer
        skip_layer_ids = {"attn": [], "mlp": []}
        self.draft_mode = False
        self.skip_layer_ids = skip_layer_ids

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = torch.nn.ModuleList(
            [LayerSkipLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = LlamaRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def set_skip_layer_ids(self, skip_layer_ids: Dict[str, List[int]]):
        self.skip_layer_ids = skip_layer_ids

        for layer in self.layers:
            layer.set_skip_layer_ids(skip_layer_ids=skip_layer_ids)

    def set_draft_mode(self, _mode: bool):
        self.draft_mode = _mode

        for layer in self.layers:
            layer.set_draft_mode(_mode=_mode)

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # kept for BC (non `Cache` `past_key_values` inputs)
        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):
            return_legacy_cache = True
            if past_key_values is None:
                past_key_values = DynamicCache()
            else:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )
        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask
    

class LayerSkipLlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)

        # Set skip layer
        skip_layer_ids = {"attn": [], "mlp": []}
        self.draft_mode = False
        self.skip_layer_ids = skip_layer_ids

        self.model = LayerSkipLlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def set_skip_layer_ids(self, skip_layer_ids: Dict[str, List[int]]):
        assert "attn" in skip_layer_ids and "mlp" in skip_layer_ids, "`skip_layer_ids` must contain both `attn` and `mlp` keys!"
        assert isinstance(skip_layer_ids["attn"], list), "`skip_layer_ids['attn']` needs to be a list!"
        assert isinstance(skip_layer_ids["mlp"], list), "`skip_layer_ids['mlp']` needs to be a list!"

        # Indices must fall inside [0, num_layers); a negative index would silently never match `layer_idx`.
        for attn_layer_idx in skip_layer_ids["attn"]:
            assert 0 <= attn_layer_idx < len(self.model.layers), f"attn_layer_idx {attn_layer_idx} is out of range ({len(self.model.layers)})"

        for mlp_layer_idx in skip_layer_ids["mlp"]:
            assert 0 <= mlp_layer_idx < len(self.model.layers), f"mlp_layer_idx {mlp_layer_idx} is out of range ({len(self.model.layers)})"

        self.skip_layer_ids = skip_layer_ids
        self.model.set_skip_layer_ids(skip_layer_ids=skip_layer_ids)

        print("skip_layer_ids:", self.skip_layer_ids)

    def set_draft_mode(self, _mode: bool):
        self.draft_mode = _mode
        self.model.set_draft_mode(_mode=_mode)

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
            logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **loss_kwargs)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


Once completed, let’s test the results!


Testing the Effectiveness of the LayerSkip Model

Using a 1.7B model, it seems that skipping even a few layers noticeably degrades output quality. This may explain why Meta AI's LayerSkip models require training to compensate.

Still, my basic calculations show that LayerSkip indeed saves time! (Which is expected, as some computations are simply omitted.)

Currently, I’ve reached the stage where, after choosing which layers to skip, I can test the effects of Self-Speculative Decoding. However, I still need to implement a method for selecting the layers to skip.
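As a rough picture of where this is heading, the sketch below shows how a single model with a draft-mode switch could drive a draft-then-verify loop. Everything here is hypothetical: `ToyModel` is a deterministic stand-in whose `next_token()` is made up for illustration, verification uses simple greedy matching (an assumption, not necessarily the paper's rule), and a real implementation would verify all draft tokens in one batched forward pass rather than one call per position.

```python
from typing import List


class ToyModel:
    """Deterministic stand-in: draft mode sometimes disagrees with full mode."""

    def __init__(self) -> None:
        self.draft_mode = False

    def set_draft_mode(self, mode: bool) -> None:
        self.draft_mode = mode

    def next_token(self, tokens: List[int]) -> int:
        full = (sum(tokens) * 31 + len(tokens)) % 50
        if self.draft_mode and len(tokens) % 3 == 0:
            return (full + 1) % 50  # the cheaper draft pass gets this position wrong
        return full


def greedy_generate(model: ToyModel, prompt: List[int], max_new_tokens: int) -> List[int]:
    """Reference: ordinary greedy decoding with the full model."""
    model.set_draft_mode(False)
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        tokens.append(model.next_token(tokens))
    return tokens


def self_speculative_generate(
    model: ToyModel, prompt: List[int], max_new_tokens: int, num_draft: int = 4
) -> List[int]:
    tokens = list(prompt)
    target_len = len(prompt) + max_new_tokens
    while len(tokens) < target_len:
        # Draft phase: the same model, with some layers skipped, proposes tokens.
        model.set_draft_mode(True)
        draft: List[int] = []
        for _ in range(num_draft):
            draft.append(model.next_token(tokens + draft))

        # Verify phase: the full model checks each proposal in order.
        model.set_draft_mode(False)
        accepted: List[int] = []
        for draft_tok in draft:
            target_tok = model.next_token(tokens + accepted)
            accepted.append(target_tok)
            if target_tok != draft_tok:
                break  # the remaining draft tokens are discarded
        else:
            # All proposals matched: the full model contributes one extra token.
            accepted.append(model.next_token(tokens + accepted))
        tokens.extend(accepted)
    return tokens[:target_len]


model = ToyModel()
spec = self_speculative_generate(model, [1, 2, 3], max_new_tokens=8)
ref = greedy_generate(model, [1, 2, 3], max_new_tokens=8)
print(spec == ref)  # True
```

Every token that ends up in the output was chosen by the full model, which is why this kind of scheme does not change the result; the skipped layers only affect how many tokens each round manages to accept.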

Below is a simple test comparing the original model and the model with a few skipped layers. I chose to skip both the attention and MLP sublayers of layers 2, 15, and 18.

import time

import torch
from transformers import AutoTokenizer
from layerskip_modeling.modeling_layerskip_llama import LayerSkipLlamaForCausalLM


if __name__ == "__main__":
    pretrained_model_name_or_path = "../models/HuggingFaceTB--SmolLM2-1.7B-Instruct/"
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
    model = LayerSkipLlamaForCausalLM.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.bfloat16).to(device)

    skip_layer_ids = {
        "attn": [
            2,
            15,
            18,
        ],
        "mlp": [
            2,
            15,
            18,
        ]
    }

    model.set_skip_layer_ids(skip_layer_ids=skip_layer_ids)


    messages = [
        [
            {
                "role": "user",
                "content": "What is the capital of Taiwan. And why?",
            },
        ],
    ]


    # Tokenize
    input_text = tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    ).to(device)

    prompt_token_num = inputs["input_ids"].shape[-1]

    # Original Model
    model.set_draft_mode(False)

    start_time = time.time()

    outputs = model.generate(**inputs, max_new_tokens=512)
    total_token_num = outputs.shape[-1]
    completion_token_num = total_token_num - prompt_token_num
    cost_time = time.time() - start_time

    token_per_second = completion_token_num / cost_time
    response = tokenizer.batch_decode(outputs)[0]

    print(f"{'='*15} Original Model {'='*15}")
    print(response)
    print()
    print(f"Completion Token Number: {completion_token_num}")
    print(f"Cost Time: {cost_time}, Speed: {token_per_second} token/sec\n")


    # LayerSkip Model
    model.set_draft_mode(True)

    start_time = time.time()

    outputs = model.generate(**inputs, max_new_tokens=512)
    total_token_num = outputs.shape[-1]
    completion_token_num = total_token_num - prompt_token_num
    cost_time = time.time() - start_time

    token_per_second = completion_token_num / cost_time
    response = tokenizer.batch_decode(outputs)[0]

    print(f"{'='*15} LayerSkip Model {'='*15}")
    print(response)
    print()
    print(f"Completion Token Number: {completion_token_num}")
    print(f"Cost Time: {cost_time}, Speed: {token_per_second} token/sec\n")


Output:

skip_layer_ids: {'attn': [2, 15, 18], 'mlp': [2, 15, 18]}
=============== Original Model ===============
<|im_start|>system
You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
<|im_start|>user
What is the capital of Taiwan. And why?<|im_end|>
<|im_start|>assistant
The capital of Taiwan is Taipei. It is the largest city in Taiwan and serves as the political, economic, and cultural center of the country. The reason for this is that Taipei was established as the capital city in 1949, following the Chinese Civil War, when the government of the Republic of China (ROC) relocated from mainland China to Taiwan. This decision was made to ensure the continuity of the ROC's political and administrative functions, and to maintain its claim to the entirety of China.<|im_end|>

Completion Token Number: 110
Cost Time: 2.2670738697052, Speed: 48.52069509949576 token/sec

=============== LayerSkip Model ===============
<|im_start|>system
You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
<|im_start|>user
What is the capital of Taiwan. And why?<|im_end|>
<|im_start|>assistant
The capital of Taiwan is Taipei.<|im_end|>

Completion Token Number: 13
Cost Time: 0.1961832046508789, Speed: 66.26459193147736 token/sec

In my opinion, the original model clearly gives the more complete answer, since it explains why Taipei is Taiwan’s capital. The LayerSkip model still produces a coherent response, though, and throughput rises from 48.52 to 66.26 tokens per second. Since this is a single run in which the two modes generated 110 and 13 tokens respectively, the speedup figure is only a rough indication.
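The two timed sections in the script repeat the same measurement, so they could be folded into one helper that also runs a warm-up pass and averages over several runs. The sketch below is one way to do that; `tokens_per_second` and its parameters are hypothetical names, and it only assumes that the callable returns a sequence of token ids that includes the prompt.

```python
import time
from typing import Callable, Sequence


def tokens_per_second(
    generate: Callable[[], Sequence[int]],
    prompt_len: int,
    warmup: int = 1,
    repeats: int = 3,
) -> float:
    """Average generation throughput over `repeats` timed runs."""
    # Warm-up runs are not timed; they absorb one-off start-up costs.
    for _ in range(warmup):
        generate()

    total_new_tokens = 0
    start = time.perf_counter()
    for _ in range(repeats):
        output = generate()
        # Count only newly generated tokens, not the prompt.
        total_new_tokens += len(output) - prompt_len
    elapsed = time.perf_counter() - start

    return total_new_tokens / elapsed
```

With the test script above, the callable could be something like `lambda: model.generate(**inputs, max_new_tokens=512)[0]`, called once with draft mode off and once with it on. On a GPU it may also be worth calling `torch.cuda.synchronize()` before reading the clock, so that pending kernels are included in the measurement.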

Overall, this was an interesting implementation, and I plan to continue refining it until I have fully implemented Self-Speculative Decoding.

