
Speculative Decoding Implementation Note (with Simple Experimental Results)

Last Updated on 2024-11-09 by Clay

Introduction

Speculative Decoding is an extremely practical inference acceleration technique. A small model (the draft model) quickly decodes several tokens and keeps the probability distribution each token was sampled from. The larger target model, which is the model we actually want to accelerate, then scores the whole draft in a single forward pass. At each drafted position, the draft model's probability for the drafted token is checked against the target model's probability for the same token, and the token is accepted if it is deemed sufficiently reliable.

For a more detailed explanation, you may refer to my previous post summarizing Google's Speculative Decoding paper: [Paper Summary] Fast Inference from Transformers via Speculative Decoding

Translating this into code is fairly straightforward. Suppose the draft model decodes gamma tokens. For each of them, we compare the probability the draft model assigned to that token (q) with the probability the target model assigns to the same token at the same position (p), which leads to two possible cases:

  1. draft model token prob <= target model token prob (q <= p): Accept the draft token, since the target model would generate the same token with at least as high a probability.
  2. draft model token prob > target model token prob (q > p): Reject the token with a probability of 1 - (target model token prob / draft model token prob), i.e. 1 - p / q.

Verification stops at the first rejected token: everything the draft model decoded after it is discarded, and the target model supplies the token at that position instead.
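To make the two cases concrete, here is a minimal pure-Python sketch of verifying one round of draft tokens over a toy vocabulary. It assumes each distribution is a plain list of probabilities indexed by token ID and takes a random.Random instance so runs are reproducible with a fixed seed. On rejection it samples the replacement from the normalized residual max(0, p - q), as in the paper; the helper names (verify_draft, residual, sample) are illustrative and not part of this post's code.

```python
import random
from typing import List, Sequence


def residual(p: Sequence[float], q: Sequence[float]) -> List[float]:
    """Normalized max(0, p - q): the distribution used to resample after a rejection."""
    diff = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    total = sum(diff)
    # total is 0 only when p == q, in which case no rejection can happen; fall back to p
    return [d / total for d in diff] if total > 0 else list(p)


def sample(dist: Sequence[float], u: float) -> int:
    """Inverse-CDF sampling: return the first token whose cumulative probability exceeds u."""
    cumulative = 0.0
    for token_id, prob in enumerate(dist):
        cumulative += prob
        if u < cumulative:
            return token_id
    return len(dist) - 1  # guard against rounding when u is very close to 1


def verify_draft(
    draft_tokens: Sequence[int],
    q_dists: Sequence[Sequence[float]],  # draft model distribution at each drafted position
    p_dists: Sequence[Sequence[float]],  # target model distributions, gamma + 1 of them
    rng: random.Random,
) -> List[int]:
    """Return the accepted draft prefix plus one token produced from the target model."""
    accepted: List[int] = []
    for i, token in enumerate(draft_tokens):
        p, q = p_dists[i][token], q_dists[i][token]
        # Case 1 (q <= p) always accepts; case 2 accepts with probability p / q
        if q <= p or rng.random() < p / q:
            accepted.append(token)
            continue
        # First rejection: replace this token with a sample from the residual and stop
        accepted.append(sample(residual(p_dists[i], q_dists[i]), rng.random()))
        return accepted
    # Every draft token was accepted: the last target distribution gives a bonus token
    accepted.append(sample(p_dists[len(draft_tokens)], rng.random()))
    return accepted
```

With the residual resampling, the first emitted token follows the target distribution p regardless of q, which is the property the paper proves; sampling the replacement directly from p instead would break it.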

Implementation

Below is how I apply the sampling parameters to the logits output by the model, in this order: repetition penalty, temperature, top-k, top-p, then softmax and multinomial sampling. It differs slightly from Hugging Face's approach, but it is enough to control the sampling parameters for now. For a detailed explanation of each sampling parameter, see my separate implementation post: Notes on Decoding and Sampling for Large Language Models

In this implementation note, I carefully distinguish between logits and probs: logits are the raw model outputs, which can in principle take any value in (-inf, inf), while probs are the probabilities obtained by applying softmax to the logits, which lie in [0, 1] and sum to 1 over the vocabulary (a token masked to -inf by top-k or top-p gets a probability of exactly 0).
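As a small illustration of how softmax turns logits into probs, here is a numerically stable version in plain Python, a sketch independent of the PyTorch code below. Subtracting the largest logit before exponentiating leaves the result unchanged but avoids overflow when logits are large.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Map raw logits in (-inf, inf) to probabilities that sum to 1."""
    max_logit = max(logits)
    exps = [math.exp(x - max_logit) for x in logits]  # math.exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax([2.0, 1.0, -float("inf")])
print(probs)  # the masked third token gets probability 0.0
```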

from typing import Tuple

import warnings
warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=UserWarning)

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def apply_repetition_penalty(
    logits: torch.FloatTensor,
    prefix_token_ids: torch.LongTensor,
    repetition_penalty: float = 1.0,
) -> torch.FloatTensor:
    batch_size, gamma, vocab_size = logits.shape
    seq_length = prefix_token_ids.shape[1]

    for batch_idx in range(batch_size):
        for gamma_idx in range(gamma):
            current_prefix = prefix_token_ids[batch_idx, :seq_length - gamma + gamma_idx + 1]

            unique_token_ids = set(current_prefix.tolist())

            for token_id in unique_token_ids:
                if logits[batch_idx, gamma_idx, token_id] > 0:
                    logits[batch_idx, gamma_idx, token_id] /= repetition_penalty
                else:
                    logits[batch_idx, gamma_idx, token_id] *= repetition_penalty

    return logits

def top_k_filtering(logits: torch.FloatTensor, top_k: int) -> torch.FloatTensor:
    if top_k > 0:
        values, _ = torch.topk(logits, top_k, dim=-1)
        min_values = values[:, :, -1].unsqueeze(dim=-1)
        logits = torch.where(logits < min_values, torch.full_like(logits, -float("Inf")), logits)

    return logits


def top_p_filtering(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    # top_p >= 1.0 keeps every token; returning early also stops rounding in the cumulative sum from dropping tail tokens
    if top_p >= 1.0:
        return logits

    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

    # Mark sorted positions whose cumulative probability exceeds top_p
    sorted_indices_to_remove = cumulative_probs > top_p

    # Shift the mask right by one so the token that crosses top_p is kept (at least one token always survives)
    sorted_indices_to_remove[:, :, 1:] = sorted_indices_to_remove[:, :, :-1].clone()
    sorted_indices_to_remove[:, :, 0] = False

    # Scatter the mask from sorted order back to vocabulary order, matching the layout of `logits`
    indices_to_remove = sorted_indices_to_remove.scatter(2, sorted_indices, sorted_indices_to_remove)
    logits = logits.masked_fill(indices_to_remove, -float("Inf"))

    return logits


def sample_next_token(
    logits: torch.FloatTensor,
    prefix_token_ids: torch.LongTensor,
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
    repetition_penalty: float = 1.0,
    eps: float = 1e-7,
    probs_num: int = 1,
) -> Tuple[torch.LongTensor, torch.FloatTensor]:
    curr_logits = logits[:, -probs_num:, :]

    # Apply repetition penalty
    if repetition_penalty != 1.0:
        curr_logits = apply_repetition_penalty(
            logits=curr_logits,
            prefix_token_ids=prefix_token_ids,
            repetition_penalty=repetition_penalty,
        )

    # Apply temperature
    curr_logits = curr_logits / (temperature + eps)

    # Apply `top_k`
    curr_logits = top_k_filtering(logits=curr_logits, top_k=top_k)

    # Apply `top_p`
    curr_logits = top_p_filtering(logits=curr_logits, top_p=top_p)

    # Convert logits into probs
    probs = torch.softmax(curr_logits, dim=-1)

    # Sampling
    seq_tokens = []
    for seq_idx in range(probs.shape[1]):
        seq_token = torch.multinomial(probs[:, seq_idx, :], num_samples=1)
        seq_tokens.append(seq_token)

    seq_token_ids = torch.cat(seq_tokens, dim=1)

    return seq_token_ids, probs



Below, we dive into the Speculative Decoding implementation itself. Here, I only test up to the verification stage (a single draft-and-verify round). For now, I assume batch_size=1: different sequences in a batch can accept different numbers of draft tokens, which leaves them with different lengths and requires padding tokens to line them up.

First, let's import all required libraries. sample_next_token comes from the sampling code above, saved as sampling.py.

from typing import Dict, List, Optional, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence

from transformers import LlamaForCausalLM, GPT2TokenizerFast, PreTrainedTokenizerBase

from sampling import sample_next_token



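Here, the draft model decodes gamma tokens one at a time. After each step, I append the sampled token ID (and a 1 in the attention mask) to the inputs and keep that step's probability distribution, so the target model can verify the draft later. Note that this loop does not use a KV cache: every step runs the draft model over the whole sequence again.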

def drafter_speculative_decode(
    draft_model: torch.nn.Module,
    draft_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    gamma: int = 10,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # 0 (the default) disables top-k filtering
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], torch.FloatTensor]:
    draft_probs = []

    for idx in range(gamma):
        with torch.no_grad():
            outputs = draft_model(**inputs)

        next_tokens, probs = sample_next_token(
            logits=outputs.logits,
            prefix_token_ids=inputs["input_ids"],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        draft_probs.append(probs)
        input_ids = torch.cat([inputs["input_ids"], next_tokens[:, -1:]], dim=-1)
        # ones_like keeps the mask's dtype and device, so the appended 1 matches the existing mask
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones_like(inputs["attention_mask"][:, :1])], dim=-1)

        inputs["input_ids"] = input_ids
        inputs["attention_mask"] = attention_mask

    return inputs, torch.cat(draft_probs, dim=1)



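The verification step on the target model is a bit more intricate. It runs only one forward pass over the prompt plus the gamma draft tokens, but it needs probability distributions for the last gamma+1 logits: the first gamma score the draft tokens, and the last one provides an extra token when every draft token is accepted. At the first rejected position, the draft token is replaced by the target model's own sample, and everything drafted after it is discarded.

As described earlier, for each draft token we compare its probability under the target model (p) with its probability under the draft model (q), and reject it with probability 1 - p / q when q > p. One difference from the paper: there, the replacement token is sampled from the normalized residual distribution max(0, p - q), whereas the code below samples it directly from the target distribution p. This is simpler, but it means the output is not guaranteed to follow the target model's distribution exactly.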

def target_speculative_decode(
    target_model: torch.nn.Module,
    target_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    draft_probs: torch.FloatTensor,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # 0 (the default) disables top-k filtering
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
) -> Dict[str, torch.Tensor]:
    with torch.no_grad():
        outputs = target_model(**inputs)

    next_tokens, target_probs = sample_next_token(
        logits=outputs.logits,
        prefix_token_ids=inputs["input_ids"],
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        probs_num=draft_probs.shape[1] + 1,
    )

    next_token = next_tokens[:, -1:]

    # Evaluation
    indices = inputs["input_ids"][:, -draft_probs.shape[1]:]

    eval_probs = target_probs[:, :-1, :]

    expanded_indices = indices.unsqueeze(-1)
    selected_draft_probs = torch.gather(draft_probs, dim=-1, index=expanded_indices)
    selected_draft_probs = selected_draft_probs.squeeze(-1)

    selected_eval_probs = torch.gather(eval_probs, dim=-1, index=expanded_indices)
    selected_eval_probs = selected_eval_probs.squeeze(-1)

    # Compare draft_prob and eval_prob, and check the reject_mask
    mask_to_reject = selected_draft_probs > selected_eval_probs

    # Calculate reject probability 1 - (eval_prob / draft_prob)
    rejection_probs = 1 - (selected_eval_probs / selected_draft_probs)

    # Generate random values to determine acceptance or rejection
    random_values = torch.rand_like(rejection_probs)
    rejection_decisions = random_values < rejection_probs

    # Get the final rejection masks
    rejection_masks = mask_to_reject & rejection_decisions
    acceptance_mask = torch.ones_like(selected_draft_probs, dtype=torch.bool)
    acceptance_mask[rejection_masks] = False

    # Concat `input_ids`
    if torch.all(acceptance_mask):
        input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones_like(inputs["attention_mask"][:, :1])], dim=-1)
    else:
        new_input_ids = []
        new_attention_mask = []

        for batch_idx in range(next_tokens.shape[0]):
            for pos_idx in range(acceptance_mask[batch_idx].shape[0]):
                if not acceptance_mask[batch_idx][pos_idx]:
                    gamma = next_tokens.shape[1] - 1
                    start_idx = inputs["input_ids"].shape[1] - gamma

                    inputs["input_ids"][batch_idx][start_idx+pos_idx] = next_tokens[batch_idx][pos_idx]

                    new_input_ids.append(inputs["input_ids"][batch_idx][:start_idx+pos_idx+1])
                    new_attention_mask.append(inputs["attention_mask"][batch_idx][:start_idx+pos_idx+1])
                    
                    break

        input_ids = pad_sequence(new_input_ids, batch_first=True, padding_value=target_tokenizer.pad_token_id)
        attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)

    inputs["input_ids"] = input_ids
    inputs["attention_mask"] = attention_mask

    return inputs
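The slicing in target_speculative_decode relies on the fact that the logits at position t predict the token at position t + 1. The following sketch (a hypothetical helper, not part of the post's code) spells out which positions are involved when the verified sequence has seq_len tokens, the last gamma of which are draft tokens; this shift is why the function asks sample_next_token for gamma + 1 distributions.

```python
from typing import Dict, List


def verification_positions(seq_len: int, gamma: int) -> Dict[str, List[int]]:
    """Positions used in one verification pass (0-indexed, batch_size=1)."""
    prompt_len = seq_len - gamma
    return {
        # Token positions holding the draft tokens
        "draft_tokens": list(range(prompt_len, seq_len)),
        # Logit positions whose distributions score those draft tokens (shifted left by one)
        "verify_logits": list(range(prompt_len - 1, seq_len - 1)),
        # The final logit position predicts the extra token after the draft
        "bonus_logit": [seq_len - 1],
    }


print(verification_positions(seq_len=46, gamma=10))
```

In the first run below, the sequence has 46 tokens after drafting with gamma=10, so the draft occupies positions 36 to 45, is scored by logits 35 to 44 (target_probs[:, :-1, :]), and logit 45 gives next_tokens[:, -1:].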



Finally, let’s examine the execution results:

if __name__ == "__main__":
    # Settings
    target_model_path = "../models/HuggingFaceTB--SmolLM2-1.7B-Instruct/"
    draft_model_path = "../models/HuggingFaceTB--SmolLM2-135M-Instruct/"
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load Tokenizer
    draft_tokenizer = GPT2TokenizerFast.from_pretrained(draft_model_path)
    target_tokenizer = GPT2TokenizerFast.from_pretrained(target_model_path)

    # Load Model
    draft_model = LlamaForCausalLM.from_pretrained(draft_model_path, torch_dtype=torch.bfloat16).to(device)
    target_model = LlamaForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.bfloat16).to(device)

    # Tokenize
    messages = [
        [
            {
                "role": "user",
                "content": "What is the capital of Taiwan. And why?",
            },
        ],
    ]


    input_text = draft_tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = draft_tokenizer(
        input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    ).to(device)


    # Draft model
    target_inputs, draft_probs = drafter_speculative_decode(
        draft_model=draft_model,
        draft_tokenizer=draft_tokenizer,
        inputs=inputs,
        gamma=10,
    )

    print(target_inputs["input_ids"])
    print("".join(draft_tokenizer.batch_decode(target_inputs["input_ids"][0])))

    # Target model
    outputs = target_speculative_decode(
        target_model=target_model,
        target_tokenizer=target_tokenizer,
        inputs=target_inputs,
        draft_probs=draft_probs,
    )

    print(outputs["input_ids"])
    print("".join(target_tokenizer.batch_decode(outputs["input_ids"][0])))


Output:

tensor([[    1,  9690,   198,  2683,   359,   253,  5356,  5646, 11173,  3365,
          3511,   308, 34519,    28,  7018,   411,   407, 19712,  8182,     2,
           198,     1,  4093,   198,  1780,   314,   260,  3575,   282, 15914,
            30,  1350,  1701,    47,     2,   198,     1,   520,  9531,   198,
           504,  3575,   282, 15914,   314, 12545]], device='cuda:0')

<|im_start|>system
You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
<|im_start|>user
What is the capital of Taiwan. And why?<|im_end|>
<|im_start|>assistant
The capital of Taiwan is Tai

tensor([[    1,  9690,   198,  2683,   359,   253,  5356,  5646, 11173,  3365,
          3511,   308, 34519,    28,  7018,   411,   407, 19712,  8182,     2,
           198,     1,  4093,   198,  1780,   314,   260,  3575,   282, 15914,
            30,  1350,  1701,    47,     2,   198,     1,   520,  9531,   198,
           504,  3575,   282, 15914,   314, 12545, 46162]], device='cuda:0')

<|im_start|>system
You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
<|im_start|>user
What is the capital of Taiwan. And why?<|im_end|>
<|im_start|>assistant
The capital of Taiwan is Taipei

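This is a positive example. The draft model decoded 10 tokens, all of which were accepted, and the target model then produced the 11th token ("pei") from the same forward pass, so a single target forward pass yielded 11 tokens. A great success. (The 10 draft tokens include the <|im_start|>assistant header and its newline: apply_chat_template was called without add_generation_prompt=True, so the draft model generated the header itself.)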

However, with a more ambiguous question, the target model may reject the draft earlier, with a bit of an "I'll take over" attitude.

if __name__ == "__main__":
    # Settings
    target_model_path = "../models/HuggingFaceTB--SmolLM2-1.7B-Instruct/"
    draft_model_path = "../models/HuggingFaceTB--SmolLM2-135M-Instruct/"
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load Tokenizer
    draft_tokenizer = GPT2TokenizerFast.from_pretrained(draft_model_path)
    target_tokenizer = GPT2TokenizerFast.from_pretrained(target_model_path)

    # Load Model
    draft_model = LlamaForCausalLM.from_pretrained(draft_model_path, torch_dtype=torch.bfloat16).to(device)
    target_model = LlamaForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.bfloat16).to(device)

    # Tokenize
    messages = [
        [
            {
                "role": "user",
                "content": "What???",
            },
        ],
    ]


    input_text = draft_tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = draft_tokenizer(
        input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    ).to(device)


    # Draft model
    target_inputs, draft_probs = drafter_speculative_decode(
        draft_model=draft_model,
        draft_tokenizer=draft_tokenizer,
        inputs=inputs,
        gamma=10,
    )

    print(target_inputs["input_ids"])
    print("".join(draft_tokenizer.batch_decode(target_inputs["input_ids"][0])))

    # Target model
    outputs = target_speculative_decode(
        target_model=target_model,
        target_tokenizer=target_tokenizer,
        inputs=target_inputs,
        draft_probs=draft_probs,
    )

    print(outputs["input_ids"])
    print("".join(target_tokenizer.batch_decode(outputs["input_ids"][0])))


Output:

tensor([[    1,  9690,   198,  2683,   359,   253,  5356,  5646, 11173,  3365,
          3511,   308, 34519,    28,  7018,   411,   407, 19712,  8182,     2,
           198,     1,  4093,   198,  1780, 16693,    47,     2,   198,     1,
           520,  9531,   198, 22234,  8165,    28,   198,   198, 42519]],
       device='cuda:0')

<|im_start|>system
You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
<|im_start|>user
What???<|im_end|>
<|im_start|>assistant
Hey Sarah,

Hope

tensor([[    1,  9690,   198,  2683,   359,   253,  5356,  5646, 11173,  3365,
          3511,   308, 34519,    28,  7018,   411,   407, 19712,  8182,     2,
           198,     1,  4093,   198,  1780, 16693,    47,     2,   198,     1,
           520,  9531,   198,    57]], device='cuda:0')

<|im_start|>system
You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
<|im_start|>user
What???<|im_end|>
<|im_start|>assistant
I

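As we can see, the target model accepted the first four draft tokens (the <|im_start|>assistant header and its newline) but rejected the first token of the reply, "Hey", replaced it with its own token, "I", and discarded the rest of the draft. Since the draft model is meant to be much faster than the target model, even a rejection this early should not cost much time: the target model's forward pass still produces one token, so what is wasted is mostly the draft model's steps. At least, that's the objective.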
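How much a rejection costs can be estimated with the analysis from the paper summarized earlier. If each draft token is accepted with probability alpha (the paper treats acceptances as i.i.d.), one draft-and-verify round yields (1 - alpha^(gamma + 1)) / (1 - alpha) tokens on average; with c as the cost of one draft step relative to one target step, the expected wall-clock improvement factor is (1 - alpha^(gamma + 1)) / ((1 - alpha)(gamma * c + 1)). The sketch below only evaluates these two formulas; the values passed in are illustrative, not measurements from this post.

```python
def expected_tokens_per_round(alpha: float, gamma: int) -> float:
    """Expected tokens from one draft-and-verify round (accepted drafts + 1 target token)."""
    if alpha == 1.0:
        return gamma + 1.0  # every draft token accepted, plus the extra target token
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)


def expected_speedup(alpha: float, gamma: int, c: float) -> float:
    """Expected wall-clock improvement over plain target decoding.

    c is the cost of one draft-model step divided by the cost of one target-model step.
    """
    return expected_tokens_per_round(alpha, gamma) / (gamma * c + 1)


# Illustrative values only
print(expected_tokens_per_round(0.0, 10))  # 1.0
print(expected_speedup(0.8, 5, 0.05))
```

With alpha = 0 (every draft rejected), a round still yields one token and the factor drops to 1 / (gamma * c + 1), so the loss is exactly the time spent on draft steps; a cheap draft model (small c) keeps that loss small.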

Over time, I intend to improve several aspects of the speculative decoding framework. Feel free to check out my GitHub if you’re interested: https://github.com/ccs96307/fast-llm-inference


(2024/11/06 Updated) Experimental Results

Initially, I wanted to compare the speed of my Speculative Decoding implementation against direct decoding. At first, direct decoding with either the target model alone or the draft model alone turned out to be faster than speculative decoding! The reason was that the baselines used the .generate() method from transformers, which is already heavily optimized (among other things, it reuses the KV cache between steps, whereas my loop re-runs the full sequence through the model at every step). So I switched the baselines to my custom sampling as well, so that all runs decode the same way.

This adjustment yielded results that I found quite satisfying.

from typing import Dict, List, Optional, Tuple

import copy
import time

import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase

from sampling import sample_next_token


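def calculate_continuous_acceptance(acceptance_mask: torch.Tensor) -> int:
    # Number of draft tokens accepted before the first rejection (assumes batch_size=1)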
    continuous_acceptance = 0
    for accepted in acceptance_mask.long().squeeze(0):
        if accepted == 1:
            continuous_acceptance += 1
        else:
            break
    return continuous_acceptance


def drafter_speculative_decode(
    draft_model: torch.nn.Module,
    draft_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    gamma: int = 10,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # 0 (the default) disables top-k filtering
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], torch.FloatTensor]:
    draft_probs = []

    for idx in range(gamma):
        with torch.no_grad():
            outputs = draft_model(**inputs)

        next_tokens, probs = sample_next_token(
            logits=outputs.logits,
            prefix_token_ids=inputs["input_ids"],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        draft_probs.append(probs)
        input_ids = torch.cat([inputs["input_ids"], next_tokens[:, -1:]], dim=-1)
        # ones_like keeps the mask's dtype and device, so the appended 1 matches the existing mask
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones_like(inputs["attention_mask"][:, :1])], dim=-1)

        inputs["input_ids"] = input_ids
        inputs["attention_mask"] = attention_mask

    return inputs, torch.cat(draft_probs, dim=1)


def target_speculative_decode(
    target_model: torch.nn.Module,
    target_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    draft_probs: torch.FloatTensor,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # 0 (the default) disables top-k filtering
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], bool, int]:
    with torch.no_grad():
        outputs = target_model(**inputs)

    next_tokens, target_probs = sample_next_token(
        logits=outputs.logits,
        prefix_token_ids=inputs["input_ids"],
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        probs_num=draft_probs.shape[1] + 1,
    )

    next_token = next_tokens[:, -1:]

    # Evaluation
    indices = inputs["input_ids"][:, -draft_probs.shape[1]:]

    eval_probs = target_probs[:, :-1, :]

    expanded_indices = indices.unsqueeze(-1)
    selected_draft_probs = torch.gather(draft_probs, dim=-1, index=expanded_indices)
    selected_draft_probs = selected_draft_probs.squeeze(-1)

    selected_eval_probs = torch.gather(eval_probs, dim=-1, index=expanded_indices)
    selected_eval_probs = selected_eval_probs.squeeze(-1)

    # Compare draft_prob and eval_prob, and check the reject_mask
    mask_to_reject = selected_draft_probs > selected_eval_probs

    # Calculate reject probability 1 - (eval_prob / draft_prob)
    rejection_probs = 1 - (selected_eval_probs / selected_draft_probs)

    # Generate random values to determine acceptance or rejection
    random_values = torch.rand_like(rejection_probs)
    rejection_decisions = random_values < rejection_probs

    # Get the final rejection masks
    rejection_masks = mask_to_reject & rejection_decisions
    acceptance_mask = torch.ones_like(selected_draft_probs, dtype=torch.bool)
    acceptance_mask[rejection_masks] = False

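    is_end = False

    # Concat `input_ids`
    # Fast path: every draft token was accepted and none of them is EOS, so append the extra target token
    if torch.all(acceptance_mask) and not torch.any(indices == target_tokenizer.eos_token_id):
        input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones_like(inputs["attention_mask"][:, :1])], dim=-1)
        is_end = next_token[0, 0].item() == target_tokenizer.eos_token_id  # assumes batch_size=1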
    else:
        new_input_ids = []
        new_attention_mask = []

        for batch_idx in range(next_tokens.shape[0]):
            gamma = next_tokens.shape[1] - 1
            start_idx = inputs["input_ids"].shape[1] - gamma

            for pos_idx in range(acceptance_mask[batch_idx].shape[0]):
                if (acceptance_mask[batch_idx][pos_idx] and inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id) or not acceptance_mask[batch_idx][pos_idx]:
                    inputs["input_ids"][batch_idx][start_idx+pos_idx] = next_tokens[batch_idx][pos_idx]

                    new_input_ids.append(inputs["input_ids"][batch_idx][:start_idx+pos_idx+1])
                    new_attention_mask.append(inputs["attention_mask"][batch_idx][:start_idx+pos_idx+1])
                    
                    is_end = inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id
                    break

        input_ids = pad_sequence(new_input_ids, batch_first=True, padding_value=target_tokenizer.pad_token_id)
        attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)

    inputs["input_ids"] = input_ids
    inputs["attention_mask"] = attention_mask

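    # Match the annotated return type: (inputs, is_end, number of leading accepted draft tokens)
    return inputs, is_end, calculate_continuous_acceptance(acceptance_mask)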
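The driver loop that produced the numbers below is not shown in the post. One possible way to write such a loop is sketched here with the model calls abstracted as two hypothetical callables, so it runs without any model: draft_step proposes gamma draft tokens, and verify_step returns how many were accepted plus the target's replacement or extra token. The acceptance rate is assumed to be accepted draft tokens divided by drafted tokens; that definition may not match how the Accept Rate below was computed.

```python
import time
from typing import Callable, List, Sequence, Tuple

DraftStep = Callable[[List[int], int], List[int]]               # (tokens, gamma) -> draft tokens
VerifyStep = Callable[[List[int], List[int]], Tuple[int, int]]  # (tokens, draft) -> (num_accepted, target_token)


def speculative_generate(
    prompt: Sequence[int],
    draft_step: DraftStep,
    verify_step: VerifyStep,
    gamma: int,
    max_new_tokens: int,
    eos_token_id: int,
) -> Tuple[List[int], float, float]:
    """Run draft/verify rounds; return (new tokens, acceptance rate, seconds spent)."""
    tokens = list(prompt)
    accepted_total = drafted_total = 0
    start = time.perf_counter()

    while len(tokens) - len(prompt) < max_new_tokens:
        draft = draft_step(tokens, gamma)
        num_accepted, target_token = verify_step(tokens, draft)
        accepted_total += num_accepted
        drafted_total += len(draft)

        # Keep the accepted prefix, then the target's replacement or extra token
        new_tokens = draft[:num_accepted] + [target_token]
        if eos_token_id in new_tokens:
            tokens.extend(new_tokens[: new_tokens.index(eos_token_id) + 1])
            break
        tokens.extend(new_tokens)

    elapsed = time.perf_counter() - start
    rate = accepted_total / drafted_total if drafted_total else 0.0
    return tokens[len(prompt):], rate, elapsed
```

Because each round appends several tokens at once, this loop can overshoot max_new_tokens by up to gamma tokens; trim the result if an exact count is needed.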


Output:

Generate token number: 101
Speculative Decoding Spent Time: 2.8416874408721924 seconds.
Accept Rate: 0.40588235294117647


Generate token number: 100
Normal Target Model Decoding Spent Time: 3.5783841609954834 seconds.


Generate token number: 100
Normal Draft Model Decoding Spent Time: 1.0656843185424805 seconds.

Speculative decoding took about 2.84 seconds to generate 101 tokens, with an acceptance rate of about 0.41.

Decoding directly with the target model took 3.58 seconds for 100 tokens, while the draft model alone took only 1.07 seconds for 100 tokens.

So speculative decoding did speed up the target model's decoding: 3.58 / 2.84 is about 1.26 times faster in wall-clock time, or roughly 35.5 versus 27.9 tokens per second. Successful experiment!

