Using the Target Model's Confidence Threshold to Decide Whether to Enable Draft Speculation in Speculative Decoding

Last Updated on 2024-11-22 by Clay

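Many of the inference acceleration techniques I have read about so far, Speculative Decoding among them, set a threshold on the draft model's confidence score to decide how many draft tokens to decode before handing them to the target model for verification. This cuts the extra time the draft model would otherwise spend speculating when its confidence is low.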
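To make the conventional approach concrete, here is a minimal sketch of drafting with an early stop on the draft model's confidence. It is only an illustration: `draft_step`, `toy_draft_step`, `max_draft_tokens`, and the 0.5 threshold are hypothetical names and values I made up for this example, and real methods differ in details such as whether the low-confidence token itself is kept.

```python
from typing import Callable, List, Sequence, Tuple

# Hypothetical stand-in for a draft model: maps a token prefix to
# (next_token, confidence), where confidence is the probability the
# draft model assigned to the token it picked.
DraftStep = Callable[[Sequence[int]], Tuple[int, float]]


def draft_until_unsure(
    draft_step: DraftStep,
    prefix: Sequence[int],
    max_draft_tokens: int = 5,
    threshold: float = 0.5,  # illustrative value, not taken from any paper
) -> List[int]:
    """Propose draft tokens until the draft model's confidence drops below `threshold`."""
    tokens = list(prefix)
    drafted: List[int] = []
    for _ in range(max_draft_tokens):
        token, confidence = draft_step(tokens)
        if confidence < threshold:
            # The draft model is unsure: stop early and hand over to the target model.
            break
        drafted.append(token)
        tokens.append(token)
    return drafted


def toy_draft_step(tokens: Sequence[int]) -> Tuple[int, float]:
    # Toy behaviour: confidence decays as the sequence gets longer.
    confidence = 0.9 - 0.15 * len(tokens)
    return len(tokens), confidence


print(draft_until_unsure(toy_draft_step, prefix=[0]))  # [1, 2]
```

With the toy model, the confidences are 0.75, 0.60, and 0.45, so drafting stops after two tokens instead of the full five.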

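This dynamic approach does improve performance, but I had a doubt the first time I saw it: if there really is a capability gap between my draft model and my target model, and especially if the draft model's comprehension is lacking, should I still trust the draft model even when its confidence score is high enough?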

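So, conversely, wouldn't it be more reasonable to use the target model's confidence score as the basis for enabling draft speculation? If we already suspect that a high confidence score from the weaker draft model is unreliable (after all, the target model is the one that makes the acceptance decision), then a low confidence score from the target model should be reason enough not to hand the current decoding step to the draft model.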

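Put more plainly: I assume the draft model's high confidence may simply be overconfidence and should not by itself justify letting it keep generating. Conversely, if the target model's confidence is low, the target model itself is unsure it can decode this step well, so there is even less reason to hand it to the draft model for speculation.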
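The gating rule can be sketched in a few lines. This is a simplified illustration, not the implementation shown later in this post: `target_step` and `scripted_target` are hypothetical stand-ins for a real model call, `max_steps` is a safety cap I added, and the 0.5 threshold simply mirrors the value hard-coded in the function below.

```python
from typing import Callable, Iterable, List, Sequence, Tuple

# Hypothetical stand-in for the target model: maps a token prefix to
# (next_token, confidence).
TargetStep = Callable[[Sequence[int]], Tuple[int, float]]


def decode_until_confident(
    target_step: TargetStep,
    tokens: Sequence[int],
    confidence: float,
    threshold: float = 0.5,
    max_steps: int = 16,  # safety cap, an assumption of this sketch
) -> Tuple[List[int], float]:
    """Keep decoding with the target model alone while its confidence is low.

    `confidence` is the target model's confidence in its latest token.
    Drafting would resume from the returned sequence.
    """
    out = list(tokens)
    for _ in range(max_steps):
        if confidence >= threshold:
            # The target model is confident again: hand back to the draft model.
            break
        token, confidence = target_step(out)
        out.append(token)
    return out, confidence


def scripted_target(confidences: Iterable[float]) -> TargetStep:
    """Toy target model that replays a fixed list of confidence scores."""
    remaining = iter(confidences)

    def step(tokens: Sequence[int]) -> Tuple[int, float]:
        return len(tokens), next(remaining)

    return step


tokens, confidence = decode_until_confident(
    scripted_target([0.3, 0.8, 0.1]), tokens=[0, 1, 2], confidence=0.2
)
print(tokens, confidence)  # [0, 1, 2, 3, 4] 0.8
```

In the toy run the target model decodes two tokens by itself (confidence 0.3, then 0.8) and only then would the draft model be asked to speculate again.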

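Conclusion first: perhaps I am not well read enough, but I have not yet seen a paper that discusses which is better, using the draft model's confidence score to decide whether to keep decoding, or using the target model's confidence score to decide whether to enable draft speculation. In the simple case I tested, though, I saw a clear and direct speedup, so I am recording it here as a reference for my own study of inference acceleration techniques.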


Background

First, feel free to follow my GitHub. Many of my implementations go into the fast-llm-inference project: https://github.com/ccs96307/fast-llm-inference

Strictly speaking, this note has a prequel: Speculative Decoding Implementation Note (with Simple Experimental Results)

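In that previous post, I recorded my own implementation of the sampling method and of Speculative Decoding, along with a short set of experimental results. Here I improve on that code and compare the speedup obtained when the target model's confidence score decides whether draft speculation is enabled.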
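For readers who have not seen the previous post, the verification step at the heart of Speculative Decoding can be summarized as: accept a drafted token with probability min(1, p_target / p_draft). The sketch below shows only that acceptance test on plain floats. The function names are hypothetical, and it leaves out everything else (batching, tensors, what to do after a rejection).

```python
import random
from typing import Sequence


def accept_draft_token(p_target: float, p_draft: float, rng: random.Random) -> bool:
    """Accept a drafted token with probability min(1, p_target / p_draft)."""
    if p_draft <= 0.0:
        raise ValueError("p_draft must be positive for a token the draft model sampled")
    if p_target >= p_draft:
        # The target model likes the token at least as much as the draft model did.
        return True
    # Otherwise accept with probability p_target / p_draft.
    return rng.random() < p_target / p_draft


def count_accepted_prefix(
    p_targets: Sequence[float], p_drafts: Sequence[float], rng: random.Random
) -> int:
    """Number of leading draft tokens accepted before the first rejection."""
    accepted = 0
    for p_target, p_draft in zip(p_targets, p_drafts):
        if not accept_draft_token(p_target, p_draft, rng):
            break
        accepted += 1
    return accepted


rng = random.Random(0)
# Second token has zero probability under the target model, so it is always rejected.
print(count_accepted_prefix([0.9, 0.0, 0.9], [0.5, 0.5, 0.5], rng))  # 1
```

Only the tokens before the first rejection count, which is also how the accept rate in the logs below is accumulated.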

I ran the previous post's code several times, and the results were all roughly the same:

Generate token number: 100
Generate speed: 34.81801971715937 tokens/sec
Speculative Decoding Spent Time: 2.8720760345458984 seconds.
Accept Rate: 0.34054054054054056

Generate token number: 100
Generate speed: 28.07058497562908 tokens/sec
Normal Target Model Decoding Spent Time: 3.562448024749756 seconds.

Generate token number: 100
Generate speed: 94.92253307563831 tokens/sec
Normal Draft Model Decoding Spent Time: 1.0534906387329102 seconds.


The short version:

  • Draft model decoding speed: 94.92 tokens/sec
  • Target model decoding speed: 28.07 tokens/sec
  • Speculative Decoding speed: 34.82 tokens/sec

That is roughly a 1.24x speedup over decoding with the target model alone.
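The speedup figure is just the ratio of the two throughputs. As a quick check, using the unrounded numbers from the log above (the helper name `speedup` is my own):

```python
def speedup(method_tokens_per_sec: float, baseline_tokens_per_sec: float) -> float:
    """Throughput of a method relative to a baseline."""
    return method_tokens_per_sec / baseline_tokens_per_sec


target_speed = 28.07058497562908       # Normal Target Model Decoding
speculative_speed = 34.81801971715937  # Speculative Decoding

print(f"{speedup(speculative_speed, target_speed):.2f}x")  # 1.24x
```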


Using the Target Model's Confidence Score to Decide Whether to Enable Draft Model Speculation

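Following the idea that if I am not confident about the current decoding step myself, I should not bother the weaker draft model with it, I made the following changes: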

def target_speculative_decode(
    target_model: torch.nn.Module,
    target_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    draft_probs: torch.FloatTensor,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, it means do not select top-k tokens
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], bool, int]:
    with torch.no_grad():
        outputs = target_model(**inputs)

    next_tokens, target_probs = sample_next_token(
        logits=outputs.logits,
        prefix_token_ids=inputs["input_ids"],
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        probs_num=draft_probs.shape[1] + 1,
    )

    next_token = next_tokens[:, -1:]

    # Evaluation
    indices = inputs["input_ids"][:, -draft_probs.shape[1]:]

    eval_probs = target_probs[:, :-1, :]

    expanded_indices = indices.unsqueeze(-1)
    selected_draft_probs = torch.gather(draft_probs, dim=-1, index=expanded_indices)
    selected_draft_probs = selected_draft_probs.squeeze(-1)

    selected_eval_probs = torch.gather(eval_probs, dim=-1, index=expanded_indices)
    selected_eval_probs = selected_eval_probs.squeeze(-1)

    # Compare draft_prob and eval_prob, and check the reject_mask
    mask_to_reject = selected_draft_probs > selected_eval_probs

    # Calculate the rejection probability: 1 - (eval_prob / draft_prob)
    rejection_probs = 1 - (selected_eval_probs / selected_draft_probs)

    # Generate random values to determine whether to accept or reject
    random_values = torch.rand_like(rejection_probs)
    rejection_decisions = random_values < rejection_probs

    # Get the final reject masks
    rejection_masks = mask_to_reject & rejection_decisions
    acceptance_mask = torch.ones_like(selected_draft_probs, dtype=torch.bool)
    acceptance_mask[rejection_masks] = False

    is_end = False

    # Concat `input_ids`
    confidence_score = 0

    if torch.all(acceptance_mask):
        input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)
        confidence_score = target_probs[:, -1, next_token[0][0]].item()
        print(f"Confidence for next token: {confidence_score:.4f}")
    else:
        new_input_ids = []
        new_attention_mask = []

        for batch_idx in range(next_tokens.shape[0]):
            gamma = next_tokens.shape[1] - 1
            start_idx = inputs["input_ids"].shape[1] - gamma

            for pos_idx in range(acceptance_mask[batch_idx].shape[0]):
                if (acceptance_mask[batch_idx][pos_idx] and inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id) or not acceptance_mask[batch_idx][pos_idx]:
                    inputs["input_ids"][batch_idx][start_idx+pos_idx] = next_tokens[batch_idx][pos_idx]
                    confidence_score = target_probs[batch_idx, pos_idx, next_tokens[batch_idx, pos_idx]].max().item()
                    print(f"Replacement Confidence for next token: {confidence_score:.4f}")

                    new_input_ids.append(inputs["input_ids"][batch_idx][:start_idx+pos_idx+1])
                    new_attention_mask.append(inputs["attention_mask"][batch_idx][:start_idx+pos_idx+1])
                    
                    is_end = inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id
                    break

        input_ids = pad_sequence(new_input_ids, batch_first=True, padding_value=target_tokenizer.pad_token_id)
        attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)

    inputs["input_ids"] = input_ids
    inputs["attention_mask"] = attention_mask

    # Keep generating if confidence_score is less than confidence threshold
    while confidence_score < 0.5:
        with torch.no_grad():
            outputs = target_model(**inputs)

        next_tokens, target_probs = sample_next_token(
            logits=outputs.logits,
            prefix_token_ids=inputs["input_ids"],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            probs_num=1,
        )

        # Update `confidence_score`
        next_token = next_tokens[:, -1:]
        confidence_score = target_probs[0, -1, next_token[0][0]].item()
        print(f"keep generate confidence_score: {confidence_score:.4f}")

        input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)

        inputs["input_ids"] = input_ids
        inputs["attention_mask"] = attention_mask

        is_end = inputs["input_ids"][0][-1].item() == target_tokenizer.eos_token_id
        if is_end:
            break

    return inputs, is_end, calculate_continuous_acceptance(acceptance_mask)


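Compared with the original code, I now read the target model's confidence score at every verification step. Because my implementation already has the probability distribution in hand once sampling finishes, this adds almost no extra time.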
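To illustrate why the confidence score is nearly free, here is a small pure-Python sketch. It is not the `sample_next_token` function from my repository (that one works on torch tensors and is not shown in this post); `softmax` and `sample_with_confidence` are hypothetical helpers written only for this example. The point is that the confidence is simply the probability of the sampled token, read from a distribution the sampler has already computed.

```python
import math
import random
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_with_confidence(
    logits: Sequence[float], rng: random.Random
) -> Tuple[int, float]:
    """Sample a token and return it with the probability it was assigned."""
    probs = softmax(logits)
    token = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    # The confidence is a lookup into `probs`; no extra model call is needed.
    return token, probs[token]


token, confidence = sample_with_confidence([2.0, 1.0, 0.1], random.Random(0))
print(f"token={token}, confidence={confidence:.4f}")
```

A low value here would be the signal to keep decoding with the target model rather than switching back to draft speculation.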


Here is my full source code and experimental setup:

  • GPU: RTX 4060 8GB
  • OS: Ubuntu 22.04
  • Target model: HuggingFaceTB/SmolLM2-1.7B-Instruct
  • Draft model: HuggingFaceTB/SmolLM2-135M-Instruct

from typing import Dict, List, Optional, Tuple

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import argparse
import copy
import time

import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase

from sampling.sampling import sample_next_token


"""
python speculative_decoding/run_speculative_decoding.py \
    --target_model_path HuggingFaceTB/SmolLM2-1.7B-Instruct \
    --draft_model_path HuggingFaceTB/SmolLM2-135M-Instruct \
    --device cuda:0 \
    --question 'What is the capital of Taiwan. And why?' \
    --gamma 5 \
    --test_token_num 100 
"""


def calculate_continuous_acceptance(acceptance_mask: torch.BoolTensor) -> int:
    continuous_acceptance = 0
    for accepted in acceptance_mask.long().squeeze(0):
        if accepted == 1:
            continuous_acceptance += 1
        else:
            break
    return continuous_acceptance


def drafter_speculative_decode(
    draft_model: torch.nn.Module,
    draft_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    gamma: int = 10,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, it means do not select top-k tokens
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], torch.FloatTensor]:
    draft_probs = []

    for idx in range(gamma):
        with torch.no_grad():
            outputs = draft_model(**inputs)

        next_tokens, probs = sample_next_token(
            logits=outputs.logits,
            prefix_token_ids=inputs["input_ids"],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        draft_probs.append(probs)
        input_ids = torch.cat([inputs["input_ids"], next_tokens[:, -1:]], dim=-1)
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)

        inputs["input_ids"] = input_ids
        inputs["attention_mask"] = attention_mask

    return inputs, torch.cat(draft_probs, dim=1)


def target_speculative_decode(
    target_model: torch.nn.Module,
    target_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    draft_probs: torch.FloatTensor,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, it means do not select top-k tokens
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], bool, int]:
    with torch.no_grad():
        outputs = target_model(**inputs)

    next_tokens, target_probs = sample_next_token(
        logits=outputs.logits,
        prefix_token_ids=inputs["input_ids"],
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        probs_num=draft_probs.shape[1] + 1,
    )

    next_token = next_tokens[:, -1:]

    # Evaluation
    indices = inputs["input_ids"][:, -draft_probs.shape[1]:]

    eval_probs = target_probs[:, :-1, :]

    expanded_indices = indices.unsqueeze(-1)
    selected_draft_probs = torch.gather(draft_probs, dim=-1, index=expanded_indices)
    selected_draft_probs = selected_draft_probs.squeeze(-1)

    selected_eval_probs = torch.gather(eval_probs, dim=-1, index=expanded_indices)
    selected_eval_probs = selected_eval_probs.squeeze(-1)

    # Compare draft_prob and eval_prob, and check the reject_mask
    mask_to_reject = selected_draft_probs > selected_eval_probs

    # Calculate the rejection probability: 1 - (eval_prob / draft_prob)
    rejection_probs = 1 - (selected_eval_probs / selected_draft_probs)

    # Generate random values to determine whether to accept or reject
    random_values = torch.rand_like(rejection_probs)
    rejection_decisions = random_values < rejection_probs

    # Get the final reject masks
    rejection_masks = mask_to_reject & rejection_decisions
    acceptance_mask = torch.ones_like(selected_draft_probs, dtype=torch.bool)
    acceptance_mask[rejection_masks] = False

    is_end = False

    # Concat `input_ids`
    confidence_score = 0

    if torch.all(acceptance_mask):
        input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)
        confidence_score = target_probs[:, -1, next_token[0][0]].item()
        print(f"Confidence for next token: {confidence_score:.4f}")
    else:
        new_input_ids = []
        new_attention_mask = []

        for batch_idx in range(next_tokens.shape[0]):
            gamma = next_tokens.shape[1] - 1
            start_idx = inputs["input_ids"].shape[1] - gamma

            for pos_idx in range(acceptance_mask[batch_idx].shape[0]):
                if (acceptance_mask[batch_idx][pos_idx] and inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id) or not acceptance_mask[batch_idx][pos_idx]:
                    inputs["input_ids"][batch_idx][start_idx+pos_idx] = next_tokens[batch_idx][pos_idx]
                    confidence_score = target_probs[batch_idx, pos_idx, next_tokens[batch_idx, pos_idx]].max().item()
                    print(f"Replacement Confidence for next token: {confidence_score:.4f}")

                    new_input_ids.append(inputs["input_ids"][batch_idx][:start_idx+pos_idx+1])
                    new_attention_mask.append(inputs["attention_mask"][batch_idx][:start_idx+pos_idx+1])
                    
                    is_end = inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id
                    break

        input_ids = pad_sequence(new_input_ids, batch_first=True, padding_value=target_tokenizer.pad_token_id)
        attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)

    inputs["input_ids"] = input_ids
    inputs["attention_mask"] = attention_mask

    # Keep generating if confidence_score is less than confidence threshold
    while confidence_score < 0.5:
        with torch.no_grad():
            outputs = target_model(**inputs)

        next_tokens, target_probs = sample_next_token(
            logits=outputs.logits,
            prefix_token_ids=inputs["input_ids"],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            probs_num=1,
        )

        # Update `confidence_score`
        next_token = next_tokens[:, -1:]
        confidence_score = target_probs[0, -1, next_token[0][0]].item()
        print(f"keep generate confidence_score: {confidence_score:.4f}")

        input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)

        inputs["input_ids"] = input_ids
        inputs["attention_mask"] = attention_mask

        is_end = inputs["input_ids"][0][-1].item() == target_tokenizer.eos_token_id
        if is_end:
            break

    return inputs, is_end, calculate_continuous_acceptance(acceptance_mask)


def run_test(args) -> None:
    # Device
    device = torch.device(args.device if args.device != "cpu" and torch.cuda.is_available() else "cpu")
    print(device)

    # Model path 
    target_model_path = args.target_model_path
    draft_model_path = args.draft_model_path

    # Load Tokenizer
    draft_tokenizer = AutoTokenizer.from_pretrained(draft_model_path)
    target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)

    # Load Model
    draft_model = AutoModelForCausalLM.from_pretrained(draft_model_path, torch_dtype=torch.bfloat16).to(device)
    target_model = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.bfloat16).to(device)

    # Tokenize
    messages = [
        [
            {
                "role": "user",
                "content": args.question,
            },
        ],
    ]

    input_text=draft_tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = draft_tokenizer(
        input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    ).to(device)

    # Warm up the model (CUDA)
    inputs_dummy = {k: v.clone() for k, v in inputs.items()}
    with torch.no_grad():
        draft_model(**inputs_dummy)
        target_model(**inputs_dummy)
    # Only synchronize when running on CUDA; the default --device is "cpu"
    if device.type == "cuda":
        torch.cuda.synchronize()

    is_end = False

    # Record
    raw_inputs = copy.deepcopy(inputs)
    raw_token_num = raw_inputs["input_ids"].shape[1]
    start_time = time.time()

    total_draft_tokens = 0
    total_accept_tokens = 0
    gamma = args.gamma
    max_new_tokens = args.test_token_num

    while not is_end:
        # Draft model
        target_inputs, draft_probs = drafter_speculative_decode(
            draft_model=draft_model,
            draft_tokenizer=draft_tokenizer,
            inputs=inputs,
            gamma=gamma,
        )

        total_draft_tokens += gamma

        # Target model
        outputs, is_end, accept_tokens = target_speculative_decode(
            target_model=target_model,
            target_tokenizer=target_tokenizer,
            inputs=target_inputs,
            draft_probs=draft_probs,
        )

        total_accept_tokens += accept_tokens

        inputs = outputs

        if inputs["input_ids"].shape[1] - raw_token_num >= max_new_tokens:
            break

    generate_token_num = outputs["input_ids"].shape[1] - raw_token_num
    spent_time = time.time() - start_time

    print(f"Generate token number: {generate_token_num}")
    print(f"Generate speed: {generate_token_num / spent_time} tokens/sec")
    print(f"Speculative Decoding Spent Time: {spent_time} seconds.")
    print(f"Accept Rate: {total_accept_tokens / total_draft_tokens}\n")

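    # Normal Target Model Speed
    # Note: `inputs` now holds the prompt plus the tokens generated above,
    # so this baseline continues decoding from that longer context.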
    raw_inputs = copy.deepcopy(inputs)
    start_time = time.time()
    target_inputs, draft_probs = drafter_speculative_decode(
        draft_model=target_model,
        draft_tokenizer=draft_tokenizer,
        inputs=raw_inputs,
        gamma=args.test_token_num,
    )

    spent_time = time.time() - start_time

    print(f"Generate token number: {max_new_tokens}")
    print(f"Generate speed: {max_new_tokens / spent_time} tokens/sec")
    print(f"Normal Target Model Decoding Spent Time: {spent_time} seconds.\n")

    # Normal Draft Model Speed
    raw_inputs = copy.deepcopy(inputs)
    start_time = time.time()
    target_inputs, draft_probs = drafter_speculative_decode(
        draft_model=draft_model,
        draft_tokenizer=draft_tokenizer,
        inputs=raw_inputs,
        gamma=args.test_token_num,
    )

    spent_time = time.time() - start_time

    print(f"Generate token number: {max_new_tokens}")
    print(f"Generate speed: {max_new_tokens / spent_time} tokens/sec")
    print(f"Normal Draft Model Decoding Spent Time: {spent_time} seconds.\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--target_model_path", type=str, default="HuggingFaceTB/SmolLM2-1.7B-Instruct")
    parser.add_argument("--draft_model_path", type=str, default="HuggingFaceTB/SmolLM2-135M-Instruct")
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--question", type=str, default="What is the capital of Taiwan. And why?")
    parser.add_argument("--gamma", type=int, default=5)
    parser.add_argument("--test_token_num", type=int, default=100)
    args = parser.parse_args()

    run_test(args)


Output:

Generate token number: 102
Generate speed: 46.418809914955794 tokens/sec
Speculative Decoding Spent Time: 2.19738507270813 seconds.
Accept Rate: 0.5545454545454546

Generate token number: 100
Generate speed: 27.916420540976226 tokens/sec
Normal Target Model Decoding Spent Time: 3.5821211338043213 seconds.

Generate token number: 100
Generate speed: 96.10154773224576 tokens/sec
Normal Draft Model Decoding Spent Time: 1.0405659675598145 seconds.


Putting the numbers side by side in a table makes the result obvious:

Method            Draft model        Target model       Speculative Decoding   Total Acceleration
Original          94.92 tokens/sec   28.07 tokens/sec   34.82 tokens/sec       1.24x
Target Threshold  96.10 tokens/sec   27.92 tokens/sec   46.42 tokens/sec       1.66x


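If I go with my intuition, using the target model's confidence score to decide whether to enable draft speculation is a robust strategy, though its ceiling may not be high. But for someone like me who lacks training resources, being able to steadily improve the speedup matters a lot.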

