
Implementation Notes on Integrating Speculative Decoding and KV Cache

Last Updated on 2024-12-17 by Clay

Preface

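Speculative Decoding and KV Cache are both acceleration techniques that can be applied to Transformers. The former uses a faster draft model to speculatively generate several upcoming tokens, which the target model (the model we want to speed up) then verifies in a single forward pass, saving part of the cost of autoregressive decoding. The latter exploits a property of the Transformer's causal attention: earlier tokens never attend to later tokens, so the key/value results already computed for past tokens can be stored and reused, avoiding redundant computation at every inference step.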
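To make the KV Cache argument concrete, here is a minimal single-layer, single-head toy in pure Python. The 2×2 matrices `W_Q`, `W_K`, `W_V` and the input vectors are made-up values, not taken from any real model. Because attention is causal, the keys and values of earlier positions never change, so appending only the newest token's K/V to a cache yields the same outputs as recomputing everything at every step.

```python
import math
from typing import List, Tuple

Vec = List[float]

# Hypothetical toy projection matrices standing in for W_q, W_k, W_v.
W_Q = [[1.0, 0.5], [0.0, 1.0]]
W_K = [[0.5, 0.0], [1.0, 1.0]]
W_V = [[1.0, 2.0], [0.0, 1.0]]


def matvec(w: List[Vec], x: Vec) -> Vec:
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def attend(q: Vec, keys: List[Vec], values: List[Vec]) -> Vec:
    """Single-head scaled dot-product attention over the given keys/values."""
    scale = math.sqrt(len(q))
    scores = [sum(a * b for a, b in zip(q, k)) / scale for k in keys]
    top = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(len(values[0]))]


def decode_without_cache(xs: List[Vec]) -> List[Vec]:
    """Recompute K and V for the whole prefix at every step."""
    outputs = []
    for t in range(len(xs)):
        keys = [matvec(W_K, x) for x in xs[: t + 1]]
        values = [matvec(W_V, x) for x in xs[: t + 1]]
        outputs.append(attend(matvec(W_Q, xs[t]), keys, values))
    return outputs


def decode_with_cache(xs: List[Vec]) -> Tuple[List[Vec], int]:
    """Append only the newest token's K and V; causal attention never changes past entries."""
    k_cache: List[Vec] = []
    v_cache: List[Vec] = []
    outputs = []
    for x in xs:
        k_cache.append(matvec(W_K, x))
        v_cache.append(matvec(W_V, x))
        outputs.append(attend(matvec(W_Q, x), k_cache, v_cache))
    return outputs, len(k_cache)


xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
cached_outputs, cache_len = decode_with_cache(xs)
full_outputs = decode_without_cache(xs)
print(cache_len)
```

In a real multi-layer model the same holds layer by layer, since each position's hidden state depends only on itself and earlier positions.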

For more details, you can refer to my earlier notes:

These two techniques can be combined. Below are my implementation approach and process.


KV Cache in Speculative Decoding

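Since HuggingFace's transformers library is about to drop the legacy tuple-based KV Cache in 4.47.0 in favor of its own DynamicCache() class, I adopted DynamicCache() in this implementation as well. After using it, I found that its crop() method makes it easy to discard KV Cache entries that are no longer needed, which is very convenient.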

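This way, updating the KV Cache can be left to the model, which appends new entries automatically during inference; all we need to do is decide, based on the outcome of the speculative decoding verification step, whether to truncate part of the KV Cache.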
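The division of labor described above (the forward pass appends, we only crop) can be sketched with a toy cache. `ToyKVCache` is hypothetical: it stores one placeholder per cached position instead of real key/value tensors, and although its method names mirror DynamicCache's `get_seq_length()` and `crop()`, it is not the transformers class.

```python
from typing import List


class ToyKVCache:
    """Hypothetical toy cache holding one placeholder per cached token position."""

    def __init__(self) -> None:
        self.positions: List[int] = []

    def update(self, token_ids: List[int]) -> None:
        # Stands in for the model appending K/V for every token processed in a forward pass.
        self.positions.extend(token_ids)

    def get_seq_length(self) -> int:
        return len(self.positions)

    def crop(self, max_length: int) -> None:
        # Keep only the first `max_length` positions; a no-op if the cache is already shorter.
        self.positions = self.positions[:max_length]


# Target model: a 10-token prefix is cached, then one verification pass processes gamma + 1 = 6 tokens.
cache = ToyKVCache()
cache.update(list(range(10)))
cache.update([100, 101, 102, 103, 104, 105])
# Suppose only k = 2 draft tokens are accepted: roll the cache back to prefix + k positions.
cache.crop(max_length=10 + 2)
print(cache.get_seq_length())  # 12
```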

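In general, speculative decoding with a KV Cache runs into the following four situations:

  1. With a KV Cache, we only need to feed the model the last token of the sequence: the original input_ids of shape (batch_size, seq_len) becomes (batch_size, 1).
  2. When the target model evaluates the draft model's gamma speculative tokens, the input has shape (batch_size, gamma+1): the gamma draft tokens plus the last token, which is not yet in the target model's cache.
  3. The target model rejects some of the draft model's gamma speculative tokens, leaving only k accepted:
    • the draft model's KV Cache, which had grown to raw_kv_cache_length + gamma positions, must be cropped to raw_kv_cache_length + k
    • the target model's KV Cache, which had grown to raw_kv_cache_length + gamma + 1 positions, must be cropped to raw_kv_cache_length + k
  4. When the target model accepts all of the draft model's tokens, the draft model's input_ids for the next round of speculation contain one extra token produced by the target model, so the input shape is (batch_size, 2).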
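All four situations reduce to comparing the sequence length with the cache length. The sketch below, a hypothetical helper `plan_forward` with made-up lengths, mirrors the `distance` logic used in drafter_speculative_decode and target_speculative_decode: feed every token the cache has not seen yet, or, if the cache is at least as long as the sequence, crop it to `seq_len - 1` and feed only the last token. The example assumes a 10-token sequence whose last (target-produced) token is not yet cached, and gamma = 5.

```python
from typing import Tuple


def plan_forward(seq_len: int, cache_len: int) -> Tuple[int, int]:
    """Return (tokens_to_feed, cache_len_to_keep) for one forward pass."""
    if cache_len == 0:
        return seq_len, 0  # empty cache: prefill the whole sequence
    distance = seq_len - cache_len
    if distance >= 1:
        return distance, cache_len  # feed only the tokens missing from the cache
    return 1, seq_len - 1  # cache too long: crop, then feed the last token


gamma = 5
print(plan_forward(11, 10))  # case 1, normal decoding: (1, 10)
print(plan_forward(10 + gamma, 9))  # case 2, target verifies gamma drafts: (6, 9)
print(plan_forward(13, 15))  # case 3, target after rejection with k = 2 (10 + 2 + 1 tokens): (1, 12)
print(plan_forward(16, 14))  # case 4, drafter after full acceptance plus the target's token: (2, 14)
```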

As long as these situations are handled correctly, we get Speculative Decoding with a KV Cache.


Implementation Details

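First, we import all the packages we will use. sample_next_token is a sampling function I implemented myself; for its implementation, see: Notes on Decoding and Sampling in Large Language Models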

from typing import Dict, List, Optional, Tuple, Union

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import argparse
import copy
import time

import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModelForCausalLM, AutoTokenizer, Cache, DynamicCache, PreTrainedTokenizerBase

from sampling.sampling import sample_next_token


"""
python speculative_decoding/run_speculative_decoding.py \
    --target_model_path HuggingFaceTB/SmolLM2-1.7B-Instruct \
    --draft_model_path HuggingFaceTB/SmolLM2-135M-Instruct \
    --device cuda:0 \
    --question 'What is the capital of Taiwan. And why?' \
    --gamma 5 \
    --test_token_num 100
"""


def calculate_continuous_acceptance(acceptance_mask: torch.BoolTensor) -> int:
    # Count the draft tokens accepted before the first rejection (assumes batch size 1)
    continuous_acceptance = 0
    for accepted in acceptance_mask.long().squeeze(0):
        if accepted == 1:
            continuous_acceptance += 1
        else:
            break
    return continuous_acceptance
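calculate_continuous_acceptance only counts accepted tokens up to the first rejection, because every draft token after a rejected one was conditioned on a token that is thrown away. The same counting rule on a plain list of booleans, as a tensor-free sketch (the name `leading_accepted` is hypothetical):

```python
from itertools import takewhile
from typing import Sequence


def leading_accepted(acceptance: Sequence[bool]) -> int:
    """Count accepted draft tokens up to (not including) the first rejection."""
    return sum(1 for _ in takewhile(bool, acceptance))


print(leading_accepted([True, True, False, True, True]))  # 2: accepts after the rejection are discarded
```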


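Next comes the draft model's sampling. Depending on the current input length, we either crop the KV Cache or feed more than one token of input_ids (with a KV Cache, the input length is normally 1).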

def drafter_speculative_decode(
    draft_model: torch.nn.Module,
    draft_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    gamma: int = 10,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, it means do not select top-k tokens
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
) -> Tuple[Dict[str, torch.Tensor], torch.FloatTensor, Optional[Union[Cache, List[torch.FloatTensor]]]]:
    draft_probs = []

    for idx in range(gamma):
        raw_inputs_ids = inputs.input_ids

        if isinstance(past_key_values, Cache) and past_key_values.get_seq_length() > 0:
            distance = inputs.input_ids.shape[1] - past_key_values.get_seq_length()

            if distance >= 1:
                inputs.input_ids = inputs.input_ids[:, -distance:]
            else:
                past_key_values.crop(max_length=inputs.input_ids.shape[1]-1)
                inputs.input_ids = inputs.input_ids[:, -1:]

        with torch.no_grad():
            outputs = draft_model(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                past_key_values=past_key_values,
                use_cache=past_key_values is not None,
            )

        past_key_values = outputs.past_key_values

        next_tokens, probs = sample_next_token(
            logits=outputs.logits,
            prefix_token_ids=inputs.input_ids,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        draft_probs.append(probs)
        input_ids = torch.cat([raw_inputs_ids, next_tokens[:, -1:]], dim=-1)
        attention_mask = torch.cat([inputs.attention_mask, torch.ones_like(inputs.attention_mask[:, :1])], dim=-1)

        inputs.input_ids = input_ids
        inputs.attention_mask = attention_mask

    return inputs, torch.cat(draft_probs, dim=1), past_key_values
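To see what the cache saves inside the drafting loop, the following sketch counts how many token positions the draft model runs over during one round of gamma steps, with and without a cache. It is pure bookkeeping with no model; `draft_round_cost` is a hypothetical helper, and it assumes, as in the function above, that after each forward pass the cache covers the whole input fed so far.

```python
def draft_round_cost(seq_len: int, cache_len: int, gamma: int, use_cache: bool) -> int:
    """Number of token positions the draft model processes during one drafting round."""
    processed = 0
    for _ in range(gamma):
        if not use_cache:
            fed = seq_len  # no cache: the whole sequence is re-encoded every step
        elif seq_len > cache_len:
            fed = seq_len - cache_len  # feed only tokens missing from the cache
        else:
            fed = 1  # cache too long: crop to seq_len - 1, feed the last token
        if use_cache:
            cache_len = seq_len  # the cache now covers the full input
        processed += fed
        seq_len += 1  # append the newly sampled draft token
    return processed


print(draft_round_cost(10, 0, 5, use_cache=False))  # 60 = 10 + 11 + 12 + 13 + 14
print(draft_round_cost(10, 0, 5, use_cache=True))  # 14 = 10 (prefill) + 4 single tokens
print(draft_round_cost(16, 14, 5, use_cache=True))  # 6 = 2 (after full acceptance) + 4 single tokens
```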


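Next is the target model's verification. Here we likewise adjust the KV Cache according to the length of input_ids; the key difference is that once the target model rejects a draft token, the target model's KV Cache must also be rolled back to the truncation point.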

def target_speculative_decode(
    target_model: torch.nn.Module,
    target_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    draft_probs: torch.FloatTensor,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, it means do not select top-k tokens
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
) -> Tuple[Dict[str, torch.Tensor], bool, int, Optional[Union[Cache, List[torch.FloatTensor]]]]:
    raw_inputs_ids = inputs.input_ids

    if isinstance(past_key_values, Cache) and past_key_values.get_seq_length() > 0:
        distance = inputs.input_ids.shape[1] - past_key_values.get_seq_length()
        if distance >= 1:
            inputs.input_ids = inputs.input_ids[:, -distance:]
        else:
            past_key_values.crop(max_length=inputs.input_ids.shape[1]-1)
            inputs.input_ids = inputs.input_ids[:, -1:]

    with torch.no_grad():
        outputs = target_model(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            past_key_values=past_key_values,
            use_cache=past_key_values is not None,
        )

    past_key_values = outputs.past_key_values
    inputs.input_ids = raw_inputs_ids

    next_tokens, target_probs = sample_next_token(
        logits=outputs.logits,
        diff_probs=draft_probs,
        prefix_token_ids=inputs.input_ids,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        probs_num=draft_probs.shape[1] + 1,
    )

    next_token = next_tokens[:, -1:]

    # Evaluation
    indices = inputs.input_ids[:, -draft_probs.shape[1]:]

    eval_probs = target_probs[:, :-1, :]

    expanded_indices = indices.unsqueeze(-1)
    selected_draft_probs = torch.gather(draft_probs, dim=-1, index=expanded_indices)
    selected_draft_probs = selected_draft_probs.squeeze(-1)

    selected_eval_probs = torch.gather(eval_probs, dim=-1, index=expanded_indices)
    selected_eval_probs = selected_eval_probs.squeeze(-1)    

    # A draft token can only be rejected where draft_prob > eval_prob
    mask_to_reject = selected_draft_probs > selected_eval_probs

    # Rejection probability: 1 - (eval_prob / draft_prob)
    rejection_probs = 1 - (selected_eval_probs / selected_draft_probs)

    # Draw uniform random values to decide acceptance or rejection
    random_values = torch.rand_like(rejection_probs)
    rejection_decisions = random_values < rejection_probs

    # Get the final reject masks
    rejection_masks = mask_to_reject & rejection_decisions
    acceptance_mask = torch.ones_like(selected_draft_probs, dtype=torch.bool)
    acceptance_mask[rejection_masks] = False

    is_end = False

    # Concat `input_ids`
    if torch.all(acceptance_mask):
        inputs.input_ids = torch.cat([inputs.input_ids, next_token], dim=-1)
        inputs.attention_mask = torch.cat([inputs.attention_mask, torch.ones_like(inputs.attention_mask[:, :1])], dim=-1)
    else:
        new_input_ids = []
        new_attention_mask = []

        for batch_idx in range(next_tokens.shape[0]):
            gamma = next_tokens.shape[1] - 1
            start_idx = inputs.input_ids.shape[1] - gamma

            for pos_idx in range(acceptance_mask[batch_idx].shape[0]):
                if (acceptance_mask[batch_idx][pos_idx] and inputs.input_ids[batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id) or not acceptance_mask[batch_idx][pos_idx]:
                    inputs.input_ids[batch_idx][start_idx+pos_idx] = next_tokens[batch_idx][pos_idx]

                    new_input_ids.append(inputs.input_ids[batch_idx][:start_idx+pos_idx+1])
                    new_attention_mask.append(inputs.attention_mask[batch_idx][:start_idx+pos_idx+1])
                    
                    is_end = inputs.input_ids[batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id
                    break

        input_ids = pad_sequence(new_input_ids, batch_first=True, padding_value=target_tokenizer.pad_token_id)
        attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)

        inputs.input_ids = input_ids
        inputs.attention_mask = attention_mask

    if isinstance(past_key_values, Cache) and inputs.input_ids.shape[1] <= past_key_values.get_seq_length():
        past_key_values.crop(max_length=inputs.input_ids.shape[1]-1)

    return inputs, is_end, calculate_continuous_acceptance(acceptance_mask), past_key_values
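The acceptance test above rejects a draft token x with probability 1 - p(x)/q(x) only when q(x) > p(x), where p is the target distribution (eval_probs) and q the draft distribution (draft_probs); equivalently, x is accepted with probability min(1, p(x)/q(x)). On rejection, standard speculative sampling resamples from the normalized residual max(0, p - q). I am assuming that is what sample_next_token does when given diff_probs, since its code is not shown here. The pure-Python sketch below uses hypothetical helper names and toy distributions, implements this rule, and checks empirically that the kept tokens follow p.

```python
import random
from typing import List, Sequence, Tuple


def accept_probability(p_x: float, q_x: float) -> float:
    """Probability of keeping draft token x: min(1, p(x) / q(x))."""
    return min(1.0, p_x / q_x)


def residual_distribution(p: Sequence[float], q: Sequence[float]) -> List[float]:
    """Normalized max(0, p - q), used to resample after a rejection."""
    diff = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    total = sum(diff)
    if total == 0.0:  # p == q: rejection cannot happen, fall back to p
        return list(p)
    return [d / total for d in diff]


def verify(token: int, p: Sequence[float], q: Sequence[float], rng: random.Random) -> Tuple[int, bool]:
    """Return (kept_token, accepted) for one draft token that was sampled from q."""
    if rng.random() < accept_probability(p[token], q[token]):
        return token, True
    residual = residual_distribution(p, q)
    return rng.choices(range(len(p)), weights=residual, k=1)[0], False


p = [0.5, 0.3, 0.2]  # toy target distribution
q = [0.2, 0.6, 0.2]  # toy draft distribution
rng = random.Random(0)
n_samples = 20000
counts = [0, 0, 0]
for _ in range(n_samples):
    draft_token = rng.choices(range(3), weights=q, k=1)[0]
    kept = verify(draft_token, p, q, rng)[0]
    counts[kept] += 1
print([c / n_samples for c in counts])  # close to p
```

Working it out by hand: token 0 is always accepted (0.2), token 1 is kept half the time (0.3), and the other 0.3 of rejected mass goes to token 0 via the residual, giving exactly p.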


Finally, let's compare the speed of Speculative Decoding with and without the KV Cache:

def run_test(args) -> None:
    # Device
    device = torch.device(args.device if args.device != "cpu" and torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # Model path 
    target_model_path = args.target_model_path
    draft_model_path = args.draft_model_path

    # Load Tokenizer
    draft_tokenizer = AutoTokenizer.from_pretrained(draft_model_path)
    target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)

    # Load Model
    draft_model = AutoModelForCausalLM.from_pretrained(draft_model_path, torch_dtype=torch.bfloat16).to(device)
    target_model = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.bfloat16).to(device)

    # Tokenize
    messages = [
        [
            {
                "role": "user",
                "content": args.question,
            },
        ],
    ]

    input_text=draft_tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = draft_tokenizer(
        input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    ).to(device)

    # Warm up the model (CUDA)
    inputs_dummy = {k: v.clone() for k, v in inputs.items()}
    with torch.no_grad():
        draft_model(**inputs_dummy)
        target_model(**inputs_dummy)
    if device.type == "cuda":
        torch.cuda.synchronize()

    is_end = False

    # Record
    raw_inputs = copy.deepcopy(inputs)
    raw_token_num = raw_inputs.input_ids.shape[1]
    start_time = time.time()

    total_draft_tokens = 0
    total_accept_tokens = 0
    gamma = args.gamma
    max_new_tokens = args.test_token_num

    draft_past_key_values = None
    target_past_key_values = None

    while not is_end:
        # Draft model
        target_inputs, draft_probs, draft_past_key_values = drafter_speculative_decode(
            draft_model=draft_model,
            draft_tokenizer=draft_tokenizer,
            inputs=inputs,
            gamma=gamma,
            temperature=0,
            past_key_values=draft_past_key_values,
        )

        total_draft_tokens += gamma

        # Target model
        outputs, is_end, accept_tokens, target_past_key_values = target_speculative_decode(
            target_model=target_model,
            target_tokenizer=target_tokenizer,
            inputs=target_inputs,
            draft_probs=draft_probs,
            temperature=0,
            past_key_values=target_past_key_values,
        )

        total_accept_tokens += accept_tokens

        inputs = outputs

        if inputs.input_ids.shape[1] - raw_token_num >= max_new_tokens:
            break

    generate_token_num = outputs.input_ids.shape[1] - raw_token_num
    spent_time = time.time() - start_time

    print(f"(Without KV Cache) Generate token number: {generate_token_num}")
    print(f"(Without KV Cache) Generate speed: {generate_token_num / spent_time} tokens/sec")
    print(f"(Without KV Cache) Speculative Decoding Spent Time: {spent_time} seconds.")
    print(f"(Without KV Cache) Accept Rate: {total_accept_tokens / total_draft_tokens}\n")


    # KV Cache Speculative Decoding
    is_end = False

    # Record
    inputs = copy.deepcopy(raw_inputs)
    raw_token_num = inputs.input_ids.shape[1]
    start_time = time.time()

    total_draft_tokens = 0
    total_accept_tokens = 0
    gamma = args.gamma
    max_new_tokens = args.test_token_num

    draft_past_key_values = DynamicCache()
    target_past_key_values = DynamicCache()

    while not is_end:
        # Draft model
        target_inputs, draft_probs, draft_past_key_values = drafter_speculative_decode(
            draft_model=draft_model,
            draft_tokenizer=draft_tokenizer,
            inputs=inputs,
            gamma=gamma,
            temperature=0,
            past_key_values=draft_past_key_values,
        )

        total_draft_tokens += gamma

        # Target model
        outputs, is_end, accept_tokens, target_past_key_values = target_speculative_decode(
            target_model=target_model,
            target_tokenizer=target_tokenizer,
            inputs=target_inputs,
            draft_probs=draft_probs,
            temperature=0,
            past_key_values=target_past_key_values,
        )

        total_accept_tokens += accept_tokens

        inputs = outputs

        if inputs.input_ids.shape[1] - raw_token_num >= max_new_tokens:
            break

    generate_token_num = outputs.input_ids.shape[1] - raw_token_num
    spent_time = time.time() - start_time

    print(f"(KV Cache) Generate token number: {generate_token_num}")
    print(f"(KV Cache) Generate speed: {generate_token_num / spent_time} tokens/sec")
    print(f"(KV Cache) Speculative Decoding Spent Time: {spent_time} seconds.")
    print(f"(KV Cache) Accept Rate: {total_accept_tokens / total_draft_tokens}\n")


Output:

(Without KV Cache) Generate token number: 101
(Without KV Cache) Generate speed: 51.32459081142281 tokens/sec
(Without KV Cache) Speculative Decoding Spent Time: 1.9678676128387451 seconds.
(Without KV Cache) Accept Rate: 0.7719298245614035

(KV Cache) Generate token number: 101
(KV Cache) Generate speed: 62.468003457069095 tokens/sec
(KV Cache) Speculative Decoding Spent Time: 1.6168277263641357 seconds.
(KV Cache) Accept Rate: 0.8035714285714286

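We can see that the KV Cache does speed things up: for the same 101 generated tokens, throughput rises from about 51.3 to 62.5 tokens/sec, and the total time drops from about 1.97 s to 1.62 s.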
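For reference, the speedup implied by the output can be computed directly from the printed numbers (nothing new is measured here):

```python
# Values copied from the output above (101 generated tokens in both runs).
speed_without_cache = 51.32459081142281  # tokens/sec
speed_with_cache = 62.468003457069095  # tokens/sec
time_without_cache = 1.9678676128387451  # seconds
time_with_cache = 1.6168277263641357  # seconds

throughput_speedup = speed_with_cache / speed_without_cache
time_speedup = time_without_cache / time_with_cache
print(f"{throughput_speedup:.2f}x, {time_speedup:.2f}x")  # 1.22x, 1.22x
```

The two ratios agree because both runs generated the same number of tokens, so speed is simply 101 divided by the elapsed time.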

