
A Note Of Large Language Model Decode Sampling

Last Updated on 2024-11-08 by Clay

When we use large language models for generative tasks, particularly auto-regressive generation, the model essentially performs a massive classification task at every step. The classes are the tokens in the vocabulary, the basic units (often sub-word pieces) from which the text is built.

If we want greedy decoding, we can simply pick the token with the largest logit (the argmax) from the model's final output layer. However, if we want to introduce diversity and some randomness into the output, we can adjust several parameters that reshape the logits before they are turned into a probability distribution, and then sample from that distribution.
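To make the difference concrete, here is a minimal sketch on a toy logits tensor (the values are made up for illustration, not taken from a real model): greedy decoding is a plain `argmax`, while sampling first turns the logits into a distribution with softmax and then draws from it.

```python
import torch

torch.manual_seed(0)  # make the random draw reproducible

# Toy logits for one sequence over a 4-token vocabulary: (batch_size=1, vocab_size=4)
logits = torch.tensor([[1.0, 3.0, 0.5, 2.0]])

# Greedy decoding: always take the highest-scoring token
greedy_token = torch.argmax(logits, dim=-1, keepdim=True)
print(greedy_token)  # tensor([[1]])

# Sampling: turn the logits into probabilities, then draw one token at random
probs = torch.softmax(logits, dim=-1)
sampled_token = torch.multinomial(probs, num_samples=1)
print(sampled_token)  # usually 1, but any index can come out, weighted by probs
```

Greedy decoding always returns the same token for the same logits; sampling can return a different token on each call, which is where the diversity comes from.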

The code in this note is not a standard implementation like those in popular frameworks such as HuggingFace. It is an experimental version I wrote while developing a framework to accelerate inference, so treat it as a reference for the concepts rather than as production code.


Sampling Parameters

There are many types of sampling parameters, but this article only discusses the most common ones:

  • Repetition Penalty: reduces the probability of tokens that have already appeared in the sequence.
  • Temperature: scales the logits, amplifying or dampening the differences between tokens.
  • Top-K: keeps only the K highest-scoring tokens as decoding candidates.
  • Top-P: keeps the smallest set of highest-probability tokens whose cumulative probability reaches top_p as decoding candidates.

In the following, we assume the logits to be decoded have the shape (batch_size, vocab_size). I have dropped the sequence length (seq_length) dimension because the adjustments and sampling are always applied to the logits at the last position of the sequence, which predict the next token.

Let's look at the decoding process in order:

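import torch


def sample_next_token(
    logits: torch.FloatTensor,
    prefix_token_ids: torch.LongTensor,
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
    repetition_penalty: float = 1.0,
    eps: float = 1e-7,
) -> torch.LongTensor:
    # (batch_size, seq_length, vocab_size) -> (batch_size, vocab_size): only the last position predicts the next token.
    # .clone() keeps the in-place repetition penalty from modifying the model's output tensor.
    curr_logits = logits[:, -1, :].clone()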

    # Apply repetition penalty
    if repetition_penalty != 1.0:
        curr_logits = apply_repetition_penalty(
            logits=curr_logits,
            prefix_token_ids=prefix_token_ids,
            repetition_penalty=repetition_penalty,
        )

    # Apply temperature
    curr_logits = curr_logits / (temperature + eps)

    # Apply `top_k`
    curr_logits = top_k_filtering(logits=curr_logits, top_k=top_k)

    # Apply `top_p`
    curr_logits = top_p_filtering(logits=curr_logits, top_p=top_p)

    # Convert logits into probs
    probs = torch.softmax(curr_logits, dim=-1)

    # Sampling
    next_token = torch.multinomial(probs, num_samples=1)

    return next_token


When the logits from the model's last layer enter the sampling stage, their shape is (batch_size, seq_length, vocab_size). Here, I take only the last position along seq_length, which gives the (batch_size, vocab_size) tensor used by all the functions below.
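A quick shape check, using random numbers as a stand-in for real model output (the sizes are illustrative; 50257 is GPT-2's vocabulary size). It also shows that the slice is a view of the original tensor, which matters because a function that edits its input in place will also change the model's logits unless the slice is cloned.

```python
import torch

# Illustrative sizes; 50257 is GPT-2's vocabulary size
batch_size, seq_length, vocab_size = 2, 5, 50257
logits = torch.randn(batch_size, seq_length, vocab_size)  # stand-in for model(...).logits

# Keep only the last position, whose logits predict the next token
curr_logits = logits[:, -1, :]
print(curr_logits.shape)  # torch.Size([2, 50257])

# The slice is a view: an in-place edit writes through to the original tensor
curr_logits[0, 0] = 123.0
print(logits[0, -1, 0].item())  # 123.0
```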

Following the sequence of operations, we apply repetition penalty, temperature, top-k sampling, top-p sampling, then convert the results to a probability distribution via Softmax. Finally, torch.multinomial selects the next decoding token from the distribution.
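Setting a logit to -inf is what actually removes a token from the candidate pool. The small sketch below, on made-up logits, checks that softmax gives such tokens a probability of exactly 0 and that `torch.multinomial` therefore never draws them.

```python
import torch

torch.manual_seed(0)

# Toy logits where tokens 2 and 3 were already masked out with -inf (e.g. by top-k or top-p)
logits = torch.tensor([[2.0, 1.0, -float("inf"), -float("inf")]])

# softmax maps -inf to a probability of exactly 0 and renormalizes the rest
probs = torch.softmax(logits, dim=-1)
print(probs)

# torch.multinomial never draws a zero-probability index
samples = torch.multinomial(probs, num_samples=1000, replacement=True)
print(samples.unique())  # only indices 0 and 1 can appear
```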

The repetition penalty adjusts the current logits based on tokens that appeared earlier in the sequence: for each such token, a positive logit is divided by repetition_penalty and a negative one is multiplied by it, so with repetition_penalty > 1.0 the token always becomes less likely to reappear.

def apply_repetition_penalty(
    logits: torch.FloatTensor,
    prefix_token_ids: torch.LongTensor,
    repetition_penalty: float = 1.0,
) -> torch.FloatTensor:
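    for batch_idx in range(prefix_token_ids.shape[0]):
        # Penalize each distinct token in the prefix once, however often it occurs
        for token_id in set(prefix_token_ids[batch_idx].tolist()):
            # Multiply negative logits and divide positive ones, so a penalty > 1 always lowers the score
            if logits[batch_idx, token_id] < 0:
                logits[batch_idx, token_id] *= repetition_penalty
            else:
                logits[batch_idx, token_id] /= repetition_penalty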
    return logits
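The double Python loop above is easy to read but issues one small indexing operation per token. As an alternative, here is a vectorized sketch of the same sign rule in the style HuggingFace's `RepetitionPenaltyLogitsProcessor` uses (gather the logits of previously seen tokens, rescale them, scatter them back). The function name `apply_repetition_penalty_vectorized` is hypothetical, and the input values are made up for illustration.

```python
import torch


def apply_repetition_penalty_vectorized(
    logits: torch.FloatTensor,           # (batch_size, vocab_size)
    prefix_token_ids: torch.LongTensor,  # (batch_size, prefix_length)
    repetition_penalty: float = 1.0,
) -> torch.FloatTensor:
    # Gather the current logits of every token id seen in the prefix
    score = torch.gather(logits, 1, prefix_token_ids)
    # Same sign rule as the loop: multiply negatives, divide positives
    score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
    # Scatter the penalized values back (out of place); a repeated id writes the same value
    # again, so each token is still penalized only once
    return logits.scatter(1, prefix_token_ids, score)


logits = torch.tensor([[2.0, -1.0, 0.5, 3.0]])
prefix_token_ids = torch.tensor([[0, 1, 1]])  # token 1 appears twice
penalized = apply_repetition_penalty_vectorized(logits, prefix_token_ids, repetition_penalty=2.0)
print(penalized.tolist())  # [[1.0, -2.0, 0.5, 3.0]]
```

Unlike the loop version, `scatter` returns a new tensor, so the input logits are left untouched.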


Temperature adjustment is straightforward; we divide the logits by the temperature. If the temperature is below 1.0, it amplifies the differences between tokens, making higher probability tokens more likely to be chosen. Conversely, a temperature above 1.0 reduces the differences, giving less probable tokens a higher chance of being sampled.

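Note that a temperature of 0 would mean dividing by zero, which turns the logits into inf/NaN values, and torch.multinomial would then raise an error. To prevent this, I add a small eps value (1e-7 by default) to the divisor; with temperature = 0, the division then stretches the gaps between logits so much that sampling effectively becomes greedy.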

# Apply temperature
curr_logits = curr_logits / (temperature + eps)
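To see the effect numerically, the sketch below applies the same division to three toy logits (invented for illustration) and prints the resulting probabilities. It assumes float32 logits; in float16, dividing by 1e-7 can overflow to inf and produce NaN instead of a one-hot distribution.

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.0])  # toy logits
eps = 1e-7

top_prob = {}
for temperature in (0.5, 1.0, 2.0):
    probs = torch.softmax(logits / (temperature + eps), dim=-1)
    top_prob[temperature] = probs[0].item()
    print(f"temperature={temperature}: {[round(p, 3) for p in probs.tolist()]}")

# temperature = 0 turns into division by eps: the gaps explode and softmax becomes one-hot (greedy)
greedy_probs = torch.softmax(logits / (0.0 + eps), dim=-1)
print(greedy_probs.tolist())  # [1.0, 0.0, 0.0]
```

The probability of the top token shrinks as the temperature grows, and collapses to 1.0 when the temperature is 0.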



Top-K filtering is also quite simple. I take the K-th largest logit in each row (the smallest of the top-K values) and use torch.where() to set every logit below it to -float("Inf"), so those tokens receive zero probability after softmax.

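def top_k_filtering(logits: torch.FloatTensor, top_k: int) -> torch.FloatTensor:
    if top_k > 0:
        # torch.topk fails if k exceeds the vocabulary size, so clamp it
        top_k = min(top_k, logits.size(-1))
        values, _ = torch.topk(logits, top_k, dim=-1)
        # The K-th largest logit of each row: (batch_size, 1)
        min_values = values[:, -1].unsqueeze(dim=-1)
        # Mask every logit below it with -inf so softmax gives it zero probability
        logits = torch.where(logits < min_values, torch.full_like(logits, -float("Inf")), logits)

    return logits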
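Because the comparison is `logits < min_values`, every token tied with the K-th largest logit survives, so a row can keep more than K candidates. The toy example below (values invented for illustration) reruns the same three steps to show this.

```python
import torch

logits = torch.tensor([[3.0, 1.0, 2.0, 2.0, 0.5]])  # toy logits; tokens 2 and 3 tie
top_k = 2

# Same three steps as top_k_filtering
values, _ = torch.topk(logits, top_k, dim=-1)
min_values = values[:, -1].unsqueeze(dim=-1)  # K-th largest logit per row (2.0 here)
filtered = torch.where(logits < min_values, torch.full_like(logits, -float("inf")), logits)

print(filtered.tolist())  # [[3.0, -inf, 2.0, 2.0, -inf]]
print(int(torch.isfinite(filtered).sum()))  # 3 candidates survive, although top_k = 2
```

With continuous logits from a real model exact ties are rare, so in practice this rarely matters.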



Top-P (nucleus) sampling is a bit more complex. It adds tokens in order of decreasing probability until their cumulative probability reaches the top_p threshold, and samples only from that set. Unlike top-K, the number of candidates therefore depends on the shape of the distribution: few tokens when the model is confident, many when it is not.

For example, suppose we have a probability distribution [0.4, 0.2, 0.15, 0.15, 0.1], which sums to 1. If we set top_p to 0.8, our sampling process would proceed as follows:

  • Take 0.4; cumulative probability is now 0.4
  • Take 0.2; cumulative probability is now 0.6
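  • Take 0.15; cumulative probability is now 0.75, still below 0.8
  • Take 0.15; cumulative probability is now 0.9, which reaches the 0.8 threshold, so we stop here. This token is still selected, because without it the candidates would only cover 0.75. The last element (0.1) is discarded, leaving [0.4, 0.2, 0.15, 0.15] as the candidates for sampling.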
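The walk-through above can be reproduced in a few lines. This sketch works directly on the example probabilities (already sorted) and uses an exclusive cumulative sum, that is, the mass accumulated before each token, which is the same keep rule that the shifted mask in top_p_filtering implements.

```python
import torch

probs = torch.tensor([0.4, 0.2, 0.15, 0.15, 0.1])  # already sorted in descending order
top_p = 0.8

cumulative = torch.cumsum(probs, dim=-1)  # about [0.4, 0.6, 0.75, 0.9, 1.0]
# Probability mass accumulated *before* each token (an exclusive cumulative sum)
mass_before = torch.cat([torch.zeros(1), cumulative[:-1]])
# Keep a token as long as the mass before it has not exceeded top_p
keep = mass_before <= top_p
print(keep.tolist())  # [True, True, True, True, False]
```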

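In the implementation, we first sort the logits in descending order, keeping the original indices so we can map the mask back later. We then apply softmax, take the cumulative sum, mark every position whose cumulative probability exceeds top_p, and shift that mark one position to the right so the token that crosses the threshold (and always the most probable token) is kept. Finally, the marked tokens are set to -float("Inf") in the original order; the softmax that follows renormalizes the probabilities over the remaining candidates.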

def top_p_filtering(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    # top_p >= 1.0 keeps every token; returning early also guards against float round-off
    # in the cumulative sum (e.g. 1.0000001 > 1.0) dropping tail tokens
    if top_p >= 1.0:
        return logits

    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

    # Mark positions whose cumulative probability exceeds top_p
    sorted_indices_to_remove = cumulative_probs > top_p

    # Shift the mask right by one so the token that crosses top_p is kept,
    # and always keep the most probable token
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = False

    # Map the mask from sorted order back to the original vocabulary order
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    logits = logits.masked_fill(indices_to_remove, -float("Inf"))

    return logits



Results Comparison

Below is a comparison between my implementation and HuggingFace's, using the same fixed sampling hyperparameters. The results aren't identical: HuggingFace may apply additional sampling rules, and both runs draw random samples without a fixed seed. The model used here is GPT-2, with the following sampling parameters:

  • temperature = 0.1
  • top_k = 50
  • top_p = 0.9
  • repetition_penalty = 1.2

import warnings
warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=UserWarning)

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


# Settings
pretrained_model_name_or_path = "openai-community/gpt2"

# Model & Tokenizer
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path).cuda().eval()
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
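tokenizer.pad_token = tokenizer.eos_token
# Right padding: the shorter prompt is padded with <|endoftext|> and new tokens are appended after the pad.
# (HuggingFace recommends left padding for batched generation with decoder-only models.)
tokenizer.padding_side = "right"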


def apply_repetition_penalty(
    logits: torch.FloatTensor,
    prefix_token_ids: torch.LongTensor,
    repetition_penalty: float = 1.0,
) -> torch.FloatTensor:
    for batch_idx in range(prefix_token_ids.shape[0]):
        for token_id in set(prefix_token_ids[batch_idx].tolist()):
            if logits[batch_idx, token_id] < 0:
                logits[batch_idx, token_id] *= repetition_penalty
            else:
                logits[batch_idx, token_id] /= repetition_penalty
    return logits

def top_k_filtering(logits: torch.FloatTensor, top_k: int) -> torch.FloatTensor:
    if top_k > 0:
        # torch.topk fails if k exceeds the vocabulary size, so clamp it
        top_k = min(top_k, logits.size(-1))
        values, _ = torch.topk(logits, top_k, dim=-1)
        min_values = values[:, -1].unsqueeze(dim=-1)
        logits = torch.where(logits < min_values, torch.full_like(logits, -float("Inf")), logits)

    return logits


def top_p_filtering(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    # top_p >= 1.0 keeps every token; skip the sort entirely
    if top_p >= 1.0:
        return logits

    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

    # Mark positions whose cumulative probability exceeds top_p
    sorted_indices_to_remove = cumulative_probs > top_p

    # Shift right by one so the token that crosses top_p (and the top-1 token) is kept
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = False

    # Map the mask from sorted order back to the original vocabulary order
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    logits = logits.masked_fill(indices_to_remove, -float("Inf"))

    return logits


def sample_next_token(
    logits: torch.FloatTensor,
    prefix_token_ids: torch.LongTensor,
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
    repetition_penalty: float = 1.0,
    eps: float = 1e-7,
) -> torch.LongTensor:
    curr_logits = logits[:, -1, :].clone()

    # Apply repetition penalty
    if repetition_penalty != 1.0:
        curr_logits = apply_repetition_penalty(
            logits=curr_logits,
            prefix_token_ids=prefix_token_ids,
            repetition_penalty=repetition_penalty,
        )

    # Apply temperature
    curr_logits = curr_logits / (temperature + eps)

    # Apply `top_k`
    curr_logits = top_k_filtering(logits=curr_logits, top_k=top_k)

    # Apply `top_p`
    curr_logits = top_p_filtering(logits=curr_logits, top_p=top_p)

    # Convert logits into probs
    probs = torch.softmax(curr_logits, dim=-1)

    # Sampling
    next_token = torch.multinomial(probs, num_samples=1)

    return next_token


temperature = 0.1
top_k = 50
top_p = 0.9
repetition_penalty = 1.2


# Test data
sentences = [
    "Today is a nice day",
    "How are you?",
]

inputs = tokenizer(
    sentences,
    max_length=512,
    truncation=True,
    padding=True,
    return_tensors="pt",
).to("cuda:0")


print("=== My Sampling ===")
with torch.no_grad():  # inference only: no need to track gradients
    for _ in range(10):
        # Re-run the full sequence each step (no KV cache, for simplicity)
        outputs = model(**inputs)

        next_tokens = sample_next_token(
            outputs.logits,
            prefix_token_ids=inputs["input_ids"],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        # Append the sampled token and extend the attention mask by one position
        inputs["input_ids"] = torch.cat([inputs["input_ids"], next_tokens], dim=-1)
        inputs["attention_mask"] = torch.cat(
            [inputs["attention_mask"], torch.ones_like(next_tokens)], dim=-1
        )


for sent in tokenizer.batch_decode(inputs.input_ids):
    print(sent)

print("\n=== HuggingFace ===")


# Test data
sentences = [
    "Today is a nice day",
    "How are you?",
]

inputs = tokenizer(
    sentences,
    max_length=512,
    truncation=True,
    padding=True,
    return_tensors="pt",
).to("cuda:0")


outputs = model.generate(
    **inputs,
    do_sample=True,  # otherwise generate() decodes greedily and temperature/top_k/top_p have no effect
    temperature=temperature,
    top_k=top_k,
    top_p=top_p,
    repetition_penalty=repetition_penalty,
)


for sent in tokenizer.batch_decode(outputs):
    print(sent)


Output:

=== My Sampling ===
Today is a nice day for the world to celebrate our country's independence.
How are you?<|endoftext|>The answer is simple. You're a young man

=== HuggingFace ===
Today is a nice day for the world to celebrate our country's independence.
<|endoftext|><|endoftext|><|endoftext|><|endoftext|>
How are you?<|endoftext|>The answer is simple: You're not a human being. Your brain works


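As we can see, both implementations produce similar continuations; at temperature = 0.1 sampling is close to greedy, so the first sentence even comes out identical. The <|endoftext|> right after "How are you?" is the padding token added by right padding, and new tokens are appended after it. My loop adds exactly 10 tokens, while generate() keeps going until its default maximum length, which is why the HuggingFace continuations are longer and the first one ends in a run of <|endoftext|> tokens once that sequence has finished.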

