
Using The Target Model's Confidence Threshold To Decide Whether To Enable Speculative Decoding

Last Updated on 2024-11-22 by Clay

Many of the inference acceleration techniques I have studied, Speculative Decoding among them, rely on a threshold over the draft model's confidence scores. The threshold decides how many draft tokens to generate before handing them to the target model for verification, which cuts the wasted computation when the draft model is operating with low confidence.
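
As a reference point, that draft-side gating can be boiled down to something like the minimal sketch below. This is not the code used in this post; the function name and the parameters `max_draft_tokens` and `confidence_threshold` are illustrative, and it assumes a Hugging Face causal LM run with batch size 1.

import torch

def draft_until_unconfident(
    draft_model: torch.nn.Module,
    inputs: dict,
    max_draft_tokens: int = 10,
    confidence_threshold: float = 0.5,
):
    # Greedily draft tokens until the draft model's own confidence drops.
    drafted = 0
    while drafted < max_draft_tokens:
        with torch.no_grad():
            logits = draft_model(**inputs).logits[:, -1, :]

        probs = torch.softmax(logits, dim=-1)
        confidence, next_token = probs.max(dim=-1, keepdim=True)  # assumes batch size 1

        # Stop drafting as soon as the draft model itself is unsure.
        if confidence.item() < confidence_threshold:
            break

        inputs["input_ids"] = torch.cat([inputs["input_ids"], next_token], dim=-1)
        inputs["attention_mask"] = torch.cat(
            [inputs["attention_mask"], torch.ones_like(next_token)], dim=-1
        )
        drafted += 1

    return inputs, drafted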

Although this method does adapt the decoding dynamically, one concern lingered when I first encountered it: if the draft model is inherently less capable, especially in scenarios it understands poorly, should I still trust its high confidence scores?

In other words, would it not make more sense to base the decision to activate speculative decoding on the confidence score of the target model? Since the weaker draft model's high confidence might not be reliable (as the target model ultimately validates the results), the target model's low confidence should provide sufficient justification to avoid delegating decoding tasks to the draft model.

To put it more plainly: The draft model's high confidence may simply be unwarranted self-assurance, and this alone shouldn't justify allowing it to continue generating tokens. Conversely, if the target model exhibits low confidence, it signals uncertainty, making it unreasonable to rely on the weaker draft model for predictions.
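
Reduced to its essence, the gating rule I have in mind is a single check on the target model's most recent confidence. The sketch below is only conceptual (the helper name is made up); the 0.5 value matches the threshold hard-coded in the implementation later in this post.

CONFIDENCE_THRESHOLD = 0.5  # matches the value hard-coded later in this post

def should_delegate_to_draft_model(last_target_confidence: float) -> bool:
    # Hand the next round of drafting to the small model only when the target
    # model was confident about the token it just verified or generated;
    # otherwise the target model keeps decoding on its own.
    return last_target_confidence >= CONFIDENCE_THRESHOLD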

In conclusion, perhaps I simply have not read widely enough, but I have yet to see a paper that compares gating speculative decoding on the draft model's confidence against gating it on the target model's confidence. That said, the acceleration in my initial tests was noticeable, so I am documenting this as a reference for my research on inference acceleration techniques.


Background Knowledge

First, feel free to check out my GitHub, where I upload many of my implementations to the fast-llm-inference repository.

This note is a follow-up to my previous post: Speculative Decoding Implementation Notes (with Basic Experimental Results).

In the previous post, I documented my implementation of sampling-based Speculative Decoding, along with a simple experiment. Here, I will refine that code and compare the acceleration results by incorporating the target model's confidence score to decide whether to enable speculative decoding.

After running several experiments with the original code, I obtained the following results:

Generate token number: 100
Generate speed: 34.81801971715937 tokens/sec
Speculative Decoding Spent Time: 2.8720760345458984 seconds.
Accept Rate: 0.34054054054054056

Generate token number: 100
Generate speed: 28.07058497562908 tokens/sec
Normal Target Model Decoding Spent Time: 3.562448024749756 seconds.

Generate token number: 100
Generate speed: 94.92253307563831 tokens/sec
Normal Draft Model Decoding Spent Time: 1.0534906387329102 seconds.


In summary:

  • Draft model decoding speed: 94.92 tokens/sec
  • Target model decoding speed: 28.07 tokens/sec
  • Speculative Decoding speed: 34.82 tokens/sec

This is roughly a 1.24x speedup over decoding with the target model alone (34.82 / 28.07 ≈ 1.24).


Using Target Model Confidence to Decide Whether to Enable Speculative Decoding

Motivated by the idea that if the current decoding task is uncertain, I should not burden the weaker draft model, I made the following changes:

def target_speculative_decode(
    target_model: torch.nn.Module,
    target_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    draft_probs: torch.FloatTensor,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, meaning top-k filtering is disabled
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], bool, int]:
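    # Verify the drafted tokens with a single forward pass of the target model,
    # then apply rejection sampling position by position.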
    with torch.no_grad():
        outputs = target_model(**inputs)

    next_tokens, target_probs = sample_next_token(
        logits=outputs.logits,
        prefix_token_ids=inputs["input_ids"],
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        probs_num=draft_probs.shape[1] + 1,
    )

    next_token = next_tokens[:, -1:]

    # Evaluation
    indices = inputs["input_ids"][:, -draft_probs.shape[1]:]

    eval_probs = target_probs[:, :-1, :]

    expanded_indices = indices.unsqueeze(-1)
    selected_draft_probs = torch.gather(draft_probs, dim=-1, index=expanded_indices)
    selected_draft_probs = selected_draft_probs.squeeze(-1)

    selected_eval_probs = torch.gather(eval_probs, dim=-1, index=expanded_indices)
    selected_eval_probs = selected_eval_probs.squeeze(-1)

    # Compare draft_prob and eval_prob, and check the reject_mask
    mask_to_reject = selected_draft_probs > selected_eval_probs

    # Calculate rejection probability: 1 - (eval_prob / draft_prob)
    rejection_probs = 1 - (selected_eval_probs / selected_draft_probs)

    # Generate random values to determine accept or reject
    random_values = torch.rand_like(rejection_probs)
    rejection_decisions = random_values < rejection_probs

    # Get the final reject masks
    rejection_masks = mask_to_reject & rejection_decisions
    acceptance_mask = torch.ones_like(selected_draft_probs, dtype=torch.bool)
    acceptance_mask[rejection_masks] = False

    is_end = False

    # Concat `input_ids`
    confidence_score = 0

    if torch.all(acceptance_mask):
        input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)
        confidence_score = target_probs[:, -1, next_token[0][0]].item()
        print(f"Confidence for next token: {confidence_score:.4f}")
    else:
        new_input_ids = []
        new_attention_mask = []

        for batch_idx in range(next_tokens.shape[0]):
            gamma = next_tokens.shape[1] - 1
            start_idx = inputs["input_ids"].shape[1] - gamma

            for pos_idx in range(acceptance_mask[batch_idx].shape[0]):
                if (acceptance_mask[batch_idx][pos_idx] and inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id) or not acceptance_mask[batch_idx][pos_idx]:
                    inputs["input_ids"][batch_idx][start_idx+pos_idx] = next_tokens[batch_idx][pos_idx]
                    confidence_score = target_probs[batch_idx, pos_idx, next_tokens[batch_idx, pos_idx]].max().item()
                    print(f"Replacement Confidence for next token: {confidence_score:.4f}")

                    new_input_ids.append(inputs["input_ids"][batch_idx][:start_idx+pos_idx+1])
                    new_attention_mask.append(inputs["attention_mask"][batch_idx][:start_idx+pos_idx+1])
                    
                    is_end = inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id
                    break

        input_ids = pad_sequence(new_input_ids, batch_first=True, padding_value=target_tokenizer.pad_token_id)
        attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)

    inputs["input_ids"] = input_ids
    inputs["attention_mask"] = attention_mask

    # Keep generating with the target model alone while confidence_score is below the threshold (hard-coded to 0.5)
    while confidence_score < 0.5:
        with torch.no_grad():
            outputs = target_model(**inputs)

        next_tokens, target_probs = sample_next_token(
            logits=outputs.logits,
            prefix_token_ids=inputs["input_ids"],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            probs_num=1,
        )

        # Update `confidence_score`
        next_token = next_tokens[:, -1:]
        confidence_score = target_probs[0, -1, next_token[0][0]].item()
        print(f"keep generate confidence_score: {confidence_score:.4f}")

        input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)

        inputs["input_ids"] = input_ids
        inputs["attention_mask"] = attention_mask

        is_end = inputs["input_ids"][0][-1].item() == target_tokenizer.eos_token_id
        if is_end:
            break

    return inputs, is_end, calculate_continuous_acceptance(acceptance_mask)


Compared with the original code, the change is that I now read the target model's confidence score for the newly emitted token during verification, and as long as it stays below the threshold (0.5 here), the target model keeps decoding on its own instead of handing control back to the draft model. Since the probability distribution is already available after sampling, reading the confidence adds virtually no overhead in my implementation.
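
The key line is the one that reads the probability the target model assigned to the token it just emitted, straight out of the already-computed `target_probs` (reproduced from the function above):

confidence_score = target_probs[:, -1, next_token[0][0]].item()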


Below are the complete source code and experimental setup:

  • GPU: RTX 4060 8GB
  • OS: Ubuntu 22.04
  • Target model: HuggingFaceTB/SmolLM2-1.7B-Instruct
  • Draft model: HuggingFaceTB/SmolLM2-135M-Instruct
from typing import Dict, List, Optional, Tuple

import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import argparse
import copy
import time

import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase

from sampling.sampling import sample_next_token


"""
python speculative_decoding/run_speculative_decoding.py \
    --target_model_path HuggingFaceTB/SmolLM2-1.7B-Instruct \
    --draft_model_path HuggingFaceTB/SmolLM2-135M-Instruct \
    --device cuda:0 \
    --question 'What is the capital of Taiwan. And why?' \
    --gamma 5 \
    --test_token_num 100 
"""


def calculate_continuous_acceptance(acceptance_mask: torch.BoolTensor) -> int:
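    # Count the leading run of accepted draft tokens, e.g. a mask of
    # [[True, True, False, True]] yields 2; everything from the first
    # rejection onward is replaced or truncated by the target model.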
    continuous_acceptance = 0
    for accepted in acceptance_mask.long().squeeze(0):
        if accepted == 1:
            continuous_acceptance += 1
        else:
            break
    return continuous_acceptance


def drafter_speculative_decode(
    draft_model: torch.nn.Module,
    draft_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    gamma: int = 10,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, meaning top-k filtering is disabled
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], torch.FloatTensor]:
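    # Autoregressively draft `gamma` tokens with the small model, collecting the
    # full probability distribution at each step so the target model can verify them.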
    draft_probs = []

    for idx in range(gamma):
        with torch.no_grad():
            outputs = draft_model(**inputs)

        next_tokens, probs = sample_next_token(
            logits=outputs.logits,
            prefix_token_ids=inputs["input_ids"],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        draft_probs.append(probs)
        input_ids = torch.cat([inputs["input_ids"], next_tokens[:, -1:]], dim=-1)
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)

        inputs["input_ids"] = input_ids
        inputs["attention_mask"] = attention_mask

    return inputs, torch.cat(draft_probs, dim=1)


def target_speculative_decode(
    target_model: torch.nn.Module,
    target_tokenizer: PreTrainedTokenizerBase,
    inputs: Dict[str, torch.Tensor],
    draft_probs: torch.FloatTensor,
    temperature: float = 1.0,
    top_k: Optional[int] = 0,  # Default is 0, meaning top-k filtering is disabled
    top_p: Optional[float] = 1.0,
    repetition_penalty: Optional[float] = 1.0,
) -> Tuple[Dict[str, torch.Tensor], bool, int]:
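    # Verify the drafted tokens with a single forward pass of the target model,
    # then apply rejection sampling position by position.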
    with torch.no_grad():
        outputs = target_model(**inputs)

    next_tokens, target_probs = sample_next_token(
        logits=outputs.logits,
        prefix_token_ids=inputs["input_ids"],
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        probs_num=draft_probs.shape[1] + 1,
    )

    next_token = next_tokens[:, -1:]

    # Evaluation
    indices = inputs["input_ids"][:, -draft_probs.shape[1]:]

    eval_probs = target_probs[:, :-1, :]

    expanded_indices = indices.unsqueeze(-1)
    selected_draft_probs = torch.gather(draft_probs, dim=-1, index=expanded_indices)
    selected_draft_probs = selected_draft_probs.squeeze(-1)

    selected_eval_probs = torch.gather(eval_probs, dim=-1, index=expanded_indices)
    selected_eval_probs = selected_eval_probs.squeeze(-1)

    # Compare draft_prob and eval_prob, and check the reject_mask
    mask_to_reject = selected_draft_probs > selected_eval_probs

    # Calculate rejection probability: 1 - (eval_prob / draft_prob)
    rejection_probs = 1 - (selected_eval_probs / selected_draft_probs)

    # Generate random values to determine accept or reject
    random_values = torch.rand_like(rejection_probs)
    rejection_decisions = random_values < rejection_probs

    # Get the final reject masks
    rejection_masks = mask_to_reject & rejection_decisions
    acceptance_mask = torch.ones_like(selected_draft_probs, dtype=torch.bool)
    acceptance_mask[rejection_masks] = False

    is_end = False

    # Concat `input_ids`
    confidence_score = 0

    if torch.all(acceptance_mask):
        input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)
        confidence_score = target_probs[:, -1, next_token[0][0]].item()
        print(f"Confidence for next token: {confidence_score:.4f}")
    else:
        new_input_ids = []
        new_attention_mask = []

        for batch_idx in range(next_tokens.shape[0]):
            gamma = next_tokens.shape[1] - 1
            start_idx = inputs["input_ids"].shape[1] - gamma

            for pos_idx in range(acceptance_mask[batch_idx].shape[0]):
                if (acceptance_mask[batch_idx][pos_idx] and inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id) or not acceptance_mask[batch_idx][pos_idx]:
                    inputs["input_ids"][batch_idx][start_idx+pos_idx] = next_tokens[batch_idx][pos_idx]
                    confidence_score = target_probs[batch_idx, pos_idx, next_tokens[batch_idx, pos_idx]].max().item()
                    print(f"Replacement Confidence for next token: {confidence_score:.4f}")

                    new_input_ids.append(inputs["input_ids"][batch_idx][:start_idx+pos_idx+1])
                    new_attention_mask.append(inputs["attention_mask"][batch_idx][:start_idx+pos_idx+1])
                    
                    is_end = inputs["input_ids"][batch_idx][start_idx+pos_idx].item() == target_tokenizer.eos_token_id
                    break

        input_ids = pad_sequence(new_input_ids, batch_first=True, padding_value=target_tokenizer.pad_token_id)
        attention_mask = pad_sequence(new_attention_mask, batch_first=True, padding_value=0)

    inputs["input_ids"] = input_ids
    inputs["attention_mask"] = attention_mask

    # Keep generating with the target model alone while confidence_score is below the threshold (hard-coded to 0.5)
    while confidence_score < 0.5:
        with torch.no_grad():
            outputs = target_model(**inputs)

        next_tokens, target_probs = sample_next_token(
            logits=outputs.logits,
            prefix_token_ids=inputs["input_ids"],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            probs_num=1,
        )

        # Update `confidence_score`
        next_token = next_tokens[:, -1:]
        confidence_score = target_probs[0, -1, next_token[0][0]].item()
        print(f"keep generate confidence_score: {confidence_score:.4f}")

        input_ids = torch.cat([inputs["input_ids"], next_token], dim=-1)
        attention_mask = torch.cat([inputs["attention_mask"], torch.ones(inputs["attention_mask"].shape[0], 1).to(inputs["input_ids"].device)], dim=-1)

        inputs["input_ids"] = input_ids
        inputs["attention_mask"] = attention_mask

        is_end = inputs["input_ids"][0][-1].item() == target_tokenizer.eos_token_id
        if is_end:
            break

    return inputs, is_end, calculate_continuous_acceptance(acceptance_mask)


def run_test(args) -> None:
    # Device
    device = torch.device(args.device if args.device != "cpu" and torch.cuda.is_available() else "cpu")
    print(device)

    # Model path 
    target_model_path = args.target_model_path
    draft_model_path = args.draft_model_path

    # Load Tokenizer
    draft_tokenizer = AutoTokenizer.from_pretrained(draft_model_path)
    target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)

    # Load Model
    draft_model = AutoModelForCausalLM.from_pretrained(draft_model_path, torch_dtype=torch.bfloat16).to(device)
    target_model = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.bfloat16).to(device)

    # Tokenize
    messages = [
        [
            {
                "role": "user",
                "content": args.question,
            },
        ],
    ]

    input_text = draft_tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = draft_tokenizer(
        input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    ).to(device)

    # Warm up the model (CUDA)
    inputs_dummy = {k: v.clone() for k, v in inputs.items()}
    with torch.no_grad():
        draft_model(**inputs_dummy)
        target_model(**inputs_dummy)
    if device.type == "cuda":
        torch.cuda.synchronize()

    is_end = False

    # Record
    raw_inputs = copy.deepcopy(inputs)
    raw_token_num = raw_inputs["input_ids"].shape[1]
    start_time = time.time()

    total_draft_tokens = 0
    total_accept_tokens = 0
    gamma = args.gamma
    max_new_tokens = args.test_token_num
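    # Accept rate is reported below as total_accept_tokens / total_draft_tokens,
    # i.e. the fraction of drafted tokens that the target model keeps.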

    while not is_end:
        # Draft model
        target_inputs, draft_probs = drafter_speculative_decode(
            draft_model=draft_model,
            draft_tokenizer=draft_tokenizer,
            inputs=inputs,
            gamma=gamma,
        )

        total_draft_tokens += gamma

        # Target model
        outputs, is_end, accept_tokens = target_speculative_decode(
            target_model=target_model,
            target_tokenizer=target_tokenizer,
            inputs=target_inputs,
            draft_probs=draft_probs,
        )

        total_accept_tokens += accept_tokens

        inputs = outputs

        if inputs["input_ids"].shape[1] - raw_token_num >= max_new_tokens:
            break

    generate_token_num = outputs["input_ids"].shape[1] - raw_token_num
    spent_time = time.time() - start_time

    print(f"Generate token number: {generate_token_num}")
    print(f"Generate speed: {generate_token_num / spent_time} tokens/sec")
    print(f"Speculative Decoding Spent Time: {spent_time} seconds.")
    print(f"Accept Rate: {total_accept_tokens / total_draft_tokens}\n")

    # Normal Target Model Speed
    raw_inputs = copy.deepcopy(inputs)
    start_time = time.time()
    target_inputs, draft_probs = drafter_speculative_decode(
        draft_model=target_model,
        draft_tokenizer=draft_tokenizer,
        inputs=raw_inputs,
        gamma=args.test_token_num,
    )

    spent_time = time.time() - start_time

    print(f"Generate token number: {max_new_tokens}")
    print(f"Generate speed: {max_new_tokens / spent_time} tokens/sec")
    print(f"Normal Target Model Decoding Spent Time: {spent_time} seconds.\n")

    # Normal Draft Model Speed
    raw_inputs = copy.deepcopy(inputs)
    start_time = time.time()
    target_inputs, draft_probs = drafter_speculative_decode(
        draft_model=draft_model,
        draft_tokenizer=draft_tokenizer,
        inputs=raw_inputs,
        gamma=args.test_token_num,
    )

    spent_time = time.time() - start_time

    print(f"Generate token number: {max_new_tokens}")
    print(f"Generate speed: {max_new_tokens / spent_time} tokens/sec")
    print(f"Normal Draft Model Decoding Spent Time: {spent_time} seconds.\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--target_model_path", type=str, default="HuggingFaceTB/SmolLM2-1.7B-Instruct")
    parser.add_argument("--draft_model_path", type=str, default="HuggingFaceTB/SmolLM2-135M-Instruct")
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--question", type=str, default="What is the capital of Taiwan. And why?")
    parser.add_argument("--gamma", type=int, default=5)
    parser.add_argument("--test_token_num", type=int, default=100)
    args = parser.parse_args()

    run_test(args)


Output:

Generate token number: 102
Generate speed: 46.418809914955794 tokens/sec
Speculative Decoding Spent Time: 2.19738507270813 seconds.
Accept Rate: 0.5545454545454546

Generate token number: 100
Generate speed: 27.916420540976226 tokens/sec
Normal Target Model Decoding Spent Time: 3.5821211338043213 seconds.

Generate token number: 100
Generate speed: 96.10154773224576 tokens/sec
Normal Draft Model Decoding Spent Time: 1.0405659675598145 seconds.


To present the results more clearly:

                    Draft Model         Target Model        Speculative Decoding    Total Acceleration
Original            94.92 tokens/sec    28.07 tokens/sec    34.82 tokens/sec        1.24x
Target Threshold    96.10 tokens/sec    27.92 tokens/sec    46.42 tokens/sec        1.66x
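
Here, total acceleration is measured against the target model decoding on its own: 34.82 / 28.07 ≈ 1.24 for the original version and 46.42 / 27.92 ≈ 1.66 for the target-threshold version.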


My intuition is that relying on the target model's confidence to decide whether to enable speculative decoding is a robust strategy, though perhaps one with a limited ceiling. For someone like me with limited training resources, however, a steady and reliable acceleration gain is what matters most.

