
Speculative Decoding Implementation Notes

Last Updated on 2024-11-05 by Clay

Introduction

Speculative Decoding is a highly practical technique for accelerating inference. A small draft model quickly decodes several tokens in a row while keeping the sampling probability distribution at each step; the large target model (the model we actually want to speed up) then predicts the next token on top of that draft and, in the same forward pass, computes the sampling probability distributions for every drafted token position. The target model probs are then used to verify the draft model probs, and the draft model's speculative tokens that are reliable enough are accepted.

For a more detailed explanation of the underlying theory, you can refer to my earlier write-up of the Speculative Decoding paper published by the Google team: [Paper Reading] Fast Inference from Transformers via Speculative Decoding

Translated into code, the concept is actually very simple. Suppose the draft model has decoded k tokens. We compare, position by position, the probability the draft model assigned to each token against the probability the target model assigns to decoding that same token, and there are two cases (a minimal sketch of this test is shown right after the list):

  1. draft model token prob <= target model token prob: always accept the draft model's token, since the target model assigns it an equal or higher probability
  2. draft model token prob > target model token prob: reject this token with probability 1 - (target model token prob / draft model token prob)
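
To make the rule concrete, here is a minimal sketch of the per-token acceptance test. The function name and the direct use of torch.rand are my own illustration and are not part of the implementation further below:

import torch

# `p_draft` and `p_target` are the probabilities that the draft / target model
# assigned to the SAME sampled token.
def accept_draft_token(p_draft: float, p_target: float) -> bool:
    if p_draft <= p_target:
        return True  # Case 1: always accept
    # Case 2: accept with probability p_target / p_draft,
    # i.e. reject with probability 1 - (p_target / p_draft)
    return torch.rand(1).item() < (p_target / p_draft)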

Implementation

Below is my design of the sampling parameters applied after the model has computed its logits. It differs somewhat from HuggingFace's implementation, but for now this is how I control the sampling. For a detailed explanation of the sampling parameters themselves, see my other hands-on write-up: Decoding and Sampling Notes for Large Language Models

In this implementation note I distinguish fairly strictly between logits and probs: logits are the model's raw final outputs, with a theoretical range of (-inf, inf); probs are probabilities, obtained by passing the logits through softmax, and therefore lie in the (0, 1) range.
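
As a tiny, self-contained illustration of that distinction (the numbers here are made up, not from any real model):

import torch

logits = torch.tensor([[2.0, -1.0, 0.5]])   # raw scores, theoretically in (-inf, inf)
probs = torch.softmax(logits, dim=-1)       # probabilities in (0, 1) that sum to 1
print(probs)  # approximately tensor([[0.7856, 0.0391, 0.1753]])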

from typing import Tuple

import warnings
warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=UserWarning)

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def apply_repetition_penalty(
    logits: torch.FloatTensor,
    prefix_token_ids: torch.LongTensor,
    repetition_penalty: float = 1.0,
) -> torch.FloatTensor:
    batch_size, gamma, vocab_size = logits.shape
    seq_length = prefix_token_ids.shape[1]

    for batch_idx in range(batch_size):
        for gamma_idx in range(gamma):
            current_prefix = prefix_token_ids[batch_idx, :seq_length - gamma + gamma_idx + 1]

            unique_token_ids = set(current_prefix.tolist())

            for token_id in unique_token_ids:
                if logits[batch_idx, gamma_idx, token_id] > 0:
                    logits[batch_idx, gamma_idx, token_id] /= repetition_penalty
                else:
                    logits[batch_idx, gamma_idx, token_id] *= repetition_penalty

    return logits

def top_k_filtering(logits: torch.FloatTensor, top_k: int) -> torch.FloatTensor:
    if top_k > 0:
        values, _ = torch.topk(logits, top_k, dim=-1)
        min_values = values[:, :, -1].unsqueeze(dim=-1)
        logits = torch.where(logits < min_values, torch.full_like(logits, -float("Inf")), logits)

    return logits


def top_p_filtering(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    
    # Mark the positions where the cumulative probability exceeds `top_p`
    sorted_indices_to_remove = cumulative_probs > top_p

    # Shift the mask right so that at least one token is always kept
    sorted_indices_to_remove[:, :, 1:] = sorted_indices_to_remove[:, :, :-1].clone()
    sorted_indices_to_remove[:, :, 0] = False

    # Scatter the mask back to the original token order so it matches the shape of `logits`
    indices_to_remove = sorted_indices_to_remove.scatter(2, sorted_indices, sorted_indices_to_remove)
    logits = logits.masked_fill(indices_to_remove, -float("Inf"))
    
    return logits


def sample_next_token(
    logits: torch.FloatTensor,
    prefix_token_ids: torch.LongTensor,
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
    repetition_penalty: float = 1.0,
    eps: float = 1e-7,
    probs_num: int = 1,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    curr_logits = logits[:, -probs_num:, :]

    # Apply repetition penalty
    if repetition_penalty != 1.0:
        curr_logits = apply_repetition_penalty(
            logits=curr_logits,
            prefix_token_ids=prefix_token_ids,
            repetition_penalty=repetition_penalty,
        )

    # Apply temperature
    curr_logits = curr_logits / (temperature + eps)

    # Apply `top_k`
    curr_logits = top_k_filtering(logits=curr_logits, top_k=top_k)

    # Apply `top_p`
    curr_logits = top_p_filtering(logits=curr_logits, top_p=top_p)

    # Convert logits into probs
    probs = torch.softmax(curr_logits, dim=-1)

    # Sampling
    seq_tokens = []
    for seq_idx in range(probs.shape[1]):
        seq_token = torch.multinomial(probs[:, seq_idx, :], num_samples=1)
        seq_tokens.append(seq_token)

    seq_token_ids = torch.cat(seq_tokens, dim=1)

    return seq_token_ids, probs
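
Since this sampling code lives in a standalone sampling module (imported below as from sampling import sample_next_token), a quick shape check like the following can be useful. The dummy shapes and parameter values are purely illustrative and not part of the original experiments:

# Sanity check with random dummy logits (batch_size=1, seq_len=5, vocab_size=100)
if __name__ == "__main__":
    dummy_logits = torch.randn(1, 5, 100)
    dummy_prefix_ids = torch.randint(0, 100, (1, 5))

    token_ids, probs = sample_next_token(
        logits=dummy_logits,
        prefix_token_ids=dummy_prefix_ids,
        temperature=0.7,
        top_k=50,
        top_p=0.9,
        repetition_penalty=1.1,
        probs_num=3,
    )

    print(token_ids.shape)  # torch.Size([1, 3]), one sampled token per kept position
    print(probs.shape)      # torch.Size([1, 3, 100])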



Now we move on to the actual Speculative Decoding implementation. Here I only test up to the end of the verification step, and because batch_size > 1 introduces the problem of sequences progressing at different speeds (padding them to the same length can cause its own issues), I only consider the batch_size=1 case for now. A rough sketch of how the two steps could be chained into a full generation loop is shown after the verification function below.

First, we import all of the packages we need.

from typing import Dict, List, Optional, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence

from transformers import LlamaForCausalLM, GPT2TokenizerFast, PreTrainedTokenizerBase

from sampling import sample_next_token



Next is the draft model's speculative decoding of gamma tokens. I concatenate the new tokens and the attention mask back onto the original inputs, and keep all of the probs produced along the way.

def drafter_speculative_decode(
    draft_model: torch.nn.Module,
    draft_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    gamma: int = 10,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, meaning top-k filtering is disabled
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], torch.FloatTensor]:
    draft_probs = []

    for idx in range(gamma):
        with torch.no_grad():
            outputs = draft_model(**inputs)

        next_tokens, probs = sample_next_token(
            logits=outputs.logits,
            prefix_token_ids=inputs["input_ids"],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        draft_probs.append(probs)
        input_ids = torch.cat([inputs["input_ids"], next_tokens[:, -1:]], dim=-1)
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device.type)], dim=-1)

        inputs["input_ids"] = input_ids
        inputs["attention_mask"] = attention_mask

    return inputs, torch.cat(draft_probs, dim=1)



The target model's verification step is a bit more involved. Although it only runs a single forward pass, we have to push all gamma + 1 positions of logits through the sampling step to obtain their probs, and we also keep the decoded token at every position, so that when a draft model token is rejected we can substitute the target model's token directly.

Then, following the rule described earlier, we compare the target model token prob with the draft model token prob and, in the relevant case, reject the token with a certain probability.

def target_speculative_decode(
    target_model: torch.nn.Module,
    target_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    draft_probs: torch.FloatTensor,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, meaning top-k filtering is disabled
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], torch.FloatTensor]:
    with torch.no_grad():
        outputs = target_model(**inputs)

    next_tokens, target_probs = sample_next_token(
        logits=outputs.logits,
        prefix_token_ids=inputs["input_ids"],
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        probs_num=draft_probs.shape[1] + 1,
    )

    next_token = next_tokens[:, -1:]

    # Evaluation
    indices = inputs["input_ids"][:, -draft_probs.shape[1]:]

    eval_probs = target_probs[:, :-1, :]

    expanded_indices = indices.unsqueeze(-1)
    selected_draft_probs = torch.gather(draft_probs, dim=-1, index=expanded_indices)
    selected_draft_probs = selected_draft_probs.squeeze(-1)

    selected_eval_probs = torch.gather(eval_probs, dim=-1, index=expanded_indices)
    selected_eval_probs = selected_eval_probs.squeeze(-1)

    # Compare draft_prob and eval_prob to build the candidate rejection mask
    mask_to_reject = selected_draft_probs > selected_eval_probs

    # Calculate the rejection probability: 1 - (eval_prob / draft_prob)
    rejection_probs = 1 - (selected_eval_probs / selected_draft_probs)

    # Generate random values to decide whether to accept or reject
    random_values = torch.rand_like(rejection_probs)
    rejection_decisions = random_values < rejection_probs

    # Get the final reject masks
    rejection_masks = mask_to_reject & rejection_decisions
    acceptance_mask = torch.ones_like(selected_draft_probs, dtype=torch.bool)
    acceptance_mask[rejection_masks] = False

    # Concat `input_ids`
    if torch.all(acceptance_mask):
        input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device.type)], dim=-1)
    else:
        new_input_ids = []
        new_attention_mask = []

        for batch_idx in range(next_tokens.shape[0]):
            for pos_idx in range(acceptance_mask[batch_idx].shape[0]):
                if not acceptance_mask[batch_idx][pos_idx]:
                    gamma = next_tokens.shape[1] - 1
                    start_idx = inputs["input_ids"].shape[1] - gamma

                    inputs["input_ids"][batch_idx][start_idx+pos_idx] = next_tokens[batch_idx][pos_idx]

                    new_input_ids.append(inputs["input_ids"][batch_idx][:start_idx+pos_idx+1])
                    new_attention_mask.append(inputs["attention_mask"][batch_idx][:start_idx+pos_idx+1])
                    
                    break

        input_ids = pad_sequence(new_input_ids, batch_first=True, padding_value=target_tokenizer.pad_token_id)
        attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)

    inputs["input_ids"] = input_ids
    inputs["attention_mask"] = attention_mask

    return inputs
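
As mentioned above, I only test up to the verification step in this post, but to make the overall flow concrete, here is a rough sketch (not part of the original code) of how the two functions above could be chained into a full generation loop; EOS handling and KV-cache reuse are deliberately left out:

def speculative_generate(
    draft_model: torch.nn.Module,
    draft_tokenizer: PreTrainedTokenizerBase,
    target_model: torch.nn.Module,
    target_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    max_new_tokens: int = 100,
    gamma: int = 10,
) -> Dict[str, torch.Tensor]:
    start_length = inputs["input_ids"].shape[1]

    # Repeat draft -> verify until enough new tokens have been generated
    while inputs["input_ids"].shape[1] - start_length < max_new_tokens:
        # Draft: decode `gamma` tokens with the small model
        inputs, draft_probs = drafter_speculative_decode(
            draft_model=draft_model,
            draft_tokenizer=draft_tokenizer,
            inputs=inputs,
            gamma=gamma,
        )

        # Verify: the large model accepts / rejects the drafted tokens
        inputs = target_speculative_decode(
            target_model=target_model,
            target_tokenizer=target_tokenizer,
            inputs=inputs,
            draft_probs=draft_probs,
        )

    return inputs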



Finally, let's look at the result of actually running it:

if __name__ == "__main__":
    # Settings
    target_model_path = "../models/HuggingFaceTB--SmolLM2-1.7B-Instruct/"
    draft_model_path = "../models/HuggingFaceTB--SmolLM2-135M-Instruct/"
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load Tokenizer
    draft_tokenizer = GPT2TokenizerFast.from_pretrained(draft_model_path)
    target_tokenizer = GPT2TokenizerFast.from_pretrained(target_model_path)

    # Load Model
    draft_model = LlamaForCausalLM.from_pretrained(draft_model_path, torch_dtype=torch.bfloat16).to(device)
    target_model = LlamaForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.bfloat16).to(device)

    # Tokenize
    messages = [
        [
            {
                "role": "user",
                "content": "What is the capital of Taiwan. And why?",
            },
        ],
    ]


    input_text = draft_tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = draft_tokenizer(
        input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    ).to(device)


    # Draft model
    target_inputs, draft_probs = drafter_speculative_decode(
        draft_model=draft_model,
        draft_tokenizer=draft_tokenizer,
        inputs=inputs,
        gamma=10,
    )

    print(target_inputs["input_ids"])
    print("".join(draft_tokenizer.batch_decode(target_inputs["input_ids"][0])))

    # Target model
    outputs = target_speculative_decode(
        target_model=target_model,
        target_tokenizer=target_tokenizer,
        inputs=target_inputs,
        draft_probs=draft_probs,
    )

    print(outputs["input_ids"])
    print("".join(target_tokenizer.batch_decode(outputs["input_ids"][0])))


Output:

tensor([[    1,  9690,   198,  2683,   359,   253,  5356,  5646, 11173,  3365,
          3511,   308, 34519,    28,  7018,   411,   407, 19712,  8182,     2,
           198,     1,  4093,   198,  1780,   314,   260,  3575,   282, 15914,
            30,  1350,  1701,    47,     2,   198,     1,   520,  9531,   198,
           504,  3575,   282, 15914,   314, 12545]], device='cuda:0')

<|im_start|>system
You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
<|im_start|>user
What is the capital of Taiwan. And why?<|im_end|>
<|im_start|>assistant
The capital of Taiwan is Tai

tensor([[    1,  9690,   198,  2683,   359,   253,  5356,  5646, 11173,  3365,
          3511,   308, 34519,    28,  7018,   411,   407, 19712,  8182,     2,
           198,     1,  4093,   198,  1780,   314,   260,  3575,   282, 15914,
            30,  1350,  1701,    47,     2,   198,     1,   520,  9531,   198,
           504,  3575,   282, 15914,   314, 12545, 46162]], device='cuda:0')

<|im_start|>system
You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
<|im_start|>user
What is the capital of Taiwan. And why?<|im_end|>
<|im_start|>assistant
The capital of Taiwan is Taipei

This is a happy case: the draft model decoded 10 tokens and every one of them was accepted, so the target model even produced an 11th token on top of that. A great deal.

But if we try a more ambiguous prompt, we can see the target model quickly reject the draft model's speculation, as if to say: "Step aside, I'll handle this myself!"

if __name__ == "__main__":
    # Settings
    target_model_path = "../models/HuggingFaceTB--SmolLM2-1.7B-Instruct/"
    draft_model_path = "../models/HuggingFaceTB--SmolLM2-135M-Instruct/"
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load Tokenizer
    draft_tokenizer = GPT2TokenizerFast.from_pretrained(draft_model_path)
    target_tokenizer = GPT2TokenizerFast.from_pretrained(target_model_path)

    # Load Model
    draft_model = LlamaForCausalLM.from_pretrained(draft_model_path, torch_dtype=torch.bfloat16).to(device)
    target_model = LlamaForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.bfloat16).to(device)

    # Tokenize
    messages = [
        [
            {
                "role": "user",
                "content": "What???",
            },
        ],
    ]


    input_text = draft_tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = draft_tokenizer(
        input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    ).to(device)


    # Draft model
    target_inputs, draft_probs = drafter_speculative_decode(
        draft_model=draft_model,
        draft_tokenizer=draft_tokenizer,
        inputs=inputs,
        gamma=10,
    )

    print(target_inputs["input_ids"])
    print("".join(draft_tokenizer.batch_decode(target_inputs["input_ids"][0])))

    # Target model
    outputs = target_speculative_decode(
        target_model=target_model,
        target_tokenizer=target_tokenizer,
        inputs=target_inputs,
        draft_probs=draft_probs,
    )

    print(outputs["input_ids"])
    print("".join(target_tokenizer.batch_decode(outputs["input_ids"][0])))


Output:

tensor([[    1,  9690,   198,  2683,   359,   253,  5356,  5646, 11173,  3365,
          3511,   308, 34519,    28,  7018,   411,   407, 19712,  8182,     2,
           198,     1,  4093,   198,  1780, 16693,    47,     2,   198,     1,
           520,  9531,   198, 22234,  8165,    28,   198,   198, 42519]],
        device='cuda:0')

<|im_start|>system
You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
<|im_start|>user
What???<|im_end|>
<|im_start|>assistant
Hey Sarah,

Hope

tensor([[    1,  9690,   198,  2683,   359,   253,  5356,  5646, 11173,  3365,
          3511,   308, 34519,    28,  7018,   411,   407, 19712,  8182,     2,
           198,     1,  4093,   198,  1780, 16693,    47,     2,   198,     1,
           520,  9531,   198,    57]], device='cuda:0')

<|im_start|>system
You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
<|im_start|>user
What???<|im_end|>
<|im_start|>assistant
I

We can see that the target model rejected the very first token and re-generated it on its own. However, since our ultimate goal is for the draft model to decode many times faster than the target model, the time lost even when the first token is rejected is not that significant; at least, that is the goal.
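
As a rough way to estimate what rejections cost us: the paper's analysis shows that, under the simplifying assumption that each draft token is accepted independently with rate alpha, a single target-model forward pass yields on average (1 - alpha^(gamma + 1)) / (1 - alpha) tokens. A quick, purely illustrative calculation:

# alpha and gamma below are made-up numbers, not measured values
alpha, gamma = 0.8, 10
expected_tokens_per_step = (1 - alpha ** (gamma + 1)) / (1 - alpha)
print(round(expected_tokens_per_step, 2))  # ~4.57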

Over time, I plan to gradually flesh out many more pieces of this inference-acceleration framework. If you are interested, feel free to browse my GitHub: https://github.com/ccs96307/fast-llm-inference

