Last Updated on 2024-11-08 by Clay
When we use large language models for generative tasks, particularly auto-regressive generation, the model is essentially performing a massive classification task at every step. The classification targets are the tokens in our vocabulary, the smallest building blocks that make up words.
If we want greedy decoding, we can simply take the argmax of the logits the model produces for the last position. However, if we want to introduce diversity and some level of randomness into the model's output, we have several parameters we can adjust to reshape the logits before converting them into a probability distribution.
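As a minimal sketch of the difference (toy values I am adding for illustration, not code from this framework):

import torch

# Toy logits for the last position: batch_size=1, vocab_size=5
last_logits = torch.tensor([[2.0, 1.0, 0.5, 0.2, -1.0]])

# Greedy decoding: always take the single highest-scoring token
greedy_token = last_logits.argmax(dim=-1, keepdim=True)

# Stochastic decoding: convert the logits to probabilities and sample
probs = torch.softmax(last_logits, dim=-1)
sampled_token = torch.multinomial(probs, num_samples=1)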
This note is not a standard implementation like those found in popular frameworks such as HuggingFace. It is merely an experimental approach I used when developing a framework to accelerate inference, so feel free to reference the concepts rather than the exact code.
Sampling Parameters
There are many types of sampling parameters, but this article only discusses the most common ones:
- Repetition Penalty: reduces the probability of tokens that have previously appeared in the sequence.
- Temperature: scales the difference between logits for different tokens, amplifying or dampening distinctions between them.
- Top-K: selects the top-K tokens for decoding candidates.
- Top-P: selects the smallest set of top tokens whose cumulative probability reaches Top-P as decoding candidates.
In the following, we assume the logits to be decoded have the shape (batch_size, vocab_size). I have hidden the sequence length (seq_length) because we always apply the sampling adjustments to the last position of the sequence.
Let's look at the decoding process in order:
def sample_next_token(
    logits: torch.FloatTensor,
    prefix_token_ids: torch.LongTensor,
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
    repetition_penalty: float = 1.0,
    eps: float = 1e-7,
) -> torch.LongTensor:
    # Only the logits of the last position are used for decoding
    curr_logits = logits[:, -1, :]

    # Apply repetition penalty
    if repetition_penalty != 1.0:
        curr_logits = apply_repetition_penalty(
            logits=curr_logits,
            prefix_token_ids=prefix_token_ids,
            repetition_penalty=repetition_penalty,
        )

    # Apply temperature
    curr_logits = curr_logits / (temperature + eps)

    # Apply `top_k`
    curr_logits = top_k_filtering(logits=curr_logits, top_k=top_k)

    # Apply `top_p`
    curr_logits = top_p_filtering(logits=curr_logits, top_p=top_p)

    # Convert logits into probs
    probs = torch.softmax(curr_logits, dim=-1)

    # Sampling
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token
When the logits from the model's final layer enter the sampling stage, their original shape is (batch_size, seq_length, vocab_size). Here, I fix the sampling to the last position along seq_length.
Following the order of operations, we apply the repetition penalty, temperature, top-k filtering, and top-p filtering, then convert the result into a probability distribution via softmax. Finally, torch.multinomial samples the next decoding token from that distribution.
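A single step can be exercised standalone. Here is a usage sketch with toy shapes (the random inputs are placeholders I am adding, not real model outputs):

import torch

batch_size, seq_length, vocab_size = 2, 8, 50257
logits = torch.randn(batch_size, seq_length, vocab_size)
prefix_token_ids = torch.randint(0, vocab_size, (batch_size, seq_length))

# Sample one next token per batch element
next_token = sample_next_token(
    logits,
    prefix_token_ids=prefix_token_ids,
    temperature=0.7,
    top_k=50,
    top_p=0.9,
    repetition_penalty=1.2,
)
print(next_token.shape)  # torch.Size([2, 1])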
The repetition penalty adjusts the current decoding logits based on tokens that appeared earlier in the sequence: for each such token, the corresponding logit is divided by repetition_penalty if positive, or multiplied by it if negative, reducing that token's probability of reappearing (assuming repetition_penalty > 1.0).
def apply_repetition_penalty(
    logits: torch.FloatTensor,
    prefix_token_ids: torch.LongTensor,
    repetition_penalty: float = 1.0,
) -> torch.FloatTensor:
    for batch_idx in range(prefix_token_ids.shape[0]):
        for token_id in set(prefix_token_ids[batch_idx].tolist()):
            # Penalize in whichever direction lowers the logit
            if logits[batch_idx, token_id] < 0:
                logits[batch_idx, token_id] *= repetition_penalty
            else:
                logits[batch_idx, token_id] /= repetition_penalty
    return logits
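As a quick sanity check (toy values I am adding), a penalty of 1.2 pushes both positive and negative logits toward lower probability:

import torch

# The prefix contains token ids 0 and 2
logits = torch.tensor([[2.0, 1.0, -0.5, 0.3]])
prefix_token_ids = torch.tensor([[0, 2]])

penalized = apply_repetition_penalty(
    logits=logits.clone(),
    prefix_token_ids=prefix_token_ids,
    repetition_penalty=1.2,
)
# Token 0: 2.0 / 1.2 ≈ 1.667 (positive logit is divided)
# Token 2: -0.5 * 1.2 = -0.6 (negative logit is multiplied)
print(penalized)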
Temperature adjustment is straightforward; we divide the logits by the temperature. If the temperature is below 1.0, it amplifies the differences between tokens, making higher probability tokens more likely to be chosen. Conversely, a temperature above 1.0 reduces the differences, giving less probable tokens a higher chance of being sampled.
Note that a temperature of 0 would mean dividing by zero, which produces infinite logits, so I add a small eps value (defaulting to 1e-7) to keep the division safe.
# Apply temperature
curr_logits = curr_logits / (temperature + eps)
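To make the effect concrete, here is a small illustration (toy values I am adding) of how temperature reshapes the softmax distribution:

import torch

logits = torch.tensor([2.0, 1.0, 0.5])

# Low temperature sharpens the distribution toward the top token
print(torch.softmax(logits / 0.5, dim=-1))  # ~[0.84, 0.11, 0.04]

# Temperature 1.0 leaves the distribution unchanged
print(torch.softmax(logits / 1.0, dim=-1))  # ~[0.63, 0.23, 0.14]

# High temperature flattens it, giving rarer tokens more chance
print(torch.softmax(logits / 2.0, dim=-1))  # ~[0.48, 0.29, 0.23]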
Top-K filtering is also quite simple. I find the smallest value among the top-K logits and use torch.where() to set every position whose value falls below that top-K minimum to -float("Inf").
def top_k_filtering(logits: torch.FloatTensor, top_k: int) -> torch.FloatTensor:
    if top_k > 0:
        # Keep only the `top_k` largest logits; everything below their minimum is masked
        values, _ = torch.topk(logits, top_k, dim=-1)
        min_values = values[:, -1].unsqueeze(dim=-1)
        logits = torch.where(logits < min_values, torch.full_like(logits, -float("Inf")), logits)
    return logits
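For instance (toy values), with top_k=2 only the two largest logits survive:

import torch

logits = torch.tensor([[2.0, 1.0, 0.5, 0.2]])
print(top_k_filtering(logits=logits, top_k=2))
# tensor([[2., 1., -inf, -inf]])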
Top-P (nucleus) sampling is relatively more complex. In essence, it includes tokens from highest to lowest probability until the cumulative probability reaches the threshold set by top_p. It is a method that focuses on the head of the probability distribution.
For example, suppose we have the probability distribution [0.4, 0.2, 0.15, 0.15, 0.1], which sums to 1. If we set top_p to 0.8, the selection proceeds as follows:
- Take 0.4; the cumulative probability is now 0.4
- Take 0.2; the cumulative probability is now 0.6
- Take 0.15; the cumulative probability is now 0.75
- Take the next 0.15; the cumulative probability is now 0.9, which crosses the 0.8 threshold. Following the same convention as HuggingFace, the implementation below keeps the token that crosses the threshold and masks everything after it, leaving [0.4, 0.2, 0.15, 0.15] as the candidates for sampling.
In the implementation, we first sort logits from largest to smallest, retaining the original indices so we can build a mask over the unsorted positions later. After applying softmax(logits) and taking the cumulative sum, we mark every position where the cumulative probability exceeds top_p, shift the mask right by one so that the threshold-crossing token (and at least the top token) survives, and set the removed positions to -float("Inf").
def top_p_filtering(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

    # Find the positions where the cumulative probability exceeds `top_p`
    sorted_indices_to_remove = cumulative_probs > top_p

    # Shift the mask right so the threshold-crossing token (and at least one token) is kept
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = False

    # Scatter the mask back to the original (unsorted) token positions
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)

    logits = logits.masked_fill(indices_to_remove, -float("Inf"))
    return logits
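We can verify the walkthrough above against this function (a check I am adding, using logits whose softmax is exactly the example distribution):

import torch

probs = torch.tensor([[0.4, 0.2, 0.15, 0.15, 0.1]])
logits = probs.log()  # softmax(log(p)) == p when p sums to 1

filtered = top_p_filtering(logits=logits, top_p=0.8)
print(filtered)
# The first four tokens survive; only the last (0.1) is set to -inf,
# matching the walkthrough above.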
Results Comparison
Below is a comparison between my implementation and HuggingFace's, using fixed sampling hyperparameters (of course, HuggingFace may apply additional sampling rules, so the results aren't identical). The model used here is GPT-2, with the following sampling parameters:
- temperature = 0.1
- top_k = 50
- top_p = 0.9
- repetition_penalty = 1.2
import warnings
warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=UserWarning)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Settings
pretrained_model_name_or_path = "openai-community/gpt2"
# Model & Tokenizer
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path).cuda().eval()
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
def apply_repetition_penalty(
    logits: torch.FloatTensor,
    prefix_token_ids: torch.LongTensor,
    repetition_penalty: float = 1.0,
) -> torch.FloatTensor:
    for batch_idx in range(prefix_token_ids.shape[0]):
        for token_id in set(prefix_token_ids[batch_idx].tolist()):
            # Penalize in whichever direction lowers the logit
            if logits[batch_idx, token_id] < 0:
                logits[batch_idx, token_id] *= repetition_penalty
            else:
                logits[batch_idx, token_id] /= repetition_penalty
    return logits
def top_k_filtering(logits: torch.FloatTensor, top_k: int) -> torch.FloatTensor:
    if top_k > 0:
        # Keep only the `top_k` largest logits; everything below their minimum is masked
        values, _ = torch.topk(logits, top_k, dim=-1)
        min_values = values[:, -1].unsqueeze(dim=-1)
        logits = torch.where(logits < min_values, torch.full_like(logits, -float("Inf")), logits)
    return logits
def top_p_filtering(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

    # Find the positions where the cumulative probability exceeds `top_p`
    sorted_indices_to_remove = cumulative_probs > top_p

    # Shift the mask right so the threshold-crossing token (and at least one token) is kept
    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
    sorted_indices_to_remove[:, 0] = False

    # Scatter the mask back to the original (unsorted) token positions
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)

    logits = logits.masked_fill(indices_to_remove, -float("Inf"))
    return logits
def sample_next_token(
    logits: torch.FloatTensor,
    prefix_token_ids: torch.LongTensor,
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
    repetition_penalty: float = 1.0,
    eps: float = 1e-7,
) -> torch.LongTensor:
    # Only the logits of the last position are used for decoding
    curr_logits = logits[:, -1, :]

    # Apply repetition penalty
    if repetition_penalty != 1.0:
        curr_logits = apply_repetition_penalty(
            logits=curr_logits,
            prefix_token_ids=prefix_token_ids,
            repetition_penalty=repetition_penalty,
        )

    # Apply temperature
    curr_logits = curr_logits / (temperature + eps)

    # Apply `top_k`
    curr_logits = top_k_filtering(logits=curr_logits, top_k=top_k)

    # Apply `top_p`
    curr_logits = top_p_filtering(logits=curr_logits, top_p=top_p)

    # Convert logits into probs
    probs = torch.softmax(curr_logits, dim=-1)

    # Sampling
    next_token = torch.multinomial(probs, num_samples=1)
    return next_token
temperature = 0.1
top_k = 50
top_p = 0.9
repetition_penalty = 1.2
# Test data
sentences = [
    "Today is a nice day",
    "How are you?",
]

inputs = tokenizer(
    sentences,
    max_length=512,
    truncation=True,
    padding=True,
    return_tensors="pt",
).to("cuda:0")
print("=== My Sampling ===")
for _ in range(10):
    outputs = model(**inputs)
    next_tokens = sample_next_token(
        outputs.logits,
        prefix_token_ids=inputs["input_ids"],
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )

    # Append the sampled token and extend the attention mask by one position
    input_ids = torch.cat([inputs["input_ids"], next_tokens], dim=-1)
    attention_mask = torch.cat(
        [
            inputs["attention_mask"],
            torch.ones(
                inputs["attention_mask"].shape[0],
                1,
                dtype=inputs["attention_mask"].dtype,
                device=inputs["attention_mask"].device,
            ),
        ],
        dim=-1,
    )

    inputs["input_ids"] = input_ids
    inputs["attention_mask"] = attention_mask

for sent in tokenizer.batch_decode(inputs.input_ids):
    print(sent)
print("\n=== HuggingFace ===")
# Test data
sentences = [
    "Today is a nice day",
    "How are you?",
]

inputs = tokenizer(
    sentences,
    max_length=512,
    truncation=True,
    padding=True,
    return_tensors="pt",
).to("cuda:0")
outputs = model.generate(
    **inputs,
    do_sample=True,  # required for temperature/top-k/top-p to take effect
    temperature=temperature,
    top_k=top_k,
    top_p=top_p,
    repetition_penalty=repetition_penalty,
)
for sent in tokenizer.batch_decode(outputs):
print(sent)
Output:
=== My Sampling ===
Today is a nice day for the world to celebrate our country's independence.
How are you?<|endoftext|>The answer is simple. You're a young man
=== HuggingFace ===
Today is a nice day for the world to celebrate our country's independence.
<|endoftext|><|endoftext|><|endoftext|><|endoftext|>
How are you?<|endoftext|>The answer is simple: You're not a human being. Your brain works
As we can see, the outputs from both implementations are quite similar.