
Kangaroo: Inference Acceleration Architecture Implementation

Last Updated on 2024-12-10 by Clay

Introduction

Kangaroo is an implementation of Self-Speculative Decoding that introduces a trainable adapter layer. Over the past few weeks, I have been working on fine-tuning its adapter layer and have achieved some preliminary results, which I am documenting here.


Kangaroo Architecture

Kangaroo is a variant of Self-Speculative Decoding, which itself is a branch of Speculative Decoding.

In the original Speculative Decoding method, there is a draft model and a target model (the one we aim to accelerate). Typically, the two models share the same vocabulary, and the draft model runs much faster. The draft model speculatively generates several tokens, which the target model then verifies in a single forward pass.

In the original paper (Fast Inference from Transformers via Speculative Decoding), the research team designed a validation algorithm for the target model to ensure that the output probability distribution after speculative decoding matches that of the original target model, thereby achieving acceleration without altering the decoding results.
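
For reference, the core of that verification step can be sketched in just a few lines of PyTorch. This is only a conceptual illustration of the accept/reject rule (the function and tensor names here are my own, not from the paper):

import torch

def verify_draft_tokens(draft_tokens, draft_probs, target_probs):
    # draft_tokens: (num_drafted,) token ids proposed by the draft model
    # draft_probs / target_probs: (num_drafted, vocab_size) distributions at each position
    accepted = []
    for i, token in enumerate(draft_tokens.tolist()):
        p, q = target_probs[i, token], draft_probs[i, token]
        # Accept the drafted token with probability min(1, p / q)
        if torch.rand(1).item() < min(1.0, (p / q).item()):
            accepted.append(token)
        else:
            # On rejection, resample from the residual distribution max(0, p - q) and stop
            residual = torch.clamp(target_probs[i] - draft_probs[i], min=0.0)
            accepted.append(torch.multinomial(residual / residual.sum(), num_samples=1).item())
            break
    return accepted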

Self-Speculative Decoding (Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding) uses part of the target model's own network as the draft model. Simply put, each Transformer decoder layer consists of an attention sublayer and an MLP sublayer, and we can selectively skip either of them.

The choice of which layers to skip is critical. Skipping too few layers yields no meaningful speedup, while skipping too many causes the draft model's outputs to diverge so far from the target model that they are rarely accepted.

To balance inference speed against accuracy, Self-Speculative Decoding employs Bayesian Optimization to search for a layer-skipping configuration that satisfies both.
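
To make the idea concrete, here is a toy sketch of such a skipping draft pass. The ToyDecoderLayer below is a stand-in I wrote purely for illustration, not the real Llama modules; the Bayesian Optimization step then just searches over the skip_attn / skip_mlp sets.

import torch
import torch.nn as nn

class ToyDecoderLayer(nn.Module):
    """A simplified decoder layer: attention and MLP sublayers with residual connections."""
    def __init__(self, dim: int = 64, num_heads: int = 4):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads=num_heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x, skip_attn: bool = False, skip_mlp: bool = False):
        if not skip_attn:
            x = x + self.attn(x, x, x, need_weights=False)[0]
        if not skip_mlp:
            x = x + self.mlp(x)
        return x

def draft_forward(x, layers, skip_attn: set, skip_mlp: set):
    """Draft pass that bypasses the attention / MLP sublayers chosen by the search."""
    for idx, layer in enumerate(layers):
        x = layer(x, skip_attn=idx in skip_attn, skip_mlp=idx in skip_mlp)
    return x

# Example: skip the attention of layers 3-5 and the MLP of layer 4 (an arbitrary choice)
layers = nn.ModuleList(ToyDecoderLayer() for _ in range(8))
hidden = draft_forward(torch.randn(1, 16, 64), layers, skip_attn={3, 4, 5}, skip_mlp={4})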

I previously conducted an experiment: Self-Speculative Decoding Implementation: LayerSkip Model, Bayesian Optimization, and Adaptive Draft-Exiting Mechanism (with Gemma-2-9B-it Results).

However, the experimental results were dismal. Due to hardware constraints, I could only test models up to around the 9B scale, and I saw no acceleration at all. In retrospect, it makes sense that the paper only reports results for models of 13B and larger; perhaps smaller models simply cannot use part of their own network as an effective draft model.

Thus, I decided to take on the challenge of implementing Kangaroo. Kangaroo follows the same principle as Self-Speculative Decoding but adds a training component. Of course, directly fine-tuning the draft model would risk degrading the target model, since the draft model shares its layers with the target model.

This led to the following architecture proposed by Kangaroo:

The draft model in Kangaroo consists mainly of a shallow sub-network (the first few decoder layers of the target model). Its hidden states are then passed through an additional trainable Adapter-Network, which bridges the shallow sub-network and the LM Head to produce logits.

Only the Adapter-Network is trained. During training, we pass both the adapter's output and the output of the last remaining layer through the LM Head, and optimize the similarity of the two resulting logit distributions via gradient descent.

From another perspective, the draft model can be seen as an early-exit path that leaves the inference process partway through and jumps directly to the LM Head. However, jumping straight from a shallow layer to the LM Head usually produces nonsensical tokens (which is expected, since the shallow network was never trained to connect directly to the final head during pretraining), which is why an adapter is needed to bridge the shallow sub-network's output to the LM Head.

In this way, we can fine-tune the shallow-sub-network-based draft model to align more closely with the target model without touching the original model's weights.
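
In rough pseudo-PyTorch, the two paths share everything except the bridge between the shallow layers and the LM Head. The sketch below is purely conceptual (model here is an abstract object with shallow_layers, remaining_layers, adapter, norm, and lm_head attributes, echoing the real implementation later in this post):

def kangaroo_logits(model, hidden_states):
    # Shared shallow sub-network (the first few decoder layers)
    for layer in model.shallow_layers:
        hidden_states = layer(hidden_states)
    shallow_hidden = hidden_states

    # Draft path: the trainable adapter bridges the shallow output to the shared LM Head
    draft_logits = model.lm_head(model.norm(model.adapter(shallow_hidden)))

    # Target path: keep going through the remaining (frozen) layers as usual
    for layer in model.remaining_layers:
        hidden_states = layer(hidden_states)
    target_logits = model.lm_head(model.norm(hidden_states))

    # Training only updates the adapter so that draft_logits approaches target_logits
    return draft_logits, target_logits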

To conclude, while I was unable to achieve acceleration with an 8B model during Self-Speculative Decoding experiments, the final Kangaroo training achieved an inference speed improvement of approximately 10% ~ 15%. Honestly, after the initial setbacks, seeing this result felt like a tearful victory for me XDDD.


Code Implementation

I wrote a fair amount of code for this Kangaroo implementation, which you can find on my GitHub: https://github.com/ccs96307/fast-llm-inference. The repository includes various implementations of inference acceleration techniques; feel free to contact me for discussions.

Here, I will briefly introduce three scripts:

  • modeling_kangaroo_llama3.py: The modified Kangaroo model based on the Llama architecture
  • train.py: The training script
  • test_kangaroo.py: The script to test Kangaroo's acceleration performance


Starting with Kangaroo's architecture implementation, I logically split LlamaForCausalLM's self.model.layers into shallow_layers and remaining_layers. Following the original paper's design, my shallow network contains only two layers.

For the Adapter, I initially implemented the Attention module described in the paper and later extended it to a full decoder layer with an MLP module, because the Attention-only adapter had trouble converging and I hoped that adding parameters would give better results.

Regarding the loss function, I introduced an additional hard-label loss term, which is weighted and combined with the soft-label cross-entropy loss from the original paper. This seemed slightly more effective in my experiments, though I could not fully reproduce the paper's results, possibly due to some suboptimal implementation details on my side.
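
Concretely, the objective I ended up using looks like the sketch below (this mirrors the train-mode branch of forward() in the model code further down; alpha weights the soft-label term):

import torch
import torch.nn.functional as F

def adapter_distillation_loss(draft_logits, target_logits, alpha: float = 0.8):
    # Soft labels: cross-entropy between the target distribution and the draft prediction
    draft_log_probs = F.log_softmax(draft_logits, dim=-1)
    target_probs = F.softmax(target_logits, dim=-1)
    soft_loss = -(target_probs * draft_log_probs).sum(dim=-1).mean()

    # Hard labels: standard cross-entropy against the target model's argmax tokens
    hard_labels = target_probs.argmax(dim=-1)
    hard_loss = F.cross_entropy(draft_logits.view(-1, draft_logits.size(-1)), hard_labels.view(-1))

    return alpha * soft_loss + (1 - alpha) * hard_loss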

I also implemented a prepare_dataset() method compatible with the official Kangaroo project, which extracts the outputs of the early-exit layer and the final layer as training data, so that the full model does not need to be re-run at every training step. However, pre-extracting these hidden states for the ShareGPT4 dataset required nearly 900 GB of disk space across my servers, so I kept both the pre-extraction approach and the straightforward approach of running the draft path and target path on the fly.
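
As a rough sketch, the pre-extraction approach can be used like this (the batching and file layout here are hypothetical, not the official Kangaroo format):

import torch

@torch.no_grad()
def dump_hidden_state_pairs(model, dataloader, save_path: str):
    # Pre-extract (shallow_hidden_states, final_hidden_states) pairs so the full
    # model does not have to be re-run at every adapter-training step.
    pairs = []
    for input_ids, attention_mask in dataloader:
        shallow_hidden, final_hidden = model.prepare_dataset(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        pairs.append((shallow_hidden.cpu(), final_hidden.cpu()))
    torch.save(pairs, save_path)

With that context in mind, the full modeling_kangaroo_llama3.py is below: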

from typing import Any, List, Optional, Tuple, Union

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))

from dataclasses import dataclass
import json

import torch
torch.set_default_dtype(torch.bfloat16)

from torch.nn.utils.rnn import pad_sequence
from transformers import LlamaForCausalLM
from transformers.models.llama.modeling_llama import Cache, DynamicCache, LlamaSdpaAttention, LlamaDecoderLayer, LlamaRMSNorm
from transformers.modeling_outputs import CausalLMOutputWithPast

from sampling.sampling import sample_next_token
from utils.utils import calculate_continuous_acceptance


@dataclass
class KangarooModelMode:
    draft_only_mode: str = "draft_only"
    target_only_mode: str = "target_only"
    train_mode: str = "train"


@dataclass
class AdapterMode:
    attention_only_mode: str = "attention_only"
    mlp_only_mode: str = "mlp_only"
    decoder_layer_mode: str = "decoder_layer"


class KangarooLlamaForCausalLM(LlamaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.adapter_layer_mode = None
        self.draft_mode_adapter_layer = None
        self.shallow_layers = None
        self.mode = KangarooModelMode.target_only_mode
        self.confidence_threshold = 0.5
        self.accept_rate = 0
        self.total_accept_tokens = 0
        self.total_draft_generated_token = 0
        self.draft_temperature = 1.0
        self.target_temperature = 1.0
        self.alpha = 0.8
        self.shallow_layer_num = 10
    
    def set_skip_layer(self, shallow_layer_num: int) -> None:
        self.shallow_layer_num = shallow_layer_num
        self.shallow_layers = self.model.layers[:shallow_layer_num]
        self.remaining_layers = self.model.layers[shallow_layer_num:]

    def set_adapter_layer(self, _mode: str) -> None:
        self.adapter_layer_mode = _mode

        if _mode == AdapterMode.attention_only_mode:
            self.attn_input_norm = LlamaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
            self.attn_output_norm = LlamaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)

            self.draft_mode_adapter_layer = LlamaSdpaAttention(
                config=self.config,
                layer_idx=self.config.num_hidden_layers,
            )
        elif _mode == AdapterMode.decoder_layer_mode:
            self.draft_mode_adapter_layer = LlamaDecoderLayer(
                config=self.config,
                layer_idx=self.config.num_hidden_layers,
            )

    def set_draft_mode(self) -> None:
        self.mode = KangarooModelMode.draft_only_mode

    def set_target_mode(self) -> None:
        self.mode = KangarooModelMode.target_only_mode

    def set_train_mode(self) -> None:
        self.mode = KangarooModelMode.train_mode

    def save_head(
        self,
        save_dir: str,
    ) -> None:
        """
        Save the parameters of the draft_mode_adapter_layer, loss_history, and shallow_layer_num to the specified directory.

        Args:
            save_dir (str): Directory to save the adapter parameters.
        """
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        
        # Save the lm_head's state dictionary
        head_path = os.path.join(save_dir, "lm_head.pt")
        torch.save(self.lm_head.state_dict(), head_path)
        print(f"`lm_head` saved at {head_path}")

    def save_norm(
        self,
        save_dir: str,
    ) -> None:
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        
        # Save the final norm's state dictionary
        norm_path = os.path.join(save_dir, "norm.pt")
        torch.save(self.model.norm.state_dict(), norm_path)
        print(f"`norm` saved at {norm_path}")
    
    def save_adapter(
        self,
        save_dir: str,
        train_loss_history: List[float],
        eval_loss_history: List[float],
    ) -> None:
        """
        Save the parameters of the draft_mode_adapter_layer, loss_history, and shallow_layer_num to the specified directory.

        Args:
            save_dir (str): Directory to save the adapter parameters.
        """
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        
        # Save the adapter's state dictionary
        adapter_path = os.path.join(save_dir, "draft_adapter.pt")
        torch.save(self.draft_mode_adapter_layer.state_dict(), adapter_path)
        print(f"Draft adapter saved at {adapter_path}")

        # Save additional information (loss_history and shallow_layer_num)
        metadata = {
            "train_loss_history": train_loss_history,
            "eval_loss_history": eval_loss_history,
            "shallow_layer_num": self.shallow_layer_num
        }
        metadata_path = os.path.join(save_dir, "adapter_metadata.json")
        with open(metadata_path, "w") as f:
            json.dump(metadata, f)
        print(f"Adapter metadata saved at {metadata_path}")

    def load_adapter(self, load_dir: str) -> None:
        """
        Load the parameters of the draft_mode_adapter_layer, loss_history, and shallow_layer_num from the specified directory.

        Args:
            load_dir (str): Directory to load the adapter parameters from.

        Raises:
            FileNotFoundError: If the adapter file does not exist in the specified directory.
        """
        adapter_path = os.path.join(load_dir, "draft_adapter.pt")
        if not os.path.exists(adapter_path):
            raise FileNotFoundError(f"Draft adapter not found at {adapter_path}")
        
        # Load the adapter's state dictionary
        state_dict = torch.load(adapter_path, map_location=self.device)
        self.draft_mode_adapter_layer.load_state_dict(state_dict=state_dict)
        print(f"Draft adapter loaded from {adapter_path}")

    def prepare_dataset(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        if self.shallow_layers is None:
            raise AttributeError(f"You do not set the `shallow_layers`!")
        
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Model
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.model.gradient_checkpointing and self.model.training and use_cache:
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)

        # kept for BC (non `Cache` `past_key_values` inputs)
        if use_cache and not isinstance(past_key_values, Cache):
            if past_key_values is None:
                past_key_values = DynamicCache()
            else:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self.model._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )
        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.model.rotary_emb(hidden_states, position_ids)

        # decoder layers
        for decoder_layer in self.shallow_layers:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

            hidden_states = layer_outputs[0]

        shallow_hidden_states = hidden_states

        # Remaining decoder layers
        for decoder_layer in self.remaining_layers:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

            hidden_states = layer_outputs[0]

        hidden_states = self.model.norm(hidden_states)

        return shallow_hidden_states, hidden_states

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        if self.mode == KangarooModelMode.target_only_mode:
            return super().forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                cache_position=cache_position,
                num_logits_to_keep=num_logits_to_keep,
                **loss_kwargs,
            )
        elif self.mode == KangarooModelMode.draft_only_mode:
            if self.shallow_layers is None:
                raise AttributeError(f"You do not set the `shallow_layers`!")

            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
            output_hidden_states = (
                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
            )
            return_dict = return_dict if return_dict is not None else self.config.use_return_dict

            # Model
            if (input_ids is None) ^ (inputs_embeds is not None):
                raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

            if self.model.gradient_checkpointing and self.model.training and use_cache:
                use_cache = False

            if inputs_embeds is None:
                inputs_embeds = self.model.embed_tokens(input_ids)

            # kept for BC (non `Cache` `past_key_values` inputs)
            return_legacy_cache = False
            if use_cache and not isinstance(past_key_values, Cache):
                return_legacy_cache = True
                if past_key_values is None:
                    past_key_values = DynamicCache()
                else:
                    past_key_values = DynamicCache.from_legacy_cache(past_key_values)

            if cache_position is None:
                past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
                cache_position = torch.arange(
                    past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
                )
            if position_ids is None:
                position_ids = cache_position.unsqueeze(0)

            causal_mask = self.model._update_causal_mask(
                attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
            )
            hidden_states = inputs_embeds

            # create position embeddings to be shared across the decoder layers
            position_embeddings = self.model.rotary_emb(hidden_states, position_ids)

            # decoder layers
            all_hidden_states = () if output_hidden_states else None
            all_self_attns = () if output_attentions else None
            next_decoder_cache = None

            for decoder_layer in self.shallow_layers:
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)

                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

                hidden_states = layer_outputs[0]

                if use_cache:
                    next_decoder_cache = layer_outputs[2 if output_attentions else 1]

                if output_attentions:
                    all_self_attns += (layer_outputs[1],)

            # add hidden states from the last decoder layer
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            next_cache = next_decoder_cache if use_cache else None
            if return_legacy_cache:
                next_cache = next_cache.to_legacy_cache()

            # Adapter
            if self.adapter_layer_mode == AdapterMode.attention_only_mode:
                residual = hidden_states
                hidden_states = self.attn_input_norm(hidden_states)

                hidden_states, all_self_attns, past_key_values = self.draft_mode_adapter_layer(
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )
                hidden_states = residual + hidden_states
                hidden_states = self.attn_output_norm(hidden_states)

            elif self.adapter_layer_mode == AdapterMode.decoder_layer_mode:
                layer_outputs = self.draft_mode_adapter_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )
                hidden_states = layer_outputs[0]
                hidden_states = self.model.norm(hidden_states)

            if self.config.pretraining_tp > 1:
                lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
                logits = [torch.nn.functional.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
                logits = torch.cat(logits, dim=-1)
            else:
                # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
                logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

            loss = None
            if labels is not None:
                loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **loss_kwargs)

            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=past_key_values,
                hidden_states=hidden_states,
                attentions=all_self_attns,
            )
        elif self.mode == KangarooModelMode.train_mode:
            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
            output_hidden_states = (
                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
            )
            return_dict = return_dict if return_dict is not None else self.config.use_return_dict

            if (input_ids is None) ^ (inputs_embeds is not None):
                raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

            if self.model.gradient_checkpointing and self.model.training and use_cache:
                use_cache = False

            if inputs_embeds is None:
                inputs_embeds = self.model.embed_tokens(input_ids)

            # kept for BC (non `Cache` `past_key_values` inputs)
            return_legacy_cache = False
            if use_cache and not isinstance(past_key_values, Cache):
                return_legacy_cache = True
                if past_key_values is None:
                    past_key_values = DynamicCache()
                else:
                    past_key_values = DynamicCache.from_legacy_cache(past_key_values)

            if cache_position is None:
                past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
                cache_position = torch.arange(
                    past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
                )
            if position_ids is None:
                position_ids = cache_position.unsqueeze(0)

            causal_mask = self.model._update_causal_mask(
                attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
            )
            hidden_states = inputs_embeds

            # create position embeddings to be shared across the decoder layers
            position_embeddings = self.model.rotary_emb(hidden_states, position_ids)

            # decoder layers
            all_hidden_states = () if output_hidden_states else None
            all_self_attns = () if output_attentions else None
            next_decoder_cache = None

            for decoder_layer in self.shallow_layers:
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)

                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

                hidden_states = layer_outputs[0]

                if use_cache:
                    next_decoder_cache = layer_outputs[2 if output_attentions else 1]

                if output_attentions:
                    all_self_attns += (layer_outputs[1],)

            # Cache hidden states
            remaining_hidden_states = hidden_states

            # Add hidden states from the last decoder layer
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            next_cache = next_decoder_cache if use_cache else None
            if return_legacy_cache:
                next_cache = next_cache.to_legacy_cache()

            # Adapter
            if self.adapter_layer_mode == AdapterMode.attention_only_mode:
                residual = hidden_states
                hidden_states = self.attn_input_norm(hidden_states)

                hidden_states, all_self_attns, past_key_values = self.draft_mode_adapter_layer(
                    hidden_states=hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

                hidden_states = residual + hidden_states
                hidden_states = self.attn_output_norm(hidden_states)

            elif self.adapter_layer_mode == AdapterMode.decoder_layer_mode:
                layer_outputs = self.draft_mode_adapter_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )
                hidden_states = layer_outputs[0]
                hidden_states = self.model.norm(hidden_states)

                if use_cache:
                    next_decoder_cache = layer_outputs[2 if output_attentions else 1]

                if output_attentions:
                    all_self_attns += (layer_outputs[1],)

            # Remaining decoder layers
            for decoder_layer in self.remaining_layers:
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)

                layer_outputs = decoder_layer(
                    remaining_hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

                remaining_hidden_states = layer_outputs[0]

                if use_cache:
                    next_decoder_cache = layer_outputs[2 if output_attentions else 1]

                if output_attentions:
                    all_self_attns += (layer_outputs[1],)

            remaining_hidden_states = self.model.norm(remaining_hidden_states)

            # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
            draft_logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) / self.draft_temperature
            target_logits = self.lm_head(remaining_hidden_states[:, -num_logits_to_keep:, :]) / self.target_temperature

            # Compute the log probabilities for both models
            draft_log_probs = torch.nn.functional.log_softmax(draft_logits, dim=-1)
            target_log_probs = torch.nn.functional.log_softmax(target_logits, dim=-1)
            target_probs = torch.nn.functional.softmax(target_logits, dim=-1)

            # Cross-entropy loss between target and draft model predictions
            # kl_loss = torch.nn.functional.kl_div(draft_log_probs, target_probs, reduction="batchmean")
            hard_labels = torch.argmax(target_probs, dim=-1)
            soft_label_cross_entropy_loss = -(target_probs * draft_log_probs).sum(dim=-1).mean()
            hard_label_loss = torch.nn.functional.cross_entropy(
                draft_logits.view(-1, draft_logits.size(-1)),  # Flatten logits
                hard_labels.view(-1)  # Flatten hard labels
            )

            loss = self.alpha * soft_label_cross_entropy_loss + (1 - self.alpha) * hard_label_loss

            return CausalLMOutputWithPast(
                loss=loss,
                logits=target_logits,
                past_key_values=past_key_values,
                hidden_states=hidden_states,
                attentions=all_self_attns,
            )

    @torch.no_grad()
    def kangaroo_generate(
        self,
        eos_token_id: int,
        pad_token_id: int,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        max_new_tokens: int = 100,
        **kwargs,
    ) -> torch.LongTensor:
        if self.shallow_layers is None:
            raise AttributeError(f"You do not set the `shallow_layers`!")

        confidence_score = 1.0
        total_generate_tokens = 0
        is_end = False  # Initialized up front so the final check below works even when every drafted token is accepted
        
        while total_generate_tokens < max_new_tokens:
            draft_generate_tokens = 0
            draft_probs = []

            while confidence_score >= self.confidence_threshold:
                output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
                output_hidden_states = (
                    output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
                )
                return_dict = return_dict if return_dict is not None else self.config.use_return_dict

                if (input_ids is None) ^ (inputs_embeds is not None):
                    raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

                if self.model.gradient_checkpointing and self.model.training and use_cache:
                    use_cache = False

                if inputs_embeds is None:
                    inputs_embeds = self.model.embed_tokens(input_ids)

                # kept for BC (non `Cache` `past_key_values` inputs)
                return_legacy_cache = False
                if use_cache and not isinstance(past_key_values, Cache):
                    return_legacy_cache = True
                    if past_key_values is None:
                        past_key_values = DynamicCache()
                    else:
                        past_key_values = DynamicCache.from_legacy_cache(past_key_values)

                if cache_position is None:
                    past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
                    cache_position = torch.arange(
                        past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
                    )
                if position_ids is None:
                    position_ids = cache_position.unsqueeze(0)

                causal_mask = self.model._update_causal_mask(
                    attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
                )
                hidden_states = inputs_embeds

                # create position embeddings to be shared across the decoder layers
                position_embeddings = self.model.rotary_emb(hidden_states, position_ids)

                # decoder layers
                all_hidden_states = () if output_hidden_states else None
                all_self_attns = () if output_attentions else None
                next_decoder_cache = None

                for decoder_layer in self.shallow_layers:
                    if output_hidden_states:
                        all_hidden_states += (hidden_states,)

                    layer_outputs = decoder_layer(
                        hidden_states,
                        attention_mask=causal_mask,
                        position_ids=position_ids,
                        past_key_value=past_key_values,
                        output_attentions=output_attentions,
                        use_cache=use_cache,
                        cache_position=cache_position,
                        position_embeddings=position_embeddings,
                    )

                    hidden_states = layer_outputs[0]

                    if use_cache:
                        next_decoder_cache = layer_outputs[2 if output_attentions else 1]

                    if output_attentions:
                        all_self_attns += (layer_outputs[1],)

                # Cache hidden states
                remaining_hidden_states = hidden_states

                # Add hidden states from the last decoder layer
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)

                next_cache = next_decoder_cache if use_cache else None
                if return_legacy_cache:
                    next_cache = next_cache.to_legacy_cache()

                # Adapter
                if self.adapter_layer_mode == AdapterMode.attention_only_mode:
                    residual = hidden_states
                    hidden_states = self.attn_input_norm(hidden_states)
                    
                    hidden_states, all_self_attns, past_key_values = self.draft_mode_adapter_layer(
                        hidden_states=hidden_states,
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                        past_key_values=past_key_values,
                        output_attentions=output_attentions,
                        use_cache=use_cache,
                        cache_position=cache_position,
                        position_embeddings=position_embeddings,
                    )

                    hidden_states = residual + hidden_states
                    hidden_states = self.attn_output_norm(hidden_states)

                elif self.adapter_layer_mode == AdapterMode.decoder_layer_mode:
                    layer_outputs = self.draft_mode_adapter_layer(
                        hidden_states,
                        attention_mask=causal_mask,
                        position_ids=position_ids,
                        past_key_value=past_key_values,
                        output_attentions=output_attentions,
                        use_cache=use_cache,
                        cache_position=cache_position,
                        position_embeddings=position_embeddings,
                    )
                    hidden_states = layer_outputs[0]
                    hidden_states = self.model.norm(hidden_states)

                    if use_cache:
                        next_decoder_cache = layer_outputs[2 if output_attentions else 1]

                    if output_attentions:
                        all_self_attns += (layer_outputs[1],)

                # Re-init
                inputs_embeds = None
                position_ids = None
                cache_position = None

                # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
                draft_logits = self.lm_head(hidden_states[:, -1:, :])

                # Sampling and get the probabilities
                next_tokens, probs = sample_next_token(
                    logits=draft_logits,
                    prefix_token_ids=input_ids,
                )

                draft_probs.append(probs)
                input_ids = torch.cat([input_ids, next_tokens[:, -1:]], dim=-1)
                attention_mask = torch.cat([attention_mask, torch.ones(attention_mask.shape[0], 1).to(input_ids.device)], dim=-1)

                draft_generate_tokens += 1
                self.total_draft_generated_token += 1

                # Support bs=1
                decode_token_id = next_tokens[:, -1].item()
                if probs[:, -1, decode_token_id] < self.confidence_threshold or total_generate_tokens + draft_generate_tokens >= max_new_tokens:
                    draft_probs = torch.cat(draft_probs, dim=1)
                    break

            # Use whole model for evaluating
            for decoder_layer in self.remaining_layers:
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)

                layer_outputs = decoder_layer(
                    remaining_hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

                remaining_hidden_states = layer_outputs[0]

                if use_cache:
                    next_decoder_cache = layer_outputs[2 if output_attentions else 1]

                if output_attentions:
                    all_self_attns += (layer_outputs[1],)

            remaining_hidden_states = self.model.norm(remaining_hidden_states)
            num_logits_to_keep = draft_probs.shape[1]
            target_logits = self.lm_head(remaining_hidden_states[:, -num_logits_to_keep:, :])

            target_input_ids = input_ids[:, :-1]

            next_tokens, target_probs = sample_next_token(
                logits=target_logits,
                prefix_token_ids=target_input_ids,
                probs_num=num_logits_to_keep,
            )

            # Evaluation
            expanded_indices = input_ids[:, -draft_probs.shape[1]:].unsqueeze(-1)

            # Gather the probability of each drafted token
            selected_draft_probs = torch.gather(draft_probs, dim=-1, index=expanded_indices).squeeze(-1)
            selected_eval_probs = torch.gather(target_probs, dim=-1, index=expanded_indices).squeeze(-1)

            # Compare draft_prob and eval_prob, and check the reject_mask
            mask_to_reject = selected_draft_probs > selected_eval_probs

            # Calculate rejection probability 1 - (eval_prob / draft_prob)
            rejection_probs = 1 - (selected_eval_probs / selected_draft_probs)

            # Generate random values to determine accept or reject
            random_values = torch.rand_like(rejection_probs)
            rejection_decisions = random_values < rejection_probs

            # Get the final reject masks
            rejection_masks = mask_to_reject & rejection_decisions
            acceptance_mask = torch.ones_like(selected_draft_probs, dtype=torch.bool)
            acceptance_mask[rejection_masks] = False

            # Concat `input_ids`
            if torch.all(acceptance_mask):
                total_generate_tokens += draft_generate_tokens
            else:
                new_input_ids = []
                new_attention_mask = []
                is_end = False

                for batch_idx in range(next_tokens.shape[0]):
                    gamma = next_tokens.shape[1]
                    start_idx = input_ids.shape[1] - gamma

                    for pos_idx in range(acceptance_mask[batch_idx].shape[0]):
                        total_generate_tokens += 1
                        if (acceptance_mask[batch_idx][pos_idx] and input_ids[batch_idx][start_idx+pos_idx].item() == eos_token_id) or not acceptance_mask[batch_idx][pos_idx]:
                            input_ids[batch_idx][start_idx+pos_idx] = next_tokens[batch_idx][pos_idx]

                            new_input_ids.append(input_ids[batch_idx][:start_idx+pos_idx+1])
                            new_attention_mask.append(attention_mask[batch_idx][:start_idx+pos_idx+1])
                            
                            is_end = input_ids[batch_idx][start_idx+pos_idx].item() == eos_token_id
                            break

                input_ids = pad_sequence(new_input_ids, batch_first=True, padding_value=pad_token_id)
                attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)

            self.total_accept_tokens += calculate_continuous_acceptance(acceptance_mask=acceptance_mask)
            self.accept_rate = self.total_accept_tokens / self.total_draft_generated_token
        
            if is_end:
                break

        return {"input_ids": input_ids}

The training script is nothing special; it simply loads the dataset and trains the model in the standard way.

from typing import Dict, List

import os
from dataclasses import dataclass

from datasets import load_dataset

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
from sklearn.model_selection import train_test_split

from kangaroo_modeling.modeling_kangaroo_llama3 import KangarooLlamaForCausalLM


class CustomDataset(Dataset):
    def __init__(self, inputs: Dict[str, torch.LongTensor], device: torch.device):
        self.inputs = inputs
        self.device = device

    def __len__(self) -> int:
        return self.inputs.input_ids.shape[0]
    
    def __getitem__(self, index: int):
        return (
            self.inputs.input_ids[index].to(self.device),
            self.inputs.attention_mask[index].to(self.device),
        )
    

def main() -> None:
    # Settings
    epochs = 100
    batch_size = 4
    max_length = 512
    lr = 5e-5
    shallow_layer_num = 2
    adapter_mode = "attention_only"
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

    # Load model and tokenizer
    pretrained_model_name_or_path = "../models/meta-llama--Meta-Llama-3.1-8B-Instruct"
    # pretrained_model_name_or_path = "../models/HuggingFaceTB--SmolLM2-1.7B-Instruct"
    model = KangarooLlamaForCausalLM.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
    tokenizer.pad_token = tokenizer.eos_token

    model.set_skip_layer(shallow_layer_num=shallow_layer_num)
    model.set_adapter_layer(_mode=adapter_mode)
    model.set_train_mode()

    model = model.to(device)

    # Freeze all parameters
    for param in model.parameters():
        param.requires_grad = False

    # Unfreeze adapter layer
    model.draft_mode_adapter_layer.train()
    for param in model.draft_mode_adapter_layer.parameters():
        param.requires_grad = True

    if hasattr(model, "attn_input_norm"):
        print("Attention-Adapter!")
        
        for param in model.attn_input_norm.parameters():
            param.requires_grad = True

        for param in model.attn_output_norm.parameters():
            param.requires_grad = True

    # Load dataset
    dataset = load_dataset("shibing624/sharegpt_gpt4")
    
    samples = dataset["train"]["conversations"]
    samples = [[{"role": sample[0]["from"].replace("human", "user").replace("gpt", "assistant"), "content": sample[0]["value"]}] for sample in samples]

    train_samples, eval_samples = train_test_split(samples, test_size=0.1, random_state=2999)
    print(len(samples))

    # Tokenized
    train_inputs = tokenizer(
        [tokenizer.apply_chat_template(messages, tokenize=False) for messages in train_samples],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    )

    eval_inputs = tokenizer(
        [tokenizer.apply_chat_template(messages, tokenize=False) for messages in eval_samples],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    )

    train_dataset = CustomDataset(inputs=train_inputs, device=device)
    eval_dataset = CustomDataset(inputs=eval_inputs, device=device)

    # Dataloader
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)

    # Optimizer (cover every parameter unfrozen above, i.e. the adapter and any extra norm layers)
    optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=lr)

    # Training loop
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        train_loss_history = []
        eval_loss_history = []

        for batch_idx, batch in enumerate(train_dataloader, 1):
            input_ids, attention_mask = batch

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                num_logits_to_keep=max_length,
            )

            # Zero gradients
            optimizer.zero_grad()

            # Calculate loss
            loss = outputs.loss
            total_loss += loss.item()
            train_loss_history.append(loss.item())

            # Backward pass
            loss.backward()

            # Optimizer step
            optimizer.step()

            # Log training loss
            avg_loss = total_loss / batch_idx
            print(f"Train - Epoch [{epoch + 1}/{epochs}] Steps [{batch_idx}/{len(train_dataloader)}], Training Loss: {avg_loss:.4f}")

        # Evaluate the model
        model.eval()
        eval_loss = 0
        with torch.no_grad():
            for batch_idx, batch in enumerate(eval_dataloader, 1):
                input_ids, attention_mask = batch

                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    num_logits_to_keep=max_length,
                )
                eval_loss += outputs.loss.item()
                eval_loss_history.append(outputs.loss.item())

                avg_loss = eval_loss / batch_idx
                print(f"Eval - Epoch [{epoch + 1}/{epochs}] Steps [{batch_idx}/{len(eval_dataloader)}], Eval Loss: {avg_loss:.4f}")

        # Save model checkpoint
        save_dir = "./checkpoints/checkpoints_hce_attn_20241209/"
        save_path = os.path.join(save_dir, f"epoch_{epoch+1}")
        model.save_adapter(
            save_path,
            train_loss_history=train_loss_history,
            eval_loss_history=eval_loss_history,
        )
        print(f"Adapter checkpoint saved at {save_path}")


if __name__ == "__main__":
    main()


After training, I attempted to plot the loss curve:
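
The curve can be reproduced from the adapter_metadata.json written by save_adapter(); here is a minimal plotting sketch, assuming a checkpoint directory like the one used in the training script above:

import json
import matplotlib.pyplot as plt

# Adjust this path to the checkpoint epoch you want to inspect
with open("./checkpoints/checkpoints_hce_attn_20241209/epoch_1/adapter_metadata.json") as f:
    metadata = json.load(f)

plt.plot(metadata["train_loss_history"], label="train")
plt.plot(metadata["eval_loss_history"], label="eval")
plt.xlabel("Step")
plt.ylabel("Loss")
plt.legend()
plt.show()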

The evaluation loss eventually plateaued at approximately 0.88. At that point, we can go on to test the speedup achieved by the Kangaroo architecture after training.

from typing import Dict, List, Optional, Tuple

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import copy
import time

import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from kangaroo_modeling.modeling_kangaroo_llama3 import KangarooLlamaForCausalLM
from sampling.sampling import sample_next_token


def calculate_continuous_acceptance(acceptance_mask: torch.BoolTensor) -> int:
    continuous_acceptance = 0
    for accepted in acceptance_mask.long().squeeze(0):
        if accepted == 1:
            continuous_acceptance += 1
        else:
            break
    return continuous_acceptance

def drafter_speculative_decode(
    draft_model: torch.nn.Module,
    draft_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    gamma: int = 10,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, meaning top-k filtering is disabled
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], torch.FloatTensor]:
    draft_probs = []

    for idx in range(gamma):
        with torch.no_grad():
            outputs = draft_model(**inputs)

        next_tokens, probs = sample_next_token(
            logits=outputs.logits,
            prefix_token_ids=inputs["input_ids"],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        draft_probs.append(probs)
        input_ids = torch.cat([inputs["input_ids"], next_tokens[:, -1:]], dim=-1)
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)

        inputs["input_ids"] = input_ids
        inputs["attention_mask"] = attention_mask

    return inputs, torch.cat(draft_probs, dim=1)


def target_speculative_decode(
    target_model: torch.nn.Module,
    target_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    draft_probs: torch.FloatTensor,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, meaning top-k filtering is disabled
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], bool, int]:
    with torch.no_grad():
        outputs = target_model(**inputs)

    next_tokens, target_probs = sample_next_token(
        logits=outputs.logits,
        diff_probs=draft_probs,
        prefix_token_ids=inputs["input_ids"],
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        probs_num=draft_probs.shape[1] + 1,
    )

    next_token = next_tokens[:, -1:]

    # Evaluation
    indices = inputs["input_ids"][:, -draft_probs.shape[1]:]

    eval_probs = target_probs[:, :-1, :]

    expanded_indices = indices.unsqueeze(-1)
    selected_draft_probs = torch.gather(draft_probs, dim=-1, index=expanded_indices)
    selected_draft_probs = selected_draft_probs.squeeze(-1)

    selected_eval_probs = torch.gather(eval_probs, dim=-1, index=expanded_indices)
    selected_eval_probs = selected_eval_probs.squeeze(-1)    

    # Compare draft_prob and eval_prob, and check the reject_mask
    mask_to_reject = selected_draft_probs > selected_eval_probs

    # Calculate rejection probability 1 - (eval_prob / draft_prob)
    rejection_probs = 1 - (selected_eval_probs / selected_draft_probs)

    # Generate random values to determine accept or reject
    random_values = torch.rand_like(rejection_probs)
    rejection_decisions = random_values < rejection_probs

    # Get the final reject masks
    rejection_masks = mask_to_reject & rejection_decisions
    acceptance_mask = torch.ones_like(selected_draft_probs, dtype=torch.bool)
    acceptance_mask[rejection_masks] = False

    is_end = False

    # Concat `input_ids`
    if torch.all(acceptance_mask):
        input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)
    else:
        new_input_ids = []
        new_attention_mask = []

        for batch_idx in range(next_tokens.shape[0]):
            gamma = next_tokens.shape[1] - 1
            start_idx = inputs["input_ids"].shape[1] - gamma

            for pos_idx in range(acceptance_mask[batch_idx].shape[0]):
                if (acceptance_mask[batch_idx][pos_idx] and inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id) or not acceptance_mask[batch_idx][pos_idx]:
                    inputs["input_ids"][batch_idx][start_idx+pos_idx] = next_tokens[batch_idx][pos_idx]

                    new_input_ids.append(inputs["input_ids"][batch_idx][:start_idx+pos_idx+1])
                    new_attention_mask.append(inputs["attention_mask"][batch_idx][:start_idx+pos_idx+1])
                    
                    is_end = inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id
                    break

        input_ids = pad_sequence(new_input_ids, batch_first=True, padding_value=target_tokenizer.pad_token_id)
        attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)

    inputs["input_ids"] = input_ids
    inputs["attention_mask"] = attention_mask

    return inputs, is_end, calculate_continuous_acceptance(acceptance_mask)


def run_test() -> None:
    # Device
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # Model path 
    pretrained_model_name_or_path = "../models/meta-llama--Meta-Llama-3.1-8B-Instruct"
    adapter_dir = "checkpoints/checkpoints_hce_decoder_layer_20241205/epoch_45/"

    # Load Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
    tokenizer.pad_token = tokenizer.eos_token

    # Load Model
    model = KangarooLlamaForCausalLM.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.bfloat16)
    model.set_skip_layer(shallow_layer_num=2)
    model.set_adapter_layer("decoder_layer")
    if adapter_dir:
        model.load_adapter(adapter_dir)

    model = model.to(device)

    # Tokenize
    messages = [
        [
            {
                "role": "user",
                "content": "What is the capital of Taiwan. And why?",
            },
        ],
    ]

    input_text = tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    ).to(device)

    # Warm up the model (CUDA)
    inputs_dummy = {k: v.clone() for k, v in inputs.items()}
    with torch.no_grad():
        model.set_draft_mode()
        model(**inputs_dummy)
        model.set_target_mode()
        model(**inputs_dummy)
    torch.cuda.synchronize()

    # Record
    raw_inputs = copy.deepcopy(inputs)
    raw_token_num = raw_inputs["input_ids"].shape[1]
    total_draft_tokens = 0
    total_accept_tokens = 0
    gamma = 1
    max_new_tokens = 100
    is_end = False

    start_time = time.time()

    while not is_end:
        # Draft model
        model.set_draft_mode()
        target_inputs, draft_probs = drafter_speculative_decode(
            draft_model=model,
            draft_tokenizer=tokenizer,
            inputs=inputs,
            gamma=gamma,
            temperature=0,
        )

        total_draft_tokens += gamma

        # Target model
        model.set_target_mode()
        outputs, is_end, accept_tokens = target_speculative_decode(
            target_model=model,
            target_tokenizer=tokenizer,
            inputs=target_inputs,
            draft_probs=draft_probs,
            temperature=1,
        )

        total_accept_tokens += accept_tokens

        inputs = outputs

        if inputs["input_ids"].shape[1] - raw_token_num >= max_new_tokens:
            break

    generate_token_num = outputs["input_ids"].shape[1] - raw_token_num
    spent_time = time.time() - start_time

    print(f"Generate token number: {generate_token_num}")
    print(f"Generate speed: {generate_token_num / spent_time} tokens/sec")
    print(f"Speculative Decoding Spent Time: {spent_time} seconds.")
    print(f"Accept Rate: {total_accept_tokens / total_draft_tokens}\n")

    # Normal Target Model Speed
    inputs = copy.deepcopy(raw_inputs)
    start_time = time.time()
    target_inputs, draft_probs = drafter_speculative_decode(
        draft_model=model,
        draft_tokenizer=tokenizer,
        inputs=inputs,
        gamma=max_new_tokens,
    )

    spent_time = time.time() - start_time

    print(f"Generate token number: {max_new_tokens}")
    print(f"Generate speed: {max_new_tokens / spent_time} tokens/sec")
    print(f"Normal Target Model Decoding Spent Time: {spent_time} seconds.\n")


if __name__ == "__main__":
    run_test()


Output:

Generate token number: 100
Generate speed: 41.15903447942894 tokens/sec
Speculative Decoding Spent Time: 2.429600238800049 seconds.
Accept Rate: 0.2987012987012987

Generate token number: 100
Generate speed: 35.865311488261504 tokens/sec
Normal Target Model Decoding Spent Time: 2.7882094383239746 seconds.

This shows an approximate speedup of (41.15 - 35.86) / 35.86 ~= 14.75%.

