Implementation Notes on Integrating Speculative Decoding with KV Cache

Last Updated on 2025-07-01 by Clay

Introduction

Speculative Decoding and KV Cache are both acceleration techniques applicable to Transformer models. The former uses a faster draft model to speculatively generate several subsequent tokens, which are then validated in a batch by the target model to reduce the cost of autoregressive decoding. The latter leverages the causal attention mechanism of Transformers—where past tokens do not attend to future tokens—to cache previously computed results and avoid redundant calculations during inference.
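
The verification step relies on the standard speculative sampling rule: a draft token with draft probability q and target probability p is accepted with probability min(1, p/q). Equivalently, whenever q > p, the token is rejected with probability 1 - p/q, which is exactly the check implemented later in target_speculative_decode. Here is a minimal sketch of the rule for a single token (the function name and probabilities are purely illustrative):

import torch

def accept_draft_token(draft_prob: float, target_prob: float) -> bool:
    # Accept with probability min(1, target_prob / draft_prob)
    if target_prob >= draft_prob:
        return True
    return torch.rand(1).item() < target_prob / draft_prob

# Hypothetical probabilities for one speculative token: accepted ~50% of the time
print(accept_draft_token(draft_prob=0.6, target_prob=0.3))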

For more details, feel free to check out my previous notes on Speculative Decoding and KV Cache.

These two techniques can be combined. Below, I walk through my thought process and the implementation steps.


KV Cache in Speculative Decoding

Since the transformers library developed by HuggingFace will deprecate the older tuple-style KV Cache format in version 4.47.0, switching instead to the new DynamicCache() class, I adopted DynamicCache() in this implementation. I found that its crop() method makes it quite convenient to discard unneeded KV Cache segments.

This allows us to let the model automatically handle KV Cache updates during inference. We only need to decide whether to truncate certain parts of the KV Cache based on the validation result of Speculative Decoding.
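
Before diving in, here is a minimal sketch of the DynamicCache workflow this implementation relies on (the prompt is arbitrary; the model is the same small one used in the benchmark below):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

model_name = "HuggingFaceTB/SmolLM2-135M-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("Taipei is", return_tensors="pt")
cache = DynamicCache()

with torch.no_grad():
    model(**inputs, past_key_values=cache, use_cache=True)

print(cache.get_seq_length())  # Number of cached positions == prompt length

# Discard the last two cached positions, e.g. after rejecting speculative tokens
cache.crop(max_length=cache.get_seq_length() - 2)
print(cache.get_seq_length())  # Two positions fewer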

Generally, during Speculative Decoding, we encounter the following four scenarios:

  1. When using KV Cache, we only need to input the last token of the sequence into the model. The input_ids shape changes from (batch_size, seq_len) to (batch_size, 1)
  2. When the target model evaluates gamma speculative tokens generated by the draft model, the input shape becomes (batch_size, gamma+1)
  3. If the target model rejects all but k of the gamma tokens generated by the draft model (see the arithmetic sketch after this list):
    • The draft model’s accumulated KV Cache of length raw_kv_cache_length + gamma needs to be cropped to raw_kv_cache_length + k
    • The target model’s accumulated KV Cache of length raw_kv_cache_length + gamma + 1 needs to be cropped to raw_kv_cache_length + k
  4. When the target model accepts all of the draft model’s speculative tokens, the next speculative decoding step will have one additional token generated by the target model, so the input shape becomes (batch_size, 2)

By handling the above scenarios correctly, we can successfully implement Speculative Decoding with KV Cache.
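
As a concrete instance of scenario 3, with hypothetical numbers, the crop lengths work out as follows:

# Hypothetical numbers: prompt + previously accepted tokens = 50 positions,
# the draft model speculated gamma = 5 tokens, the target accepted only k = 2.
raw_kv_cache_length = 50
gamma = 5
k = 2

# The draft model cached every speculative token: 55 entries, cropped back to 52.
draft_cache_length = raw_kv_cache_length + gamma       # 55
draft_cache_after_crop = raw_kv_cache_length + k       # 52

# The target model also cached the position of its bonus token: 56 entries, cropped to 52.
target_cache_length = raw_kv_cache_length + gamma + 1  # 56
target_cache_after_crop = raw_kv_cache_length + k      # 52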


Implementation Details

First, we need to import all necessary packages. The sample_next_token function is a custom sampling function I implemented; you can refer to my article Decoding Sampling Notes for Large Language Models for details, and a simplified stand-in is sketched after the usage string below.

from typing import Dict, List, Optional, Tuple, Union

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import argparse
import copy
import time

import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModelForCausalLM, AutoTokenizer, Cache, DynamicCache, PreTrainedTokenizerBase

from sampling.sampling import sample_next_token


"""
python speculative_decoding/run_speculative_decoding.py \
    --target_model_path HuggingFaceTB/SmolLM2-1.7B-Instruct \
    --draft_model_path HuggingFaceTB/SmolLM2-135M-Instruct \
    --device cuda:0 \
    --question 'What is the capital of Taiwan. And why?' \
    --gamma 5 \
    --test_token_num 100
"""


def calculate_continuous_acceptance(acceptance_mask: torch.BoolTensor) -> int:
    """Count how many draft tokens are accepted consecutively from the start."""
    continuous_acceptance = 0
    for accepted in acceptance_mask.long().squeeze(0):
        if accepted == 1:
            continuous_acceptance += 1
        else:
            break
    return continuous_acceptance
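
For instance, an acceptance mask of [[True, True, False, True]] returns 2, since only the leading run of accepted tokens counts:

print(calculate_continuous_acceptance(torch.tensor([[True, True, False, True]])))  # 2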


Next, we implement the sampling step for the draft model. Here, we dynamically crop the KV Cache or adjust the input length based on the current input length (once a KV Cache exists, the model normally only needs the last token as input, i.e. a shape of (batch_size, 1)).

def drafter_speculative_decode(
    draft_model: torch.nn.Module,
    draft_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    gamma: int = 10,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, meaning top-k filtering is disabled
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
) -> Tuple[Dict[str, torch.Tensor], torch.FloatTensor, Optional[Union[Cache, List[torch.FloatTensor]]]]:
    draft_probs = []

    for idx in range(gamma):
        raw_inputs_ids = inputs.input_ids

        if isinstance(past_key_values, Cache) and past_key_values.get_seq_length() > 0:
            # `distance` is how many input tokens the cache has not seen yet
            distance = inputs.input_ids.shape[1] - past_key_values.get_seq_length()

            if distance >= 1:
                # Feed only the uncached tail of the sequence
                inputs.input_ids = inputs.input_ids[:, -distance:]
            else:
                # The cache is ahead of the input (tokens were rejected): roll it back
                past_key_values.crop(max_length=inputs.input_ids.shape[1]-1)
                inputs.input_ids = inputs.input_ids[:, -1:]

        with torch.no_grad():
            outputs = draft_model(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                past_key_values=past_key_values,
                use_cache=past_key_values is not None,
            )

        past_key_values = outputs.past_key_values

        next_tokens, probs = sample_next_token(
            logits=outputs.logits,
            prefix_token_ids=inputs.input_ids,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        draft_probs.append(probs)
        input_ids = torch.cat([raw_inputs_ids, next_tokens[:, -1:]], dim=-1)
        attention_mask = torch.cat([inputs.attention_mask, torch.ones(inputs.attention_mask.shape[0], 1).to(inputs.input_ids.device)], dim=-1)

        inputs.input_ids = input_ids
        inputs.attention_mask = attention_mask

    return inputs, torch.cat(draft_probs, dim=1), past_key_values


The following part handles the validation step with the target model. As with the draft model, we adjust the KV Cache based on input length. The key difference is that if the target model rejects part of the speculative tokens, its own KV Cache must also be rolled back accordingly.

def target_speculative_decode(
    target_model: torch.nn.Module,
    target_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    draft_probs: torch.FloatTensor,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, meaning top-k filtering is disabled
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
) -> Tuple[Dict[str, torch.Tensor], bool, int, Optional[Union[Cache, List[torch.FloatTensor]]]]:
    raw_inputs_ids = inputs.input_ids

    if isinstance(past_key_values, Cache) and past_key_values.get_seq_length() > 0:
        distance = inputs.input_ids.shape[1] - past_key_values.get_seq_length()
        if distance >= 1:
            inputs.input_ids = inputs.input_ids[:, -distance:]
        else:
            past_key_values.crop(max_length=inputs.input_ids.shape[1]-1)
            inputs.input_ids = inputs.input_ids[:, -1:]

    with torch.no_grad():
        outputs = target_model(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            past_key_values=past_key_values,
            use_cache=past_key_values is not None,
        )

    past_key_values = outputs.past_key_values
    inputs.input_ids = raw_inputs_ids

    next_tokens, target_probs = sample_next_token(
        logits=outputs.logits,
        diff_probs=draft_probs,
        prefix_token_ids=inputs.input_ids,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        probs_num=draft_probs.shape[1] + 1,
    )

    next_token = next_tokens[:, -1:]

    # Evaluation
    indices = inputs.input_ids[:, -draft_probs.shape[1]:]

    eval_probs = target_probs[:, :-1, :]

    expanded_indices = indices.unsqueeze(-1)
    selected_draft_probs = torch.gather(draft_probs, dim=-1, index=expanded_indices)
    selected_draft_probs = selected_draft_probs.squeeze(-1)

    selected_eval_probs = torch.gather(eval_probs, dim=-1, index=expanded_indices)
    selected_eval_probs = selected_eval_probs.squeeze(-1)    

    # Compare draft_prob and eval_prob to build the candidate reject mask
    mask_to_reject = selected_draft_probs > selected_eval_probs

    # Calculate rejection probability 1 - (eval_prob / draft_prob)
    rejection_probs = 1 - (selected_eval_probs / selected_draft_probs)

    # Generate random values to determine accept or reject
    random_values = torch.rand_like(rejection_probs)
    rejection_decisions = random_values < rejection_probs

    # Get the final reject masks
    rejection_masks = mask_to_reject & rejection_decisions
    acceptance_mask = torch.ones_like(selected_draft_probs, dtype=torch.bool)
    acceptance_mask[rejection_masks] = False

    is_end = False

    # Concat `input_ids`
    if torch.all(acceptance_mask):
        inputs.input_ids = torch.cat([inputs.input_ids, next_token], dim=-1)
        inputs.attention_mask = torch.cat([inputs.attention_mask, torch.ones(inputs.attention_mask.shape[0], 1).to(inputs.input_ids.device)], dim=-1)
    else:
        new_input_ids = []
        new_attention_mask = []

        for batch_idx in range(next_tokens.shape[0]):
            gamma = next_tokens.shape[1] - 1
            start_idx = inputs.input_ids.shape[1] - gamma

            for pos_idx in range(acceptance_mask[batch_idx].shape[0]):
                if (acceptance_mask[batch_idx][pos_idx] and inputs.input_ids[batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id) or not acceptance_mask[batch_idx][pos_idx]:
                    inputs.input_ids[batch_idx][start_idx+pos_idx] = next_tokens[batch_idx][pos_idx]

                    new_input_ids.append(inputs.input_ids[batch_idx][:start_idx+pos_idx+1])
                    new_attention_mask.append(inputs.attention_mask[batch_idx][:start_idx+pos_idx+1])
                    
                    is_end = inputs.input_ids[batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id
                    break

        input_ids = pad_sequence(new_input_ids, batch_first=True, padding_value=target_tokenizer.pad_token_id)
        attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)

        inputs.input_ids = input_ids
        inputs.attention_mask = attention_mask

    # If any speculative tokens were rejected, roll the target cache back to match
    if isinstance(past_key_values, Cache) and inputs.input_ids.shape[1] <= past_key_values.get_seq_length():
        past_key_values.crop(max_length=inputs.input_ids.shape[1]-1)

    return inputs, is_end, calculate_continuous_acceptance(acceptance_mask), past_key_values


Finally, let’s compare the performance of Speculative Decoding with and without KV Cache enabled:

def run_test(args) -> None:
    # Device
    device = torch.device(args.device if args.device != "cpu" and torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # Model path 
    target_model_path = args.target_model_path
    draft_model_path = args.draft_model_path

    # Load Tokenizer
    draft_tokenizer = AutoTokenizer.from_pretrained(draft_model_path)
    target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)

    # Load Model
    draft_model = AutoModelForCausalLM.from_pretrained(draft_model_path, torch_dtype=torch.bfloat16).to(device)
    target_model = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.bfloat16).to(device)

    # Tokenize
    messages = [
        [
            {
                "role": "user",
                "content": args.question,
            },
        ],
    ]

    input_text = draft_tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = draft_tokenizer(
        input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    ).to(device)

    # Warm up the models (synchronization is only meaningful on CUDA)
    inputs_dummy = {k: v.clone() for k, v in inputs.items()}
    with torch.no_grad():
        draft_model(**inputs_dummy)
        target_model(**inputs_dummy)

    if device.type == "cuda":
        torch.cuda.synchronize()

    is_end = False

    # Record
    raw_inputs = copy.deepcopy(inputs)
    raw_token_num = raw_inputs.input_ids.shape[1]
    start_time = time.time()

    total_draft_tokens = 0
    total_accept_tokens = 0
    gamma = args.gamma
    max_new_tokens = args.test_token_num

    draft_past_key_values = None
    target_past_key_values = None

    while not is_end:
        # Draft model
        target_inputs, draft_probs, draft_past_key_values = drafter_speculative_decode(
            draft_model=draft_model,
            draft_tokenizer=draft_tokenizer,
            inputs=inputs,
            gamma=gamma,
            temperature=0,
            past_key_values=draft_past_key_values,
        )

        total_draft_tokens += gamma

        # Target model
        outputs, is_end, accept_tokens, target_past_key_values = target_speculative_decode(
            target_model=target_model,
            target_tokenizer=target_tokenizer,
            inputs=target_inputs,
            draft_probs=draft_probs,
            temperature=0,
            past_key_values=target_past_key_values,
        )

        total_accept_tokens += accept_tokens

        inputs = outputs

        if inputs.input_ids.shape[1] - raw_token_num >= max_new_tokens:
            break

    generate_token_num = outputs.input_ids.shape[1] - raw_token_num
    spent_time = time.time() - start_time

    print(f"(Without KV Cache) Generate token number: {generate_token_num}")
    print(f"(Without KV Cache) Generate speed: {generate_token_num / spent_time} tokens/sec")
    print(f"(Without KV Cache) Speculative Decoding Spent Time: {spent_time} seconds.")
    print(f"(Without KV Cache) Accept Rate: {total_accept_tokens / total_draft_tokens}\n")


    # KV Cache Speculative Decoding
    is_end = False

    # Record
    inputs = copy.deepcopy(raw_inputs)
    raw_token_num = inputs.input_ids.shape[1]
    start_time = time.time()

    total_draft_tokens = 0
    total_accept_tokens = 0
    gamma = args.gamma
    max_new_tokens = args.test_token_num

    draft_past_key_values = DynamicCache()
    target_past_key_values = DynamicCache()

    while not is_end:
        # Draft model
        target_inputs, draft_probs, draft_past_key_values = drafter_speculative_decode(
            draft_model=draft_model,
            draft_tokenizer=draft_tokenizer,
            inputs=inputs,
            gamma=gamma,
            temperature=0,
            past_key_values=draft_past_key_values,
        )

        total_draft_tokens += gamma

        # Target model
        outputs, is_end, accept_tokens, target_past_key_values = target_speculative_decode(
            target_model=target_model,
            target_tokenizer=target_tokenizer,
            inputs=target_inputs,
            draft_probs=draft_probs,
            temperature=0,
            past_key_values=target_past_key_values,
        )

        total_accept_tokens += accept_tokens

        inputs = outputs

        if inputs.input_ids.shape[1] - raw_token_num >= max_new_tokens:
            break

    generate_token_num = outputs.input_ids.shape[1] - raw_token_num
    spent_time = time.time() - start_time

    print(f"(KV Cache) Generate token number: {generate_token_num}")
    print(f"(KV Cache) Generate speed: {generate_token_num / spent_time} tokens/sec")
    print(f"(KV Cache) Speculative Decoding Spent Time: {spent_time} seconds.")
    print(f"(KV Cache) Accept Rate: {total_accept_tokens / total_draft_tokens}\n")


Output:

(Without KV Cache) Generate token number: 101
(Without KV Cache) Generate speed: 51.32459081142281 tokens/sec
(Without KV Cache) Speculative Decoding Spent Time: 1.9678676128387451 seconds.
(Without KV Cache) Accept Rate: 0.7719298245614035

(KV Cache) Generate token number: 101
(KV Cache) Generate speed: 62.468003457069095 tokens/sec
(KV Cache) Speculative Decoding Spent Time: 1.6168277263641357 seconds.
(KV Cache) Accept Rate: 0.8035714285714286

As we can see, enabling the KV Cache raises throughput from roughly 51.3 to 62.5 tokens/sec, about a 22% speedup, while the acceptance rate stays comparable.

