
Notes on Implementing the Kangaroo Accelerated Inference Architecture

Last Updated on 2024-12-10 by Clay

Introduction

Kangaroo is a Self-Speculative Decoding implementation that introduces a trainable Adapter layer. I have spent the past few weeks fine-tuning its Adapter and finally have some preliminary results, so I am writing them down here.


The Kangaroo Architecture

Kangaroo is a variant of Self-Speculative Decoding, which in turn is a branch of Speculative Decoding.

In the original Speculative Decoding setup, we have a draft model and a target model (the model we ultimately want to accelerate). The draft model usually shares the target model's vocabulary and runs much faster; it speculatively generates several tokens, and the target model then verifies them all in a single forward pass.

In the original paper (Fast Inference from Transformers via Speculative Decoding), the research team designed the target model's verification algorithm specifically so that decoding with the draft model's speculation preserves exactly the same output distribution as the original target model, i.e., acceleration without changing the decoding results.
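
The core of this lossless guarantee is the accept/reject rule used during verification. As a rough illustration (my own notation, not the exact code later in this post): for each drafted token, the target model accepts it with probability min(1, p_target / p_draft), and on rejection resamples from the normalized positive part of (p_target − p_draft). A minimal PyTorch sketch of that rule:

import torch

def verify_drafted_token(
    draft_probs: torch.Tensor,   # [vocab_size], draft distribution at this position
    target_probs: torch.Tensor,  # [vocab_size], target distribution at this position
    drafted_token: int,
) -> int:
    """Standard speculative-sampling verification for a single position (a sketch)."""
    p_draft = draft_probs[drafted_token]
    p_target = target_probs[drafted_token]

    # Accept with probability min(1, p_target / p_draft)
    accept_prob = torch.clamp(p_target / p_draft, max=1.0)
    if torch.rand(1).item() < accept_prob:
        return drafted_token

    # Otherwise resample from the normalized residual (p_target - p_draft)_+
    residual = torch.clamp(target_probs - draft_probs, min=0.0)
    residual = residual / residual.sum()
    return int(torch.multinomial(residual, num_samples=1))

Every token produced this way is distributed exactly as if it had been sampled from the target model directly, which is why the acceleration is lossless.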

Self-Speculative Decoding (Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding) instead uses part of the target model's own network as the draft model. Put simply, every Transformer block contains an attention layer and an MLP layer, and we can choose whether to skip each of them.

Which layers to skip matters: skip too few and the draft model is not fast enough to actually accelerate the target model; skip too many and the draft model's speculative decoding drifts so far from the target model that its tokens may hardly ever be accepted.

So, to balance the skipped network's inference speed against its correctness, Self-Speculative Decoding uses Bayesian optimization to search for a layer-skipping strategy that satisfies both requirements.
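
To make that search concrete, here is a minimal sketch of what the outer optimization loop looks like. Plain random search stands in for the Bayesian optimizer, and measure_tokens_per_sec() is a hypothetical placeholder that you would replace with an actual speculative-decoding benchmark of a given skip configuration:

import random
from typing import FrozenSet

def measure_tokens_per_sec(skip_layers: FrozenSet[int]) -> float:
    """Hypothetical benchmark: run speculative decoding with these layers
    skipped in the draft path and return the measured tokens/sec."""
    # Dummy objective so the sketch runs end-to-end; replace with a real measurement.
    return 30.0 + 0.5 * len(skip_layers) - 0.3 * abs(len(skip_layers) - 16)

def search_skip_config(num_layers: int = 32, trials: int = 50) -> FrozenSet[int]:
    best_config, best_speed = frozenset(), float("-inf")
    for _ in range(trials):
        k = random.randint(4, num_layers - 4)  # how many layers this candidate skips
        candidate = frozenset(random.sample(range(num_layers), k))
        speed = measure_tokens_per_sec(candidate)
        if speed > best_speed:
            best_config, best_speed = candidate, speed
    return best_config

print(sorted(search_skip_config()))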

I ran a round of experiments on this earlier as well: Complete Self-Speculative Decoding Implementation: LayerSkip Model, Bayesian Optimization, and Adaptive Draft-Exiting Mechanism (with gemma-2-9b-it experiment results).

The results back then were, frankly, dismal. Limited by my hardware, I could only test up to the 9B scale, and at that scale I got no speedup at all. No wonder the results in the paper start at the 13B scale: it may well be that at smaller scales, a draft model built from only part of the network simply does not perform well enough.

So I moved on to implementing Kangaroo. Kangaroo follows the same spirit as Self-Speculative Decoding, but it allows part of the model to be trained. Of course, if we fine-tuned the draft model directly we could no longer guarantee the target model's performance, since the draft model is part of the target model and fine-tuning it could make the target model lose some of its abilities.

So Kangaroo proposes the following architecture:

Kangaroo's draft model consists of a shallow sub-network of the target model, followed by an additional trainable Adapter-Network that produces the hidden states, which are finally passed through the LM Head to obtain the logits.

The Adapter-Network is trainable: during training, we pass both the Adapter-Network's output and the output of the last remaining layer through the LM Head to obtain vocabulary-sized logits, and use gradient descent to make the two sets of logits as similar as possible.

Looked at another way, the draft model is an early-exit path that leaves the normal inference flow and jumps straight to the LM Head. Jumping directly from a shallow layer to the LM Head naturally makes the LLM decode strange tokens (unsurprising, since pretraining never intended the shallow layers' outputs to feed the LM Head directly), so we need an extra Adapter to act as a bridge between the shallow network's output and the LM Head.

This way, we can fine-tune the shallow-network draft model to be closer to the target model without affecting the original model at all.
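
In code, the two paths can be summarized roughly as follows. This is a simplified sketch with made-up module names (it ignores attention masks, KV caches, and normalization), not the real implementation shown later:

import torch
import torch.nn as nn

class KangarooSketch(nn.Module):
    """Schematic of the two decoding paths sharing the shallow layers and the LM Head."""
    def __init__(self, hidden: int = 64, vocab: int = 128):
        super().__init__()
        self.shallow_layers = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(2)])
        self.remaining_layers = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(30)])
        self.adapter = nn.Linear(hidden, hidden)  # the only trainable part
        self.lm_head = nn.Linear(hidden, vocab)

    def forward(self, x: torch.Tensor):
        for layer in self.shallow_layers:
            x = layer(x)
        shallow = x  # shared by both paths

        # Draft path: early exit through the adapter into the LM Head
        draft_logits = self.lm_head(self.adapter(shallow))

        # Target path: the original, untouched full network
        for layer in self.remaining_layers:
            x = layer(x)
        target_logits = self.lm_head(x)
        return draft_logits, target_logits

model = KangarooSketch()
draft_logits, target_logits = model(torch.randn(1, 8, 64))

Since only the adapter (plus, in my attention-only variant, its two extra RMSNorm layers) receives gradient updates, the target path's weights never change and the target model's outputs stay exactly the same.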

Let me start with the conclusion: in the Self-Speculative Decoding experiments I could not get an 8B-scale model to speed up no matter what I tried, but at the end of Kangaroo training I finally achieved roughly a 10% ~ 15% inference speedup with Kangaroo's method. Honestly, after the setbacks with Self-Speculative Decoding and the early Kangaroo experiments, seeing this result was enough to move me to tears XDDD


Implementation Source Code

There is quite a lot of code in my Kangaroo implementation; the full version lives in my GitHub repo: https://github.com/ccs96307/fast-llm-inference, which collects my various accelerated-inference implementations. You are welcome to contact me to discuss them anytime.

Here I will briefly introduce three scripts:

  • modeling_kangaroo_llama3.py: the Kangaroo model I modified from the Llama architecture
  • train.py: the training script
  • test_kangaroo.py: a script that measures the speedup brought by the Kangaroo model


First, the Kangaroo architecture itself, described at a high level: I subclass LlamaForCausalLM and split the original self.model.layers into shallow_layers and remaining_layers; following the original paper's design, my shallow network has only two layers.

For the Adapter, I have so far implemented both an attention-only version (the original method in the paper) and a full decoder-layer version that also includes the MLP module. The attention-only adapter did not converge well in my training, so I added more parameters in the hope of getting better results.

For the loss function, I also added an extra hard-label loss that can be weighted and combined with the soft-label cross-entropy from the original paper. In my experiments this seemed to work a bit better, although I still could not reproduce results as good as the paper's, so some part of my implementation is probably still not quite right.
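
Concretely, the combined objective is a weighted sum of the soft-label cross-entropy against the target distribution and a hard-label cross-entropy against the target's argmax tokens. A small standalone sketch on random logits (alpha here plays the role of self.alpha in the model code below):

import torch
import torch.nn.functional as F

alpha = 0.8
draft_logits = torch.randn(2, 16, 128)   # [batch, seq, vocab]
target_logits = torch.randn(2, 16, 128)

target_probs = F.softmax(target_logits, dim=-1)
draft_log_probs = F.log_softmax(draft_logits, dim=-1)

# Soft labels: cross-entropy between the target distribution and the draft prediction
soft_loss = -(target_probs * draft_log_probs).sum(dim=-1).mean()

# Hard labels: ordinary cross-entropy against the target model's argmax tokens
hard_labels = target_probs.argmax(dim=-1)
hard_loss = F.cross_entropy(draft_logits.view(-1, draft_logits.size(-1)), hard_labels.view(-1))

loss = alpha * soft_loss + (1 - alpha) * hard_loss
print(loss)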

One more thing worth mentioning: I implemented a prepare_dataset() method. Following the official Kangaroo project, it extracts only the early-exit layer's output and the final layer's output as training data, which saves the intermediate forward computation during training. However, on many of my servers the sharegpt4 dataset would then require close to 900 GB of disk space, so I kept both this precomputation-based training path and the plain path that simply runs the draft and target branches on the fly (a small usage sketch of the precomputation idea appears right after the listing below).

from typing import Any, List, Optional, Tuple, Union

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))

from dataclasses import dataclass
import json

import torch
torch.set_default_dtype(torch.bfloat16)

from torch.nn.utils.rnn import pad_sequence
from transformers import LlamaForCausalLM
from transformers.models.llama.modeling_llama import Cache, DynamicCache, LlamaSdpaAttention, LlamaDecoderLayer, LlamaRMSNorm
from transformers.modeling_outputs import CausalLMOutputWithPast

from sampling.sampling import sample_next_token
from utils.utils import calculate_continuous_acceptance


@dataclass
class KangarooModelMode:
    draft_only_mode: str = "draft_only"
    target_only_mode: str = "target_only"
    train_mode: str = "train"


@dataclass
class AdapterMode:
    attention_only_mode: str = "attention_only"
    mlp_only_mode: str = "mlp_only"
    decoder_layer_mode: str = "decoder_layer"


class KangarooLlamaForCausalLM(LlamaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.adapter_layer_mode = None
        self.draft_mode_adapter_layer = None
        self.shallow_layers = None
        self.mode = KangarooModelMode.target_only_mode
        self.confidence_threshold = 0.5
        self.accept_rate = 0
        self.total_accept_tokens = 0
        self.total_draft_generated_token = 0
        self.draft_temperature = 1.0
        self.target_temperature = 1.0
        self.alpha = 0.8
        self.shallow_layer_num = 10
    
    def set_skip_layer(self, shallow_layer_num: int) -> None:
        self.shallow_layer_num = shallow_layer_num
        self.shallow_layers = self.model.layers[:shallow_layer_num]
        self.remaining_layers = self.model.layers[shallow_layer_num:]

    def set_adapter_layer(self, _mode: str) -> None:
        self.adapter_layer_mode = _mode

        if _mode == AdapterMode.attention_only_mode:
            self.attn_input_norm = LlamaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
            self.attn_output_norm = LlamaRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)

            self.draft_mode_adapter_layer = LlamaSdpaAttention(
                config=self.config,
                layer_idx=self.config.num_hidden_layers,
            )
        elif _mode == AdapterMode.decoder_layer_mode:
            self.draft_mode_adapter_layer = LlamaDecoderLayer(
                config=self.config,
                layer_idx=self.config.num_hidden_layers,
            )

    def set_draft_mode(self) -> None:
        self.mode = KangarooModelMode.draft_only_mode

    def set_target_mode(self) -> None:
        self.mode = KangarooModelMode.target_only_mode

    def set_train_mode(self) -> None:
        self.mode = KangarooModelMode.train_mode

    def save_head(
        self,
        save_dir: str,
    ) -> None:
        """
        Save the parameters of the draft_mode_adapter_layer, loss_history, and shallow_layer_num to the specified directory.

        Args:
            save_dir (str): Directory to save the adapter parameters.
        """
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        
        # Save the adapter's state dictionary
        head_path = os.path.join(save_dir, "lm_head.pt")
        torch.save(self.lm_head.state_dict(), head_path)
        print(f"`lm_head` saved at {head_path}")

    def save_norm(
        self,
        save_dir: str,
    ) -> None:
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        
        # Save the final RMSNorm's state dictionary
        norm_path = os.path.join(save_dir, "norm.pt")
        torch.save(self.model.norm.state_dict(), norm_path)
        print(f"`norm` saved at {norm_path}")
    
    def save_adapter(
        self,
        save_dir: str,
        train_loss_history: List[float],
        eval_loss_history: List[float],
    ) -> None:
        """
        Save the draft_mode_adapter_layer parameters, the train/eval loss histories, and shallow_layer_num to the specified directory.

        Args:
            save_dir (str): Directory to save the adapter parameters.
        """
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        
        # Save the adapter's state dictionary
        adapter_path = os.path.join(save_dir, "draft_adapter.pt")
        torch.save(self.draft_mode_adapter_layer.state_dict(), adapter_path)
        print(f"Draft adapter saved at {adapter_path}")

        # Save additional information (loss_history and shallow_layer_num)
        metadata = {
            "train_loss_history": train_loss_history,
            "eval_loss_history": eval_loss_history,
            "shallow_layer_num": self.shallow_layer_num
        }
        metadata_path = os.path.join(save_dir, "adapter_metadata.json")
        with open(metadata_path, "w") as f:
            json.dump(metadata, f)
        print(f"Adapter metadata saved at {metadata_path}")

    def load_adapter(self, load_dir: str) -> None:
        """
        Load the draft_mode_adapter_layer parameters from the specified directory.

        Args:
            load_dir (str): Directory to load the adapter parameters from.

        Raises:
            FileNotFoundError: If the adapter file does not exist in the specified directory.
        """
        adapter_path = os.path.join(load_dir, "draft_adapter.pt")
        if not os.path.exists(adapter_path):
            raise FileNotFoundError(f"Draft adapter not found at {adapter_path}")
        
        # Load the adapter's state dictionary
        state_dict = torch.load(adapter_path, map_location=self.device)
        self.draft_mode_adapter_layer.load_state_dict(state_dict=state_dict)
        print(f"Draft adapter loaded from {adapter_path}")

    def prepare_dataset(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        if self.shallow_layers is None:
            raise AttributeError(f"You do not set the `shallow_layers`!")
        
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Model
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.model.gradient_checkpointing and self.model.training and use_cache:
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)

        # kept for BC (non `Cache` `past_key_values` inputs)
        if use_cache and not isinstance(past_key_values, Cache):
            if past_key_values is None:
                past_key_values = DynamicCache()
            else:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self.model._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )
        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.model.rotary_emb(hidden_states, position_ids)

        # decoder layers
        for decoder_layer in self.shallow_layers:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

            hidden_states = layer_outputs[0]

        shallow_hidden_states = hidden_states

        # Remaining decoder layers
        for decoder_layer in self.remaining_layers:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

            hidden_states = layer_outputs[0]

        hidden_states = self.model.norm(hidden_states)

        return shallow_hidden_states, hidden_states

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        if self.mode == KangarooModelMode.target_only_mode:
            return super().forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                cache_position=cache_position,
                num_logits_to_keep=num_logits_to_keep,
                **loss_kwargs,
            )
        elif self.mode == KangarooModelMode.draft_only_mode:
            if self.shallow_layers is None:
                raise AttributeError(f"You do not set the `shallow_layers`!")

            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
            output_hidden_states = (
                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
            )
            return_dict = return_dict if return_dict is not None else self.config.use_return_dict

            # Model
            if (input_ids is None) ^ (inputs_embeds is not None):
                raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

            if self.model.gradient_checkpointing and self.model.training and use_cache:
                use_cache = False

            if inputs_embeds is None:
                inputs_embeds = self.model.embed_tokens(input_ids)

            # kept for BC (non `Cache` `past_key_values` inputs)
            return_legacy_cache = False
            if use_cache and not isinstance(past_key_values, Cache):
                return_legacy_cache = True
                if past_key_values is None:
                    past_key_values = DynamicCache()
                else:
                    past_key_values = DynamicCache.from_legacy_cache(past_key_values)

            if cache_position is None:
                past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
                cache_position = torch.arange(
                    past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
                )
            if position_ids is None:
                position_ids = cache_position.unsqueeze(0)

            causal_mask = self.model._update_causal_mask(
                attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
            )
            hidden_states = inputs_embeds

            # create position embeddings to be shared across the decoder layers
            position_embeddings = self.model.rotary_emb(hidden_states, position_ids)

            # decoder layers
            all_hidden_states = () if output_hidden_states else None
            all_self_attns = () if output_attentions else None
            next_decoder_cache = None

            for decoder_layer in self.shallow_layers:
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)

                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

                hidden_states = layer_outputs[0]

                if use_cache:
                    next_decoder_cache = layer_outputs[2 if output_attentions else 1]

                if output_attentions:
                    all_self_attns += (layer_outputs[1],)

            # add hidden states from the last decoder layer
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            next_cache = next_decoder_cache if use_cache else None
            if return_legacy_cache:
                next_cache = next_cache.to_legacy_cache()

            # Adapter
            if self.adapter_layer_mode == AdapterMode.attention_only_mode:
                residual = hidden_states
                hidden_states = self.attn_input_norm(hidden_states)

                hidden_states, all_self_attns, past_key_values = self.draft_mode_adapter_layer(
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )
                hidden_states = residual + hidden_states
                hidden_states = self.attn_output_norm(hidden_states)

            elif self.adapter_layer_mode == AdapterMode.decoder_layer_mode:
                layer_outputs = self.draft_mode_adapter_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )
                hidden_states = layer_outputs[0]
                hidden_states = self.model.norm(hidden_states)

            if self.config.pretraining_tp > 1:
                lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
                logits = [torch.nn.functional.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
                logits = torch.cat(logits, dim=-1)
            else:
                # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
                logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

            loss = None
            if labels is not None:
                loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **loss_kwargs)

            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=past_key_values,
                hidden_states=hidden_states,
                attentions=all_self_attns,
            )
        elif self.mode == KangarooModelMode.train_mode:
            output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
            output_hidden_states = (
                output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
            )
            return_dict = return_dict if return_dict is not None else self.config.use_return_dict

            if (input_ids is None) ^ (inputs_embeds is not None):
                raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

            if self.model.gradient_checkpointing and self.model.training and use_cache:
                use_cache = False

            if inputs_embeds is None:
                inputs_embeds = self.model.embed_tokens(input_ids)

            # kept for BC (non `Cache` `past_key_values` inputs)
            return_legacy_cache = False
            if use_cache and not isinstance(past_key_values, Cache):
                return_legacy_cache = True
                if past_key_values is None:
                    past_key_values = DynamicCache()
                else:
                    past_key_values = DynamicCache.from_legacy_cache(past_key_values)

            if cache_position is None:
                past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
                cache_position = torch.arange(
                    past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
                )
            if position_ids is None:
                position_ids = cache_position.unsqueeze(0)

            causal_mask = self.model._update_causal_mask(
                attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
            )
            hidden_states = inputs_embeds

            # create position embeddings to be shared across the decoder layers
            position_embeddings = self.model.rotary_emb(hidden_states, position_ids)

            # decoder layers
            all_hidden_states = () if output_hidden_states else None
            all_self_attns = () if output_attentions else None
            next_decoder_cache = None

            for decoder_layer in self.shallow_layers:
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)

                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

                hidden_states = layer_outputs[0]

                if use_cache:
                    next_decoder_cache = layer_outputs[2 if output_attentions else 1]

                if output_attentions:
                    all_self_attns += (layer_outputs[1],)

            # Cache hidden states
            remaining_hidden_states = hidden_states

            # Add hidden states from the last decoder layer
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            next_cache = next_decoder_cache if use_cache else None
            if return_legacy_cache:
                next_cache = next_cache.to_legacy_cache()

            # Adapter
            if self.adapter_layer_mode == AdapterMode.attention_only_mode:
                residual = hidden_states
                hidden_states = self.attn_input_norm(hidden_states)

                hidden_states, all_self_attns, past_key_values = self.draft_mode_adapter_layer(
                    hidden_states=hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

                hidden_states = residual + hidden_states
                hidden_states = self.attn_output_norm(hidden_states)

            elif self.adapter_layer_mode == AdapterMode.decoder_layer_mode:
                layer_outputs = self.draft_mode_adapter_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )
                hidden_states = layer_outputs[0]
                hidden_states = self.model.norm(hidden_states)

                if use_cache:
                    next_decoder_cache = layer_outputs[2 if output_attentions else 1]

                if output_attentions:
                    all_self_attns += (layer_outputs[1],)

            # Remaining decoder layers
            for decoder_layer in self.remaining_layers:
                if output_hidden_states:
                    all_hidden_states += (remaining_hidden_states,)

                layer_outputs = decoder_layer(
                    remaining_hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

                remaining_hidden_states = layer_outputs[0]

                if use_cache:
                    next_decoder_cache = layer_outputs[2 if output_attentions else 1]

                if output_attentions:
                    all_self_attns += (layer_outputs[1],)

            remaining_hidden_states = self.model.norm(remaining_hidden_states)

            # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
            draft_logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) / self.draft_temperature
            target_logits = self.lm_head(remaining_hidden_states[:, -num_logits_to_keep:, :]) / self.target_temperature

            # Compute the log probabilities for both models
            draft_log_probs = torch.nn.functional.log_softmax(draft_logits, dim=-1)
            target_log_probs = torch.nn.functional.log_softmax(target_logits, dim=-1)
            target_probs = torch.nn.functional.softmax(target_logits, dim=-1)

            # Cross-entropy loss between target and draft model predictions
            # kl_loss = torch.nn.functional.kl_div(draft_log_probs, target_probs, reduction="batchmean")
            hard_labels = torch.argmax(target_probs, dim=-1)
            soft_label_cross_entropy_loss = -(target_probs * draft_log_probs).sum(dim=-1).mean()
            hard_label_loss = torch.nn.functional.cross_entropy(
                draft_logits.view(-1, draft_logits.size(-1)),  # Flatten logits
                hard_labels.view(-1)  # Flatten hard labels
            )

            loss = self.alpha * soft_label_cross_entropy_loss + (1 - self.alpha) * hard_label_loss

            return CausalLMOutputWithPast(
                loss=loss,
                logits=target_logits,
                past_key_values=past_key_values,
                hidden_states=hidden_states,
                attentions=all_self_attns,
            )

    @torch.no_grad()
    def kangaroo_generate(
        self,
        eos_token_id: int,
        pad_token_id: int,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        max_new_tokens: int = 100,
        **kwargs,
    ) -> torch.LongTensor:
        if self.shallow_layers is None:
            raise AttributeError(f"You do not set the `shallow_layers`!")

        confidence_score = 1.0
        total_generate_tokens = 0
        is_end = False  # Ensure `is_end` is defined even if every drafted token is accepted
        
        while total_generate_tokens < max_new_tokens:
            draft_generate_tokens = 0
            draft_probs = []

            while confidence_score >= self.confidence_threshold:
                output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
                output_hidden_states = (
                    output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
                )
                return_dict = return_dict if return_dict is not None else self.config.use_return_dict

                if (input_ids is None) ^ (inputs_embeds is not None):
                    raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

                if self.model.gradient_checkpointing and self.model.training and use_cache:
                    use_cache = False

                if inputs_embeds is None:
                    inputs_embeds = self.model.embed_tokens(input_ids)

                # kept for BC (non `Cache` `past_key_values` inputs)
                return_legacy_cache = False
                if use_cache and not isinstance(past_key_values, Cache):
                    return_legacy_cache = True
                    if past_key_values is None:
                        past_key_values = DynamicCache()
                    else:
                        past_key_values = DynamicCache.from_legacy_cache(past_key_values)

                if cache_position is None:
                    past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
                    cache_position = torch.arange(
                        past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
                    )
                if position_ids is None:
                    position_ids = cache_position.unsqueeze(0)

                causal_mask = self.model._update_causal_mask(
                    attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
                )
                hidden_states = inputs_embeds

                # create position embeddings to be shared across the decoder layers
                position_embeddings = self.model.rotary_emb(hidden_states, position_ids)

                # decoder layers
                all_hidden_states = () if output_hidden_states else None
                all_self_attns = () if output_attentions else None
                next_decoder_cache = None

                for decoder_layer in self.shallow_layers:
                    if output_hidden_states:
                        all_hidden_states += (hidden_states,)

                    layer_outputs = decoder_layer(
                        hidden_states,
                        attention_mask=causal_mask,
                        position_ids=position_ids,
                        past_key_value=past_key_values,
                        output_attentions=output_attentions,
                        use_cache=use_cache,
                        cache_position=cache_position,
                        position_embeddings=position_embeddings,
                    )

                    hidden_states = layer_outputs[0]

                    if use_cache:
                        next_decoder_cache = layer_outputs[2 if output_attentions else 1]

                    if output_attentions:
                        all_self_attns += (layer_outputs[1],)

                # Cache hidden states
                remaining_hidden_states = hidden_states

                # Add hidden states from the last decoder layer
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)

                next_cache = next_decoder_cache if use_cache else None
                if return_legacy_cache:
                    next_cache = next_cache.to_legacy_cache()

                # Adapter
                if self.adapter_layer_mode == AdapterMode.attention_only_mode:
                    residual = hidden_states
                    hidden_states = self.attn_input_norm(hidden_states)
                    
                    hidden_states, all_self_attns, past_key_values = self.draft_mode_adapter_layer(
                        hidden_states=hidden_states,
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                        past_key_values=past_key_values,
                        output_attentions=output_attentions,
                        use_cache=use_cache,
                        cache_position=cache_position,
                        position_embeddings=position_embeddings,
                    )

                    hidden_states = residual + hidden_states
                    hidden_states = self.attn_output_norm(hidden_states)

                elif self.adapter_layer_mode == AdapterMode.decoder_layer_mode:
                    layer_outputs = self.draft_mode_adapter_layer(
                        hidden_states,
                        attention_mask=causal_mask,
                        position_ids=position_ids,
                        past_key_value=past_key_values,
                        output_attentions=output_attentions,
                        use_cache=use_cache,
                        cache_position=cache_position,
                        position_embeddings=position_embeddings,
                    )
                    hidden_states = layer_outputs[0]
                    hidden_states = self.model.norm(hidden_states)

                    if use_cache:
                        next_decoder_cache = layer_outputs[2 if output_attentions else 1]

                    if output_attentions:
                        all_self_attns += (layer_outputs[1],)

                # Re-init
                inputs_embeds = None
                position_ids = None
                cache_position = None

                # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
                draft_logits = self.lm_head(hidden_states[:, -1:, :])

                # Sampling and get the probabilities
                next_tokens, probs = sample_next_token(
                    logits=draft_logits,
                    prefix_token_ids=input_ids,
                )

                draft_probs.append(probs)
                input_ids = torch.cat([input_ids, next_tokens[:, -1:]], dim=-1)
                attention_mask = torch.cat([attention_mask, torch.ones(attention_mask.shape[0], 1).to(input_ids.device)], dim=-1)

                draft_generate_tokens += 1
                self.total_draft_generated_token += 1

                # Support bs=1
                decode_token_id = next_tokens[:, -1].item()
                if probs[:, -1, decode_token_id] < self.confidence_threshold or total_generate_tokens + draft_generate_tokens >= max_new_tokens:
                    draft_probs = torch.cat(draft_probs, dim=1)
                    break

            # Use whole model for evaluating
            for decoder_layer in self.remaining_layers:
                if output_hidden_states:
                    all_hidden_states += (remaining_hidden_states,)

                layer_outputs = decoder_layer(
                    remaining_hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

                remaining_hidden_states = layer_outputs[0]

                if use_cache:
                    next_decoder_cache = layer_outputs[2 if output_attentions else 1]

                if output_attentions:
                    all_self_attns += (layer_outputs[1],)

            remaining_hidden_states = self.model.norm(remaining_hidden_states)
            num_logits_to_keep = draft_probs.shape[1]
            target_logits = self.lm_head(remaining_hidden_states[:, -num_logits_to_keep:, :])

            target_input_ids = input_ids[:, :-1]

            next_tokens, target_probs = sample_next_token(
                logits=target_logits,
                prefix_token_ids=target_input_ids,
                probs_num=num_logits_to_keep,
            )

            # Evaluation
            expanded_indices = input_ids[:, -draft_probs.shape[1]:].unsqueeze(-1)

            # Gather the probability of each drafted token under both distributions
            selected_draft_probs = torch.gather(draft_probs, dim=-1, index=expanded_indices).squeeze(-1)
            selected_eval_probs = torch.gather(target_probs, dim=-1, index=expanded_indices).squeeze(-1)

            # Compare draft_prob and eval_prob, and check the reject_mask
            mask_to_reject = selected_draft_probs > selected_eval_probs

            # Calculate rejection probability 1 - (eval_prob / draft_prob)
            rejection_probs = 1 - (selected_eval_probs / selected_draft_probs)

            # Generate random values to determine accept or reject
            random_values = torch.rand_like(rejection_probs)
            rejection_decisions = random_values < rejection_probs

            # Get the final reject masks
            rejection_masks = mask_to_reject & rejection_decisions
            acceptance_mask = torch.ones_like(selected_draft_probs, dtype=torch.bool)
            acceptance_mask[rejection_masks] = False

            # Concat `input_ids`
            if torch.all(acceptance_mask):
                total_generate_tokens += draft_generate_tokens
            else:
                new_input_ids = []
                new_attention_mask = []
                is_end = False

                for batch_idx in range(next_tokens.shape[0]):
                    gamma = next_tokens.shape[1]
                    start_idx = input_ids.shape[1] - gamma

                    for pos_idx in range(acceptance_mask[batch_idx].shape[0]):
                        total_generate_tokens += 1
                        if (acceptance_mask[batch_idx][pos_idx] and input_ids[batch_idx][start_idx+pos_idx].item() == eos_token_id) or not acceptance_mask[batch_idx][pos_idx]:
                            input_ids[batch_idx][start_idx+pos_idx] = next_tokens[batch_idx][pos_idx]

                            new_input_ids.append(input_ids[batch_idx][:start_idx+pos_idx+1])
                            new_attention_mask.append(attention_mask[batch_idx][:start_idx+pos_idx+1])
                            
                            is_end = input_ids[batch_idx][start_idx+pos_idx].item() == eos_token_id
                            break

                input_ids = pad_sequence(new_input_ids, batch_first=True, padding_value=pad_token_id)
                attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)

            self.total_accept_tokens += calculate_continuous_acceptance(acceptance_mask=acceptance_mask)
            self.accept_rate = self.total_accept_tokens / self.total_draft_generated_token
        
            if is_end:
                break

        return {"input_ids": input_ids}

As for the training script, there is not much worth going into detail about: it simply declares the dataset and the model in the standard way and runs the training loop.

from typing import Dict, List

import os
from dataclasses import dataclass

from datasets import load_dataset

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
from sklearn.model_selection import train_test_split

from kangaroo_modeling.modeling_kangaroo_llama3 import KangarooLlamaForCausalLM


class CustomDataset(Dataset):
    def __init__(self, inputs: Dict[str, torch.LongTensor], device: torch.device):
        self.inputs = inputs
        self.device = device

    def __len__(self) -> int:
        return self.inputs.input_ids.shape[0]
    
    def __getitem__(self, index: int):
        return (
            self.inputs.input_ids[index].to(self.device),
            self.inputs.attention_mask[index].to(self.device),
        )
    

def main() -> None:
    # Settings
    epochs = 100
    batch_size = 4
    max_length = 512
    lr = 5e-5
    shallow_layer_num = 2
    adapter_mode = "attention_only"
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

    # Load model and tokenizer
    pretrained_model_name_or_path = "../models/meta-llama--Meta-Llama-3.1-8B-Instruct"
    # pretrained_model_name_or_path = "../models/HuggingFaceTB--SmolLM2-1.7B-Instruct"
    model = KangarooLlamaForCausalLM.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
    tokenizer.pad_token = tokenizer.eos_token

    model.set_skip_layer(shallow_layer_num=shallow_layer_num)
    model.set_adapter_layer(_mode=adapter_mode)
    model.set_train_mode()

    model = model.to(device)

    # Freeze all parameters
    for param in model.parameters():
        param.requires_grad = False

    # Unfreeze adapter layer
    model.draft_mode_adapter_layer.train()
    for param in model.draft_mode_adapter_layer.parameters():
        param.requires_grad = True

    if hasattr(model, "attn_input_norm"):
        print("Attention-Adapter!")
        
        for param in model.attn_input_norm.parameters():
            param.requires_grad = True

        for param in model.attn_output_norm.parameters():
            param.requires_grad = True

    # Load dataset
    dataset = load_dataset("shibing624/sharegpt_gpt4")
    
    samples = dataset["train"]["conversations"]
    samples = [[{"role": sample[0]["from"].replace("human", "user").replace("gpt", "assistant"), "content": sample[0]["value"]}] for sample in samples]

    train_samples, eval_samples = train_test_split(samples, test_size=0.1, random_state=2999)
    print(len(samples))

    # Tokenized
    train_inputs = tokenizer(
        [tokenizer.apply_chat_template(messages, tokenize=False) for messages in train_samples],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    )

    eval_inputs = tokenizer(
        [tokenizer.apply_chat_template(messages, tokenize=False) for messages in eval_samples],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    )

    train_dataset = CustomDataset(inputs=train_inputs, device=device)
    eval_dataset = CustomDataset(inputs=eval_inputs, device=device)

    # Dataloader
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)

    # Optimizer (collect every trainable parameter so that, in attention-only mode,
    # the extra attn_input_norm / attn_output_norm layers unfrozen above are updated too)
    optimizer = torch.optim.AdamW(
        [param for param in model.parameters() if param.requires_grad],
        lr=lr,
    )

    # Training loop
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        train_loss_history = []
        eval_loss_history = []

        for batch_idx, batch in enumerate(train_dataloader, 1):
            input_ids, attention_mask = batch

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                num_logits_to_keep=max_length,
            )

            # Zero gradients
            optimizer.zero_grad()

            # Calculate loss
            loss = outputs.loss
            total_loss += loss.item()
            train_loss_history.append(loss.item())

            # Backward pass
            loss.backward()

            # Optimizer step
            optimizer.step()

            # Log training loss
            avg_loss = total_loss / batch_idx
            print(f"Train - Epoch [{epoch + 1}/{epochs}] Steps [{batch_idx}/{len(train_dataloader)}], Training Loss: {avg_loss:.4f}")

        # Evaluate the model
        model.eval()
        eval_loss = 0
        with torch.no_grad():
            for batch_idx, batch in enumerate(eval_dataloader, 1):
                input_ids, attention_mask = batch

                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    num_logits_to_keep=max_length,
                )
                eval_loss += outputs.loss.item()
                eval_loss_history.append(outputs.loss.item())

                avg_loss = eval_loss / batch_idx
                print(f"Eval - Epoch [{epoch + 1}/{epochs}] Steps [{batch_idx}/{len(eval_dataloader)}], Eval Loss: {avg_loss:.4f}")

        # Save model checkpoint
        save_dir = "./checkpoints/checkpoints_hce_attn_20241209/"
        save_path = os.path.join(save_dir, f"epoch_{epoch+1}")
        model.save_adapter(
            save_path,
            train_loss_history=train_loss_history,
            eval_loss_history=eval_loss_history,
        )
        print(f"Adapter checkpoint saved at {save_path}")


if __name__ == "__main__":
    main()


After training finished, I plotted my loss curves.

From the curves, the eval loss plateaus at around 0.88 and refuses to go any lower. With that in place, let's actually measure how much speedup the trained Kangaroo architecture delivers.

from typing import Dict, List, Optional, Tuple

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import copy
import time

import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from kangaroo_modeling.modeling_kangaroo_llama3 import KangarooLlamaForCausalLM
from sampling.sampling import sample_next_token


def calculate_continuous_acceptance(acceptance_mask: torch.BoolTensor) -> int:
    continuous_acceptance = 0
    for accepted in acceptance_mask.long().squeeze(0):
        if accepted == 1:
            continuous_acceptance += 1
        else:
            break
    return continuous_acceptance

def drafter_speculative_decode(
    draft_model: torch.nn.Module,
    draft_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    gamma: int = 10,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, it means do not select top-k tokens
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], torch.FloatTensor]:
    draft_probs = []

    for idx in range(gamma):
        with torch.no_grad():
            outputs = draft_model(**inputs)

        next_tokens, probs = sample_next_token(
            logits=outputs.logits,
            prefix_token_ids=inputs["input_ids"],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        draft_probs.append(probs)
        input_ids = torch.cat([inputs["input_ids"], next_tokens[:, -1:]], dim=-1)
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)

        inputs["input_ids"] = input_ids
        inputs["attention_mask"] = attention_mask

    return inputs, torch.cat(draft_probs, dim=1)


def target_speculative_decode(
    target_model: torch.nn.Module,
    target_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    draft_probs: torch.FloatTensor,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, it means do not select top-k tokens
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], bool, int]:
    with torch.no_grad():
        outputs = target_model(**inputs)

    next_tokens, target_probs = sample_next_token(
        logits=outputs.logits,
        diff_probs=draft_probs,
        prefix_token_ids=inputs["input_ids"],
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        probs_num=draft_probs.shape[1] + 1,
    )

    next_token = next_tokens[:, -1:]

    # Evaluation
    indices = inputs["input_ids"][:, -draft_probs.shape[1]:]

    eval_probs = target_probs[:, :-1, :]

    expanded_indices = indices.unsqueeze(-1)
    selected_draft_probs = torch.gather(draft_probs, dim=-1, index=expanded_indices)
    selected_draft_probs = selected_draft_probs.squeeze(-1)

    selected_eval_probs = torch.gather(eval_probs, dim=-1, index=expanded_indices)
    selected_eval_probs = selected_eval_probs.squeeze(-1)    

    # Compare draft_prob and eval_prob, and check the reject_mask
    mask_to_reject = selected_draft_probs > selected_eval_probs

    # Calculate rejection probability 1 - (eval_prob / draft_prob)
    rejection_probs = 1 - (selected_eval_probs / selected_draft_probs)

    # Generate random values to determine accept or reject
    random_values = torch.rand_like(rejection_probs)
    rejection_decisions = random_values < rejection_probs

    # Get the final reject masks
    rejection_masks = mask_to_reject & rejection_decisions
    acceptance_mask = torch.ones_like(selected_draft_probs, dtype=torch.bool)
    acceptance_mask[rejection_masks] = False

    is_end = False

    # Concat `input_ids`
    if torch.all(acceptance_mask):
        input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)
    else:
        new_input_ids = []
        new_attention_mask = []

        for batch_idx in range(next_tokens.shape[0]):
            gamma = next_tokens.shape[1] - 1
            start_idx = inputs["input_ids"].shape[1] - gamma

            for pos_idx in range(acceptance_mask[batch_idx].shape[0]):
                if (acceptance_mask[batch_idx][pos_idx] and inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id) or not acceptance_mask[batch_idx][pos_idx]:
                    inputs["input_ids"][batch_idx][start_idx+pos_idx] = next_tokens[batch_idx][pos_idx]

                    new_input_ids.append(inputs["input_ids"][batch_idx][:start_idx+pos_idx+1])
                    new_attention_mask.append(inputs["attention_mask"][batch_idx][:start_idx+pos_idx+1])
                    
                    is_end = inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id
                    break

        input_ids = pad_sequence(new_input_ids, batch_first=True, padding_value=target_tokenizer.pad_token_id)
        attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)

    inputs["input_ids"] = input_ids
    inputs["attention_mask"] = attention_mask

    return inputs, is_end, calculate_continuous_acceptance(acceptance_mask)


def run_test() -> None:
    # Device
    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # Model path 
    pretrained_model_name_or_path = "../models/meta-llama--Meta-Llama-3.1-8B-Instruct"
    adapter_dir = "checkpoints/checkpoints_hce_decoder_layer_20241205/epoch_45/"

    # Load Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
    tokenizer.pad_token = tokenizer.eos_token

    # Load Model
    model = KangarooLlamaForCausalLM.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.bfloat16)
    model.set_skip_layer(shallow_layer_num=2)
    model.set_adapter_layer("decoder_layer")
    if adapter_dir:
        model.load_adapter(adapter_dir)

    model = model.to(device)

    # Tokenize
    messages = [
        [
            {
                "role": "user",
                "content": "What is the capital of Taiwan. And why?",
            },
        ],
    ]

    input_text = tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    ).to(device)

    # Warm up the model (CUDA)
    inputs_dummy = {k: v.clone() for k, v in inputs.items()}
    with torch.no_grad():
        model.set_draft_mode()
        model(**inputs_dummy)
        model.set_target_mode()
        model(**inputs_dummy)
    torch.cuda.synchronize()

    # Record
    raw_inputs = copy.deepcopy(inputs)
    raw_token_num = raw_inputs["input_ids"].shape[1]
    total_draft_tokens = 0
    total_accept_tokens = 0
    gamma = 1
    max_new_tokens = 100
    is_end = False

    start_time = time.time()

    while not is_end:
        # Draft model
        model.set_draft_mode()
        target_inputs, draft_probs = drafter_speculative_decode(
            draft_model=model,
            draft_tokenizer=tokenizer,
            inputs=inputs,
            gamma=gamma,
            temperature=0,
        )

        total_draft_tokens += gamma

        # Target model
        model.set_target_mode()
        outputs, is_end, accept_tokens = target_speculative_decode(
            target_model=model,
            target_tokenizer=tokenizer,
            inputs=target_inputs,
            draft_probs=draft_probs,
            temperature=1,
        )

        total_accept_tokens += accept_tokens

        inputs = outputs

        if inputs["input_ids"].shape[1] - raw_token_num >= max_new_tokens:
            break

    generate_token_num = outputs["input_ids"].shape[1] - raw_token_num
    spent_time = time.time() - start_time

    print(f"Generate token number: {generate_token_num}")
    print(f"Generate speed: {generate_token_num / spent_time} tokens/sec")
    print(f"Speculative Decoding Spent Time: {spent_time} seconds.")
    print(f"Accept Rate: {total_accept_tokens / total_draft_tokens}\n")

    # Normal Target Model Speed
    inputs = copy.deepcopy(raw_inputs)
    start_time = time.time()
    target_inputs, draft_probs = drafter_speculative_decode(
        draft_model=model,
        draft_tokenizer=tokenizer,
        inputs=inputs,
        gamma=max_new_tokens,
    )

    spent_time = time.time() - start_time

    print(f"Generate token number: {max_new_tokens}")
    print(f"Generate speed: {max_new_tokens / spent_time} tokens/sec")
    print(f"Normal Target Model Decoding Spent Time: {spent_time} seconds.\n")


if __name__ == "__main__":
    run_test()


Output:

Generate token number: 100
Generate speed: 41.15903447942894 tokens/sec
Speculative Decoding Spent Time: 2.429600238800049 seconds.
Accept Rate: 0.2987012987012987

Generate token number: 100
Generate speed: 35.865311488261504 tokens/sec
Normal Target Model Decoding Spent Time: 2.7882094383239746 seconds.

That is a speedup of roughly (41.15 - 35.86) / 35.86 ≈ 14.75%.


References

  • Fast Inference from Transformers via Speculative Decoding
  • Draft & Verify: Lossless Large Language Model Acceleration via Self-Speculative Decoding
  • Kangaroo: Lossless Self-Speculative Decoding via Double Early Exiting
  • GitHub - ccs96307/fast-llm-inference: https://github.com/ccs96307/fast-llm-inference